From 12a5df593d3cb54d1119a0ae24439b1b18983e08 Mon Sep 17 00:00:00 2001 From: AustinRochford Date: Fri, 14 Oct 2016 09:23:14 -0400 Subject: [PATCH] Don't broadcast in alltrue --- pymc3/distributions/dist_math.py | 5 +---- pymc3/tests/test_dist_math.py | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 pymc3/tests/test_dist_math.py diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index 9e3d3562cb..997d00eef2 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -29,10 +29,7 @@ def bound(logp, *conditions): def alltrue(vals): - ret = 1 - for c in vals: - ret = ret * (1 * c) - return ret + return tt.all([tt.all(1 * val) for val in vals]) def logpow(x, m): diff --git a/pymc3/tests/test_dist_math.py b/pymc3/tests/test_dist_math.py new file mode 100644 index 0000000000..4fff6464ef --- /dev/null +++ b/pymc3/tests/test_dist_math.py @@ -0,0 +1,36 @@ +import numpy as np +import theano.tensor as tt + +from ..distributions.dist_math import alltrue + + +def test_alltrue(): + assert alltrue([]).eval() + assert alltrue([True]).eval() + assert alltrue([tt.ones(10)]).eval() + assert alltrue([tt.ones(10), + 5 * tt.ones(101)]).eval() + assert alltrue([np.ones(10), + 5 * tt.ones(101)]).eval() + assert alltrue([np.ones(10), + True, + 5 * tt.ones(101)]).eval() + assert alltrue([np.array([1, 2, 3]), + True, + 5 * tt.ones(101)]).eval() + + assert not alltrue([False]).eval() + assert not alltrue([tt.zeros(10)]).eval() + assert not alltrue([True, + False]).eval() + assert not alltrue([np.array([0, -1]), + tt.ones(60)]).eval() + assert not alltrue([np.ones(10), + False, + 5 * tt.ones(101)]).eval() + + +def test_alltrue_shape(): + vals = [True, tt.ones(10), tt.zeros(5)] + + assert alltrue(vals).eval().shape == ()