Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 16 additions & 134 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,7 @@
import numpy as np
from numpy.random import default_rng

from cmdstanpy import _TMPDIR
from cmdstanpy.utils import (
cmdstan_path,
cmdstan_version_before,
create_named_text_file,
get_logger,
read_metric,
write_stan_json,
)
from cmdstanpy.utils import cmdstan_path, cmdstan_version_before, get_logger

OptionalPath = Union[str, os.PathLike, None]

Expand Down Expand Up @@ -65,9 +57,8 @@ def __init__(
save_warmup: bool = False,
thin: Optional[int] = None,
max_treedepth: Optional[int] = None,
metric: Union[
str, dict[str, Any], list[str], list[dict[str, Any]], None
] = None,
metric_type: Optional[str] = None,
metric_file: Union[str, list[str], None] = None,
step_size: Union[float, list[float], None] = None,
adapt_engaged: bool = True,
adapt_delta: Optional[float] = None,
Expand All @@ -83,9 +74,8 @@ def __init__(
self.save_warmup = save_warmup
self.thin = thin
self.max_treedepth = max_treedepth
self.metric = metric
self.metric_type: Optional[str] = None
self.metric_file: Union[str, list[str], None] = None
self.metric_type: Optional[str] = metric_type
self.metric_file: Union[str, list[str], None] = metric_file
self.step_size = step_size
self.adapt_engaged = adapt_engaged
self.adapt_delta = adapt_delta
Expand Down Expand Up @@ -178,124 +168,15 @@ def validate(self, chains: Optional[int]) -> None:
'Argument "step_size" must be > 0, '
'chain {}, found {}.'.format(i + 1, step_size)
)
if self.metric is not None:
if isinstance(self.metric, str):
if self.metric in ['diag', 'diag_e']:
self.metric_type = 'diag_e'
elif self.metric in ['dense', 'dense_e']:
self.metric_type = 'dense_e'
elif self.metric in ['unit', 'unit_e']:
self.metric_type = 'unit_e'
else:
if not os.path.exists(self.metric):
raise ValueError('no such file {}'.format(self.metric))
dims = read_metric(self.metric)
if len(dims) == 1:
self.metric_type = 'diag_e'
else:
self.metric_type = 'dense_e'
self.metric_file = self.metric
elif isinstance(self.metric, dict):
if 'inv_metric' not in self.metric:
raise ValueError(
'Entry "inv_metric" not found in metric dict.'
)
dims = list(np.asarray(self.metric['inv_metric']).shape)
if len(dims) == 1:
self.metric_type = 'diag_e'
else:
self.metric_type = 'dense_e'
dict_file = create_named_text_file(
dir=_TMPDIR, prefix="metric", suffix=".json"
)
write_stan_json(dict_file, self.metric)
self.metric_file = dict_file
elif isinstance(self.metric, (list, tuple)):
if len(self.metric) != chains:
raise ValueError(
'Number of metric files must match number of chains,'
' found {} metric files for {} chains.'.format(
len(self.metric), chains
)
)
if all(isinstance(elem, dict) for elem in self.metric):
metric_files: list[str] = []
for i, metric in enumerate(self.metric):
metric_dict: dict[str, Any] = metric # type: ignore
if 'inv_metric' not in metric_dict:
raise ValueError(
'Entry "inv_metric" not found in metric dict '
'for chain {}.'.format(i + 1)
)
if i == 0:
dims = list(
np.asarray(metric_dict['inv_metric']).shape
)
else:
dims2 = list(
np.asarray(metric_dict['inv_metric']).shape
)
if dims != dims2:
raise ValueError(
'Found inconsistent "inv_metric" entry '
'for chain {}: entry has dims '
'{}, expected {}.'.format(
i + 1, dims, dims2
)
)
dict_file = create_named_text_file(
dir=_TMPDIR, prefix="metric", suffix=".json"
)
write_stan_json(dict_file, metric_dict)
metric_files.append(dict_file)
if len(dims) == 1:
self.metric_type = 'diag_e'
else:
self.metric_type = 'dense_e'
self.metric_file = metric_files
elif all(isinstance(elem, str) for elem in self.metric):
metric_files = []
for i, metric in enumerate(self.metric):
assert isinstance(metric, str) # typecheck
if not os.path.exists(metric):
raise ValueError('no such file {}'.format(metric))
if i == 0:
dims = read_metric(metric)
else:
dims2 = read_metric(metric)
if len(dims) != len(dims2):
raise ValueError(
'Metrics files {}, {},'
' inconsistent metrics'.format(
self.metric[0], metric
)
)
if dims != dims2:
raise ValueError(
'Metrics files {}, {},'
' inconsistent metrics'.format(
self.metric[0], metric
)
)
metric_files.append(metric)
if len(dims) == 1:
self.metric_type = 'diag_e'
else:
self.metric_type = 'dense_e'
self.metric_file = metric_files
else:
raise ValueError(
'Argument "metric" must be a list of pathnames or '
'Python dicts, found list of {}.'.format(
type(self.metric[0])
)
)
else:
if self.metric_type is not None:
if self.metric_type in ['diag', 'dense', 'unit']:
self.metric_type += '_e'
if self.metric_type not in ['diag_e', 'dense_e', 'unit_e']:
raise ValueError(
'Invalid metric specified, not a recognized metric type, '
'must be either a metric type name, a filepath, dict, '
'or list of per-chain filepaths or dicts. Found '
'an object of type {}.'.format(type(self.metric))
'Argument "metric" must be one of [diag, dense, unit,'
' diag_e, dense_e, unit_e], found {}.'.format(
self.metric_type
)
)

if self.adapt_delta is not None:
Expand Down Expand Up @@ -332,7 +213,8 @@ def validate(self, chains: Optional[int]) -> None:

if self.fixed_param and (
self.max_treedepth is not None
or self.metric is not None
or self.metric_type is not None
or self.metric_file is not None
or self.step_size is not None
or not (
self.adapt_delta is None
Expand Down Expand Up @@ -371,7 +253,7 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
cmd.append(f'stepsize={self.step_size}')
else:
cmd.append(f'stepsize={self.step_size[idx]}')
if self.metric is not None:
if self.metric_type is not None:
cmd.append(f'metric={self.metric_type}')
if self.metric_file is not None:
if not isinstance(self.metric_file, list):
Expand Down
133 changes: 93 additions & 40 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Union,
)

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

Expand Down Expand Up @@ -55,7 +56,12 @@
get_logger,
returncode_msg,
)
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json
from cmdstanpy.utils.filesystem import (
temp_inits,
temp_metrics,
temp_single_json,
)
from cmdstanpy.utils.stancsv import try_deduce_metric_type

from . import progress as progbar

Expand Down Expand Up @@ -697,6 +703,13 @@ def sample(
timeout: Optional[float] = None,
*,
force_one_process_per_chain: Optional[bool] = None,
inv_metric: Union[
str,
np.ndarray,
Mapping[str, Any],
list[Union[str, np.ndarray, Mapping[str, Any]]],
None,
] = None,
) -> CmdStanMCMC:
"""
Run one or more chains of the NUTS-HMC sampler to produce a set of draws
Expand Down Expand Up @@ -785,29 +798,25 @@ def sample(
:param max_treedepth: Maximum depth of trees evaluated by NUTS sampler
per iteration.

:param metric: Specification of the mass matrix, either as a
vector consisting of the diagonal elements of the covariance
matrix ('diag' or 'diag_e') or the full covariance matrix
('dense' or 'dense_e').

If the value of the metric argument is a string other than
'diag', 'diag_e', 'dense', or 'dense_e', it must be
a valid filepath to a JSON or Rdump file which contains an entry
'inv_metric' whose value is either the diagonal vector or
the full covariance matrix.

If the value of the metric argument is a list of paths, its
length must match the number of chains and all paths must be
unique.

If the value of the metric argument is a Python dict object, it
must contain an entry 'inv_metric' which specifies either the
diagonal or dense matrix.

If the value of the metric argument is a list of Python dicts,
its length must match the number of chains and all dicts must
contain an entry 'inv_metric' and all 'inv_metric' entries must
have the same shape.
:param metric: Specify the type of the inverse mass matrix. Options are
'diag' or 'diag_e' for diagonal matrix, 'dense' or 'dense_e'
for a dense matrix, or 'unit_e' for an identity mass matrix. To provide
an initial value for the inverse mass matrix, use the ``inv_metric``
argument.

:param inv_metric: Provide an initial value for the inverse
mass matrix.

Valid options include:
- a string, which must be a valid filepath to a JSON or
Rdump file which contains an entry 'inv_metric' whose value
is either a diagonal vector or dense matrix.
- a numpy array containing either the diagonal vector or dense
matrix.
- a dictionary containing an entry 'inv_metric' whose value
is either a diagonal vector or dense matrix.
- a list of any of the above, of length num_chains, with
the same shape of metric in each entry.

:param step_size: Initial step size for HMC sampler. The value is
either a single number or a list of numbers which will be used
Expand Down Expand Up @@ -1001,35 +1010,79 @@ def sample(
'Chain_id must be a non-negative integer value,'
' found {}.'.format(chain_id)
)
if metric is not None and metric not in (
'diag',
'dense',
'unit_e',
'diag_e',
'dense_e',
):
get_logger().warning(
"Providing anything other than metric type for"
" 'metric' is deprecated and will be removed"
" in the next major release."
" Please provide such information via"
" 'inv_metric' argument."
)
if inv_metric is not None:
raise ValueError(
"Cannot provide both (deprecated) non-metric-type 'metric'"
" argument and 'inv_metric' argument."
)
inv_metric = metric # type: ignore # for backwards compatibility
metric = None

sampler_args = SamplerArgs(
num_chains=1 if one_process_per_chain else chains,
iter_warmup=iter_warmup,
iter_sampling=iter_sampling,
save_warmup=save_warmup,
thin=thin,
max_treedepth=max_treedepth,
metric=metric,
step_size=step_size,
adapt_engaged=adapt_engaged,
adapt_delta=adapt_delta,
adapt_init_phase=adapt_init_phase,
adapt_metric_window=adapt_metric_window,
adapt_step_size=adapt_step_size,
fixed_param=fixed_param,
)
if metric is None and inv_metric is not None:
metric = try_deduce_metric_type(inv_metric)

if isinstance(inv_metric, list):
if not len(inv_metric) == chains:
raise ValueError(
'Number of metric files must match number of chains,'
' found {} metric files for {} chains.'.format(
len(inv_metric), chains
)
)

with (
temp_single_json(data) as _data,
temp_inits(inits, id=chain_ids[0]) as _inits,
temp_metrics(inv_metric, id=chain_ids[0]) as _inv_metric,
):
cmdstan_inits: Union[str, list[str], int, float, None]
cmdstan_metrics: Union[str, list[str], None]

if one_process_per_chain and isinstance(inits, list): # legacy
cmdstan_inits = [
f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore
]
else:
cmdstan_inits = _inits
if one_process_per_chain and isinstance(inv_metric, list): # legacy
cmdstan_metrics = [
f"{_inv_metric[:-5]}_{i}.json" # type: ignore
for i in chain_ids
]
else:
cmdstan_metrics = _inv_metric

sampler_args = SamplerArgs(
num_chains=1 if one_process_per_chain else chains,
iter_warmup=iter_warmup,
iter_sampling=iter_sampling,
save_warmup=save_warmup,
thin=thin,
max_treedepth=max_treedepth,
metric_type=metric, # type: ignore
metric_file=cmdstan_metrics,
step_size=step_size,
adapt_engaged=adapt_engaged,
adapt_delta=adapt_delta,
adapt_init_phase=adapt_init_phase,
adapt_metric_window=adapt_metric_window,
adapt_step_size=adapt_step_size,
fixed_param=fixed_param,
)

args = CmdStanArgs(
self._name,
Expand Down
Loading
Loading