MIT 6.828 - Lab - Copy-on-write fork
课程主页:https://pdos.csail.mit.edu/6.828/2023/index.html
背景:
xv6 原本的 fork 会把父进程的内存完全拷贝一份,本实验的目的是给 xv6 的 fork 加入写时复制能力。大体思路比较清晰:在 fork 的时候不拷贝物理内存,仅仅拷贝页表,让父子进程共享物理页,并设置该物理也为只读。当父进程或子进程想要修改任何物理页的时候,因为页面是只读的,这个时候就会触发页错误并进入中断,此时内核可以对内存进行拷贝。
下面是本次实验中做的一些修改,我会细细描述为何这样修改,想要得到最终的改动,可以看后文。
共享页表
第一个改动是修改拷贝物理页的逻辑,将之前的直接拷贝修改为共享物理页。在 fork 的实现中调用了 uvmcopy 做内存的拷贝,这里需要修改此函数。
可以让子进程共享父进程的页,但是需要在页表项中做一些标记,说明这个页是一个共享页,下次写入的时候需要做写时复制。在页表项中有一些标记,对于共享的页,我们需要将其 W 标记删除掉,并额外增加一个 COW 标记。页表项的的后 10 个 bit 中尚且存在 2 个 bit 未使用,这里可以利用这里预留的 bit:
#define PTE_W (1L << 2)
#define PTE_X (1L << 3)
#define PTE_U (1L << 4) // user can access
+#define PTE_C (1L << 8) // copy on write
对 uvmcopy 做出的改动如下:
@@ -323,14 +322,15 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
if((*pte & PTE_V) == 0)
panic("uvmcopy: page not present");
pa = PTE2PA(*pte);
+ if (PTE_FLAGS(*pte) & PTE_W) { // 如果该页可写,则去掉其可写权限,并新增写时复制标记
+ *pte &= ~PTE_W;
+ *pte |= PTE_C;
+ }
flags = PTE_FLAGS(*pte);
- if((mem = kalloc()) == 0)
- goto err;
- memmove(mem, (char*)pa, PGSIZE);
- if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
- kfree(mem);
+ if (mappages(new, i, PGSIZE, (uint64)pa, flags) != 0) { // 直接映射父进程的页到子进程中
goto err;
}
+ kref((void*)pa);
}
return 0;
这里第一步是修改页表的标记,对于可写的页,删除其写标记,并增加写时复制标记。对于只读的页,我们不做任何处理直接共享。
因为一个页映射到了两个进程中,当两个进程退出的时候,内存就会被释放两次,这就会出问题。为了解决这个问题,这里给页增加了引用计数。引用计数如何实现,下一节接着说。
引用计数
因为一个页现在要被多个进程共享,多个进程在退出时需要释放页面,我们需要有机制保证其只会被释放一次,这就需要用到引用计数了。
一般情况下,实现引用计数的方式是在需要共享的对象中添加一个字段来记录该对象的引用次数。每当对象被引用时,计数器就会加一,对象的引用被解除时,计数器就会减一,当计数器为 0 时,就会释放内存。
因为页中没有额外的空间可以存储引用数,这里可以创建一个数组来存放每个页的引用次数。该数组的下标是页的编号。
struct {
struct spinlock lock;
struct run *freelist;
+ uint8 refcount[PHYSTOP/PGSIZE];
} kmem;
在 kinit 中将引用数清零:
void
@@ -28,6 +29,30 @@ kinit()
{
initlock(&kmem.lock, "kmem");
freerange(end, (void*)PHYSTOP);
+ memset(kmem.refcount, 0, sizeof(kmem.refcount));
}
引用计数的通常使用模式是下面这样的:
o = alloc()
ref(o)
// ...
unref(o)
在 alloc 后会紧跟着 ref 操作来增加其引用计数,在需要释放的时候不再调用 free 而是调用 unref,在 unref 中会判断引用计数是否为 0,当计数为零的时候会调用 free 做释放。
但这样的实现改动会比较多,这里我实现的思路是:在 alloc 的时候自动增加引用计数,在 free 的时候减少引用计数,只有当计数为 0 的时候才做真正的释放。
这里先实现三个函数用于增加、减少、查询一个页的引用计数:
#define PAGE_INDEX(pa) ((uint64)pa / PGSIZE)
void kref(void *pa) {
acquire(&kmem.lock);
kmem.refcount[PAGE_INDEX(pa)]++;
release(&kmem.lock);
}
int kunref(void *pa) {
acquire(&kmem.lock);
if (kmem.refcount[PAGE_INDEX(pa)] == 0) {
panic("kunref: double free");
}
kmem.refcount[PAGE_INDEX(pa)]--;
int count = kmem.refcount[PAGE_INDEX(pa)];
release(&kmem.lock);
return count;
}
int krefcount(void *pa) {
acquire(&kmem.lock);
int count = kmem.refcount[PAGE_INDEX(pa)];
release(&kmem.lock);
return count;
}
在 kfree 的开头做 unref 然后判断是否需要做真正的释放:
void
kfree(void *pa)
{
struct run *r;
if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
panic("kfree");
+ if (kunref(pa) > 0) {
+ return;
+ }
这里有一个小细节,我让 kunref 返回了此时的引用计数,并没有先调用 kunref 然后调用 krefcount 获取计数,这是为了保证获取当前页面计数的原子性。如果在 kunref 和 krefcount 之间其他 CPU 上也在同一个页上调用 kfree,本来获取到的计数为 1,但因为其他 CPU 也调用了 kfree,导致两个 CPU 上获取到的计数都为 0,此时就会释放两次。
在 kalloc 的末尾增加计数:
@@ -78,5 +108,7 @@ kalloc(void)
if(r)
memset((char*)r, 5, PGSIZE); // fill with junk
+
+ kref((void*)r);
return (void*)r;
}
在 kinit 中调用 freerange 来构造空闲链表,freerange 通过调用 kfree 将页表加入空闲链表中,这些页并非通过 kalloc 申请来的,因此这里需要略做调整,在外面主动增加一下引用计数。
void
@@ -35,8 +60,10 @@ freerange(void *pa_start, void *pa_end)
{
char *p;
p = (char*)PGROUNDUP((uint64)pa_start);
- for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE)
+ for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE) {
+ kref(p);
kfree(p);
+ }
}
处理异常
前面我们在共享物理页的时候,修改了页表项的标记,清除了可写标记,页都成了只读的。当用户代码在修改共享内存时,就会触发异常。这个时候就会进入到内核的异常处理函数 usertrap 中,此时可以从 scause 寄存器中拿到异常原因,并处理页错误。
下图中列出了异常编号和原因:

这里需要关注的是编号为 15 的异常,在 usertrap 中做如下修改:
@@ -67,6 +67,11 @@ usertrap(void)
syscall();
} else if((which_dev = devintr()) != 0){
// ok
+ } else if (r_scause() == 15) {
+ uint64 va = r_stval();
+ if (uvmcow(myproc()->pagetable, va) != 0) {
+ setkilled(p);
+ }
} else {
printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid);
printf(" sepc=%p stval=%p\n", r_sepc(), r_stval());
从 stval 中拿到出现错误的虚拟地址,然后调用 uvmcow 实现写时复制,将页拷贝一份,并增加可写权限。异常处理完毕后,程序会到出错的地方继续执行。
写时复制
写时复制由 uvmcow 函数实现,其内容如下:
int uvmcow(pagetable_t pagetable, uint64 va) {
if (va > MAXVA) {
return -1;
}
pte_t *pte = walk(pagetable, va, 0);
uint32 flags = PTE_FLAGS(*pte);
if (pte == 0 || (flags & (PTE_V | PTE_U)) != (PTE_V | PTE_U)) { // 如果页不存在或者没有 U 和 V 标记
return -1;
}
if (!(flags & PTE_C)) { // 如果该页面没有 COW 标记,则出错,说明这是一个真正的页错误
return -1;
}
uint64 pa = PTE2PA(*pte); // 得到物理地址
if (krefcount((void*)pa) == 1) { // 如果该地址只有一个引用,直接修改标记即可
*pte &= ~PTE_COW;
*pte |= PTE_W;
} else {
void *mem = kalloc(); // 复制页面
memmove(mem, (const void*)pa, PGSIZE);
flags |= PTE_W;
flags &= ~PTE_COW;
uvmunmap(pagetable, va, 1, 1); // 删除旧的映射
if (mappages(pagetable, va, PGSIZE, (uint64)mem, flags) != 0) { // 重新映射
kfree(mem);
return -1;
}
}
return 0;
}
uvmcow 判断当前的引用数,如果只有一处引用,则不需要拷贝只用修改标记即可,否则需要做拷贝。执行完写时复制后,此前会失败的指令再次执行就不会发生页错误了。
内核中有一个 copyout 函数,它将物理地址指定的数据拷贝到虚拟地址处,这个函数会写虚拟地址,因此也需要处理写时复制的情况。
void
@@ -367,8 +398,14 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
return -1;
pte = walk(pagetable, va0, 0);
if(pte == 0 || (*pte & PTE_V) == 0 || (*pte & PTE_U) == 0 ||
- (*pte & PTE_W) == 0)
+ (*pte & (PTE_W |PTE_C)) == 0) // 如果不能写也没有写时复制的标记
return -1;
+ uint32 flags = PTE_FLAGS(*pte);
+ if (flags & PTE_C) { // 执行写时复制
+ if (uvmcow(pagetable, va0) != 0) {
+ return -1;
+ }
+ }
pa0 = PTE2PA(*pte);
n = PGSIZE - (dstva - va0);
if(n > len)
完整变更
diff --git a/kernel/defs.h b/kernel/defs.h
index a3c962b..83b294f 100644
--- a/kernel/defs.h
+++ b/kernel/defs.h
@@ -63,6 +63,9 @@ void ramdiskrw(struct buf*);
void* kalloc(void);
void kfree(void *);
void kinit(void);
+void kref(void *pa);
+int krefcount(void *pa);
+
// log.c
void initlog(int, struct superblock*);
@@ -173,6 +176,7 @@ uint64 walkaddr(pagetable_t, uint64);
int copyout(pagetable_t, uint64, char *, uint64);
int copyin(pagetable_t, char *, uint64, uint64);
int copyinstr(pagetable_t, char *, uint64, uint64);
+int uvmcow(pagetable_t pagetable, uint64 va);
// plic.c
void plicinit(void);
diff --git a/kernel/kalloc.c b/kernel/kalloc.c
index 0699e7e..fe51996 100644
--- a/kernel/kalloc.c
+++ b/kernel/kalloc.c
@@ -21,6 +21,7 @@ struct run {
struct {
struct spinlock lock;
struct run *freelist;
+ uint8 refcount[PHYSTOP/PGSIZE];
} kmem;
void
@@ -28,6 +29,33 @@ kinit()
{
initlock(&kmem.lock, "kmem");
freerange(end, (void*)PHYSTOP);
+ memset(kmem.refcount, 0, sizeof(kmem.refcount));
+}
+
+#define PAGE_INDEX(pa) ((uint64)pa / PGSIZE)
+
+void kref(void *pa) {
+ acquire(&kmem.lock);
+ kmem.refcount[PAGE_INDEX(pa)]++;
+ release(&kmem.lock);
+}
+
+int kunref(void *pa) {
+ acquire(&kmem.lock);
+ if (kmem.refcount[PAGE_INDEX(pa)] == 0) {
+ panic("kunref: double free");
+ }
+ kmem.refcount[PAGE_INDEX(pa)]--;
+ int count = kmem.refcount[PAGE_INDEX(pa)];
+ release(&kmem.lock);
+ return count;
+}
+
+int krefcount(void *pa) {
+ acquire(&kmem.lock);
+ int count = kmem.refcount[PAGE_INDEX(pa)];
+ release(&kmem.lock);
+ return count;
}
void
@@ -35,8 +63,10 @@ freerange(void *pa_start, void *pa_end)
{
char *p;
p = (char*)PGROUNDUP((uint64)pa_start);
- for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE)
+ for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE) {
+ kref(p);
kfree(p);
+ }
}
// Free the page of physical memory pointed at by pa,
@@ -51,6 +81,10 @@ kfree(void *pa)
if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
panic("kfree");
+ if (kunref(pa) > 0) {
+ return;
+ }
+
// Fill with junk to catch dangling refs.
memset(pa, 1, PGSIZE);
@@ -78,5 +112,7 @@ kalloc(void)
if(r)
memset((char*)r, 5, PGSIZE); // fill with junk
+
+ kref((void*)r);
return (void*)r;
}
diff --git a/kernel/riscv.h b/kernel/riscv.h
index 20a01db..c498988 100644
--- a/kernel/riscv.h
+++ b/kernel/riscv.h
@@ -343,6 +343,7 @@ typedef uint64 *pagetable_t; // 512 PTEs
#define PTE_W (1L << 2)
#define PTE_X (1L << 3)
#define PTE_U (1L << 4) // user can access
+#define PTE_C (1L << 8) // copy on write
// shift a physical address to the right place for a PTE.
#define PA2PTE(pa) ((((uint64)pa) >> 12) << 10)
diff --git a/kernel/trap.c b/kernel/trap.c
index 512c850..7b6c78e 100644
--- a/kernel/trap.c
+++ b/kernel/trap.c
@@ -67,6 +67,11 @@ usertrap(void)
syscall();
} else if((which_dev = devintr()) != 0){
// ok
+ } else if (r_scause() == 15) {
+ uint64 va = r_stval();
+ if (uvmcow(myproc()->pagetable, va) != 0) {
+ setkilled(p);
+ }
} else {
printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid);
printf(" sepc=%p stval=%p\n", r_sepc(), r_stval());
diff --git a/kernel/vm.c b/kernel/vm.c
index 5c31e87..9de38f0 100644
--- a/kernel/vm.c
+++ b/kernel/vm.c
@@ -315,7 +315,6 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
pte_t *pte;
uint64 pa, i;
uint flags;
- char *mem;
for(i = 0; i < sz; i += PGSIZE){
if((pte = walk(old, i, 0)) == 0)
@@ -323,14 +322,15 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
if((*pte & PTE_V) == 0)
panic("uvmcopy: page not present");
pa = PTE2PA(*pte);
+ if (PTE_FLAGS(*pte) & PTE_W) {
+ *pte &= ~PTE_W;
+ *pte |= PTE_C;
+ }
flags = PTE_FLAGS(*pte);
- if((mem = kalloc()) == 0)
- goto err;
- memmove(mem, (char*)pa, PGSIZE);
- if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
- kfree(mem);
+ if (mappages(new, i, PGSIZE, (uint64)pa, flags) != 0) {
goto err;
}
+ kref((void*)pa);
}
return 0;
@@ -339,6 +339,37 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
return -1;
}
+int uvmcow(pagetable_t pagetable, uint64 va) {
+ if (va > MAXVA) {
+ return -1;
+ }
+ va = PGROUNDDOWN(va);
+ pte_t *pte = walk(pagetable, va, 0);
+ uint32 flags = PTE_FLAGS(*pte);
+ if (pte == 0 || (flags & (PTE_V | PTE_U)) != (PTE_V | PTE_U)) {
+ return -1;
+ }
+ if (!(flags & PTE_C)) {
+ return -1;
+ }
+ uint64 pa = PTE2PA(*pte);
+ if (krefcount((void*)pa) == 1) {
+ *pte &= ~PTE_C;
+ *pte |= PTE_W;
+ } else {
+ void *mem = kalloc();
+ memmove(mem, (const void*)pa, PGSIZE);
+ flags |= PTE_W;
+ flags &= ~PTE_C;
+ uvmunmap(pagetable, va, 1, 1);
+ if (mappages(pagetable, va, PGSIZE, (uint64)mem, flags) != 0) {
+ kfree(mem);
+ return -1;
+ }
+ }
+ return 0;
+}
+
// mark a PTE invalid for user access.
// used by exec for the user stack guard page.
void
@@ -367,8 +398,14 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
return -1;
pte = walk(pagetable, va0, 0);
if(pte == 0 || (*pte & PTE_V) == 0 || (*pte & PTE_U) == 0 ||
- (*pte & PTE_W) == 0)
+ (*pte & (PTE_W |PTE_C)) == 0)
return -1;
+ uint32 flags = PTE_FLAGS(*pte);
+ if (flags & PTE_C) {
+ if (uvmcow(pagetable, va0) != 0) {
+ return -1;
+ }
+ }
pa0 = PTE2PA(*pte);
n = PGSIZE - (dstva - va0);
if(n > len)