sync: simplify `WaitGroup` and `PoolProcessor` and use atomic counters (#8715)

pull/8678/head
Uwe Krüger 2021-02-13 13:52:27 +01:00 committed by GitHub
parent d03c1d615a
commit 835b3b2b81
2 changed files with 32 additions and 37 deletions
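
The public WaitGroup API (new_waitgroup / add / done / wait) is unchanged; only the internals shrink to one atomic u32 counter plus a single semaphore. A minimal usage sketch following the new doc comment in the waitgroup diff below (the names worker and nr_jobs are illustrative, not from this commit):

    import sync

    fn worker(id int, mut wg sync.WaitGroup) {
    	println('job ${id} finished')
    	wg.done() // the last done() drops task_count to zero and posts the semaphore
    }

    fn main() {
    	nr_jobs := 4
    	mut wg := sync.new_waitgroup()
    	wg.add(nr_jobs) // set the task count before spawning the jobs
    	for i in 0 .. nr_jobs {
    		go worker(i, mut wg)
    	}
    	wg.wait() // blocks until every job has called done()
    }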


@@ -4,6 +4,9 @@ import sync
 import runtime
 
+[trusted]
+fn C.atomic_fetch_add_u32(voidptr, u32) u32
+
 pub const (
 	no_result = voidptr(0)
 )
 
@@ -14,9 +17,8 @@ mut:
 	njobs           int
 	items           []voidptr
 	results         []voidptr
-	ntask           int // writing to this should be locked by ntask_mtx.
-	ntask_mtx       &sync.Mutex
-	waitgroup       &sync.WaitGroup
+	ntask           u32 // reading/writing to this should be atomic
+	waitgroup       sync.WaitGroup
 	shared_context  voidptr
 	thread_contexts []voidptr
 }
@@ -44,17 +46,16 @@ pub fn new_pool_processor(context PoolProcessorConfig) &PoolProcessor {
 	if isnil(context.callback) {
 		panic('You need to pass a valid callback to new_pool_processor.')
 	}
-	pool := &PoolProcessor {
+	mut pool := &PoolProcessor {
 		items: []
 		results: []
 		shared_context: voidptr(0)
 		thread_contexts: []
 		njobs: context.maxjobs
 		ntask: 0
-		ntask_mtx: sync.new_mutex()
-		waitgroup: sync.new_waitgroup()
 		thread_cb: voidptr(context.callback)
 	}
+	pool.waitgroup.init()
 	return pool
 }
 
@@ -104,16 +105,9 @@ pub fn (mut pool PoolProcessor) work_on_pointers(items []voidptr) {
 // method in a callback.
 fn process_in_thread(mut pool PoolProcessor, task_id int) {
 	cb := ThreadCB(pool.thread_cb)
-	mut idx := 0
 	ilen := pool.items.len
 	for {
-		if pool.ntask >= ilen {
-			break
-		}
-		pool.ntask_mtx.@lock()
-		idx = pool.ntask
-		pool.ntask++
-		pool.ntask_mtx.unlock()
+		idx := int(C.atomic_fetch_add_u32(&pool.ntask, 1))
 		if idx >= ilen {
 			break
 		}
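
The property the new process_in_thread() loop depends on is that C.atomic_fetch_add_u32() returns the counter value from before the increment, so every concurrent worker claims a distinct item index without a mutex. A single-threaded sketch of that contract, using a hypothetical Counter type as a stand-in for the atomic C call (no real atomicity here):

    struct Counter {
    mut:
    	n u32
    }

    // fetch_add adds delta and returns the value n had before the addition,
    // mirroring what C.atomic_fetch_add_u32 does for pool.ntask
    fn (mut c Counter) fetch_add(delta u32) u32 {
    	old := c.n
    	c.n += delta
    	return old
    }

    fn main() {
    	items := ['a', 'b', 'c']
    	mut ntask := Counter{}
    	for {
    		idx := int(ntask.fetch_add(1)) // yields 0, 1, 2, 3, ...
    		if idx >= items.len {
    			break // 3 is past the end, so this "worker" stops
    		}
    		println('claimed index ${idx}: ${items[idx]}')
    	}
    }

With several workers running this loop against one shared atomic counter, each index is handed out exactly once, which is what makes the ntask_mtx mutex and the extra pre-check on pool.ntask removable.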


@@ -3,47 +3,49 @@
 // that can be found in the LICENSE file.
 module sync
 
-// WaitGroup implementation. wait() blocks until all tasks complete.
+[trusted]
+fn C.atomic_fetch_add_u32(voidptr, u32) u32
+
+// WaitGroup
 // Do not copy an instance of WaitGroup, use a ref instead.
 //
-// Two mutexes are required so that wait() doesn't unblock on done()/add() before
-// task_count becomes zero.
+// usage: in main thread:
+// `wg := sync.new_waitgroup()`
+// `wg.add(nr_jobs)` before starting jobs with `go ...`
+// `wg.wait()` to wait for all jobs to have finished
+//
+// in each parallel job:
+// `wg.done()` when finished
 //
 // [init_with=new_waitgroup] // TODO: implement support for init_with struct attribute, and disallow WaitGroup{} from outside the sync.new_waitgroup() function.
 [ref_only]
 struct WaitGroup {
 mut:
-	task_count       int // current task count
-	task_count_mutex Mutex // This mutex protects the task_count count in add()
-	wait_blocker     Semaphore // This blocks the wait() until released by add()
+	task_count u32 // current task count - reading/writing should be atomic
+	sem        Semaphore // This blocks wait() until task_count is released by add()
 }
 
 pub fn new_waitgroup() &WaitGroup {
 	mut wg := &WaitGroup{}
-	wg.task_count_mutex.init()
-	wg.wait_blocker.init(1)
+	wg.init()
 	return wg
 }
 
+pub fn (mut wg WaitGroup) init() {
+	wg.sem.init(0)
+}
+
 // add increments (+ve delta) or decrements (-ve delta) task count by delta
 // and unblocks any wait() calls if task count becomes zero.
 // add panics if task count drops below zero.
 pub fn (mut wg WaitGroup) add(delta int) {
-	// protect task_count
-	wg.task_count_mutex.@lock()
-	defer {
-		wg.task_count_mutex.unlock()
-	}
-	// If task_count likely to leave zero, set wait() to block
-	if wg.task_count == 0 {
-		wg.wait_blocker.wait()
-	}
-	wg.task_count += delta
-	if wg.task_count < 0 {
+	old_nrjobs := int(C.atomic_fetch_add_u32(&wg.task_count, u32(delta)))
+	new_nrjobs := old_nrjobs + delta
+	if new_nrjobs < 0 {
 		panic('Negative number of jobs in waitgroup')
 	}
-	if wg.task_count == 0 { // if no more task_count tasks
-		wg.wait_blocker.post() // unblock wait()
+	if new_nrjobs == 0 {
+		wg.sem.post()
 	}
 }
 
@@ -54,6 +56,5 @@ pub fn (mut wg WaitGroup) done() {
 
 // wait blocks until all tasks are done (task count becomes zero)
 pub fn (mut wg WaitGroup) wait() {
-	wg.wait_blocker.wait() // blocks until task_count becomes 0
-	wg.wait_blocker.post() // allow other wait()s to unblock or reuse wait group
+	wg.sem.wait() // blocks until task_count becomes 0
 }
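
One subtlety in the new add(): task_count is a u32, but done() calls add(-1). The cast u32(delta) relies on unsigned wrap-around, so adding u32(-1) is the same as subtracting 1, while the signed new_nrjobs check still catches a counter that would go negative. A tiny standalone illustration of that wrapping arithmetic (variable names are illustrative):

    fn main() {
    	mut task_count := u32(2) // as after add(2)
    	delta := -1 // what done() effectively passes to add()
    	task_count += u32(delta) // u32(-1) == 0xFFFFFFFF; the addition wraps modulo 2^32
    	println(task_count) // 1
    }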