rand: add non uniform distributions in the `rand.dist` module (#9274)
parent
0c33656a19
commit
08da33fa5a
|
@ -0,0 +1,10 @@
|
|||
# Non-Uniform Distribution Functions
|
||||
|
||||
This module contains functions for sampling from non-uniform distributions.
|
||||
|
||||
All implementations of the `rand.PRNG` interface generate numbers from uniform
|
||||
distributions. This library exists to allow the generation of pseudorandom numbers
|
||||
sampled from non-uniform distributions. Additionally, it allows the user to use any
|
||||
PRNG of their choice. This is because the default RNG can be reassigned to a different
|
||||
generator. It can either be one of the pre-existing one (which are well-tested and
|
||||
recommended) or a custom user-defined one. See `rand.set_rng()`.
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) 2019-2021 Alexander Medvednikov. All rights reserved.
|
||||
// Use of this source code is governed by an MIT license
|
||||
// that can be found in the LICENSE file.
|
||||
module dist
|
||||
|
||||
import math
|
||||
import rand
|
||||
|
||||
fn check_probability_range(p f64) {
|
||||
if p < 0 || p > 1 {
|
||||
panic('$p is not a valid probability value.')
|
||||
}
|
||||
}
|
||||
|
||||
// bernoulli returns true with a probability p. Note that 0 <= p <= 1.
|
||||
pub fn bernoulli(p f64) bool {
|
||||
check_probability_range(p)
|
||||
return rand.f64() <= p
|
||||
}
|
||||
|
||||
// binomial returns the number of successful trials out of n when the
|
||||
// probability of success for each trial is p.
|
||||
pub fn binomial(n int, p f64) int {
|
||||
check_probability_range(p)
|
||||
mut count := 0
|
||||
for _ in 0 .. n {
|
||||
if bernoulli(p) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Configuration struct for the `normal_pair` function. The default value for
|
||||
// `mu` is 0 and the default value for `sigma` is 1.
|
||||
pub struct NormalConfigStruct {
|
||||
mu f64 = 0.0
|
||||
sigma f64 = 1.0
|
||||
}
|
||||
|
||||
// normal_pair returns a pair of normally distributed random numbers with the mean mu
|
||||
// and standard deviation sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
|
||||
// `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair({})`.
|
||||
pub fn normal_pair(config NormalConfigStruct) (f64, f64) {
|
||||
// This is an implementation of the Marsaglia polar method
|
||||
// See: https://doi.org/10.1137%2F1006063
|
||||
// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
|
||||
for {
|
||||
u := rand.f64_in_range(-1, 1)
|
||||
v := rand.f64_in_range(-1, 1)
|
||||
|
||||
s := u * u + v * v
|
||||
if s >= 1 || s == 0 {
|
||||
continue
|
||||
}
|
||||
t := math.sqrt(-2 * math.log(s) / s)
|
||||
x := config.mu + config.sigma * t * u
|
||||
y := config.mu + config.sigma * t * v
|
||||
return x, y
|
||||
}
|
||||
return config.mu, config.mu
|
||||
}
|
||||
|
||||
// normal returns a normally distributed random number with the mean mu and standard deviation
|
||||
// sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
|
||||
// `x := normal(mu: mean, sigma: etdev)` or `x := normal({})`.
|
||||
// **NOTE:** If you are generating a lot of normal variates, use `the normal_pair` function
|
||||
// instead. This function discards one of the two variates generated by the `normal_pair` function.
|
||||
pub fn normal(config NormalConfigStruct) f64 {
|
||||
x, _ := normal_pair(config)
|
||||
return x
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
import math
|
||||
import rand
|
||||
import rand.dist
|
||||
|
||||
const (
|
||||
// The sample size to be used
|
||||
count = 2000
|
||||
// Accepted error is within 5% of the actual values.
|
||||
error = 0.05
|
||||
// The seeds used (for reproducible testing)
|
||||
seeds = [[u32(0xffff24), 0xabcd], [u32(0x141024), 0x42851],
|
||||
[u32(0x1452), 0x90cd],
|
||||
]
|
||||
)
|
||||
|
||||
fn test_bernoulli() {
|
||||
ps := [0.0, 0.1, 1.0 / 3.0, 0.5, 0.8, 17.0 / 18.0, 1.0]
|
||||
|
||||
for seed in seeds {
|
||||
rand.seed(seed)
|
||||
for p in ps {
|
||||
mut successes := 0
|
||||
for _ in 0 .. count {
|
||||
if dist.bernoulli(p) {
|
||||
successes++
|
||||
}
|
||||
}
|
||||
assert math.abs(f64(successes) / count - p) < error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_binomial() {
|
||||
ns := [100, 200, 1000]
|
||||
ps := [0.0, 0.5, 0.95, 1.0]
|
||||
|
||||
for seed in seeds {
|
||||
rand.seed(seed)
|
||||
for n in ns {
|
||||
for p in ps {
|
||||
np := n * p
|
||||
npq := np * (1 - p)
|
||||
|
||||
mut sum := 0
|
||||
mut var := 0.0
|
||||
for _ in 0 .. count {
|
||||
x := dist.binomial(n, p)
|
||||
sum += x
|
||||
dist := (x - np)
|
||||
var += dist * dist
|
||||
}
|
||||
|
||||
assert math.abs(f64(sum / count) - np) / n < error
|
||||
assert math.abs(f64(var / count) - npq) / n < error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_normal_pair() {
|
||||
mus := [0, 10, 100, -40]
|
||||
sigmas := [1, 2, 40, 5]
|
||||
total := 2 * count
|
||||
|
||||
for seed in seeds {
|
||||
rand.seed(seed)
|
||||
for mu in mus {
|
||||
for sigma in sigmas {
|
||||
mut sum := 0.0
|
||||
mut var := 0.0
|
||||
for _ in 0 .. count {
|
||||
x, y := dist.normal_pair(mu: mu, sigma: sigma)
|
||||
sum += x + y
|
||||
dist_x := x - mu
|
||||
dist_y := y - mu
|
||||
var += dist_x * dist_x
|
||||
var += dist_y * dist_y
|
||||
}
|
||||
|
||||
variance := sigma * sigma
|
||||
assert math.abs(f64(sum / total) - mu) / sigma < 1
|
||||
assert math.abs(f64(var / total) - variance) / variance < 2 * error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_normal() {
|
||||
mus := [0, 10, 100, -40, 20]
|
||||
sigmas := [1, 2, 5]
|
||||
|
||||
for seed in seeds {
|
||||
rand.seed(seed)
|
||||
for mu in mus {
|
||||
for sigma in sigmas {
|
||||
mut sum := 0.0
|
||||
mut var := 0.0
|
||||
for _ in 0 .. count {
|
||||
x := dist.normal(mu: mu, sigma: sigma)
|
||||
sum += x
|
||||
dist := x - mu
|
||||
var += dist * dist
|
||||
}
|
||||
|
||||
variance := sigma * sigma
|
||||
assert math.abs(f64(sum / count) - mu) / sigma < 1
|
||||
assert math.abs(f64(var / count) - variance) / variance < 2 * error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,7 +4,6 @@
|
|||
module mt19937
|
||||
|
||||
import math.bits
|
||||
import rand.seed
|
||||
|
||||
/*
|
||||
C++ functions for MT19937, with initialization improved 2002/2/10.
|
||||
|
@ -57,6 +56,7 @@ const (
|
|||
)
|
||||
|
||||
// MT19937RNG is generator that uses the Mersenne Twister algorithm with period 2^19937.
|
||||
// **NOTE**: The RNG is not seeded when instantiated so remember to seed it before use.
|
||||
pub struct MT19937RNG {
|
||||
mut:
|
||||
state []u64 = []u64{len: mt19937.nn}
|
||||
|
|
|
@ -71,7 +71,9 @@ pub fn set_rng(rng &PRNG) {
|
|||
default_rng = rng
|
||||
}
|
||||
|
||||
// seed sets the given array of `u32` values as the seed for the `default_rng`. It is recommended to use
|
||||
// seed sets the given array of `u32` values as the seed for the `default_rng`. The default_rng is
|
||||
// an instance of WyRandRNG which takes 2 u32 values. When using a custom RNG, make sure to use
|
||||
// the correct number of u32s.
|
||||
pub fn seed(seed []u32) {
|
||||
default_rng.seed(seed)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue