From 0e80e57aa52da4246d75e0b62e5cd6561262f225 Mon Sep 17 00:00:00 2001 From: Subhomoy Haldar Date: Sun, 21 Mar 2021 16:34:43 +0530 Subject: [PATCH] rand.dist: add exponential distribution function and unit tests (#9402) --- vlib/rand/dist/dist.v | 13 +++++++++++++ vlib/rand/dist/dist_test.v | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/vlib/rand/dist/dist.v b/vlib/rand/dist/dist.v index 24f3617892..5ebf256476 100644 --- a/vlib/rand/dist/dist.v +++ b/vlib/rand/dist/dist.v @@ -42,6 +42,9 @@ pub struct NormalConfigStruct { // and standard deviation sigma. If not specified, mu is 0 and sigma is 1. Intended usage is // `x, y := normal_pair(mu: mean, sigma: stdev)`, or `x, y := normal_pair({})`. pub fn normal_pair(config NormalConfigStruct) (f64, f64) { + if config.sigma <= 0 { + panic('The standard deviation has to be positive.') + } // This is an implementation of the Marsaglia polar method // See: https://doi.org/10.1137%2F1006063 // Also: https://en.wikipedia.org/wiki/Marsaglia_polar_method @@ -70,3 +73,13 @@ pub fn normal(config NormalConfigStruct) f64 { x, _ := normal_pair(config) return x } + +// exponential returns an exponentially distributed random number with the rate paremeter +// lambda. It is expected that lambda is positive. +pub fn exponential(lambda f64) f64 { + if lambda <= 0 { + panic('The rate (lambda) must be positive.') + } + // Use the inverse transform sampling method + return -math.log(rand.f64()) / lambda +} diff --git a/vlib/rand/dist/dist_test.v b/vlib/rand/dist/dist_test.v index 08c58dc659..a7c565ec92 100644 --- a/vlib/rand/dist/dist_test.v +++ b/vlib/rand/dist/dist_test.v @@ -109,3 +109,26 @@ fn test_normal() { } } } + +fn test_exponential() { + lambdas := [1.0, 10, 1 / 20.0, 1 / 10000.0, 1 / 524.0, 200] + + for seed in seeds { + rand.seed(seed) + for lambda in lambdas { + mu := 1 / lambda + variance := mu * mu + mut sum := 0.0 + mut var := 0.0 + for _ in 0 .. count { + x := dist.exponential(lambda) + sum += x + dist := x - mu + var += dist * dist + } + + assert math.abs((f64(sum / count) - mu) / mu) < error + assert math.abs((f64(var / count) - variance) / variance) < 2 * error + } + } +}