go源码解读-sync.WaitGroup

WaitGroup

  • WaitGroup结构体会等待一组goroutines结束
  • Add方法会设置等待的goroutine的数量
  • goroutine结束之后调用done即可
  • WaitGroup可以用于所有线程都结束之后才执行的逻辑
  • 使用之后不可以再复制
  • state1为64位结构的值
    • 高位的32位为计数器
    • 低位的32位为等待线程数
1
2
3
4
type WaitGroup struct {
noCopy noCopy // 第一次使用之后,不可以用copy函数进行复制
state1 [3]uint32 // 计数器,高4个字节记录要等待的线程总数,低位4个字节记录还需要等待完成的线程数量
}

state方法

  • state方法返回计数和信号量的地址
1
2
3
4
5
6
7
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}

Add方法

  • Add方法添加计数的goroutine数到wg中
  • 添加完计数器变为0,所有的阻塞的线程会被释放
  • 添加完计数器为负数,会panic
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
func (wg *WaitGroup) Add(delta int) {
// 获取当前计数和信号量的地址
statep, semap := wg.state()
if race.Enabled {
_ = *statep // trigger nil deref early
if delta < 0 {
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
// 高32位来进行计数器的CAS加法
state := atomic.AddUint64(statep, uint64(delta)<<32)
// 当前需要等待的线总数
v := int32(state >> 32)
// 获取需要等待结束的线程数
w := uint32(state)
if race.Enabled && delta > 0 && v == int32(delta) {
race.Read(unsafe.Pointer(semap))
}
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// v或者w为0直接返回即可
if v > 0 || w == 0 {
return
}
// Add和Wait同步发生,此时会报错
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 将等待完成的线程数设置为0,逐一释放之前等待的线程
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(semap, false)
}
}

Done方法

  • Done方法完成counter–
1
2
3
func (wg *WaitGroup) Done() {
wg.Add(-1)
}

Wait方法

  • Wait方法完成线程阻塞,直到counter为0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
func (wg *WaitGroup) Wait() {
// 获取计数器和信号量的地址
statep, semap := wg.state()
if race.Enabled {
_ = *statep // trigger nil deref early
race.Disable()
}
// 循环等待counter为0
for {
state := atomic.LoadUint64(statep)
// 获取counter和waiter的值
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
// 如果counter为0,则直接返回,进行后续操作
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// 否则waiter计数++
if atomic.CompareAndSwapUint64(statep, state, state+1) {
if race.Enabled && w == 0 {
race.Write(unsafe.Pointer(semap))
}
// 排队休眠,等待信号量唤醒
runtime_Semacquire(semap)
// 休眠过程中,wg被重用会导致state不一致,从而panic
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
Donate comment here