Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix panic in `NormalInverseGaussian::new` with very large `alpha`; this is a Value-breaking change (#40)
- Fix panic in `Binomial::sample` with `n ≥ 2^63`; this is a Value-breaking change (#43)
- Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44)
- Avoid returning NaN from `Gamma::sample`; this is a Value-breaking change and also affects `ChiSquared` and `Dirichlet` (#46)

## [0.5.2]

Expand Down
72 changes: 50 additions & 22 deletions src/gamma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ use serde::{Deserialize, Serialize};
///
/// # Notes
///
/// When the shape (`k`) or scale (`θ`) parameters are close to the upper limits
/// of the floating point type `F`, the implementation may overflow and produce
/// `inf`. On the other hand, when `k` is relatively close to zero (like 0.005)
/// and `θ` is huge (like 1e200), the implementation is likely be affected by
/// underflow and may fail to produce tiny floating point values (like 1e-200),
/// returning 0.0 for them instead. The exact thresholds for this to occur
/// depend on `F`.
///
/// The algorithm used is that described by Marsaglia & Tsang 2000[^1],
/// falling back to directly sampling from an Exponential for `shape
/// == 1`, and using the boosting technique described in that paper for
Expand Down Expand Up @@ -173,8 +181,10 @@ where
return Err(Error::ScaleTooSmall);
}

let repr = if shape == F::one() {
One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
let repr = if shape == F::infinity() || scale == F::infinity() {
One(Exp::new(F::zero()).unwrap())
} else if shape == F::one() {
One(Exp::new(F::one() / scale).unwrap())
} else if shape < F::one() {
Small(GammaSmallShape::new_raw(shape, scale))
} else {
Expand Down Expand Up @@ -212,6 +222,28 @@ where
d,
}
}

fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// Marsaglia & Tsang method, 2000
loop {
let x: F = rng.sample(StandardNormal);
let v_cbrt = F::one() + self.c * x;
if v_cbrt <= F::zero() {
continue;
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: F = rng.sample(Open01);

let x_sqr = x * x;
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
{
// `x` is concentrated enough that `v` should always be finite
return v;
}
}
}
}

impl<F> Distribution<F> for Gamma<F>
Expand All @@ -238,35 +270,22 @@ where
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let u: F = rng.sample(Open01);

self.large_shape.sample(rng) * u.powf(self.inv_shape)
let a = self.large_shape.sample_unscaled(rng);
let b = u.powf(self.inv_shape);
// Multiplying numbers with `scale` can overflow, so do it last to avoid
// producing NaN = inf * 0.0. All the other terms are finite and small.
(a * b * self.large_shape.d) * self.large_shape.scale
}
}

impl<F> Distribution<F> for GammaLargeShape<F>
where
F: Float,
StandardNormal: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// Marsaglia & Tsang method, 2000
loop {
let x: F = rng.sample(StandardNormal);
let v_cbrt = F::one() + self.c * x;
if v_cbrt <= F::zero() {
// a^3 <= 0 iff a <= 0
continue;
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: F = rng.sample(Open01);

let x_sqr = x * x;
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
{
return self.d * v * self.scale;
}
}
self.sample_unscaled(rng) * (self.d * self.scale)
}
}

Expand All @@ -278,4 +297,13 @@ mod test {
fn gamma_distributions_can_be_compared() {
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
}

#[test]
fn gamma_extreme_values() {
let d = Gamma::new(f64::infinity(), 2.0).unwrap();
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());

let d = Gamma::new(2.0, f64::infinity()).unwrap();
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
}
}
Loading