sync: simplify `WaitGroup` and `PoolProcessor` and use atomic counters (#8715)
							parent
							
								
									d03c1d615a
								
							
						
					
					
						commit
						835b3b2b81
					
				| 
						 | 
				
			
			@ -4,6 +4,9 @@ import sync
 | 
			
		|||
 | 
			
		||||
import runtime
 | 
			
		||||
 | 
			
		||||
[trusted]
 | 
			
		||||
fn C.atomic_fetch_add_u32(voidptr, u32) u32
 | 
			
		||||
 | 
			
		||||
pub const (
 | 
			
		||||
	no_result = voidptr(0)
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -14,9 +17,8 @@ mut:
 | 
			
		|||
	njobs           int
 | 
			
		||||
	items           []voidptr
 | 
			
		||||
	results         []voidptr
 | 
			
		||||
	ntask           int // writing to this should be locked by ntask_mtx.
 | 
			
		||||
	ntask_mtx       &sync.Mutex
 | 
			
		||||
	waitgroup       &sync.WaitGroup
 | 
			
		||||
	ntask           u32 // reading/writing to this should be atomic
 | 
			
		||||
	waitgroup       sync.WaitGroup
 | 
			
		||||
	shared_context  voidptr
 | 
			
		||||
	thread_contexts []voidptr
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -44,17 +46,16 @@ pub fn new_pool_processor(context PoolProcessorConfig) &PoolProcessor {
 | 
			
		|||
	if isnil(context.callback) {
 | 
			
		||||
		panic('You need to pass a valid callback to new_pool_processor.')
 | 
			
		||||
	}
 | 
			
		||||
	pool := &PoolProcessor {
 | 
			
		||||
	mut pool := &PoolProcessor {
 | 
			
		||||
		items: []
 | 
			
		||||
		results: []
 | 
			
		||||
		shared_context: voidptr(0)
 | 
			
		||||
		thread_contexts: []
 | 
			
		||||
		njobs: context.maxjobs
 | 
			
		||||
		ntask: 0
 | 
			
		||||
		ntask_mtx: sync.new_mutex()
 | 
			
		||||
		waitgroup: sync.new_waitgroup()
 | 
			
		||||
		thread_cb: voidptr(context.callback)
 | 
			
		||||
	}
 | 
			
		||||
	pool.waitgroup.init()
 | 
			
		||||
	return pool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -104,16 +105,9 @@ pub fn (mut pool PoolProcessor) work_on_pointers(items []voidptr) {
 | 
			
		|||
// method in a callback.
 | 
			
		||||
fn process_in_thread(mut pool PoolProcessor, task_id int) {
 | 
			
		||||
	cb := ThreadCB(pool.thread_cb)
 | 
			
		||||
	mut idx := 0
 | 
			
		||||
	ilen := pool.items.len
 | 
			
		||||
	for {
 | 
			
		||||
		if pool.ntask >= ilen {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		pool.ntask_mtx.@lock()
 | 
			
		||||
		idx = pool.ntask
 | 
			
		||||
		pool.ntask++
 | 
			
		||||
		pool.ntask_mtx.unlock()
 | 
			
		||||
		idx := int(C.atomic_fetch_add_u32(&pool.ntask, 1))
 | 
			
		||||
		if idx >= ilen {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,47 +3,49 @@
 | 
			
		|||
// that can be found in the LICENSE file.
 | 
			
		||||
module sync
 | 
			
		||||
 | 
			
		||||
// WaitGroup implementation. wait() blocks until all tasks complete.
 | 
			
		||||
[trusted]
 | 
			
		||||
fn C.atomic_fetch_add_u32(voidptr, u32) u32
 | 
			
		||||
 | 
			
		||||
// WaitGroup
 | 
			
		||||
// Do not copy an instance of WaitGroup, use a ref instead.
 | 
			
		||||
//
 | 
			
		||||
// Two mutexes are required so that wait() doesn't unblock on done()/add() before
 | 
			
		||||
// task_count becomes zero.
 | 
			
		||||
// usage: in main thread:
 | 
			
		||||
// `wg := sync.new_waitgroup()
 | 
			
		||||
// `wg.add(nr_jobs)` before starting jobs with `go ...`
 | 
			
		||||
// `wg.wait()` to wait for all jobs to have finished
 | 
			
		||||
//
 | 
			
		||||
// in each parallel job:
 | 
			
		||||
// `wg.done()` when finished 
 | 
			
		||||
//
 | 
			
		||||
// [init_with=new_waitgroup] // TODO: implement support for init_with struct attribute, and disallow WaitGroup{} from outside the sync.new_waitgroup() function.
 | 
			
		||||
[ref_only]
 | 
			
		||||
struct WaitGroup {
 | 
			
		||||
mut:
 | 
			
		||||
	task_count       int // current task count
 | 
			
		||||
	task_count_mutex Mutex // This mutex protects the task_count count in add()
 | 
			
		||||
	wait_blocker     Semaphore // This blocks the wait() until released by add()
 | 
			
		||||
	task_count u32 // current task count - reading/writing should be atomic
 | 
			
		||||
	sem        Semaphore // This blocks wait() until tast_countreleased by add()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn new_waitgroup() &WaitGroup {
 | 
			
		||||
	mut wg := &WaitGroup{}
 | 
			
		||||
	wg.task_count_mutex.init()
 | 
			
		||||
	wg.wait_blocker.init(1)
 | 
			
		||||
	wg.init()
 | 
			
		||||
	return wg
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn (mut wg WaitGroup) init() {
 | 
			
		||||
	wg.sem.init(0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// add increments (+ve delta) or decrements (-ve delta) task count by delta
 | 
			
		||||
// and unblocks any wait() calls if task count becomes zero.
 | 
			
		||||
// add panics if task count drops below zero.
 | 
			
		||||
pub fn (mut wg WaitGroup) add(delta int) {
 | 
			
		||||
	// protect task_count
 | 
			
		||||
	wg.task_count_mutex.@lock()
 | 
			
		||||
	defer {
 | 
			
		||||
		wg.task_count_mutex.unlock()
 | 
			
		||||
	}
 | 
			
		||||
	// If task_count likely to leave zero, set wait() to block
 | 
			
		||||
	if wg.task_count == 0 {
 | 
			
		||||
		wg.wait_blocker.wait()
 | 
			
		||||
	}
 | 
			
		||||
	wg.task_count += delta
 | 
			
		||||
	if wg.task_count < 0 {
 | 
			
		||||
	old_nrjobs := int(C.atomic_fetch_add_u32(&wg.task_count, u32(delta)))
 | 
			
		||||
	new_nrjobs := old_nrjobs + delta
 | 
			
		||||
	if new_nrjobs < 0 {
 | 
			
		||||
		panic('Negative number of jobs in waitgroup')
 | 
			
		||||
	}
 | 
			
		||||
	if wg.task_count == 0 { // if no more task_count tasks
 | 
			
		||||
		wg.wait_blocker.post() // unblock wait()
 | 
			
		||||
	if new_nrjobs == 0 {
 | 
			
		||||
		wg.sem.post()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -54,6 +56,5 @@ pub fn (mut wg WaitGroup) done() {
 | 
			
		|||
 | 
			
		||||
// wait blocks until all tasks are done (task count becomes zero)
 | 
			
		||||
pub fn (mut wg WaitGroup) wait() {
 | 
			
		||||
	wg.wait_blocker.wait() // blocks until task_count becomes 0
 | 
			
		||||
	wg.wait_blocker.post() // allow other wait()s to unblock or reuse wait group
 | 
			
		||||
	wg.sem.wait() // blocks until task_count becomes 0
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue