要求:

  1. 提交任务不能被阻塞
  2. 可自定义最大协程的数量
  3. 实现stopWait,阻塞直到所有任务结束

基础的数据结构

定义一个WorkPool的结构体
包含taskQueue和workQueue两个channel,容量为0
包含一个双端队列Deque(开源库,支持泛型)

import (
    "github.com/gammazero/deque"
)
type Task func()
type WorkPool struct {
    taskQueue chan Task
    workQueue chan Task

    waitingQueue deque.Deque[Task]
}

submit的实现

提交任务主要逻辑就是往taskQueue中发送任务。但taskQueue不是容量为0,那么当任务队列满的时候,不会阻塞么?这里先记住这个问题,等全部实现之后将会解答。

func (p *WorkPool) Submit(task Task) {
	if task != nil {
		p.taskQueue <- task
	}
}

worker的实现

任务执行的逻辑,就是从workQueue中一直拿任务,如果任务非空,则执行,否则跳出循环,直接返回。

func (p *WorkPool) work(task Task) {
	for task != nil {
		task()
		task = <- p.workQueue
	}
}

dispatch的实现

第2、3步实现了提交任务和处理任务,那么必须有一个中间人负责将任务从taskQueue中传递到workQueue,主要有这几点工作
● 当worker数量不足时,创建worker
● 当worker处理不过来的时候,需要将任务放入waitingDeque
● 需要不断将waitingDeque中的任务投递到workQueue
先按照最简单的思路写出一个框架。

  1. 因为需要不断从taskQueue中拿任务,所以外层用一个for循环
  2. 通过recv操作从taskQueue获取task,如果taskQueue被关闭,则跳出循环
  3. 将task发送到workQueue
func (p *WorkPool) dispatch() {
	for {
		task, ok := <-p.taskQueue
		if !ok {
			break
		}
		p.workQueue <- task
	}
}

但因为没有地方创建worker,导致workQueue没有接收方,导致p.workQueue <- task被阻塞,可以在p.workQueue <- task包一层select,如果没有worker接收task,则走default逻辑,去创建worker。

func (p *WorkPool) dispatch() {
	for {
		task, ok := <-p.taskQueue
		if !ok {
			break
		}
		select {
		case p.workQueue <- task:
		default:
			go p.work(task)
		}
	}
}

但我们有一个maxWorkerNum的限制,创建worker的逻辑应该调整如下:

  1. 在结构体中新增两个变量,分别记录当前的worker数量和最大的worker数量
  2. 创建worker的时候先判断当前worker的数量是否小于最大worker数量
  3. 如果小于,则创建新的worker,否则将任务放入waitingDeque中
type WorkPool struct {
	taskQueue chan Task
	workQueue chan Task

	waitingQueue deque.Deque[Task]

	maxWorkerNum int // 新增内容
	curWorkerNum int // 新增内容
}

func (p *WorkPool) dispatch() {
	for {
		task, ok := <-p.taskQueue
		if !ok {
			break
		}
		select {
		case p.workQueue <- task:
		default:
			if p.curWorkerNum < p.maxWorkerNum {
				go p.work(task)
				p.curWorkerNum += 1
			} else {
				p.waitingQueue.PushBack(task)
			}
		}
	}

这里可以有一个优化点,每次循环开始时,如果waitingDeque的长度已经大于1,说明已经有任务在排队了,那么如果接收到新来的任务,则直接放入waitingDeque中,同时再将waitingDeque中的任务放入workQueue中,因此我们设计一个新的函数processWaitingQueue,完成上面这两件事情。
我们将这两个逻辑放在一个select中,这里需要确认一下提交任务submit是否会被阻塞。processWaitingQueue中的select包含了两个channel的recv和send操作,若两个操作都处于就绪状态,那么go将会随机选取一个执行,抛开一直选择发送操作这种极端操作,submit操作就不会被一直阻塞。又因为processWaitingQueue在for循环中,所以一直会重复以上两件事情,直到taskQuque被关闭跳出for循环或者waitingQueue长度为0,则执行之前的逻辑。

func (p *WorkPool) processWaitingQueue() bool {
	select {
	case task, ok := <-p.taskQueue:
		if !ok {
			return false
		}
		p.waitingQueue.PushBack(task)
	case p.workQueue <- p.waitingQueue.Front():
		p.waitingQueue.PopFront()
	}
	return true
}
func (p *WorkPool) dispatch() {
	for {
		if p.waitingQueue.Len() > 0 {
			if !p.processWaitingQueue() {
				break
			}
			continue
		}
		task, ok := <-p.taskQueue
		if !ok {
			break
		}
		select {
		case p.workQueue <- task:
		default:
			if p.curWorkerNum < p.maxWorkerNum {
				go p.work(task)
				p.curWorkerNum += 1
			} else {
				p.waitingQueue.PushBack(task)
			}
		}
	}

stopWait逻辑的实现
● stopWait需要阻塞,直到所有任务完成,那么第一时间想到使用sync.waitGroup,等待所有的worker结束;
● stopWait的阻塞,虽然可以通过sync.waitGroup来实现,但在go中使用较多的是通过channel来实现同步操作(也就是stopWait需要同步到“所有worker结束”这个事件)
● 需要关闭taskQueue这个channel,不让新的任务提交进来。
● 关闭taskQueue这个操作不能重复执行,通过sync.Once来实现只执行一次

type WorkPool struct {
	taskQueue chan Task
	workQueue chan Task

	waitingQueue deque.Deque[Task]

	maxWorkerNum int
	curWorkerNum int

	stopOnce sync.Once  # 新增
	stopChan chan struct{} # 新增
}

func (p *WorkPool) StopWait() {
	p.stopOnce.Do(func() {
		close(p.taskQueue)
		<-p.stopChan
	})
}

func (p *WorkPool) dispatch() {
	waitGroup := sync.WaitGroup{}  # 新增
	for {
		if p.waitingQueue.Len() > 0 {
			if !p.processWaitingQueue() {
				break
			}
			continue
		}
		task, ok := <-p.taskQueue
		if !ok {
			break
		}
		select {
		case p.workQueue <- task:
		default:
			if p.curWorkerNum < p.maxWorkerNum {
				waitGroup.Add(1) # 新增
				go p.work(task, &waitGroup) # 修改
				p.curWorkerNum += 1
			} else {
				p.waitingQueue.PushBack(task)
			}
		}
	}

	waitGroup.Wait()

	p.stopChan <- struct{}{}
}

func (p *WorkPool) work(task Task, waitGroup *sync.WaitGroup) {
	for task != nil {
		task()
		task = <-p.workQueue
	}
	waitGroup.Done() # 新增
}

但上面的代码有两个问题:
● 当调用stopWait,执行close(p.taskQueue)之后,dispatch中的for循环将会跳出,那么waitingQueue中的任务怎么办呢?我们设计一个runQueuedTasks函数,负责不停地将waitingQueue中的任务放到waitingQueue中。
● WorkPool.work函数中不可能接收到nil,将一直处于循环之中,无法触发waitGroup.Done(),流程无法继续。我们可以按照当前worker的数量,发送对应数量的nil到workQueue中

func (p *WorkPool) runQueuedTasks() { # 新增
	for p.waitingQueue.Len() > 0 {
		p.workQueue <- p.waitingQueue.PopFront()
	}
}

func (p *WorkPool) dispatch() {
	waitGroup := sync.WaitGroup{}
	for {
		if p.waitingQueue.Len() > 0 {
			if !p.processWaitingQueue() {
				break
			}
			continue
		}
		task, ok := <-p.taskQueue
		if !ok {
			break
		}
		select {
		case p.workQueue <- task:
		default:
			if p.curWorkerNum < p.maxWorkerNum {
				waitGroup.Add(1)
				go p.work(task, &waitGroup)
				p.curWorkerNum += 1
			} else {
				p.waitingQueue.PushBack(task)
			}
		}
	}

	p.runQueuedTasks() # 新增

    for p.curWorkerNum > 0 { # 新增
		p.workQueue <- nil
		p.curWorkerNum--
	}

	waitGroup.Wait()

    close(p.workQueue)
	p.stopChan <- struct{}{}
}

以上思路取自于开源项目:https://github.com/gammazero/workerpool,该项目还支持超时暂停等功能,但基本在上面这个框架中增加,重点还是需要理解整个协程的框架。