Skip to content

Commit 5802f12

Browse files
michaelosthegealoctavodia
authored andcommitted
Make trace backend related type hints more generic
* Specify covariant input types in `StatsBijection`. * Annotate `_choose_chains` to be independent of `BaseTrace` type.
1 parent ce66620 commit 5802f12

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

pymc/backends/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,18 @@
2121
import warnings
2222

2323
from abc import ABC
24-
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast
24+
from typing import (
25+
Dict,
26+
List,
27+
Optional,
28+
Sequence,
29+
Set,
30+
Sized,
31+
Tuple,
32+
TypeVar,
33+
Union,
34+
cast,
35+
)
2536

2637
import numpy as np
2738

@@ -510,7 +521,10 @@ def _squeeze_cat(results, combine, squeeze):
510521
return results
511522

512523

513-
def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]:
524+
S = TypeVar("S", bound=Sized)
525+
526+
527+
def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
514528
"""
515529
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
516530

pymc/step_methods/compound.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from abc import ABC, abstractmethod
2222
from enum import IntEnum, unique
23-
from typing import Dict, List, Sequence, Tuple, Union
23+
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union
2424

2525
import numpy as np
2626

@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
181181
class StatsBijection:
182182
"""Map between a `list` of stats to `dict` of stats."""
183183

184-
def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
184+
def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
185185
# Keep a list of flat vs. original stat names
186186
self._stat_groups: List[List[Tuple[str, str]]] = [
187187
[(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()]
188188
for s, names_dtypes in enumerate(sampler_stats_dtypes)
189189
]
190190

191-
def map(self, stats_list: StatsType) -> StatsDict:
191+
def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
192192
"""Combine stats dicts of multiple samplers into one dict."""
193193
stats_dict = {}
194194
for s, sts in enumerate(stats_list):
@@ -197,7 +197,7 @@ def map(self, stats_list: StatsType) -> StatsDict:
197197
stats_dict[sname] = sval
198198
return stats_dict
199199

200-
def rmap(self, stats_dict: StatsDict) -> StatsType:
200+
def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
201201
"""Split a global stats dict into a list of sampler-wise stats dicts."""
202202
stats_list = []
203203
for namemap in self._stat_groups:

0 commit comments

Comments
 (0)