Skip to content

Refactoring towards IBaseTrace interfaces #6475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,15 @@

"""
from copy import copy
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Sequence, Union

import numpy as np

from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc.backends.base import BaseTrace
from pymc.backends.ndarray import NDArray, point_list_to_multitrace
from pymc.backends.base import BaseTrace, IBaseTrace
from pymc.backends.ndarray import NDArray
from pymc.model import Model
from pymc.step_methods.compound import BlockedStep, CompoundStep

__all__ = ["to_inference_data", "predictions_to_inference_data"]

Expand All @@ -76,7 +80,7 @@ def _init_trace(
chain_number: int,
stats_dtypes: List[Dict[str, type]],
trace: Optional[BaseTrace],
model,
model: Model,
) -> BaseTrace:
"""Initializes a trace backend for a chain."""
strace: BaseTrace
Expand All @@ -91,3 +95,26 @@ def _init_trace(

strace.setup(expected_length, chain_number, stats_dtypes)
return strace


def init_traces(
*,
backend: Optional[BaseTrace],
chains: int,
expected_length: int,
step: Union[BlockedStep, CompoundStep],
var_dtypes: Dict[str, np.dtype],
var_shapes: Dict[str, Sequence[int]],
model: Model,
) -> Sequence[IBaseTrace]:
"""Initializes a trace recorder for each chain."""
return [
_init_trace(
expected_length=expected_length,
stats_dtypes=step.stats_dtypes,
chain_number=chain_number,
trace=backend,
model=model,
)
for chain_number in range(chains)
]
139 changes: 87 additions & 52 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@

from abc import ABC
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Expand All @@ -47,7 +49,87 @@ class BackendError(Exception):
pass


class BaseTrace(ABC):
class IBaseTrace(ABC, Sized):
"""Minimal interface needed to record and access draws and stats for one MCMC chain."""

chain: int
"""Chain number."""

varnames: List[str]
"""Names of tracked variables."""

sampler_vars: List[Dict[str, type]]
"""Sampler stats for each sampler."""

def __len__(self):
raise NotImplementedError()

def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
"""Get values from trace.

Parameters
----------
varname: str
burn: int
thin: int

Returns
-------
A NumPy array
"""
raise NotImplementedError()

def get_sampler_stats(self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1):
"""Get sampler statistics from the trace.

Parameters
----------
stat_name: str
sampler_idx: int or None
burn: int
thin: int

Returns
-------
If the `sampler_idx` is specified, return the statistic with
the given name in a numpy array. If it is not specified and there
is more than one sampler that provides this statistic, return
a numpy array of shape (m, n), where `m` is the number of
such samplers, and `n` is the number of samples.
"""
raise NotImplementedError()

def _slice(self, idx: slice) -> "IBaseTrace":
"""Slice trace object."""
raise NotImplementedError()

def point(self, idx: int) -> Dict[str, np.ndarray]:
"""Return dictionary of point values at `idx` for current chain
with variables names as keys.
"""
raise NotImplementedError()

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
"""Record results of a sampling iteration.

Parameters
----------
draw: dict
Values mapped to variable names
stats: list of dicts
The diagnostic values for each sampler
"""
raise NotImplementedError()

def close(self):
"""Close the backend.

This is called after sampling has finished.
"""
pass


class BaseTrace(IBaseTrace):
"""Base trace object

Parameters
Expand Down Expand Up @@ -127,25 +209,6 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
self._set_sampler_vars(sampler_vars)
self._is_base_setup = True

def record(self, point, sampler_states=None):
"""Record results of a sampling iteration.

Parameters
----------
point: dict
Values mapped to variable names
sampler_states: list of dicts
The diagnostic values for each sampler
"""
raise NotImplementedError

def close(self):
"""Close the database backend.

This is called after sampling has finished.
"""
pass

# Selection methods

def __getitem__(self, idx):
Expand All @@ -157,24 +220,6 @@ def __getitem__(self, idx):
except (ValueError, TypeError): # Passed variable or variable name.
raise ValueError("Can only index with slice or integer")

def __len__(self):
raise NotImplementedError

def get_values(self, varname, burn=0, thin=1):
"""Get values from trace.

Parameters
----------
varname: str
burn: int
thin: int

Returns
-------
A NumPy array
"""
raise NotImplementedError

def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
"""Get sampler statistics from the trace.

Expand Down Expand Up @@ -220,19 +265,9 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
"""Get sampler statistics."""
raise NotImplementedError()

def _slice(self, idx: Union[int, slice]):
"""Slice trace object."""
raise NotImplementedError()

def point(self, idx: int) -> Dict[str, np.ndarray]:
"""Return dictionary of point values at `idx` for current chain
with variables names as keys.
"""
raise NotImplementedError()

@property
def stat_names(self) -> Set[str]:
names = set()
names: Set[str] = set()
for vars in self.sampler_vars or []:
names.update(vars.keys())

Expand Down Expand Up @@ -290,7 +325,7 @@ class MultiTrace:
List of variable names in the trace(s)
"""

def __init__(self, straces: Sequence[BaseTrace]):
def __init__(self, straces: Sequence[IBaseTrace]):
if len({t.chain for t in straces}) != len(straces):
raise ValueError("Chains are not unique.")
self._straces = {t.chain: t for t in straces}
Expand Down Expand Up @@ -386,7 +421,7 @@ def stat_names(self) -> Set[str]:
sampler_vars = [s.sampler_vars for s in self._straces.values()]
if not all(svars == sampler_vars[0] for svars in sampler_vars):
raise ValueError("Inividual chains contain different sampler stats")
names = set()
names: Set[str] = set()
for trace in self._straces.values():
if trace.sampler_vars is None:
continue
Expand Down Expand Up @@ -472,7 +507,7 @@ def get_sampler_stats(
]
return _squeeze_cat(results, combine, squeeze)

def _slice(self, slice):
def _slice(self, slice: slice):
"""Return a new MultiTrace object sliced according to `slice`."""
new_traces = [trace._slice(slice) for trace in self._straces.values()]
trace = MultiTrace(new_traces)
Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
"""
return self.samples[varname][burn::thin]

def _slice(self, idx):
def _slice(self, idx: slice):
# Slicing directly instead of using _slice_as_ndarray to
# support stop value in slice (which is needed by
# iter_sample).
Expand All @@ -174,7 +174,7 @@ def _slice(self, idx):
return sliced
sliced._stats = []
for vars in self._stats:
var_sliced = {}
var_sliced: Dict[str, np.ndarray] = {}
sliced._stats.append(var_sliced)
for key, vals in vars.items():
var_sliced[key] = vals[idx]
Expand Down
47 changes: 23 additions & 24 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

import pymc as pm

from pymc.backends import _init_trace
from pymc.backends.base import BaseTrace, MultiTrace, _choose_chains
from pymc.backends import init_traces
from pymc.backends.base import BaseTrace, IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
Expand Down Expand Up @@ -71,7 +71,7 @@
class SamplingIteratorCallback(Protocol):
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""

def __call__(self, trace: BaseTrace, draw: Draw):
def __call__(self, trace: IBaseTrace, draw: Draw):
pass


Expand Down Expand Up @@ -486,21 +486,21 @@ def sample(
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]

# One final check that shapes and logps at the starting points are okay.
ip: Dict[str, np.ndarray]
for ip in initial_points:
model.check_start_vals(ip)
_check_start_shape(model, ip)

# Create trace backends for each chain
traces = [
_init_trace(
expected_length=draws + tune,
stats_dtypes=step.stats_dtypes,
chain_number=chain_number,
trace=trace,
model=model,
)
for chain_number in range(chains)
]
traces = init_traces(
backend=trace,
chains=chains,
expected_length=draws + tune,
step=step,
var_dtypes={vn: v.dtype for vn, v in ip.items()},
var_shapes={vn: v.shape for vn, v in ip.items()},
model=model,
)

sample_args = {
"draws": draws,
Expand Down Expand Up @@ -657,7 +657,7 @@ def _sample_many(
*,
draws: int,
chains: int,
traces: Sequence[BaseTrace],
traces: Sequence[IBaseTrace],
start: Sequence[PointType],
random_seed: Optional[Sequence[RandomSeed]],
step: Step,
Expand Down Expand Up @@ -701,7 +701,7 @@ def _sample(
start: PointType,
draws: int,
step: Step,
trace: BaseTrace,
trace: IBaseTrace,
tune: int,
model: Optional[Model] = None,
callback=None,
Expand All @@ -726,8 +726,8 @@ def _sample(
The number of samples to draw
step : function
Step function
trace : backend, optional
A backend instance.
trace
A chain backend to record draws and stats.
tune : int
Number of iterations to tune.
model : Model (optional if in ``with`` context)
Expand Down Expand Up @@ -767,7 +767,7 @@ def _iter_sample(
draws: int,
step: Step,
start: PointType,
trace: BaseTrace,
trace: IBaseTrace,
chain: int = 0,
tune: int = 0,
model: Optional[Model] = None,
Expand All @@ -785,8 +785,8 @@ def _iter_sample(
start : dict
Starting point in parameter space (or partial point).
Must contain numeric (transformed) initial values for all (transformed) free variables.
trace : backend
A backend instance.
trace
A chain backend to record draws and stats.
chain : int, optional
Chain number used to store sample in backend.
tune : int, optional
Expand Down Expand Up @@ -852,7 +852,7 @@ def _mp_sample(
random_seed: Sequence[RandomSeed],
start: Sequence[PointType],
progressbar: bool = True,
traces: Sequence[BaseTrace],
traces: Sequence[IBaseTrace],
model: Optional[Model] = None,
callback: Optional[SamplingIteratorCallback] = None,
mp_ctx=None,
Expand All @@ -879,9 +879,8 @@ def _mp_sample(
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
progressbar : bool
Whether or not to display a progress bar in the command line.
trace : BaseTrace, optional
A backend instance, or None.
If None, the NDArray backend is used.
traces
Recording backends for each chain.
model : Model (optional if in ``with`` context)
callback
A function which gets called for every sample from the trace of a chain. The function is
Expand Down
Loading