rand: move dist functions to top module and PRNG interface; minor cleanup (#14481)

master
Subhomoy Haldar 2022-05-22 15:51:52 +05:30 committed by GitHub
parent 64a686f41f
commit 3647fb4def
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 290 additions and 134 deletions

View File

@ -0,0 +1,10 @@
// Copyright (c) 2019-2022 Alexander Medvednikov. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
module buffer
pub struct PRNGBuffer {
mut:
bytes_left int
buffer u64
}

View File

@ -1,3 +1,6 @@
// Copyright (c) 2019-2022 Alexander Medvednikov. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
module config module config
import rand.seed import rand.seed
@ -12,3 +15,35 @@ pub struct PRNGConfigStruct {
pub: pub:
seed_ []u32 = seed.time_seed_array(2) seed_ []u32 = seed.time_seed_array(2)
} }
// Configuration struct for generating normally distributed floats. The default value for
// `mu` is 0 and the default value for `sigma` is 1.
[params]
pub struct NormalConfigStruct {
pub:
mu f64 = 0.0
sigma f64 = 1.0
}
// Configuration struct for the shuffle functions.
// The start index is inclusive and the end index is exclusive.
// Set the end to 0 to shuffle until the end of the array.
[params]
pub struct ShuffleConfigStruct {
pub:
start int
end int
}
// validate_for is a helper function for validating the configuration struct for the given array.
pub fn (config ShuffleConfigStruct) validate_for<T>(a []T) ? {
if config.start < 0 || config.start >= a.len {
return error("argument 'config.start' must be in range [0, a.len)")
}
if config.end < 0 || config.end > a.len {
return error("argument 'config.end' must be in range [0, a.len]")
}
if config.end != 0 && config.end <= config.start {
return error("argument 'config.end' must be greater than 'config.start'")
}
}

View File

@ -1,10 +0,0 @@
# Non-Uniform Distribution Functions
This module contains functions for sampling from non-uniform distributions.
All implementations of the `rand.PRNG` interface generate numbers from uniform
distributions. This library exists to allow the generation of pseudorandom numbers
sampled from non-uniform distributions. Additionally, it allows the user to use any
PRNG of their choice. This is because the default RNG can be reassigned to a different
generator. It can either be one of the pre-existing one (which are well-tested and
recommended) or a custom user-defined one. See `rand.set_rng()`.

85
vlib/rand/dist/dist.v vendored
View File

@ -1,85 +0,0 @@
// Copyright (c) 2019-2022 Alexander Medvednikov. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
module dist
import math
import rand
fn check_probability_range(p f64) {
if p < 0 || p > 1 {
panic('$p is not a valid probability value.')
}
}
// bernoulli returns true with a probability p. Note that 0 <= p <= 1.
pub fn bernoulli(p f64) bool {
check_probability_range(p)
return rand.f64() <= p
}
// binomial returns the number of successful trials out of n when the
// probability of success for each trial is p.
pub fn binomial(n int, p f64) int {
check_probability_range(p)
mut count := 0
for _ in 0 .. n {
if bernoulli(p) {
count++
}
}
return count
}
// Configuration struct for the `normal_pair` function. The default value for
// `mu` is 0 and the default value for `sigma` is 1.
pub struct NormalConfigStruct {
mu f64 = 0.0
sigma f64 = 1.0
}
// normal_pair returns a pair of normally distributed random numbers with the mean mu
// and standard deviation sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
// `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair()`.
pub fn normal_pair(config NormalConfigStruct) (f64, f64) {
if config.sigma <= 0 {
panic('The standard deviation has to be positive.')
}
// This is an implementation of the Marsaglia polar method
// See: https://doi.org/10.1137%2F1006063
// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
for {
u := rand.f64_in_range(-1, 1) or { 0.0 }
v := rand.f64_in_range(-1, 1) or { 0.0 }
s := u * u + v * v
if s >= 1 || s == 0 {
continue
}
t := math.sqrt(-2 * math.log(s) / s)
x := config.mu + config.sigma * t * u
y := config.mu + config.sigma * t * v
return x, y
}
return config.mu, config.mu
}
// normal returns a normally distributed random number with the mean mu and standard deviation
// sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
// `x := normal(mu: mean, sigma: etdev)` or `x := normal()`.
// **NOTE:** If you are generating a lot of normal variates, use `the normal_pair` function
// instead. This function discards one of the two variates generated by the `normal_pair` function.
pub fn normal(config NormalConfigStruct) f64 {
x, _ := normal_pair(config)
return x
}
// exponential returns an exponentially distributed random number with the rate paremeter
// lambda. It is expected that lambda is positive.
pub fn exponential(lambda f64) f64 {
if lambda <= 0 {
panic('The rate (lambda) must be positive.')
}
// Use the inverse transform sampling method
return -math.log(rand.f64()) / lambda
}

View File

@ -1,6 +1,5 @@
import math import math
import rand import rand
import rand.dist
const ( const (
// The sample size to be used // The sample size to be used
@ -20,7 +19,7 @@ fn test_bernoulli() {
for p in ps { for p in ps {
mut successes := 0 mut successes := 0
for _ in 0 .. count { for _ in 0 .. count {
if dist.bernoulli(p) { if rand.bernoulli(p) or { false } {
successes++ successes++
} }
} }
@ -43,7 +42,7 @@ fn test_binomial() {
mut sum := 0 mut sum := 0
mut var := 0.0 mut var := 0.0
for _ in 0 .. count { for _ in 0 .. count {
x := dist.binomial(n, p) x := rand.binomial(n, p) or { 0 }
sum += x sum += x
dist := (x - np) dist := (x - np)
var += dist * dist var += dist * dist
@ -68,7 +67,7 @@ fn test_normal_pair() {
mut sum := 0.0 mut sum := 0.0
mut var := 0.0 mut var := 0.0
for _ in 0 .. count { for _ in 0 .. count {
x, y := dist.normal_pair(mu: mu, sigma: sigma) x, y := rand.normal_pair(mu: mu, sigma: sigma) or { 0.0, 0.0 }
sum += x + y sum += x + y
dist_x := x - mu dist_x := x - mu
dist_y := y - mu dist_y := y - mu
@ -95,7 +94,7 @@ fn test_normal() {
mut sum := 0.0 mut sum := 0.0
mut var := 0.0 mut var := 0.0
for _ in 0 .. count { for _ in 0 .. count {
x := dist.normal(mu: mu, sigma: sigma) x := rand.normal(mu: mu, sigma: sigma) or { 0.0 }
sum += x sum += x
dist := x - mu dist := x - mu
var += dist * dist var += dist * dist
@ -120,7 +119,7 @@ fn test_exponential() {
mut sum := 0.0 mut sum := 0.0
mut var := 0.0 mut var := 0.0
for _ in 0 .. count { for _ in 0 .. count {
x := dist.exponential(lambda) x := rand.exponential(lambda)
sum += x sum += x
dist := x - mu dist := x - mu
var += dist * dist var += dist * dist

View File

@ -0,0 +1,130 @@
// Copyright (c) 2019-2022 Alexander Medvednikov. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
module rand
// NOTE: mini_math.v exists, so that we can avoid `import math`,
// just for the math.log and math.sqrt functions needed for the
// non uniform random number redistribution functions.
// Importing math is relatively heavy, both in terms of compilation
// speed (more source to process), and in terms of increases in the
// generated executable sizes (if the rest of the program does not use
// math already; many programs do not need math, for example the
// compiler itself does not, while needing random number generation.
const sqrt2 = 1.41421356237309504880168872420969807856967187537694807317667974
[inline]
fn msqrt(a f64) f64 {
if a == 0 {
return a
}
mut x := a
z, ex := frexp(x)
w := x
// approximate square root of number between 0.5 and 1
// relative error of approximation = 7.47e-3
x = 4.173075996388649989089e-1 + 5.9016206709064458299663e-1 * z // adjust for odd powers of 2
if (ex & 1) != 0 {
x *= rand.sqrt2
}
x = scalbn(x, ex >> 1)
// newton iterations
x = 0.5 * (x + w / x)
x = 0.5 * (x + w / x)
x = 0.5 * (x + w / x)
return x
}
// a simplified approximation (without the edge cases), see math.log
fn mlog(a f64) f64 {
ln2_lo := 1.90821492927058770002e-10
ln2_hi := 0.693147180369123816490
l1 := 0.6666666666666735130
l2 := 0.3999999999940941908
l3 := 0.2857142874366239149
l4 := 0.2222219843214978396
l5 := 0.1818357216161805012
l6 := 0.1531383769920937332
l7 := 0.1479819860511658591
x := a
mut f1, mut ki := frexp(x)
if f1 < rand.sqrt2 / 2 {
f1 *= 2
ki--
}
f := f1 - 1
k := f64(ki)
s := f / (2 + f)
s2 := s * s
s4 := s2 * s2
t1 := s2 * (l1 + s4 * (l3 + s4 * (l5 + s4 * l7)))
t2 := s4 * (l2 + s4 * (l4 + s4 * l6))
r := t1 + t2
hfsq := 0.5 * f * f
return k * ln2_hi - ((hfsq - (s * (hfsq + r) + k * ln2_lo)) - f)
}
fn frexp(x f64) (f64, int) {
mut y := f64_bits(x)
ee := int((y >> 52) & 0x7ff)
if ee == 0 {
if x != 0.0 {
x1p64 := f64_from_bits(u64(0x43f0000000000000))
z, e_ := frexp(x * x1p64)
return z, e_ - 64
}
return x, 0
} else if ee == 0x7ff {
return x, 0
}
e_ := ee - 0x3fe
y &= u64(0x800fffffffffffff)
y |= u64(0x3fe0000000000000)
return f64_from_bits(y), e_
}
fn scalbn(x f64, n_ int) f64 {
mut n := n_
x1p1023 := f64_from_bits(u64(0x7fe0000000000000))
x1p53 := f64_from_bits(u64(0x4340000000000000))
x1p_1022 := f64_from_bits(u64(0x0010000000000000))
mut y := x
if n > 1023 {
y *= x1p1023
n -= 1023
if n > 1023 {
y *= x1p1023
n -= 1023
if n > 1023 {
n = 1023
}
}
} else if n < -1022 {
/*
make sure final n < -53 to avoid double
rounding in the subnormal range
*/
y *= x1p_1022 * x1p53
n += 1022 - 53
if n < -1022 {
y *= x1p_1022 * x1p53
n += 1022 - 53
if n < -1022 {
n = -1022
}
}
}
return y * f64_from_bits(u64((0x3ff + n)) << 52)
}
[inline]
fn f64_from_bits(b u64) f64 {
return *unsafe { &f64(&b) }
}
[inline]
fn f64_bits(f f64) u64 {
return *unsafe { &u64(&f) }
}

View File

@ -3,6 +3,7 @@
// that can be found in the LICENSE file. // that can be found in the LICENSE file.
module mt19937 module mt19937
import rand.buffer
import rand.seed import rand.seed
/* /*
@ -60,11 +61,10 @@ const (
// MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937. // MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937.
// **NOTE**: The RNG is not seeded when instantiated so remember to seed it before use. // **NOTE**: The RNG is not seeded when instantiated so remember to seed it before use.
pub struct MT19937RNG { pub struct MT19937RNG {
buffer.PRNGBuffer
mut: mut:
state []u64 = get_first_state(seed.time_seed_array(2)) state []u64 = get_first_state(seed.time_seed_array(2))
mti int = mt19937.nn mti int = mt19937.nn
bytes_left int
buffer u64
} }
fn get_first_state(seed_data []u32) []u64 { fn get_first_state(seed_data []u32) []u64 {

View File

@ -4,15 +4,15 @@
module musl module musl
import rand.seed import rand.seed
import rand.buffer
pub const seed_len = 1 pub const seed_len = 1
// MuslRNG ported from https://git.musl-libc.org/cgit/musl/tree/src/prng/rand_r.c // MuslRNG ported from https://git.musl-libc.org/cgit/musl/tree/src/prng/rand_r.c
pub struct MuslRNG { pub struct MuslRNG {
buffer.PRNGBuffer
mut: mut:
state u32 = seed.time_seed_32() state u32 = seed.time_seed_32()
bytes_left int
buffer u32
} }
// seed sets the current random state based on `seed_data`. // seed sets the current random state based on `seed_data`.

View File

@ -4,6 +4,7 @@
module pcg32 module pcg32
import rand.seed import rand.seed
import rand.buffer
pub const seed_len = 4 pub const seed_len = 4
@ -11,11 +12,10 @@ pub const seed_len = 4
// https://github.com/imneme/pcg-c-basic/blob/master/pcg_basic.c, and // https://github.com/imneme/pcg-c-basic/blob/master/pcg_basic.c, and
// https://github.com/imneme/pcg-c-basic/blob/master/pcg_basic.h // https://github.com/imneme/pcg-c-basic/blob/master/pcg_basic.h
pub struct PCG32RNG { pub struct PCG32RNG {
buffer.PRNGBuffer
mut: mut:
state u64 = u64(0x853c49e6748fea9b) ^ seed.time_seed_64() state u64 = u64(0x853c49e6748fea9b) ^ seed.time_seed_64()
inc u64 = u64(0xda3e39cb94b95bdb) ^ seed.time_seed_64() inc u64 = u64(0xda3e39cb94b95bdb) ^ seed.time_seed_64()
bytes_left int
buffer u32
} }
// seed seeds the PCG32RNG with 4 `u32` values. // seed seeds the PCG32RNG with 4 `u32` values.

View File

@ -274,34 +274,79 @@ pub fn (mut rng PRNG) ascii(len int) string {
return internal_string_from_set(mut rng, rand.ascii_chars, len) return internal_string_from_set(mut rng, rand.ascii_chars, len)
} }
// Configuration struct for the shuffle functions. // bernoulli returns true with a probability p. Note that 0 <= p <= 1.
// The start index is inclusive and the end index is exclusive. pub fn (mut rng PRNG) bernoulli(p f64) ?bool {
// Set the end to 0 to shuffle until the end of the array. if p < 0 || p > 1 {
[params] return error('$p is not a valid probability value.')
pub struct ShuffleConfigStruct { }
pub: return rng.f64() <= p
start int
end int
} }
fn (config ShuffleConfigStruct) validate_for<T>(a []T) ? { // normal returns a normally distributed pseudorandom f64 in range `[0, 1)`.
if config.start < 0 || config.start >= a.len { // NOTE: Use normal_pair() instead if you're generating a lot of normal variates.
return error("argument 'config.start' must be in range [0, a.len)") pub fn (mut rng PRNG) normal(conf config.NormalConfigStruct) ?f64 {
x, _ := rng.normal_pair(conf)?
return x
}
// normal_pair returns a pair of normally distributed pseudorandom f64 in range `[0, 1)`.
pub fn (mut rng PRNG) normal_pair(conf config.NormalConfigStruct) ?(f64, f64) {
if conf.sigma <= 0 {
return error('Standard deviation must be positive')
} }
if config.end < 0 || config.end > a.len { // This is an implementation of the Marsaglia polar method
return error("argument 'config.end' must be in range [0, a.len]") // See: https://doi.org/10.1137%2F1006063
// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
for {
u := rng.f64_in_range(-1, 1) or { 0.0 }
v := rng.f64_in_range(-1, 1) or { 0.0 }
s := u * u + v * v
if s >= 1 || s == 0 {
continue
} }
t := msqrt(-2 * mlog(s) / s)
x := conf.mu + conf.sigma * t * u
y := conf.mu + conf.sigma * t * v
return x, y
}
return error('Implementation error. Please file an issue.')
}
// binomial returns the number of successful trials out of n when the
// probability of success for each trial is p.
pub fn (mut rng PRNG) binomial(n int, p f64) ?int {
if p < 0 || p > 1 {
return error('$p is not a valid probability value.')
}
mut count := 0
for _ in 0 .. n {
if rng.bernoulli(p)! {
count++
}
}
return count
}
// exponential returns an exponentially distributed random number with the rate paremeter
// lambda. It is expected that lambda is positive.
pub fn (mut rng PRNG) exponential(lambda f64) f64 {
if lambda <= 0 {
panic('The rate (lambda) must be positive.')
}
// Use the inverse transform sampling method
return -mlog(rng.f64()) / lambda
} }
// shuffle randomly permutates the elements in `a`. The range for shuffling is // shuffle randomly permutates the elements in `a`. The range for shuffling is
// optional and the entire array is shuffled by default. Leave the end as 0 to // optional and the entire array is shuffled by default. Leave the end as 0 to
// shuffle all elements until the end. // shuffle all elements until the end.
[direct_array_access] [direct_array_access]
pub fn (mut rng PRNG) shuffle<T>(mut a []T, config ShuffleConfigStruct) ? { pub fn (mut rng PRNG) shuffle<T>(mut a []T, config config.ShuffleConfigStruct) ? {
config.validate_for(a)? config.validate_for(a)?
new_end := if config.end == 0 { a.len } else { config.end } new_end := if config.end == 0 { a.len } else { config.end }
for i in config.start .. new_end { for i in config.start .. new_end {
x := rng.int_in_range(i, new_end) or { config.start } x := rng.int_in_range(i, new_end) or { config.start + i }
// swap // swap
a_i := a[i] a_i := a[i]
a[i] = a[x] a[i] = a[x]
@ -311,7 +356,7 @@ pub fn (mut rng PRNG) shuffle<T>(mut a []T, config ShuffleConfigStruct) ? {
// shuffle_clone returns a random permutation of the elements in `a`. // shuffle_clone returns a random permutation of the elements in `a`.
// The permutation is done on a fresh clone of `a`, so `a` remains unchanged. // The permutation is done on a fresh clone of `a`, so `a` remains unchanged.
pub fn (mut rng PRNG) shuffle_clone<T>(a []T, config ShuffleConfigStruct) ?[]T { pub fn (mut rng PRNG) shuffle_clone<T>(a []T, config config.ShuffleConfigStruct) ?[]T {
mut res := a.clone() mut res := a.clone()
rng.shuffle(mut res, config)? rng.shuffle(mut res, config)?
return res return res
@ -541,13 +586,13 @@ pub fn ascii(len int) string {
// shuffle randomly permutates the elements in `a`. The range for shuffling is // shuffle randomly permutates the elements in `a`. The range for shuffling is
// optional and the entire array is shuffled by default. Leave the end as 0 to // optional and the entire array is shuffled by default. Leave the end as 0 to
// shuffle all elements until the end. // shuffle all elements until the end.
pub fn shuffle<T>(mut a []T, config ShuffleConfigStruct) ? { pub fn shuffle<T>(mut a []T, config config.ShuffleConfigStruct) ? {
default_rng.shuffle(mut a, config)? default_rng.shuffle(mut a, config)?
} }
// shuffle_clone returns a random permutation of the elements in `a`. // shuffle_clone returns a random permutation of the elements in `a`.
// The permutation is done on a fresh clone of `a`, so `a` remains unchanged. // The permutation is done on a fresh clone of `a`, so `a` remains unchanged.
pub fn shuffle_clone<T>(a []T, config ShuffleConfigStruct) ?[]T { pub fn shuffle_clone<T>(a []T, config config.ShuffleConfigStruct) ?[]T {
return default_rng.shuffle_clone(a, config) return default_rng.shuffle_clone(a, config)
} }
@ -563,3 +608,31 @@ pub fn choose<T>(array []T, k int) ?[]T {
pub fn sample<T>(array []T, k int) []T { pub fn sample<T>(array []T, k int) []T {
return default_rng.sample(array, k) return default_rng.sample(array, k)
} }
// bernoulli returns true with a probability p. Note that 0 <= p <= 1.
pub fn bernoulli(p f64) ?bool {
return default_rng.bernoulli(p)
}
// normal returns a normally distributed pseudorandom f64 in range `[0, 1)`.
// NOTE: Use normal_pair() instead if you're generating a lot of normal variates.
pub fn normal(conf config.NormalConfigStruct) ?f64 {
return default_rng.normal(conf)
}
// normal_pair returns a pair of normally distributed pseudorandom f64 in range `[0, 1)`.
pub fn normal_pair(conf config.NormalConfigStruct) ?(f64, f64) {
return default_rng.normal_pair(conf)
}
// binomial returns the number of successful trials out of n when the
// probability of success for each trial is p.
pub fn binomial(n int, p f64) ?int {
return default_rng.binomial(n, p)
}
// exponential returns an exponentially distributed random number with the rate paremeter
// lambda. It is expected that lambda is positive.
pub fn exponential(lambda f64) f64 {
return default_rng.exponential(lambda)
}

View File

@ -4,11 +4,13 @@
module splitmix64 module splitmix64
import rand.seed import rand.seed
import rand.buffer
pub const seed_len = 2 pub const seed_len = 2
// SplitMix64RNG ported from http://xoshiro.di.unimi.it/splitmix64.c // SplitMix64RNG ported from http://xoshiro.di.unimi.it/splitmix64.c
pub struct SplitMix64RNG { pub struct SplitMix64RNG {
buffer.PRNGBuffer
mut: mut:
state u64 = seed.time_seed_64() state u64 = seed.time_seed_64()
bytes_left int bytes_left int

View File

@ -4,6 +4,7 @@
module sys module sys
import math.bits import math.bits
import rand.buffer
import rand.seed import rand.seed
// Implementation note: // Implementation note:
@ -36,10 +37,9 @@ fn calculate_iterations_for(bits int) int {
// SysRNG is the PRNG provided by default in the libc implementiation that V uses. // SysRNG is the PRNG provided by default in the libc implementiation that V uses.
pub struct SysRNG { pub struct SysRNG {
buffer.PRNGBuffer
mut: mut:
seed u32 = seed.time_seed_32() seed u32 = seed.time_seed_32()
buffer int
bytes_left int
} }
// r.seed() sets the seed of the accepting SysRNG to the given data. // r.seed() sets the seed of the accepting SysRNG to the given data.
@ -71,7 +71,7 @@ pub fn (mut r SysRNG) u8() u8 {
r.buffer >>= 8 r.buffer >>= 8
return value return value
} }
r.buffer = r.default_rand() r.buffer = u64(r.default_rand())
r.bytes_left = sys.rand_bytesize - 1 r.bytes_left = sys.rand_bytesize - 1
value := u8(r.buffer) value := u8(r.buffer)
r.buffer >>= 8 r.buffer >>= 8

View File

@ -4,6 +4,7 @@
module wyrand module wyrand
import hash import hash
import rand.buffer
import rand.seed import rand.seed
// Redefinition of some constants that we will need for pseudorandom number generation. // Redefinition of some constants that we will need for pseudorandom number generation.
@ -16,6 +17,7 @@ pub const seed_len = 2
// WyRandRNG is a RNG based on the WyHash hashing algorithm. // WyRandRNG is a RNG based on the WyHash hashing algorithm.
pub struct WyRandRNG { pub struct WyRandRNG {
buffer.PRNGBuffer
mut: mut:
state u64 = seed.time_seed_64() state u64 = seed.time_seed_64()
bytes_left int bytes_left int