Skip to content

Commit db65421

Browse files
michaelosthegealoctavodia
authored andcommitted
Add type hints to SMC code
1 parent 22b4446 commit db65421

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

pymc/smc/kernels.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
import warnings
1717

1818
from abc import ABC
19-
from typing import Dict, cast
19+
from typing import Dict, Union, cast
2020

2121
import numpy as np
2222
import pytensor.tensor as at
2323

2424
from pytensor.graph.replace import clone_replace
2525
from scipy.special import logsumexp
2626
from scipy.stats import multivariate_normal
27+
from typing_extensions import TypeAlias
2728

2829
from pymc.backends.ndarray import NDArray
2930
from pymc.blocking import DictToArrayBijection
@@ -39,6 +40,9 @@
3940
from pymc.step_methods.metropolis import MultivariateNormalProposal
4041
from pymc.vartypes import discrete_types
4142

43+
SMCStats: TypeAlias = Dict[str, Union[int, float]]
44+
SMCSettings: TypeAlias = Dict[str, Union[int, float]]
45+
4246

4347
class SMC_KERNEL(ABC):
4448
"""Base class for the Sequential Monte Carlo kernels.
@@ -304,7 +308,7 @@ def mutate(self):
304308
"""Apply kernel-specific perturbation to the particles once per stage"""
305309
pass
306310

307-
def sample_stats(self) -> Dict:
311+
def sample_stats(self) -> SMCStats:
308312
"""Stats to be saved at the end of each stage
309313
310314
These stats will be saved under `sample_stats` in the final InferenceData object.
@@ -314,7 +318,7 @@ def sample_stats(self) -> Dict:
314318
"beta": self.beta,
315319
}
316320

317-
def sample_settings(self) -> Dict:
321+
def sample_settings(self) -> SMCSettings:
318322
"""SMC_kernel settings to be saved once at the end of sampling.
319323
320324
These stats will be saved under `sample_stats` in the final InferenceData object.
@@ -425,7 +429,7 @@ def mutate(self):
425429

426430
self.acc_rate = np.mean(ac_)
427431

428-
def sample_stats(self):
432+
def sample_stats(self) -> SMCStats:
429433
stats = super().sample_stats()
430434
stats.update(
431435
{
@@ -434,7 +438,7 @@ def sample_stats(self):
434438
)
435439
return stats
436440

437-
def sample_settings(self):
441+
def sample_settings(self) -> SMCSettings:
438442
stats = super().sample_settings()
439443
stats.update(
440444
{
@@ -543,7 +547,7 @@ def mutate(self):
543547

544548
self.chain_acc_rate = np.mean(ac_, axis=0)
545549

546-
def sample_stats(self):
550+
def sample_stats(self) -> SMCStats:
547551
stats = super().sample_stats()
548552
stats.update(
549553
{
@@ -553,7 +557,7 @@ def sample_stats(self):
553557
)
554558
return stats
555559

556-
def sample_settings(self):
560+
def sample_settings(self) -> SMCSettings:
557561
stats = super().sample_settings()
558562
stats.update(
559563
{

pymc/smc/sampling.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from collections import defaultdict
2121
from itertools import repeat
22+
from typing import Any, Dict, Optional, Tuple, Union
2223

2324
import cloudpickle
2425
import numpy as np
@@ -30,7 +31,7 @@
3031

3132
from pymc.backends.arviz import dict_to_dataset, to_inference_data
3233
from pymc.backends.base import MultiTrace
33-
from pymc.model import modelcontext
34+
from pymc.model import Model, modelcontext
3435
from pymc.sampling.parallel import _cpu_count
3536
from pymc.smc.kernels import IMH
3637
from pymc.util import RandomState, _get_seeds_per_chain
@@ -50,7 +51,7 @@ def sample_smc(
5051
idata_kwargs=None,
5152
progressbar=True,
5253
**kernel_kwargs,
53-
):
54+
) -> Union[InferenceData, MultiTrace]:
5455
r"""
5556
Sequential Monte Carlo based sampling.
5657
@@ -237,19 +238,23 @@ def sample_smc(
237238

238239
if compute_convergence_checks:
239240
_compute_convergence_checks(idata, draws, model, trace)
240-
return idata if return_inferencedata else trace
241+
242+
if return_inferencedata:
243+
assert idata is not None
244+
return idata
245+
return trace
241246

242247

243248
def _save_sample_stats(
244249
sample_settings,
245250
sample_stats,
246251
chains,
247-
trace,
248-
return_inferencedata,
252+
trace: MultiTrace,
253+
return_inferencedata: bool,
249254
_t_sampling,
250255
idata_kwargs,
251-
model,
252-
):
256+
model: Model,
257+
) -> Tuple[Optional[Any], Optional[InferenceData]]:
253258
sample_settings_dict = sample_settings[0]
254259
sample_settings_dict["_t_sampling"] = _t_sampling
255260
sample_stats_dict = sample_stats[0]
@@ -262,12 +267,12 @@ def _save_sample_stats(
262267
value_list.append(chain_sample_stats[stat])
263268
sample_stats_dict[stat] = value_list
264269

270+
idata: Optional[InferenceData] = None
265271
if not return_inferencedata:
266272
for stat, value in sample_stats_dict.items():
267273
setattr(trace.report, stat, value)
268274
for stat, value in sample_settings_dict.items():
269275
setattr(trace.report, stat, value)
270-
idata = None
271276
else:
272277
for stat, value in sample_stats_dict.items():
273278
if chains > 1:
@@ -284,7 +289,7 @@ def _save_sample_stats(
284289
library=pymc,
285290
)
286291

287-
ikwargs = dict(model=model)
292+
ikwargs: Dict[str, Any] = dict(model=model)
288293
if idata_kwargs is not None:
289294
ikwargs.update(idata_kwargs)
290295
idata = to_inference_data(trace, **ikwargs)
@@ -293,7 +298,9 @@ def _save_sample_stats(
293298
return sample_stats, idata
294299

295300

296-
def _compute_convergence_checks(idata, draws, model, trace):
301+
def _compute_convergence_checks(
302+
idata: Optional[InferenceData], draws: int, model: Model, trace: MultiTrace
303+
):
297304
if draws < 100:
298305
warnings.warn(
299306
"The number of samples is too small to check convergence reliably.",

0 commit comments

Comments
 (0)