rand.dist: add exponential distribution function and unit tests (#9402)

pull/9408/head
Subhomoy Haldar 2021-03-21 16:34:43 +05:30 committed by GitHub
parent c4e6ef424e
commit 0e80e57aa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 0 deletions

13
vlib/rand/dist/dist.v vendored
View File

@ -42,6 +42,9 @@ pub struct NormalConfigStruct {
// and standard deviation sigma. If not specified, mu is 0 and sigma is 1. Intended usage is // and standard deviation sigma. If not specified, mu is 0 and sigma is 1. Intended usage is
// `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair({})`. // `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair({})`.
pub fn normal_pair(config NormalConfigStruct) (f64, f64) { pub fn normal_pair(config NormalConfigStruct) (f64, f64) {
if config.sigma <= 0 {
panic('The standard deviation has to be positive.')
}
// This is an implementation of the Marsaglia polar method // This is an implementation of the Marsaglia polar method
// See: https://doi.org/10.1137%2F1006063 // See: https://doi.org/10.1137%2F1006063
// Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method // Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method
@ -70,3 +73,13 @@ pub fn normal(config NormalConfigStruct) f64 {
x, _ := normal_pair(config) x, _ := normal_pair(config)
return x return x
} }
// exponential returns an exponentially distributed random number with the rate paremeter
// lambda. It is expected that lambda is positive.
pub fn exponential(lambda f64) f64 {
if lambda <= 0 {
panic('The rate (lambda) must be positive.')
}
// Use the inverse transform sampling method
return -math.log(rand.f64()) / lambda
}

View File

@ -109,3 +109,26 @@ fn test_normal() {
} }
} }
} }
fn test_exponential() {
lambdas := [1.0, 10, 1 / 20.0, 1 / 10000.0, 1 / 524.0, 200]
for seed in seeds {
rand.seed(seed)
for lambda in lambdas {
mu := 1 / lambda
variance := mu * mu
mut sum := 0.0
mut var := 0.0
for _ in 0 .. count {
x := dist.exponential(lambda)
sum += x
dist := x - mu
var += dist * dist
}
assert math.abs((f64(sum / count) - mu) / mu) < error
assert math.abs((f64(var / count) - variance) / variance) < 2 * error
}
}
}