Skip to content

Commit a5f3e45

Browse files
committed
Cast rounded input replacement to original discrete dtype in sample_smc
1 parent fb3e65f commit a5f3e45

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
## PyMC 4.0.1 (vNext)
44
+ Fixed an incorrect entry in `pm.Metropolis.stats_dtypes` (see #5582).
55
+ Added a check in `Empirical` approximation which does not yet support `InferenceData` inputs (see #5874, #5884).
6-
+ ...
6+
+ Fixed bug when sampling discrete variables with SMC (see #5887).
77

88
## PyMC 4.0.0 (2022-06-03)
99

pymc/smc/smc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def _logp_forw(point, out_vars, in_vars, shared):
567567
if in_var.dtype in discrete_types:
568568
float_var = at.TensorType("floatX", in_var.broadcastable)(in_var.name)
569569
new_in_vars.append(float_var)
570-
replace_int_input[in_var] = at.round(float_var)
570+
replace_int_input[in_var] = at.round(float_var).astype(in_var.dtype)
571571
else:
572572
new_in_vars.append(in_var)
573573

pymc/tests/test_smc.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_discrete_rounding_proposal(self):
110110
assert np.isclose(smc.prior_logp_func(floatX(np.array([0.51]))), np.log(0.7))
111111
assert smc.prior_logp_func(floatX(np.array([1.51]))) == -np.inf
112112

113-
def test_unobserved_discrete(self):
113+
def test_unobserved_bernoulli(self):
114114
n = 10
115115
rng = self.get_random_state()
116116
z_true = np.zeros(n, dtype=int)
@@ -126,6 +126,15 @@ def test_unobserved_discrete(self):
126126

127127
assert np.all(np.median(trace["z"], axis=0) == z_true)
128128

129+
def test_unobserved_categorical(self):
130+
with pm.Model() as m:
131+
mu = pm.Categorical("mu", p=[0.1, 0.3, 0.6], size=2)
132+
pm.Normal("like", mu=mu, sigma=0.1, observed=[1, 2])
133+
134+
trace = pm.sample_smc(chains=1, return_inferencedata=False)
135+
136+
assert np.all(np.median(trace["mu"], axis=0) == [1, 2])
137+
129138
def test_marginal_likelihood(self):
130139
"""
131140
Verifies that the log marginal likelihood function

0 commit comments

Comments
 (0)