@@ -138,7 +138,6 @@ def __init__(
138
138
dims : Optional [DimSpec ] = None ,
139
139
model = None ,
140
140
save_warmup : Optional [bool ] = None ,
141
- density_dist_obs : bool = True ,
142
141
):
143
142
144
143
self .save_warmup = rcParams ["data.save_warmup" ] if save_warmup is None else save_warmup
@@ -175,29 +174,11 @@ def __init__(
175
174
self .log_likelihood = log_likelihood
176
175
self .predictions = predictions
177
176
178
- def arbitrary_element (dct : Dict [Any , np .ndarray ]) -> np .ndarray :
179
- return next (iter (dct .values ()))
180
-
181
- if trace is None :
182
- # if you have a posterior_predictive built with keep_dims,
183
- # you'll lose here, but there's nothing I can do about that.
184
- self .nchains = 1
185
- get_from = None
186
- if predictions is not None :
187
- get_from = predictions
188
- elif posterior_predictive is not None :
189
- get_from = posterior_predictive
190
- elif prior is not None :
191
- get_from = prior
192
- if get_from is None :
193
- # pylint: disable=line-too-long
194
- raise ValueError (
195
- "When constructing InferenceData must have at least"
196
- " one of trace, prior, posterior_predictive or predictions."
197
- )
198
-
199
- aelem = arbitrary_element (get_from )
200
- self .ndraws = aelem .shape [0 ]
177
+ if all (elem is None for elem in (trace , predictions , posterior_predictive , prior )):
178
+ raise ValueError (
179
+ "When constructing InferenceData you must pass at least"
180
+ " one of trace, prior, posterior_predictive or predictions."
181
+ )
201
182
202
183
self .coords = {** self .model .coords , ** (coords or {})}
203
184
self .coords = {
@@ -214,7 +195,6 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
214
195
}
215
196
self .dims = {** model_dims , ** self .dims }
216
197
217
- self .density_dist_obs = density_dist_obs
218
198
self .observations = find_observations (self .model )
219
199
220
200
def split_trace (self ) -> Tuple [Union [None , "MultiTrace" ], Union [None , "MultiTrace" ]]:
@@ -396,34 +376,35 @@ def log_likelihood_to_xarray(self):
396
376
),
397
377
)
398
378
399
- def translate_posterior_predictive_dict_to_xarray (self , dct ) -> xr .Dataset :
379
+ def translate_posterior_predictive_dict_to_xarray (self , dct , kind ) -> xr .Dataset :
400
380
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
401
381
data = {}
382
+ warning_vars = []
402
383
for k , ary in dct .items ():
403
- shape = ary .shape
404
- if shape [0 ] == self .nchains and shape [1 ] == self .ndraws :
384
+ if (ary .shape [0 ] == self .nchains ) and (ary .shape [1 ] == self .ndraws ):
405
385
data [k ] = ary
406
- elif shape [0 ] == self .nchains * self .ndraws :
407
- data [k ] = ary .reshape ((self .nchains , self .ndraws , * shape [1 :]))
408
386
else :
409
387
data [k ] = np .expand_dims (ary , 0 )
410
- # pylint: disable=line-too-long
411
- _log .warning (
412
- "posterior predictive variable %s's shape not compatible with number of chains and draws. "
413
- "This can mean that some draws or even whole chains are not represented." ,
414
- k ,
415
- )
388
+ warning_vars .append (k )
389
+ warnings .warn (
390
+ f"The shape of variables { ', ' .join (warning_vars )} in { kind } group is not compatible "
391
+ "with number of chains and draws. The automatic dimension naming might not have worked. "
392
+ "This can also mean that some draws or even whole chains are not represented." ,
393
+ UserWarning ,
394
+ )
416
395
return dict_to_dataset (data , library = pymc , coords = self .coords , dims = self .dims )
417
396
418
397
@requires (["posterior_predictive" ])
419
398
def posterior_predictive_to_xarray (self ):
420
399
"""Convert posterior_predictive samples to xarray."""
421
- return self .translate_posterior_predictive_dict_to_xarray (self .posterior_predictive )
400
+ return self .translate_posterior_predictive_dict_to_xarray (
401
+ self .posterior_predictive , "posterior_predictive"
402
+ )
422
403
423
404
@requires (["predictions" ])
424
405
def predictions_to_xarray (self ):
425
406
"""Convert predictions (out of sample predictions) to xarray."""
426
- return self .translate_posterior_predictive_dict_to_xarray (self .predictions )
407
+ return self .translate_posterior_predictive_dict_to_xarray (self .predictions , "predictions" )
427
408
428
409
def priors_to_xarray (self ):
429
410
"""Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -545,7 +526,6 @@ def to_inference_data(
545
526
dims : Optional [DimSpec ] = None ,
546
527
model : Optional ["Model" ] = None ,
547
528
save_warmup : Optional [bool ] = None ,
548
- density_dist_obs : bool = True ,
549
529
) -> InferenceData :
550
530
"""Convert pymc data into an InferenceData object.
551
531
@@ -578,9 +558,6 @@ def to_inference_data(
578
558
save_warmup : bool, optional
579
559
Save warmup iterations InferenceData object. If not defined, use default
580
560
defined by the rcParams.
581
- density_dist_obs : bool, default True
582
- Store variables passed with ``observed`` arg to
583
- :class:`~pymc.distributions.DensityDist` in the generated InferenceData.
584
561
585
562
Returns
586
563
-------
@@ -598,7 +575,6 @@ def to_inference_data(
598
575
dims = dims ,
599
576
model = model ,
600
577
save_warmup = save_warmup ,
601
- density_dist_obs = density_dist_obs ,
602
578
).to_inference_data ()
603
579
604
580
@@ -620,6 +596,7 @@ def predictions_to_inference_data(
620
596
predictions: Dict[str, np.ndarray]
621
597
The predictions are the return value of :func:`~pymc.sample_posterior_predictive`,
622
598
a dictionary of strings (variable names) to numpy ndarrays (draws).
599
+ Requires the arrays to follow the convention ``chain, draw, *shape``.
623
600
posterior_trace: MultiTrace
624
601
This should be a trace that has been thinned appropriately for
625
602
``pymc.sample_posterior_predictive``. Specifically, any variable whose shape is
@@ -648,14 +625,21 @@ def predictions_to_inference_data(
648
625
raise ValueError (
649
626
"Do not pass True for inplace unless passing" "an existing InferenceData as idata_orig"
650
627
)
651
- new_idata = InferenceDataConverter (
628
+ converter = InferenceDataConverter (
652
629
trace = posterior_trace ,
653
630
predictions = predictions ,
654
631
model = model ,
655
632
coords = coords ,
656
633
dims = dims ,
657
634
log_likelihood = False ,
658
- ).to_inference_data ()
635
+ )
636
+ if hasattr (idata_orig , "posterior" ):
637
+ converter .nchains = idata_orig .posterior .dims ["chain" ]
638
+ converter .ndraws = idata_orig .posterior .dims ["draw" ]
639
+ else :
640
+ aelem = next (iter (predictions .values ()))
641
+ converter .nchains , converter .ndraws = aelem .shape [:2 ]
642
+ new_idata = converter .to_inference_data ()
659
643
if idata_orig is None :
660
644
return new_idata
661
645
elif inplace :
0 commit comments