rand: add non uniform distributions in the `rand.dist` module (#9274)
parent
0c33656a19
commit
08da33fa5a
|
@ -0,0 +1,10 @@
|
||||||
|
# Non-Uniform Distribution Functions
|
||||||
|
|
||||||
|
This module contains functions for sampling from non-uniform distributions.
|
||||||
|
|
||||||
|
All implementations of the `rand.PRNG` interface generate numbers from uniform
|
||||||
|
distributions. This library exists to allow the generation of pseudorandom numbers
|
||||||
|
sampled from non-uniform distributions. Additionally, it allows the user to use any
|
||||||
|
PRNG of their choice. This is because the default RNG can be reassigned to a different
|
||||||
|
generator. It can either be one of the pre-existing ones (which are well-tested and
|
||||||
|
recommended) or a custom user-defined one. See `rand.set_rng()`.
|
|
@ -0,0 +1,72 @@
|
||||||
|
// Copyright (c) 2019-2021 Alexander Medvednikov. All rights reserved.
|
||||||
|
// Use of this source code is governed by an MIT license
|
||||||
|
// that can be found in the LICENSE file.
|
||||||
|
module dist
|
||||||
|
|
||||||
|
import math
|
||||||
|
import rand
|
||||||
|
|
||||||
|
// check_probability_range panics unless `p` is a valid probability, i.e. 0 <= p <= 1.
// The condition is phrased as "not inside the range" so that NaN, for which every
// ordered comparison is false, is rejected instead of slipping through.
fn check_probability_range(p f64) {
	if !(p >= 0 && p <= 1) {
		panic('$p is not a valid probability value.')
	}
}
|
||||||
|
|
||||||
|
// bernoulli returns true with a probability p. Note that 0 <= p <= 1;
// the argument is validated by `check_probability_range`, which panics otherwise.
pub fn bernoulli(p f64) bool {
	check_probability_range(p)
	// rand.f64() is uniform over [0, 1), so the strict comparison succeeds with
	// probability exactly p: never for p == 0 and always for p == 1. With `<=`
	// a draw of exactly 0 would count as a success even when p == 0.
	return rand.f64() < p
}
|
||||||
|
|
||||||
|
// binomial returns the number of successful trials out of n when the
// probability of success for each trial is p. The probability is validated by
// `check_probability_range`, which panics when p is not in the range 0 <= p <= 1.
pub fn binomial(n int, p f64) int {
	// Validate p a single time up front instead of once per trial.
	check_probability_range(p)
	mut count := 0
	for _ in 0 .. n {
		// One Bernoulli trial: rand.f64() lies in [0, 1), so the comparison
		// holds with probability p.
		if rand.f64() < p {
			count++
		}
	}
	return count
}
|
||||||
|
|
||||||
|
// NormalConfigStruct carries the parameters of the `normal_pair` and `normal`
// functions. Both fields may be omitted: `mu` then defaults to 0 and `sigma` to 1.
pub struct NormalConfigStruct {
	mu    f64 = 0.0 // mean of the distribution
	sigma f64 = 1.0 // standard deviation of the distribution
}
|
||||||
|
|
||||||
|
// normal_pair draws two normally distributed random numbers at once, using the
// mean `config.mu` and the standard deviation `config.sigma`. When they are not
// given, mu is 0 and sigma is 1. Intended usage is
// `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair({})`.
pub fn normal_pair(config NormalConfigStruct) (f64, f64) {
	// Marsaglia polar method.
	// See: https://doi.org/10.1137%2F1006063
	// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
	for {
		// Candidate point; the first draw feeds the first variate, the second
		// draw the second one.
		a := rand.f64_in_range(-1, 1)
		b := rand.f64_in_range(-1, 1)

		r2 := a * a + b * b
		// Accept the point only when its squared radius is below 1 and not 0
		// (a zero radius would make the logarithm and the division blow up);
		// otherwise draw a fresh point.
		if r2 != 0 && r2 < 1 {
			scale := config.sigma * math.sqrt(-2 * math.log(r2) / r2)
			return config.mu + scale * a, config.mu + scale * b
		}
	}
	// Fallback after the endless loop above, which only leaves through its own return.
	return config.mu, config.mu
}
|
||||||
|
|
||||||
|
// normal returns a single normally distributed random number with the mean mu and
// the standard deviation sigma. If not specified, mu is 0 and sigma is 1.
// Intended usage is `x := normal(mu: mean, sigma: stdev)` or `x := normal({})`.
// **NOTE:** If you are generating a lot of normal variates, use the `normal_pair` function
// instead. This function discards one of the two variates generated by `normal_pair`.
pub fn normal(config NormalConfigStruct) f64 {
	// Only the first variate of the pair is kept.
	first, _ := normal_pair(config)
	return first
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
import math
|
||||||
|
import rand
|
||||||
|
import rand.dist
|
||||||
|
|
||||||
|
const (
	// Number of samples drawn for every statistical check
	count = 2000
	// Tolerance used by the assertions: 0.05, i.e. 5%
	error = 0.05
	// Fixed seeds, so that the tests are reproducible
	seeds = [
		[u32(0xffff24), 0xabcd],
		[u32(0x141024), 0x42851],
		[u32(0x1452), 0x90cd],
	]
)
|
||||||
|
|
||||||
|
// test_bernoulli checks, for every seed and every probability p, that the share
// of successes among `count` draws of `dist.bernoulli(p)` is within `error` of p.
fn test_bernoulli() {
	probabilities := [0.0, 0.1, 1.0 / 3.0, 0.5, 0.8, 17.0 / 18.0, 1.0]

	for seed in seeds {
		rand.seed(seed)
		for p in probabilities {
			mut hits := 0
			for _ in 0 .. count {
				if dist.bernoulli(p) {
					hits++
				}
			}
			observed := f64(hits) / count
			assert math.abs(observed - p) < error
		}
	}
}
|
||||||
|
|
||||||
|
// test_binomial checks, for every seed, trial count n and probability p, that the
// sample mean and sample variance of `count` draws of `dist.binomial(n, p)` are
// close to the theoretical values np and np(1 - p). Both differences are scaled
// by n and must stay below `error`.
fn test_binomial() {
	ns := [100, 200, 1000]
	ps := [0.0, 0.5, 0.95, 1.0]

	for seed in seeds {
		rand.seed(seed)
		for n in ns {
			for p in ps {
				// Theoretical mean and variance of the distribution
				np := n * p
				npq := np * (1 - p)

				mut sum := 0
				mut var := 0.0
				for _ in 0 .. count {
					x := dist.binomial(n, p)
					sum += x
					// Named `deviation` so the imported `dist` module is not shadowed.
					deviation := x - np
					var += deviation * deviation
				}

				// `sum` and `count` are both integers: convert before dividing,
				// otherwise the sample mean is truncated to a whole number.
				assert math.abs(f64(sum) / count - np) / n < error
				assert math.abs(var / count - npq) / n < error
			}
		}
	}
}
|
||||||
|
|
||||||
|
// test_normal_pair checks, for every seed, mean and standard deviation, that the
// sample mean of the variates from `dist.normal_pair` is within one sigma of mu
// and that the relative error of the sample variance is below `2 * error`.
fn test_normal_pair() {
	mus := [0, 10, 100, -40]
	sigmas := [1, 2, 40, 5]
	// Every call yields two variates, so twice `count` samples are collected.
	total := 2 * count

	for seed in seeds {
		rand.seed(seed)
		for mu in mus {
			for sigma in sigmas {
				mut acc := 0.0
				mut sq_acc := 0.0
				for _ in 0 .. count {
					x, y := dist.normal_pair(mu: mu, sigma: sigma)
					acc += x + y
					dx := x - mu
					dy := y - mu
					sq_acc += dx * dx
					sq_acc += dy * dy
				}

				sample_mean := acc / total
				sample_variance := sq_acc / total
				variance := sigma * sigma
				assert math.abs(sample_mean - mu) / sigma < 1
				assert math.abs(sample_variance - variance) / variance < 2 * error
			}
		}
	}
}
|
||||||
|
|
||||||
|
// test_normal checks, for every seed, mean and standard deviation, that the
// sample mean of `count` variates from `dist.normal` is within one sigma of mu
// and that the relative error of the sample variance is below `2 * error`.
fn test_normal() {
	mus := [0, 10, 100, -40, 20]
	sigmas := [1, 2, 5]

	for seed in seeds {
		rand.seed(seed)
		for mu in mus {
			for sigma in sigmas {
				mut acc := 0.0
				mut sq_acc := 0.0
				for _ in 0 .. count {
					x := dist.normal(mu: mu, sigma: sigma)
					acc += x
					deviation := x - mu
					sq_acc += deviation * deviation
				}

				sample_mean := acc / count
				sample_variance := sq_acc / count
				variance := sigma * sigma
				assert math.abs(sample_mean - mu) / sigma < 1
				assert math.abs(sample_variance - variance) / variance < 2 * error
			}
		}
	}
}
|
|
@ -4,7 +4,6 @@
|
||||||
module mt19937
|
module mt19937
|
||||||
|
|
||||||
import math.bits
|
import math.bits
|
||||||
import rand.seed
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
C++ functions for MT19937, with initialization improved 2002/2/10.
|
C++ functions for MT19937, with initialization improved 2002/2/10.
|
||||||
|
@ -57,6 +56,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937.
|
// MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937.
|
||||||
|
// **NOTE**: The RNG is not seeded when instantiated so remember to seed it before use.
|
||||||
pub struct MT19937RNG {
|
pub struct MT19937RNG {
|
||||||
mut:
|
mut:
|
||||||
state []u64 = []u64{len: mt19937.nn}
|
state []u64 = []u64{len: mt19937.nn}
|
||||||
|
|
|
@ -71,7 +71,9 @@ pub fn set_rng(rng &PRNG) {
|
||||||
default_rng = rng
|
default_rng = rng
|
||||||
}
|
}
|
||||||
|
|
||||||
// seed sets the given array of `u32` values as the seed for the `default_rng`. It is recommended to use
|
// seed sets the given array of `u32` values as the seed for the `default_rng`. The default_rng is
|
||||||
|
// an instance of WyRandRNG which takes 2 u32 values. When using a custom RNG, make sure to use
|
||||||
|
// the correct number of u32s.
|
||||||
pub fn seed(seed []u32) {
|
pub fn seed(seed []u32) {
|
||||||
default_rng.seed(seed)
|
default_rng.seed(seed)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue