    Any,
    Dict,
    Iterable,
-    List,
-    Mapping,
    Optional,
    Tuple,
    Union,
@@ -21,9 +19,7 @@
from aesara.tensor.sharedvar import SharedVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
from arviz import InferenceData, concat, rcParams
-from arviz.data.base import CoordSpec, DimSpec
-from arviz.data.base import dict_to_dataset as _dict_to_dataset
-from arviz.data.base import generate_dims_coords, make_attrs, requires
+from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires

import pymc

@@ -101,42 +97,6 @@ def insert(self, k: str, v, idx: int):
        self.trace_dict[k][idx, :] = v


-def dict_to_dataset(
-    data,
-    library=None,
-    coords=None,
-    dims=None,
-    attrs=None,
-    default_dims=None,
-    skip_event_dims=None,
-    index_origin=None,
-):
-    """Temporal workaround for dict_to_dataset.
-
-    Once ArviZ>0.11.2 release is available, only two changes are needed for everything to work.
-    1) this should be deleted, 2) dict_to_dataset should be imported as is from arviz, no underscore,
-    also remove unnecessary imports
-    """
-    if default_dims is None:
-        return _dict_to_dataset(
-            data,
-            attrs=attrs,
-            library=library,
-            coords=coords,
-            dims=dims,
-            skip_event_dims=skip_event_dims,
-        )
-    else:
-        out_data = {}
-        for name, vals in data.items():
-            vals = np.atleast_1d(vals)
-            val_dims = dims.get(name)
-            val_dims, crds = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
-            crds = {key: xr.IndexVariable((key,), data=crds[key]) for key in val_dims}
-            out_data[name] = xr.DataArray(vals, dims=val_dims, coords=crds)
-        return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))
-
-
class InferenceDataConverter:  # pylint: disable=too-many-instance-attributes
    """Encapsulate InferenceData specific logic."""

@@ -160,7 +120,6 @@ def __init__(
        model=None,
        save_warmup: Optional[bool] = None,
        density_dist_obs: bool = True,
-        index_origin: Optional[int] = None,
    ):

        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
@@ -196,7 +155,6 @@ def __init__(
        self.posterior_predictive = posterior_predictive
        self.log_likelihood = log_likelihood
        self.predictions = predictions
-        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin

        def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
            return next(iter(dct.values()))
@@ -344,15 +302,13 @@ def posterior_to_xarray(self):
                coords=self.coords,
                dims=self.dims,
                attrs=self.attrs,
-                index_origin=self.index_origin,
            ),
            dict_to_dataset(
                data_warmup,
                library=pymc,
                coords=self.coords,
                dims=self.dims,
                attrs=self.attrs,
-                index_origin=self.index_origin,
            ),
        )

@@ -386,15 +342,13 @@ def sample_stats_to_xarray(self):
                dims=None,
                coords=self.coords,
                attrs=self.attrs,
-                index_origin=self.index_origin,
            ),
            dict_to_dataset(
                data_warmup,
                library=pymc,
                dims=None,
                coords=self.coords,
                attrs=self.attrs,
-                index_origin=self.index_origin,
            ),
        )

@@ -427,15 +381,13 @@ def log_likelihood_to_xarray(self):
                dims=self.dims,
                coords=self.coords,
                skip_event_dims=True,
-                index_origin=self.index_origin,
            ),
            dict_to_dataset(
                data_warmup,
                library=pymc,
                dims=self.dims,
                coords=self.coords,
                skip_event_dims=True,
-                index_origin=self.index_origin,
            ),
        )

@@ -456,9 +408,7 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
                    "This can mean that some draws or even whole chains are not represented.",
                    k,
                )
-        return dict_to_dataset(
-            data, library=pymc, coords=self.coords, dims=self.dims, index_origin=self.index_origin
-        )
+        return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)

    @requires(["posterior_predictive"])
    def posterior_predictive_to_xarray(self):
@@ -493,7 +443,6 @@ def priors_to_xarray(self):
                    library=pymc,
                    coords=self.coords,
                    dims=self.dims,
-                    index_origin=self.index_origin,
                )
            )
        return priors_dict
@@ -510,7 +459,6 @@ def observed_data_to_xarray(self):
            coords=self.coords,
            dims=self.dims,
            default_dims=[],
-            index_origin=self.index_origin,
        )

    @requires(["trace", "predictions"])
@@ -557,7 +505,6 @@ def is_data(name, var) -> bool:
            coords=self.coords,
            dims=self.dims,
            default_dims=[],
-            index_origin=self.index_origin,
        )

    def to_inference_data(self):
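For reference, a minimal sketch of how the upstream helper is used once the local wrapper is gone. It assumes an ArviZ release newer than 0.11.2, where arviz.data.base.dict_to_dataset itself accepts coords, dims, and default_dims (the premise of the deleted workaround's docstring); the variable names and array shapes below are illustrative, not taken from the diff.

# Minimal sketch, not part of the diff: calling the upstream dict_to_dataset directly,
# assuming ArviZ > 0.11.2 exposes the default_dims keyword.
import numpy as np
from arviz.data.base import dict_to_dataset

import pymc

# Posterior-style draws keep the default (chain, draw) leading dimensions.
posterior = {"mu": np.random.randn(4, 500)}  # 4 chains x 500 draws (illustrative)
posterior_ds = dict_to_dataset(posterior, library=pymc, coords=None, dims=None)

# Observed or constant data pass default_dims=[] so no chain/draw dimensions
# are added, mirroring observed_data_to_xarray above.
observed = {"y": np.arange(10.0)}
observed_ds = dict_to_dataset(observed, library=pymc, coords=None, dims=None, default_dims=[])

print(posterior_ds)
print(observed_ds)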