Skip to content

Commit a968a9b

Browse files
brandonwillardtwiecki
authored andcommitted
Update Op checks to include both AdvancedIncSubtensor Ops
1 parent f8843c0 commit a968a9b

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pymc3/backends/arviz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from aesara.graph.basic import Constant
2121
from aesara.tensor.sharedvar import SharedVariable
22-
from aesara.tensor.subtensor import AdvancedIncSubtensor
22+
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
2323
from arviz import InferenceData, concat, rcParams
2424
from arviz.data.base import CoordSpec, DimSpec
2525
from arviz.data.base import dict_to_dataset as _dict_to_dataset
@@ -283,7 +283,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):
283283
point = {i.name: point[i.name] for i in log_like_fun.f.maker.inputs if i.name in point}
284284
log_like_val = np.atleast_1d(log_like_fun(point))
285285

286-
if isinstance(var.owner.op, AdvancedIncSubtensor):
286+
if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
287287
try:
288288
obs_data = extract_obs_data(var.tag.observations)
289289
except TypeError:

pymc3/tests/test_idata_conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
import pytest
77

8-
from aesara.tensor.subtensor import AdvancedIncSubtensor
8+
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
99
from arviz import InferenceData
1010
from arviz.tests.helpers import check_multiple_attrs
1111
from numpy import ma
@@ -327,7 +327,7 @@ def test_mv_missing_data_model(self):
327327
inference_data = pm.sample(100, chains=2, return_inferencedata=True)
328328

329329
# make sure that data is really missing
330-
assert isinstance(y.owner.op, AdvancedIncSubtensor)
330+
assert isinstance(y.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
331331

332332
test_dict = {
333333
"posterior": ["mu", "chol_cov"],

0 commit comments

Comments
 (0)