WangYu::Space

Study, think, create, and grow. Teach yourself and teach others.

MIT 6.828 - Lab - Multithreading

分类:操作系统标签: 6.828创建时间:2024-02-04 00:00:00

课程主页:https://pdos.csail.mit.edu/6.828/2023/index.html

实现多线程

本实验中要求实现用户态的多线程,这里的线程和我们在 Linux 用到的线程是不一样的,在 Linux 上,线程由内核调度,而本次实验中线程需要程序自己调度。一个线程通过 thread_yield 交出控制权,将 CPU 让给其他可执行的线程。在 thread_yield 中调用 thread_schedule 来找到下一个可运行的线程。

这种用户态多线程也可以被称之为协程,一个协程由一组状态和指令构成,这里状态就是栈、寄存器,指令就是一个函数。需要运行某个协程的时候,就将栈指针指向协程的栈,然后指令寄存器指向该函数。 当协程要切换的时候,就保存当前寄存器到协程的上下文中,把下一个要执行的协程的上下文载入。这里切换的时机就是用户代码中主动调用的 yield 函数。

上下文

上下文就是当前各个寄存器的状态,需要定义一个结构体来保存它们,这里可以复用 proc.h 中定义的 context 结构体。

struct context {
  uint64 ra;
  uint64 sp;

  // callee-saved
  uint64 s0;
  uint64 s1;
  uint64 s2;
  uint64 s3;
  uint64 s4;
  uint64 s5;
  uint64 s6;
  uint64 s7;
  uint64 s8;
  uint64 s9;
  uint64 s10;
  uint64 s11;
};
struct thread {
  char       stack[STACK_SIZE]; /* the thread's stack */
  int        state;             /* FREE, RUNNING, RUNNABLE */
  struct context context;       // 保存当前线程的上下文
};

这里 ra 表示函数的返回地址,当一个程序执行结束后,就会跳转到 ra 所指向的地址处继续执行。 sp 是栈指针,指向当前使用的栈。

创建线程

void 
thread_create(void (*func)())
{
  struct thread *t;

  for (t = all_thread; t < all_thread + MAX_THREAD; t++) {
    if (t->state == FREE) break;
  }
  t->state = RUNNABLE;
  // YOUR CODE HERE
  memset(&t->context, 0, sizeof(t->context));
  t->context.ra = (uint64)func;
  t->context.sp = (uint64)(t->stack + STACK_SIZE);
}

这里我设置 rafunc 的地址,设置 sp 指向当前进程的栈,因为栈是向下增长的,因此这里指向 t->stack 的末尾。

线程切换

线程切换使用 thread_switch 实现,它做的工作和进程切换调用的 swtch 是一样的。这里 thread_switch 函数的视线可以复用 swtch.Sswtch 的实现。

	.text

	/*
         * save the old thread's registers,
         * restore the new thread's registers.
         */

	.globl thread_switch
thread_switch:
	/* YOUR CODE HERE */

	sd ra, 0(a0)
	sd sp, 8(a0)
	sd s0, 16(a0)
	sd s1, 24(a0)
	sd s2, 32(a0)
	sd s3, 40(a0)
	sd s4, 48(a0)
	sd s5, 56(a0)
	sd s6, 64(a0)
	sd s7, 72(a0)
	sd s8, 80(a0)
	sd s9, 88(a0)
	sd s10, 96(a0)
	sd s11, 104(a0)

	ld ra, 0(a1)
	ld sp, 8(a1)
	ld s0, 16(a1)
	ld s1, 24(a1)
	ld s2, 32(a1)
	ld s3, 40(a1)
	ld s4, 48(a1)
	ld s5, 56(a1)
	ld s6, 64(a1)
	ld s7, 72(a1)
	ld s8, 80(a1)
	ld s9, 88(a1)
	ld s10, 96(a1)
	ld s11, 104(a1)

	ret    /* return to ra */

执行线程切换就只需要加入如下一行代码,这里将当前执行正在执行的线程的上下文保存在 t->context 中,并恢复待执行线程的上下文。

@@ -56,10 +75,7 @@ thread_schedule(void)
     next_thread->state = RUNNING;
     t = current_thread;
     current_thread = next_thread;
     /* YOUR CODE HERE
      * Invoke thread_switch to switch from t to next_thread:
      * thread_switch(??, ??);
     */
+    thread_switch((uint64)&t->context, (uint64)&current_thread->context);
   } else
     next_thread = 0;
 }

当恢复了 current_thread->context 的上下文后,指令会接着向下执行,并退出 thread_schedule 函数,然后跳转到 current_thread->context->ra 所执行的地方,如此就开始运行一个线程了。

当一个线程使用 thread_yield 释放执行权后,在 thread_yield 中会调用 thread_schedule,这里 thread_switch 就会将当前上下文保存到 context 中,此时 context 中的 ra 就不再是创建线程时候设置的 func 了,而是调用 thread_schedulethread_yield 中。

void 
thread_yield(void)
{
  current_thread->state = RUNNABLE;
  thread_schedule();
}

thread_yield 退出时,会从栈上拿出 ra,并进一步返回到调用 thread_yield 的地方,如此一个线程就又恢复了执行。

void 
thread_yield(void)
{
 11a:	1141                	add	sp,sp,-16
 11c:	e406                	sd	ra,8(sp)   # 保存 ra 到栈上
 11e:	e022                	sd	s0,0(sp)
 120:	0800                	add	s0,sp,16
  current_thread->state = RUNNABLE;
 122:	00001797          	auipc	a5,0x1
 126:	cb67b783          	ld	a5,-842(a5) # dd8 <current_thread>
 12a:	6709                	lui	a4,0x2
 12c:	97ba                	add	a5,a5,a4
 12e:	4709                	li	a4,2
 130:	c398                	sw	a4,0(a5)
  thread_schedule();
 132:	00000097          	auipc	ra,0x0
 136:	ef4080e7          	jalr	-268(ra) # 26 <thread_schedule>
}
 13a:	60a2                	ld	ra,8(sp)   # 从栈上恢复 ra
 13c:	6402                	ld	s0,0(sp)
 13e:	0141                	add	sp,sp,16
 140:	8082                	ret            # 返回到 ra 处继续运行

理解这个实验的细节,你就可以很容易地理解有栈协程的原理了。

多线程安全哈希表

第二个子实验是要实现一个多线程安全的哈希表。实验给的代码中,哈希表在插入新元素时没有加锁,当哈希冲突的时候,在处理外接链表的时候,多线程同时修改链表会导致数据丢失。

使用一个线程执行写入,没有任何问题:

$ ./ph 1
100000 puts, 1.510 seconds, 66246 puts/second
0: 0 keys missing
100000 gets, 1.521 seconds, 65730 gets/second

当使用两个线程的时候,会发现部分 key 在哈希表中找不到:

$ ./ph 2
100000 puts, 0.712 seconds, 140477 puts/second
0: 16689 keys missing
1: 16689 keys missing
200000 gets, 1.438 seconds, 139093 gets/second

解法很简单,可以在 put_thread 中对 put 函数加锁:

pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;

static void *
put_thread(void *xa)
{
  int n = (int) (long) xa; // thread number
  int b = NKEYS/nthread;

  for (int i = 0; i < b; i++) {
    pthread_mutex_lock(&mutex);
    put(keys[b*n + i], n);
    pthread_mutex_unlock(&mutex);
  }

  return NULL;
}

再次执行,就不会有问题了:

$ ./ph 2
100000 puts, 1.628 seconds, 61431 puts/second
0: 0 keys missing
1: 0 keys missing
200000 gets, 1.557 seconds, 128480 gets/second

上面的方案是对整个哈希表加锁,导致两个多个线程只能排队执行插入操作。因为这里哈希表使用拉链法,一个线程实际上只会操作一个 bucket,我们可以将加锁粒度细分到 bucket 上。

pthread_mutex_t mutexs[NBUCKET];

static 
void put(int key, int value)
{
  int i = key % NBUCKET;
  pthread_mutex_lock(&mutexs[i]);

  //...
  
  pthread_mutex_unlock(&mutexs[i]);
}

int
main(int argc, char *argv[])
{
  // 初始化 mutex
  for (int i = 0; i < NBUCKET; i++) {
    pthread_mutex_init(&mutexs[i], NULL);
  }

  //...
}
```

可以在 `put` 函数中对 `bucket` 加锁,这样可以得到更好的并发度。

```sh
# 全局加锁
$ ./ph 4                                                          
100000 puts, 1.635 seconds, 61160 puts/second
0: 0 keys missing
3: 0 keys missing
1: 0 keys missing
2: 0 keys missing


# 对 bucket 加锁
$ ./ph 4  
100000 puts, 0.691 seconds, 144680 puts/second
3: 0 keys missing
2: 0 keys missing
1: 0 keys missing
0: 0 keys missing
```

对比两种方案,可以看到对 bucket 加锁,可以获得两倍以上的性能提升。

## 实现 `barrier`

`barrier` 翻译为屏障,多个线程分别执行不同任务,每个线程完全的时间不同,但在进行下一步操作的时候,希望所有线程都执行完各自的任务,此时就需要使用 `barrier`。


```
thread 1       barrier 
=============>   ||  =====>
                 || 
thread 2         ||
=========>       ||  =====>
                 ||
thread 3         ||
==========>      ||  ===>
```

当一个线程到达屏障后,如果其他线程尚未到达,则需要等待。当最后一个线程到达屏障后,就需要通知其他线程,这样多个线程又可以继续往下执行。

这里实现的 `barrier` 如下:

```c
static void 
barrier()
{
  pthread_mutex_lock(&bstate.barrier_mutex);
  bstate.nthread++;  // 累加到达屏障的线程数
  if (bstate.nthread == nthread) {  // 如果所有线程都到达屏障
    bstate.round++;  // 累加 round 计数器,表示可以执行下一轮循环
    pthread_cond_broadcast(&bstate.barrier_cond);  // 通知所有线程,此时会唤醒在等待在 `pthread_cond_wait` 上的线程
    bstate.nthread = 0; // 清空 nthread 准备进入下一轮的循环
  } else {
    // 如果还有线程没有到达屏障,则在条件变量上等待
    // 在使用 `pthread_cond_wait` 的时候,需要对 `mutex` 加锁,在 `pthread_cond_wait` 内部会自动对 `mutex` 解锁。
    pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex);
  }
  pthread_mutex_unlock(&bstate.barrier_mutex);
}
```

## 完整变更

```diff
diff --git a/notxv6/barrier.c b/notxv6/barrier.c
index 12793e8..f546350 100644
--- a/notxv6/barrier.c
+++ b/notxv6/barrier.c
@@ -26,11 +26,16 @@ static void
 barrier()
 {
   // YOUR CODE HERE
-  //
-  // Block until all threads have called barrier() and
-  // then increment bstate.round.
-  //
-  
+  pthread_mutex_lock(&bstate.barrier_mutex);
+  bstate.nthread++;
+  if (bstate.nthread == nthread) {
+    bstate.round++;
+    pthread_cond_broadcast(&bstate.barrier_cond);
+    bstate.nthread = 0;
+  } else {
+    pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex);
+  }
+  pthread_mutex_unlock(&bstate.barrier_mutex);
 }
 
 static void *
diff --git a/notxv6/ph.c b/notxv6/ph.c
index 82afe76..2e6785e 100644
--- a/notxv6/ph.c
+++ b/notxv6/ph.c
@@ -16,6 +16,7 @@ struct entry {
 struct entry *table[NBUCKET];
 int keys[NKEYS];
 int nthread = 1;
+pthread_mutex_t mutexs[NBUCKET];
 
 
 double
@@ -40,6 +41,7 @@ static
 void put(int key, int value)
 {
   int i = key % NBUCKET;
+  pthread_mutex_lock(&mutexs[i]);
 
   // is the key already present?
   struct entry *e = 0;
@@ -54,7 +56,7 @@ void put(int key, int value)
     // the new is new.
     insert(key, value, &table[i], table[i]);
   }
-
+  pthread_mutex_unlock(&mutexs[i]);
 }
 
 static struct entry*
@@ -118,6 +120,10 @@ main(int argc, char *argv[])
     keys[i] = random();
   }
 
+  for (int i = 0; i < NBUCKET; i++) {
+    pthread_mutex_init(&mutexs[i], NULL);
+  }
+
   //
   // first the puts
   //
diff --git a/user/uthread.c b/user/uthread.c
index 18b773d..b103e19 100644
--- a/user/uthread.c
+++ b/user/uthread.c
@@ -10,10 +10,29 @@
 #define STACK_SIZE  8192
 #define MAX_THREAD  4
 
+struct thread_context {
+  uint64 ra;
+  uint64 sp;
+
+  // callee-saved
+  uint64 s0;
+  uint64 s1;
+  uint64 s2;
+  uint64 s3;
+  uint64 s4;
+  uint64 s5;
+  uint64 s6;
+  uint64 s7;
+  uint64 s8;
+  uint64 s9;
+  uint64 s10;
+  uint64 s11;
+};
 
 struct thread {
   char       stack[STACK_SIZE]; /* the thread's stack */
   int        state;             /* FREE, RUNNING, RUNNABLE */
+  struct thread_context context;
 };
 struct thread all_thread[MAX_THREAD];
 struct thread *current_thread;
@@ -56,10 +75,7 @@ thread_schedule(void)
     next_thread->state = RUNNING;
     t = current_thread;
     current_thread = next_thread;
-    /* YOUR CODE HERE
-     * Invoke thread_switch to switch from t to next_thread:
-     * thread_switch(??, ??);
-     */
+    thread_switch((uint64)&t->context, (uint64)&current_thread->context);
   } else
     next_thread = 0;
 }
@@ -74,6 +90,9 @@ thread_create(void (*func)())
   }
   t->state = RUNNABLE;
   // YOUR CODE HERE
+  memset(&t->context, 0, sizeof(t->context));
+  t->context.ra = (uint64)func;
+  t->context.sp = (uint64)(t->stack + STACK_SIZE);
 }
 
 void 
diff --git a/user/uthread_switch.S b/user/uthread_switch.S
index 5defb12..6afc3de 100644
--- a/user/uthread_switch.S
+++ b/user/uthread_switch.S
@@ -8,4 +8,35 @@
 	.globl thread_switch
 thread_switch:
 	/* YOUR CODE HERE */
+
+	sd ra, 0(a0)
+	sd sp, 8(a0)
+	sd s0, 16(a0)
+	sd s1, 24(a0)
+	sd s2, 32(a0)
+	sd s3, 40(a0)
+	sd s4, 48(a0)
+	sd s5, 56(a0)
+	sd s6, 64(a0)
+	sd s7, 72(a0)
+	sd s8, 80(a0)
+	sd s9, 88(a0)
+	sd s10, 96(a0)
+	sd s11, 104(a0)
+
+	ld ra, 0(a1)
+	ld sp, 8(a1)
+	ld s0, 16(a1)
+	ld s1, 24(a1)
+	ld s2, 32(a1)
+	ld s3, 40(a1)
+	ld s4, 48(a1)
+	ld s5, 56(a1)
+	ld s6, 64(a1)
+	ld s7, 72(a1)
+	ld s8, 80(a1)
+	ld s9, 88(a1)
+	ld s10, 96(a1)
+	ld s11, 104(a1)
+
 	ret    /* return to ra */

```

评论 (评论内容仅博主可见,不会公开显示)