MIT 6.828 - Lab - Multithreading
课程主页:https://pdos.csail.mit.edu/6.828/2023/index.html
实现多线程
本实验中要求实现用户态的多线程,这里的线程和我们在 Linux 用到的线程是不一样的,在 Linux 上,线程由内核调度,而本次实验中线程需要程序自己调度。一个线程通过 thread_yield 交出控制权,将 CPU 让给其他可执行的线程。在 thread_yield 中调用 thread_schedule 来找到下一个可运行的线程。
这种用户态多线程也可以被称之为协程,一个协程由一组状态和指令构成,这里状态就是栈、寄存器,指令就是一个函数。需要运行某个协程的时候,就将栈指针指向协程的栈,然后指令寄存器指向该函数。
当协程要切换的时候,就保存当前寄存器到协程的上下文中,把下一个要执行的协程的上下文载入。这里切换的时机就是用户代码中主动调用的 yield 函数。
上下文
上下文就是当前各个寄存器的状态,需要定义一个结构体来保存它们,这里可以复用 proc.h 中定义的 context 结构体。
struct context {
uint64 ra;
uint64 sp;
// callee-saved
uint64 s0;
uint64 s1;
uint64 s2;
uint64 s3;
uint64 s4;
uint64 s5;
uint64 s6;
uint64 s7;
uint64 s8;
uint64 s9;
uint64 s10;
uint64 s11;
};
struct thread {
char stack[STACK_SIZE]; /* the thread's stack */
int state; /* FREE, RUNNING, RUNNABLE */
struct context context; // 保存当前线程的上下文
};
这里 ra 表示函数的返回地址,当一个程序执行结束后,就会跳转到 ra 所指向的地址处继续执行。 sp 是栈指针,指向当前使用的栈。
创建线程
void
thread_create(void (*func)())
{
struct thread *t;
for (t = all_thread; t < all_thread + MAX_THREAD; t++) {
if (t->state == FREE) break;
}
t->state = RUNNABLE;
// YOUR CODE HERE
memset(&t->context, 0, sizeof(t->context));
t->context.ra = (uint64)func;
t->context.sp = (uint64)(t->stack + STACK_SIZE);
}
这里我设置 ra 为 func 的地址,设置 sp 指向当前进程的栈,因为栈是向下增长的,因此这里指向 t->stack 的末尾。
线程切换
线程切换使用 thread_switch 实现,它做的工作和进程切换调用的 swtch 是一样的。这里 thread_switch 函数的视线可以复用 swtch.S 中 swtch 的实现。
.text
/*
* save the old thread's registers,
* restore the new thread's registers.
*/
.globl thread_switch
thread_switch:
/* YOUR CODE HERE */
sd ra, 0(a0)
sd sp, 8(a0)
sd s0, 16(a0)
sd s1, 24(a0)
sd s2, 32(a0)
sd s3, 40(a0)
sd s4, 48(a0)
sd s5, 56(a0)
sd s6, 64(a0)
sd s7, 72(a0)
sd s8, 80(a0)
sd s9, 88(a0)
sd s10, 96(a0)
sd s11, 104(a0)
ld ra, 0(a1)
ld sp, 8(a1)
ld s0, 16(a1)
ld s1, 24(a1)
ld s2, 32(a1)
ld s3, 40(a1)
ld s4, 48(a1)
ld s5, 56(a1)
ld s6, 64(a1)
ld s7, 72(a1)
ld s8, 80(a1)
ld s9, 88(a1)
ld s10, 96(a1)
ld s11, 104(a1)
ret /* return to ra */
执行线程切换就只需要加入如下一行代码,这里将当前执行正在执行的线程的上下文保存在 t->context 中,并恢复待执行线程的上下文。
@@ -56,10 +75,7 @@ thread_schedule(void)
next_thread->state = RUNNING;
t = current_thread;
current_thread = next_thread;
/* YOUR CODE HERE
* Invoke thread_switch to switch from t to next_thread:
* thread_switch(??, ??);
*/
+ thread_switch((uint64)&t->context, (uint64)¤t_thread->context);
} else
next_thread = 0;
}
当恢复了 current_thread->context 的上下文后,指令会接着向下执行,并退出 thread_schedule 函数,然后跳转到 current_thread->context->ra 所执行的地方,如此就开始运行一个线程了。
当一个线程使用 thread_yield 释放执行权后,在 thread_yield 中会调用 thread_schedule,这里 thread_switch 就会将当前上下文保存到 context 中,此时 context 中的 ra 就不再是创建线程时候设置的 func 了,而是调用 thread_schedule 的 thread_yield 中。
void
thread_yield(void)
{
current_thread->state = RUNNABLE;
thread_schedule();
}
在 thread_yield 退出时,会从栈上拿出 ra,并进一步返回到调用 thread_yield 的地方,如此一个线程就又恢复了执行。
void
thread_yield(void)
{
11a: 1141 add sp,sp,-16
11c: e406 sd ra,8(sp) # 保存 ra 到栈上
11e: e022 sd s0,0(sp)
120: 0800 add s0,sp,16
current_thread->state = RUNNABLE;
122: 00001797 auipc a5,0x1
126: cb67b783 ld a5,-842(a5) # dd8 <current_thread>
12a: 6709 lui a4,0x2
12c: 97ba add a5,a5,a4
12e: 4709 li a4,2
130: c398 sw a4,0(a5)
thread_schedule();
132: 00000097 auipc ra,0x0
136: ef4080e7 jalr -268(ra) # 26 <thread_schedule>
}
13a: 60a2 ld ra,8(sp) # 从栈上恢复 ra
13c: 6402 ld s0,0(sp)
13e: 0141 add sp,sp,16
140: 8082 ret # 返回到 ra 处继续运行
理解这个实验的细节,你就可以很容易地理解有栈协程的原理了。
多线程安全哈希表
第二个子实验是要实现一个多线程安全的哈希表。实验给的代码中,哈希表在插入新元素时没有加锁,当哈希冲突的时候,在处理外接链表的时候,多线程同时修改链表会导致数据丢失。
使用一个线程执行写入,没有任何问题:
$ ./ph 1
100000 puts, 1.510 seconds, 66246 puts/second
0: 0 keys missing
100000 gets, 1.521 seconds, 65730 gets/second
当使用两个线程的时候,会发现部分 key 在哈希表中找不到:
$ ./ph 2
100000 puts, 0.712 seconds, 140477 puts/second
0: 16689 keys missing
1: 16689 keys missing
200000 gets, 1.438 seconds, 139093 gets/second
解法很简单,可以在 put_thread 中对 put 函数加锁:
pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static void *
put_thread(void *xa)
{
int n = (int) (long) xa; // thread number
int b = NKEYS/nthread;
for (int i = 0; i < b; i++) {
pthread_mutex_lock(&mutex);
put(keys[b*n + i], n);
pthread_mutex_unlock(&mutex);
}
return NULL;
}
再次执行,就不会有问题了:
$ ./ph 2
100000 puts, 1.628 seconds, 61431 puts/second
0: 0 keys missing
1: 0 keys missing
200000 gets, 1.557 seconds, 128480 gets/second
上面的方案是对整个哈希表加锁,导致两个多个线程只能排队执行插入操作。因为这里哈希表使用拉链法,一个线程实际上只会操作一个 bucket,我们可以将加锁粒度细分到 bucket 上。
pthread_mutex_t mutexs[NBUCKET];
static
void put(int key, int value)
{
int i = key % NBUCKET;
pthread_mutex_lock(&mutexs[i]);
//...
pthread_mutex_unlock(&mutexs[i]);
}
int
main(int argc, char *argv[])
{
// 初始化 mutex
for (int i = 0; i < NBUCKET; i++) {
pthread_mutex_init(&mutexs[i], NULL);
}
//...
}
```
可以在 `put` 函数中对 `bucket` 加锁,这样可以得到更好的并发度。
```sh
# 全局加锁
$ ./ph 4
100000 puts, 1.635 seconds, 61160 puts/second
0: 0 keys missing
3: 0 keys missing
1: 0 keys missing
2: 0 keys missing
# 对 bucket 加锁
$ ./ph 4
100000 puts, 0.691 seconds, 144680 puts/second
3: 0 keys missing
2: 0 keys missing
1: 0 keys missing
0: 0 keys missing
```
对比两种方案,可以看到对 bucket 加锁,可以获得两倍以上的性能提升。
## 实现 `barrier`
`barrier` 翻译为屏障,多个线程分别执行不同任务,每个线程完全的时间不同,但在进行下一步操作的时候,希望所有线程都执行完各自的任务,此时就需要使用 `barrier`。
```
thread 1 barrier
=============> || =====>
||
thread 2 ||
=========> || =====>
||
thread 3 ||
==========> || ===>
```
当一个线程到达屏障后,如果其他线程尚未到达,则需要等待。当最后一个线程到达屏障后,就需要通知其他线程,这样多个线程又可以继续往下执行。
这里实现的 `barrier` 如下:
```c
static void
barrier()
{
pthread_mutex_lock(&bstate.barrier_mutex);
bstate.nthread++; // 累加到达屏障的线程数
if (bstate.nthread == nthread) { // 如果所有线程都到达屏障
bstate.round++; // 累加 round 计数器,表示可以执行下一轮循环
pthread_cond_broadcast(&bstate.barrier_cond); // 通知所有线程,此时会唤醒在等待在 `pthread_cond_wait` 上的线程
bstate.nthread = 0; // 清空 nthread 准备进入下一轮的循环
} else {
// 如果还有线程没有到达屏障,则在条件变量上等待
// 在使用 `pthread_cond_wait` 的时候,需要对 `mutex` 加锁,在 `pthread_cond_wait` 内部会自动对 `mutex` 解锁。
pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex);
}
pthread_mutex_unlock(&bstate.barrier_mutex);
}
```
## 完整变更
```diff
diff --git a/notxv6/barrier.c b/notxv6/barrier.c
index 12793e8..f546350 100644
--- a/notxv6/barrier.c
+++ b/notxv6/barrier.c
@@ -26,11 +26,16 @@ static void
barrier()
{
// YOUR CODE HERE
- //
- // Block until all threads have called barrier() and
- // then increment bstate.round.
- //
-
+ pthread_mutex_lock(&bstate.barrier_mutex);
+ bstate.nthread++;
+ if (bstate.nthread == nthread) {
+ bstate.round++;
+ pthread_cond_broadcast(&bstate.barrier_cond);
+ bstate.nthread = 0;
+ } else {
+ pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex);
+ }
+ pthread_mutex_unlock(&bstate.barrier_mutex);
}
static void *
diff --git a/notxv6/ph.c b/notxv6/ph.c
index 82afe76..2e6785e 100644
--- a/notxv6/ph.c
+++ b/notxv6/ph.c
@@ -16,6 +16,7 @@ struct entry {
struct entry *table[NBUCKET];
int keys[NKEYS];
int nthread = 1;
+pthread_mutex_t mutexs[NBUCKET];
double
@@ -40,6 +41,7 @@ static
void put(int key, int value)
{
int i = key % NBUCKET;
+ pthread_mutex_lock(&mutexs[i]);
// is the key already present?
struct entry *e = 0;
@@ -54,7 +56,7 @@ void put(int key, int value)
// the new is new.
insert(key, value, &table[i], table[i]);
}
-
+ pthread_mutex_unlock(&mutexs[i]);
}
static struct entry*
@@ -118,6 +120,10 @@ main(int argc, char *argv[])
keys[i] = random();
}
+ for (int i = 0; i < NBUCKET; i++) {
+ pthread_mutex_init(&mutexs[i], NULL);
+ }
+
//
// first the puts
//
diff --git a/user/uthread.c b/user/uthread.c
index 18b773d..b103e19 100644
--- a/user/uthread.c
+++ b/user/uthread.c
@@ -10,10 +10,29 @@
#define STACK_SIZE 8192
#define MAX_THREAD 4
+struct thread_context {
+ uint64 ra;
+ uint64 sp;
+
+ // callee-saved
+ uint64 s0;
+ uint64 s1;
+ uint64 s2;
+ uint64 s3;
+ uint64 s4;
+ uint64 s5;
+ uint64 s6;
+ uint64 s7;
+ uint64 s8;
+ uint64 s9;
+ uint64 s10;
+ uint64 s11;
+};
struct thread {
char stack[STACK_SIZE]; /* the thread's stack */
int state; /* FREE, RUNNING, RUNNABLE */
+ struct thread_context context;
};
struct thread all_thread[MAX_THREAD];
struct thread *current_thread;
@@ -56,10 +75,7 @@ thread_schedule(void)
next_thread->state = RUNNING;
t = current_thread;
current_thread = next_thread;
- /* YOUR CODE HERE
- * Invoke thread_switch to switch from t to next_thread:
- * thread_switch(??, ??);
- */
+ thread_switch((uint64)&t->context, (uint64)¤t_thread->context);
} else
next_thread = 0;
}
@@ -74,6 +90,9 @@ thread_create(void (*func)())
}
t->state = RUNNABLE;
// YOUR CODE HERE
+ memset(&t->context, 0, sizeof(t->context));
+ t->context.ra = (uint64)func;
+ t->context.sp = (uint64)(t->stack + STACK_SIZE);
}
void
diff --git a/user/uthread_switch.S b/user/uthread_switch.S
index 5defb12..6afc3de 100644
--- a/user/uthread_switch.S
+++ b/user/uthread_switch.S
@@ -8,4 +8,35 @@
.globl thread_switch
thread_switch:
/* YOUR CODE HERE */
+
+ sd ra, 0(a0)
+ sd sp, 8(a0)
+ sd s0, 16(a0)
+ sd s1, 24(a0)
+ sd s2, 32(a0)
+ sd s3, 40(a0)
+ sd s4, 48(a0)
+ sd s5, 56(a0)
+ sd s6, 64(a0)
+ sd s7, 72(a0)
+ sd s8, 80(a0)
+ sd s9, 88(a0)
+ sd s10, 96(a0)
+ sd s11, 104(a0)
+
+ ld ra, 0(a1)
+ ld sp, 8(a1)
+ ld s0, 16(a1)
+ ld s1, 24(a1)
+ ld s2, 32(a1)
+ ld s3, 40(a1)
+ ld s4, 48(a1)
+ ld s5, 56(a1)
+ ld s6, 64(a1)
+ ld s7, 72(a1)
+ ld s8, 80(a1)
+ ld s9, 88(a1)
+ ld s10, 96(a1)
+ ld s11, 104(a1)
+
ret /* return to ra */
```