diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 40d4c1a3ac..af42fd64a0 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -226,8 +226,9 @@ def get_tau_sigma(tau=None, sigma=None): if isinstance(sigma, Variable): sigma_ = check_parameters(sigma, sigma > 0, msg="sigma > 0") else: - assert np.all(np.asarray(sigma) > 0) - sigma_ = sigma + sigma_ = np.asarray(sigma) + if np.any(sigma_ <= 0): + raise ValueError("sigma must be positive") tau = sigma_**-2.0 else: @@ -237,9 +238,9 @@ def get_tau_sigma(tau=None, sigma=None): if isinstance(tau, Variable): tau_ = check_parameters(tau, tau > 0, msg="tau > 0") else: - assert np.all(np.asarray(tau) > 0) - tau_ = tau - + tau_ = np.asarray(tau) + if np.any(tau_ <= 0): + raise ValueError("tau must be positive") sigma = tau_**-0.5 return floatX(tau), floatX(sigma) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index a8d8ea4dd4..53f2357e26 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2399,6 +2399,11 @@ def test_get_tau_sigma(self): with pytest.raises(ParameterValueError): sigma.eval() + sigma = [1, 2] + assert_almost_equal( + get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] + ) + @pytest.mark.parametrize( "value,mu,sigma,nu,logp", [