diff --git a/drivers/virt/geniezone/gzvm_mmu.c b/drivers/virt/geniezone/gzvm_mmu.c index 77104df4afe5..c4421ca294c8 100644 --- a/drivers/virt/geniezone/gzvm_mmu.c +++ b/drivers/virt/geniezone/gzvm_mmu.c @@ -107,8 +107,59 @@ int gzvm_gfn_to_pfn_memslot(struct gzvm_memslot *memslot, u64 gfn, return 0; } +static int cmp_ppages(struct rb_node *node, const struct rb_node *parent) +{ + struct gzvm_pinned_page *a = container_of(node, + struct gzvm_pinned_page, + node); + struct gzvm_pinned_page *b = container_of(parent, + struct gzvm_pinned_page, + node); + + if (a->ipa < b->ipa) + return -1; + if (a->ipa > b->ipa) + return 1; + return 0; +} + +static int gzvm_insert_ppage(struct gzvm *vm, struct gzvm_pinned_page *ppage) +{ + if (rb_find_add(&ppage->node, &vm->pinned_pages, cmp_ppages)) + return -EEXIST; + return 0; +} + +static int pin_one_page(struct gzvm *vm, unsigned long hva, u64 gpa) +{ + unsigned int flags = FOLL_HWPOISON | FOLL_LONGTERM | FOLL_WRITE; + struct gzvm_pinned_page *ppage = NULL; + struct mm_struct *mm = current->mm; + struct page *page = NULL; + + ppage = kmalloc(sizeof(*ppage), GFP_KERNEL_ACCOUNT); + if (!ppage) + return -ENOMEM; + + mmap_read_lock(mm); + pin_user_pages(hva, 1, flags, &page, NULL); + mmap_read_unlock(mm); + + if (!page) { + kfree(ppage); + return -EFAULT; + } + + ppage->page = page; + ppage->ipa = gpa; + gzvm_insert_ppage(vm, ppage); + + return 0; +} + static int handle_block_demand_page(struct gzvm *vm, int memslot_id, u64 gfn) { + unsigned long hva; u64 pfn, __gfn; int ret, i; @@ -131,6 +182,11 @@ static int handle_block_demand_page(struct gzvm *vm, int memslot_id, u64 gfn) goto err_unlock; } vm->demand_page_buffer[i] = pfn; + + hva = gzvm_gfn_to_hva_memslot(&vm->memslot[memslot_id], __gfn); + ret = pin_one_page(vm, hva, PFN_PHYS(__gfn)); + if (ret) + goto err_unlock; } ret = gzvm_arch_map_guest_block(vm->vm_id, memslot_id, start_gfn, @@ -148,6 +204,7 @@ err_unlock: static int handle_single_demand_page(struct gzvm *vm, int memslot_id, u64 gfn) { + unsigned long hva; int ret; u64 pfn; @@ -159,7 +216,8 @@ static int handle_single_demand_page(struct gzvm *vm, int memslot_id, u64 gfn) if (unlikely(ret)) return -EFAULT; - return 0; + hva = gzvm_gfn_to_hva_memslot(&vm->memslot[memslot_id], gfn); + return pin_one_page(vm, hva, PFN_PHYS(gfn)); } /** diff --git a/drivers/virt/geniezone/gzvm_vm.c b/drivers/virt/geniezone/gzvm_vm.c index 485d1e2097aa..a7d43bedfad0 100644 --- a/drivers/virt/geniezone/gzvm_vm.c +++ b/drivers/virt/geniezone/gzvm_vm.c @@ -292,6 +292,21 @@ out: return ret; } +static void gzvm_destroy_ppage(struct gzvm *gzvm) +{ + struct gzvm_pinned_page *ppage; + struct rb_node *node; + + node = rb_first(&gzvm->pinned_pages); + while (node) { + ppage = rb_entry(node, struct gzvm_pinned_page, node); + unpin_user_pages_dirty_lock(&ppage->page, 1, true); + node = rb_next(node); + rb_erase(&ppage->node, &gzvm->pinned_pages); + kfree(ppage); + } +} + static void gzvm_destroy_vm(struct gzvm *gzvm) { size_t allocated_size; @@ -315,6 +330,8 @@ static void gzvm_destroy_vm(struct gzvm *gzvm) mutex_unlock(&gzvm->lock); + gzvm_destroy_ppage(gzvm); + kfree(gzvm); } @@ -390,6 +407,7 @@ static struct gzvm *gzvm_create_vm(unsigned long vm_type) gzvm->vm_id = ret; gzvm->mm = current->mm; mutex_init(&gzvm->lock); + gzvm->pinned_pages = RB_ROOT; ret = gzvm_vm_irqfd_init(gzvm); if (ret) { diff --git a/include/linux/gzvm_drv.h b/include/linux/gzvm_drv.h index 7587a6388c32..0de4642bc01f 100644 --- a/include/linux/gzvm_drv.h +++ b/include/linux/gzvm_drv.h @@ -12,6 +12,7 @@ #include #include #include +#include /* * For the normal physical address, the highest 12 bits should be zero, so we @@ -82,6 +83,12 @@ struct gzvm_vcpu { struct gzvm_vcpu_hwstate *hwstate; }; +struct gzvm_pinned_page { + struct rb_node node; + struct page *page; + u64 ipa; +}; + struct gzvm { struct gzvm_vcpu *vcpus[GZVM_MAX_VCPUS]; /* userspace tied to this vm */ @@ -121,6 +128,9 @@ struct gzvm { * at the same time */ struct mutex demand_paging_lock; + + /* Use rb-tree to record pin/unpin page */ + struct rb_root pinned_pages; }; long gzvm_dev_ioctl_check_extension(struct gzvm *gzvm, unsigned long args);