20
20
21
21
from abc import ABC , abstractmethod
22
22
from enum import IntEnum , unique
23
- from typing import Dict , List , Sequence , Tuple , Union
23
+ from typing import Any , Dict , List , Mapping , Sequence , Tuple , Union
24
24
25
25
import numpy as np
26
26
@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
181
181
class StatsBijection :
182
182
"""Map between a `list` of stats to `dict` of stats."""
183
183
184
- def __init__ (self , sampler_stats_dtypes : Sequence [Dict [str , type ]]) -> None :
184
+ def __init__ (self , sampler_stats_dtypes : Sequence [Mapping [str , type ]]) -> None :
185
185
# Keep a list of flat vs. original stat names
186
186
self ._stat_groups : List [List [Tuple [str , str ]]] = [
187
187
[(f"sampler_{ s } __{ statname } " , statname ) for statname , _ in names_dtypes .items ()]
188
188
for s , names_dtypes in enumerate (sampler_stats_dtypes )
189
189
]
190
190
191
- def map (self , stats_list : StatsType ) -> StatsDict :
191
+ def map (self , stats_list : Sequence [ Mapping [ str , Any ]] ) -> StatsDict :
192
192
"""Combine stats dicts of multiple samplers into one dict."""
193
193
stats_dict = {}
194
194
for s , sts in enumerate (stats_list ):
@@ -197,7 +197,7 @@ def map(self, stats_list: StatsType) -> StatsDict:
197
197
stats_dict [sname ] = sval
198
198
return stats_dict
199
199
200
- def rmap (self , stats_dict : StatsDict ) -> StatsType :
200
+ def rmap (self , stats_dict : Mapping [ str , Any ] ) -> StatsType :
201
201
"""Split a global stats dict into a list of sampler-wise stats dicts."""
202
202
stats_list = []
203
203
for namemap in self ._stat_groups :
0 commit comments