86 lines
2.6 KiB
V
86 lines
2.6 KiB
V
// Copyright (c) 2019-2021 Alexander Medvednikov. All rights reserved.
|
|
// Use of this source code is governed by an MIT license
|
|
// that can be found in the LICENSE file.
|
|
module dist
|
|
|
|
import math
|
|
import rand
|
|
|
|
// check_probability_range panics unless p lies in the closed interval [0, 1].
fn check_probability_range(p f64) {
	// Negating the in-range test (instead of checking `p < 0 || p > 1`) also rejects NaN,
	// for which every comparison evaluates to false.
	if !(p >= 0 && p <= 1) {
		panic('$p is not a valid probability value.')
	}
}
|
|
|
|
// bernoulli returns true with probability p. It panics unless 0 <= p <= 1.
pub fn bernoulli(p f64) bool {
	check_probability_range(p)
	// rand.f64() is uniform on [0, 1), so P(rand.f64() < p) == p exactly. Using `<=` would
	// let bernoulli(0) return true whenever rand.f64() happens to yield exactly 0.
	return rand.f64() < p
}
|
|
|
|
// binomial returns the number of successes in n independent trials, each of which succeeds
// with probability p. It panics unless 0 <= p <= 1. A non-positive n performs no trials
// and yields 0.
pub fn binomial(n int, p f64) int {
	// Validate p once up front; going through bernoulli would re-check it on every trial.
	check_probability_range(p)
	mut count := 0
	for _ in 0 .. n {
		// One Bernoulli trial: rand.f64() is uniform on [0, 1), so `< p` succeeds with
		// probability exactly p (and never when p == 0).
		if rand.f64() < p {
			count++
		}
	}
	return count
}
|
|
|
|
// NormalConfigStruct carries the optional parameters of `normal_pair` and `normal`:
// the mean `mu` (0 unless set) and the standard deviation `sigma` (1 unless set).
pub struct NormalConfigStruct {
	mu    f64 = 0.0
	sigma f64 = 1.0
}
|
|
|
|
// normal_pair returns a pair of independent, normally distributed random numbers with mean
// `mu` and standard deviation `sigma`. If not specified, mu is 0 and sigma is 1. Intended usage
// is `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair()`.
// It panics unless sigma is a positive number (a NaN sigma is rejected too).
pub fn normal_pair(config NormalConfigStruct) (f64, f64) {
	// Written as a negated positive test so that NaN (for which every comparison is false)
	// is rejected instead of silently producing NaN samples.
	if !(config.sigma > 0) {
		panic('The standard deviation has to be positive.')
	}
	// This is an implementation of the Marsaglia polar method
	// See: https://doi.org/10.1137%2F1006063
	// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
	for {
		u := rand.f64_in_range(-1, 1)
		v := rand.f64_in_range(-1, 1)

		s := u * u + v * v
		// Rejection step: keep only points strictly inside the unit disc and not at the
		// origin, where log(s) / s would be undefined.
		if s >= 1 || s == 0 {
			continue
		}
		t := math.sqrt(-2 * math.log(s) / s)
		x := config.mu + config.sigma * t * u
		y := config.mu + config.sigma * t * v
		return x, y
	}
	// Unreachable: the loop above only exits through `return`. Presumably kept so the
	// compiler sees an explicit return at the end of the function.
	return config.mu, config.mu
}
|
|
|
|
// normal draws a single normally distributed random number with mean `mu` and standard
// deviation `sigma` (0 and 1 respectively when not given). Intended usage is
// `x := normal(mu: mean, sigma: stdev)` or `x := normal()`.
// **NOTE:** This is a thin wrapper around the `normal_pair` function and throws away the
// second variate it produces. When generating many normal variates, call `normal_pair`
// directly instead.
pub fn normal(config NormalConfigStruct) f64 {
	sample, _ := normal_pair(config)
	return sample
}
|
|
|
|
// exponential returns an exponentially distributed random number with the rate parameter
// lambda. It panics unless lambda is positive (a NaN lambda is rejected too).
pub fn exponential(lambda f64) f64 {
	// Negated positive test so NaN, for which every comparison is false, is also rejected.
	if !(lambda > 0) {
		panic('The rate (lambda) must be positive.')
	}
	// Inverse transform sampling: for U uniform on (0, 1], -ln(U) / lambda is exponentially
	// distributed with rate lambda. rand.f64() yields values in [0, 1), so 1 - rand.f64() lies
	// in (0, 1] and never hits log(0) = -inf (which would produce an infinite sample).
	return -math.log(1.0 - rand.f64()) / lambda
}
|