rand: add non uniform distributions in the `rand.dist` module (#9274)

pull/9271/head^2
Subhomoy Haldar 2021-03-13 00:54:43 +05:30 committed by GitHub
parent 0c33656a19
commit 08da33fa5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 197 additions and 2 deletions

10
vlib/rand/dist/README.md vendored 100644
View File

@ -0,0 +1,10 @@
# Non-Uniform Distribution Functions
This module contains functions for sampling from non-uniform distributions.
All implementations of the `rand.PRNG` interface generate numbers from uniform
distributions. This library exists to allow the generation of pseudorandom numbers
sampled from non-uniform distributions. Additionally, it allows the user to use any
PRNG of their choice. This is because the default RNG can be reassigned to a different
generator. It can either be one of the pre-existing one (which are well-tested and
recommended) or a custom user-defined one. See `rand.set_rng()`.

72
vlib/rand/dist/dist.v vendored 100644
View File

@ -0,0 +1,72 @@
// Copyright (c) 2019-2021 Alexander Medvednikov. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
module dist
import math
import rand
fn check_probability_range(p f64) {
if p < 0 || p > 1 {
panic('$p is not a valid probability value.')
}
}
// bernoulli returns true with a probability p. Note that 0 <= p <= 1.
pub fn bernoulli(p f64) bool {
check_probability_range(p)
return rand.f64() <= p
}
// binomial returns the number of successful trials out of n when the
// probability of success for each trial is p.
pub fn binomial(n int, p f64) int {
check_probability_range(p)
mut count := 0
for _ in 0 .. n {
if bernoulli(p) {
count++
}
}
return count
}
// Configuration struct for the `normal_pair` function. The default value for
// `mu` is 0 and the default value for `sigma` is 1.
pub struct NormalConfigStruct {
mu f64 = 0.0
sigma f64 = 1.0
}
// normal_pair returns a pair of normally distributed random numbers with the mean mu
// and standard deviation sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
// `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair({})`.
pub fn normal_pair(config NormalConfigStruct) (f64, f64) {
// This is an implementation of the Marsaglia polar method
// See: https://doi.org/10.1137%2F1006063
// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
for {
u := rand.f64_in_range(-1, 1)
v := rand.f64_in_range(-1, 1)
s := u * u + v * v
if s >= 1 || s == 0 {
continue
}
t := math.sqrt(-2 * math.log(s) / s)
x := config.mu + config.sigma * t * u
y := config.mu + config.sigma * t * v
return x, y
}
return config.mu, config.mu
}
// normal returns a normally distributed random number with the mean mu and standard deviation
// sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
// `x := normal(mu: mean, sigma: etdev)` or `x := normal({})`.
// **NOTE:** If you are generating a lot of normal variates, use `the normal_pair` function
// instead. This function discards one of the two variates generated by the `normal_pair` function.
pub fn normal(config NormalConfigStruct) f64 {
x, _ := normal_pair(config)
return x
}

111
vlib/rand/dist/dist_test.v vendored 100644
View File

@ -0,0 +1,111 @@
import math
import rand
import rand.dist
const (
// The sample size to be used
count = 2000
// Accepted error is within 5% of the actual values.
error = 0.05
// The seeds used (for reproducible testing)
seeds = [[u32(0xffff24), 0xabcd], [u32(0x141024), 0x42851],
[u32(0x1452), 0x90cd],
]
)
fn test_bernoulli() {
ps := [0.0, 0.1, 1.0 / 3.0, 0.5, 0.8, 17.0 / 18.0, 1.0]
for seed in seeds {
rand.seed(seed)
for p in ps {
mut successes := 0
for _ in 0 .. count {
if dist.bernoulli(p) {
successes++
}
}
assert math.abs(f64(successes) / count - p) < error
}
}
}
fn test_binomial() {
ns := [100, 200, 1000]
ps := [0.0, 0.5, 0.95, 1.0]
for seed in seeds {
rand.seed(seed)
for n in ns {
for p in ps {
np := n * p
npq := np * (1 - p)
mut sum := 0
mut var := 0.0
for _ in 0 .. count {
x := dist.binomial(n, p)
sum += x
dist := (x - np)
var += dist * dist
}
assert math.abs(f64(sum / count) - np) / n < error
assert math.abs(f64(var / count) - npq) / n < error
}
}
}
}
fn test_normal_pair() {
mus := [0, 10, 100, -40]
sigmas := [1, 2, 40, 5]
total := 2 * count
for seed in seeds {
rand.seed(seed)
for mu in mus {
for sigma in sigmas {
mut sum := 0.0
mut var := 0.0
for _ in 0 .. count {
x, y := dist.normal_pair(mu: mu, sigma: sigma)
sum += x + y
dist_x := x - mu
dist_y := y - mu
var += dist_x * dist_x
var += dist_y * dist_y
}
variance := sigma * sigma
assert math.abs(f64(sum / total) - mu) / sigma < 1
assert math.abs(f64(var / total) - variance) / variance < 2 * error
}
}
}
}
fn test_normal() {
mus := [0, 10, 100, -40, 20]
sigmas := [1, 2, 5]
for seed in seeds {
rand.seed(seed)
for mu in mus {
for sigma in sigmas {
mut sum := 0.0
mut var := 0.0
for _ in 0 .. count {
x := dist.normal(mu: mu, sigma: sigma)
sum += x
dist := x - mu
var += dist * dist
}
variance := sigma * sigma
assert math.abs(f64(sum / count) - mu) / sigma < 1
assert math.abs(f64(var / count) - variance) / variance < 2 * error
}
}
}
}

View File

@ -4,7 +4,6 @@
module mt19937 module mt19937
import math.bits import math.bits
import rand.seed
/* /*
C++ functions for MT19937, with initialization improved 2002/2/10. C++ functions for MT19937, with initialization improved 2002/2/10.
@ -57,6 +56,7 @@ const (
) )
// MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937. // MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937.
// **NOTE**: The RNG is not seeded when instantiated so remember to seed it before use.
pub struct MT19937RNG { pub struct MT19937RNG {
mut: mut:
state []u64 = []u64{len: mt19937.nn} state []u64 = []u64{len: mt19937.nn}

View File

@ -71,7 +71,9 @@ pub fn set_rng(rng &PRNG) {
default_rng = rng default_rng = rng
} }
// seed sets the given array of `u32` values as the seed for the `default_rng`. It is recommended to use // seed sets the given array of `u32` values as the seed for the `default_rng`. The default_rng is
// an instance of WyRandRNG which takes 2 u32 values. When using a custom RNG, make sure to use
// the correct number of u32s.
pub fn seed(seed []u32) { pub fn seed(seed []u32) {
default_rng.seed(seed) default_rng.seed(seed)
} }