diff --git a/vlib/sync/channel_select_6_test.v b/vlib/sync/channel_select_6_test.v new file mode 100644 index 0000000000..2c3573c1ef --- /dev/null +++ b/vlib/sync/channel_select_6_test.v @@ -0,0 +1,75 @@ +// This test case runs concurrent 3 instances of `do_select` that +// communicate with 6 other threads doing send and receive operations. +// There are buffered and unbuffered channels - handled by one or two +// concurrend threads on the other side + +fn do_select(ch1 chan int, ch2 chan int, chf1 chan f64, chf2 chan f64, sumch1 chan i64, sumch2 chan i64) { + mut sum1 := i64(0) + mut sum2 := i64(0) + f1 := 17. + f2 := 7. + for _ in 0 .. 20000 + chf1.cap / 3 { + select { + chf1 <- f1 {} + i := <-ch1 { + sum1 += i + } + j := <-ch2 { + sum2 += j + } + chf2 <- f2 {} + } + } + sumch1 <- sum1 + sumch2 <- sum2 +} + +fn do_send_int(ch chan int, factor int) { + for i in 0 .. 10000 { + ch <- (i * factor) + } +} + +fn do_rec_f64(ch chan f64, sumch chan f64) { + mut sum:= 0. + for _ in 0 .. 10000 { + sum += <-ch + } + sumch <- sum +} + +fn test_select() { + ch1 := chan int{cap: 3} + ch2 := chan int{} + // buffer length of chf1 mus be mutiple of 3 (# select threads) + chf1 := chan f64{cap: 30} + chf2 := chan f64{} + chsum1 := chan i64{} + chsum2 := chan i64{} + chsumf1 := chan f64{} + chsumf2 := chan f64{} + go do_send_int(ch1, 3) + go do_select(ch1, ch2, chf1, chf2, chsum1, chsum2) + go do_rec_f64(chf1, chsumf1) + go do_rec_f64(chf2, chsumf2) + go do_rec_f64(chf2, chsumf2) + go do_select(ch1, ch2, chf1, chf2, chsum1, chsum2) + go do_send_int(ch2, 7) + go do_send_int(ch2, 17) + go do_select(ch1, ch2, chf1, chf2, chsum1, chsum2) + + sum1 := <-chsum1 + <-chsum1 + <-chsum1 + sum2 := <-chsum2 + <-chsum2 + <-chsum2 + mut sumf1 := <-chsumf1 + // empty channel buffer + for _ in 0 .. chf1.cap { + sumf1 += <-chf1 + } + sumf2 := <-chsumf2 + <-chsumf2 + // Use Gauß' formula + expected_sum := i64(10000) * (10000 - 1) / 2 + assert sum1 == 3 * expected_sum + assert sum2 == (7 + 17) * expected_sum + assert sumf1 == 17. * f64(10000 + chf1.cap) + assert sumf2 == 7. * 20000 +} diff --git a/vlib/sync/channels.v b/vlib/sync/channels.v index c7dbf90005..6d944181b6 100644 --- a/vlib/sync/channels.v +++ b/vlib/sync/channels.v @@ -511,20 +511,18 @@ pub fn channel_select(mut channels []&Channel, dir []Direction, mut objrefs []vo mut subscr := []Subscription{len: channels.len} sem := new_semaphore() for i, ch in channels { + subscr[i].sem = sem if dir[i] == .push { mut null16 := u16(0) for !C.atomic_compare_exchange_weak_u16(&ch.write_sub_mtx, &null16, u16(1)) { null16 = u16(0) } - subscr[i].sem = sem subscr[i].prev = &ch.write_subscriber unsafe { subscr[i].nxt = C.atomic_exchange_ptr(&ch.write_subscriber, &subscr[i]) } if voidptr(subscr[i].nxt) != voidptr(0) { - unsafe { - subscr[i].nxt.prev = &subscr[i] - } + subscr[i].nxt.prev = &subscr[i].nxt } C.atomic_store_u16(&ch.write_sub_mtx, u16(0)) } else { @@ -532,13 +530,12 @@ pub fn channel_select(mut channels []&Channel, dir []Direction, mut objrefs []vo for !C.atomic_compare_exchange_weak_u16(&ch.read_sub_mtx, &null16, u16(1)) { null16 = u16(0) } - subscr[i].sem = sem subscr[i].prev = &ch.read_subscriber unsafe { subscr[i].nxt = C.atomic_exchange_ptr(&ch.read_subscriber, &subscr[i]) } if voidptr(subscr[i].nxt) != voidptr(0) { - unsafe { subscr[i].nxt.prev = &subscr[i] } + subscr[i].nxt.prev = &subscr[i].nxt } C.atomic_store_u16(&ch.read_sub_mtx, u16(0)) } @@ -577,7 +574,8 @@ pub fn channel_select(mut channels []&Channel, dir []Direction, mut objrefs []vo } if timeout == 0 { goto restore - } else if timeout > 0 { + } + if timeout > 0 { remaining := timeout - stopwatch.elapsed() if !sem.timed_wait(remaining) { goto restore @@ -594,8 +592,11 @@ restore: for !C.atomic_compare_exchange_weak_u16(&ch.write_sub_mtx, &null16, u16(1)) { null16 = u16(0) } - subscr[i].prev = subscr[i].nxt + unsafe { + *subscr[i].prev = subscr[i].nxt + } if subscr[i].nxt != 0 { + subscr[i].nxt.prev = subscr[i].prev // just in case we have missed a semaphore during restore subscr[i].nxt.sem.post() } @@ -605,8 +606,11 @@ restore: for !C.atomic_compare_exchange_weak_u16(&ch.read_sub_mtx, &null16, u16(1)) { null16 = u16(0) } - subscr[i].prev = subscr[i].nxt + unsafe { + *subscr[i].prev = subscr[i].nxt + } if subscr[i].nxt != 0 { + subscr[i].nxt.prev = subscr[i].prev subscr[i].nxt.sem.post() } C.atomic_store_u16(&ch.read_sub_mtx, u16(0))