
Commit b18a09e

Remove dict_to_dataset workaround
Closes #5042
1 parent ed9a60a · commit b18a09e

File tree

1 file changed: +2 −55 lines

pymc/backends/arviz.py

@@ -7,8 +7,6 @@
     Any,
     Dict,
     Iterable,
-    List,
-    Mapping,
     Optional,
     Tuple,
     Union,
@@ -21,9 +19,7 @@
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from arviz import InferenceData, concat, rcParams
-from arviz.data.base import CoordSpec, DimSpec
-from arviz.data.base import dict_to_dataset as _dict_to_dataset
-from arviz.data.base import generate_dims_coords, make_attrs, requires
+from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
 
 import pymc
 
@@ -101,42 +97,6 @@ def insert(self, k: str, v, idx: int):
         self.trace_dict[k][idx, :] = v
 
 
-def dict_to_dataset(
-    data,
-    library=None,
-    coords=None,
-    dims=None,
-    attrs=None,
-    default_dims=None,
-    skip_event_dims=None,
-    index_origin=None,
-):
-    """Temporal workaround for dict_to_dataset.
-
-    Once ArviZ>0.11.2 release is available, only two changes are needed for everything to work.
-    1) this should be deleted, 2) dict_to_dataset should be imported as is from arviz, no underscore,
-    also remove unnecessary imports
-    """
-    if default_dims is None:
-        return _dict_to_dataset(
-            data,
-            attrs=attrs,
-            library=library,
-            coords=coords,
-            dims=dims,
-            skip_event_dims=skip_event_dims,
-        )
-    else:
-        out_data = {}
-        for name, vals in data.items():
-            vals = np.atleast_1d(vals)
-            val_dims = dims.get(name)
-            val_dims, crds = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
-            crds = {key: xr.IndexVariable((key,), data=crds[key]) for key in val_dims}
-            out_data[name] = xr.DataArray(vals, dims=val_dims, coords=crds)
-        return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))
-
-
 class InferenceDataConverter:  # pylint: disable=too-many-instance-attributes
     """Encapsulate InferenceData specific logic."""
 
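
After these hunks, dict_to_dataset comes straight from arviz.data.base, which in releases newer than 0.11.2 accepts the keywords the deleted shim emulated (including default_dims and skip_event_dims). A minimal sketch of a direct call under that assumption; the variable names, shapes, and the "city" dimension are made up for illustration and are not part of this commit:

import numpy as np
import pymc
from arviz.data.base import dict_to_dataset

# Illustrative posterior-style input: arrays shaped (chain, draw, *shape).
samples = {
    "mu": np.random.randn(2, 100),        # 2 chains, 100 draws, scalar variable
    "theta": np.random.randn(2, 100, 3),  # 2 chains, 100 draws, length-3 vector
}

# The upstream function now covers what the removed workaround did,
# so it can be called directly with coords and dims.
ds = dict_to_dataset(
    samples,
    library=pymc,
    coords={"city": ["A", "B", "C"]},
    dims={"theta": ["city"]},
)
print(ds)
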
@@ -160,7 +120,6 @@ def __init__(
         model=None,
         save_warmup: Optional[bool] = None,
         density_dist_obs: bool = True,
-        index_origin: Optional[int] = None,
     ):
 
         self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
@@ -196,7 +155,6 @@ def __init__(
         self.posterior_predictive = posterior_predictive
         self.log_likelihood = log_likelihood
         self.predictions = predictions
-        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin
 
         def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
             return next(iter(dct.values()))
@@ -344,15 +302,13 @@ def posterior_to_xarray(self):
                 coords=self.coords,
                 dims=self.dims,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
             dict_to_dataset(
                 data_warmup,
                 library=pymc,
                 coords=self.coords,
                 dims=self.dims,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
         )
 
@@ -386,15 +342,13 @@ def sample_stats_to_xarray(self):
                 dims=None,
                 coords=self.coords,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
             dict_to_dataset(
                 data_warmup,
                 library=pymc,
                 dims=None,
                 coords=self.coords,
                 attrs=self.attrs,
-                index_origin=self.index_origin,
             ),
         )
 
@@ -427,15 +381,13 @@ def log_likelihood_to_xarray(self):
                 dims=self.dims,
                 coords=self.coords,
                 skip_event_dims=True,
-                index_origin=self.index_origin,
             ),
             dict_to_dataset(
                 data_warmup,
                 library=pymc,
                 dims=self.dims,
                 coords=self.coords,
                 skip_event_dims=True,
-                index_origin=self.index_origin,
             ),
         )
 
@@ -456,9 +408,7 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
                     "This can mean that some draws or even whole chains are not represented.",
                     k,
                 )
-        return dict_to_dataset(
-            data, library=pymc, coords=self.coords, dims=self.dims, index_origin=self.index_origin
-        )
+        return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)
 
     @requires(["posterior_predictive"])
     def posterior_predictive_to_xarray(self):
@@ -493,7 +443,6 @@ def priors_to_xarray(self):
                     library=pymc,
                     coords=self.coords,
                     dims=self.dims,
-                    index_origin=self.index_origin,
                 )
             )
         return priors_dict
@@ -510,7 +459,6 @@ def observed_data_to_xarray(self):
             coords=self.coords,
             dims=self.dims,
             default_dims=[],
-            index_origin=self.index_origin,
         )
 
     @requires(["trace", "predictions"])
@@ -557,7 +505,6 @@ def is_data(name, var) -> bool:
             coords=self.coords,
             dims=self.dims,
             default_dims=[],
-            index_origin=self.index_origin,
         )
 
     def to_inference_data(self):
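
With the index_origin plumbing removed in the hunks above, the converter no longer reads data.index_origin itself; the origin of the default chain/draw coordinates is resolved inside ArviZ when dict_to_dataset builds them. A minimal sketch of opting into 1-based labels through ArviZ's rc_context, assuming the upstream function falls back to the data.index_origin rcParam (an assumption about ArviZ's behaviour, not something shown in this diff):

import arviz as az
import numpy as np
import pymc
from arviz.data.base import dict_to_dataset

draws = {"mu": np.random.randn(2, 50)}  # 2 chains, 50 draws (illustrative)

# index_origin is no longer passed by pymc; the label origin follows
# ArviZ's "data.index_origin" rcParam (assumed upstream behaviour).
with az.rc_context(rc={"data.index_origin": 1}):
    ds = dict_to_dataset(draws, library=pymc)

print(ds.coords["draw"].values[:5])  # expected to start at 1 rather than 0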
