Skip to content

Commit 54d288d

Browse files
basnijholtmichaelosthege
authored andcommitted
Add tests related to named models
A test was added to asserts that strace.name is a string (this was not the case before #4365). Non-empty model names are actually not supported (again, see #4365) so attempting to SMC-sample a named model will now raise a NotImplementedError.
1 parent 4f079d4 commit 54d288d

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

pymc3/smc/sample_smc.py

+5
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ def sample_smc(
141141
_log.info("Initializing SMC sampler...")
142142

143143
model = modelcontext(model)
144+
if model.name:
145+
raise NotImplementedError(
146+
"The SMC implementation currently does not support named models. "
147+
"See https://github.com/pymc-devs/pymc3/pull/4365."
148+
)
144149
if cores is None:
145150
cores = _cpu_count()
146151

pymc3/smc/smc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def posterior_to_trace(self):
255255
varnames = [v.name for v in self.variables]
256256

257257
with self.model:
258-
strace = NDArray(self.model.name)
258+
strace = NDArray(name=self.model.name)
259259
strace.setup(lenght_pos, self.chain)
260260
for i in range(lenght_pos):
261261
value = []

pymc3/tests/test_smc.py

+21
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
import pytest
1617
import theano.tensor as tt
1718

1819
import pymc3 as pm
@@ -189,3 +190,23 @@ def test_repr_latex(self):
189190
assert expected == self.s._repr_latex_()
190191
assert self.s._repr_latex_() == self.s.__latex__()
191192
assert self.SMABC_test.model._repr_latex_() == self.SMABC_test.model.__latex__()
193+
194+
def test_name_is_string_type(self):
195+
with self.SMABC_potential:
196+
assert not self.SMABC_potential.name
197+
trace = pm.sample_smc(draws=10, kernel="ABC")
198+
assert isinstance(trace._straces[0].name, str)
199+
200+
def test_named_models_are_unsupported(self):
201+
def normal_sim(a, b):
202+
return np.random.normal(a, b, 1000)
203+
204+
with pm.Model(name="NamedModel"):
205+
a = pm.Normal("a", mu=0, sigma=1)
206+
b = pm.HalfNormal("b", sigma=1)
207+
c = pm.Potential("c", pm.math.switch(a > 0, 0, -np.inf))
208+
s = pm.Simulator(
209+
"s", normal_sim, params=(a, b), sum_stat="sort", epsilon=1, observed=self.data
210+
)
211+
with pytest.raises(NotImplementedError, match="named models"):
212+
pm.sample_smc(draws=10, kernel="ABC")

0 commit comments

Comments
 (0)