From ddb203a1373f299419e7406b3f443c339efd4646 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 14 Aug 2025 14:41:38 -0400 Subject: [PATCH 1/5] Deprecate CmdStanMCMC.metric, provide .inv_metric instead --- cmdstanpy/stanfit/mcmc.py | 24 +++++++++----- cmdstanpy_tutorial.ipynb | 2 +- cmdstanpy_tutorial.py | 2 +- .../users-guide/examples/MCMC Sampling.ipynb | 4 +-- docsrc/users-guide/hello_world.rst | 2 +- test/test_sample.py | 32 ++++++------------- 6 files changed, 31 insertions(+), 35 deletions(-) diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index 91f74d36..323f1d1e 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -233,19 +233,27 @@ def metric_type(self) -> Optional[str]: else None ) + # TODO(2.0): remove @property def metric(self) -> Optional[np.ndarray]: + """Deprecated. Use ``.inv_metric`` instead.""" + get_logger().warning( + 'The "metric" property is deprecated, use "inv_metric" instead. ' + 'This will be the same quantity, but with a more accurate name.' + ) + return self.inv_metric + + @property + def inv_metric(self) -> Optional[np.ndarray]: """ - Metric used by sampler for each chain. - When sampler algorithm 'fixed_param' is specified, metric is None. + Inverse mass matrix used by sampler for each chain. + Returns a ``nchains x nparams`` array when metric_type is 'diag_e', + a ``nchains x nparams x nparams`` array when metric_type is 'dense_e', + or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'. """ - if self._is_fixed_param: - return None - if self._metadata.cmdstan_config['metric'] == 'unit_e': - get_logger().info( - 'Unit diagnonal metric, inverse mass matrix size unknown.' - ) + if self._is_fixed_param or self.metric_type == 'unit_e': return None + self._assemble_draws() return self._metric diff --git a/cmdstanpy_tutorial.ipynb b/cmdstanpy_tutorial.ipynb index b88bb633..38401605 100644 --- a/cmdstanpy_tutorial.ipynb +++ b/cmdstanpy_tutorial.ipynb @@ -352,7 +352,7 @@ "metadata": {}, "outputs": [], "source": [ - "fit.metric_type, fit.metric" + "fit.metric_type, fit.inv_metric" ] }, { diff --git a/cmdstanpy_tutorial.py b/cmdstanpy_tutorial.py index e7f052f6..83a7b7a7 100644 --- a/cmdstanpy_tutorial.py +++ b/cmdstanpy_tutorial.py @@ -37,7 +37,7 @@ print(fit.step_size) print(fit.metric_type) -print(fit.metric) +print(fit.inv_metric) # #### Summarize the results diff --git a/docsrc/users-guide/examples/MCMC Sampling.ipynb b/docsrc/users-guide/examples/MCMC Sampling.ipynb index 3891fe46..053faa94 100644 --- a/docsrc/users-guide/examples/MCMC Sampling.ipynb +++ b/docsrc/users-guide/examples/MCMC Sampling.ipynb @@ -1483,7 +1483,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1502,7 +1502,7 @@ } ], "source": [ - "print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\nmetric:\\n{fit.metric}')" + "print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\ninverse metric:\\n{fit.inv_metric}')" ] }, { diff --git a/docsrc/users-guide/hello_world.rst b/docsrc/users-guide/hello_world.rst index 0fd88cc7..bc6d4f7b 100644 --- a/docsrc/users-guide/hello_world.rst +++ b/docsrc/users-guide/hello_world.rst @@ -171,7 +171,7 @@ access to the the per-chain HMC tuning parameters from the NUTS-HMC adaptive sam .. ipython:: python print(fit.metric_type) - print(fit.metric) + print(fit.inv_metric) print(fit.step_size) diff --git a/test/test_sample.py b/test/test_sample.py index 7f4f531f..e7fb90d8 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -89,7 +89,7 @@ def test_bernoulli_good(stanfile: str): assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS)) assert bern_fit.metric_type == 'diag_e' assert bern_fit.step_size.shape == (2,) - assert bern_fit.metric.shape == (2, 1) + assert bern_fit.inv_metric.shape == (2, 1) assert bern_fit.draws(concat_chains=True).shape == ( 200, @@ -125,7 +125,7 @@ def test_bernoulli_good(stanfile: str): assert bern_sample.shape == (100, 2, len(BERNOULLI_COLS)) assert bern_fit.metric_type == 'dense_e' assert bern_fit.step_size.shape == (2,) - assert bern_fit.metric.shape == (2, 1, 1) + assert bern_fit.inv_metric.shape == (2, 1, 1) bern_fit = bern_model.sample( data=jdata, @@ -186,9 +186,7 @@ def test_bernoulli_good(stanfile: str): @pytest.mark.parametrize("stanfile", ["bernoulli.stan"]) -def test_bernoulli_unit_e( - stanfile: str, caplog: pytest.LogCaptureFixture -) -> None: +def test_bernoulli_unit_e(stanfile: str) -> None: stan = os.path.join(DATAFILES_PATH, stanfile) bern_model = CmdStanModel(stan_file=stan) @@ -204,19 +202,9 @@ def test_bernoulli_unit_e( show_progress=False, ) assert bern_fit.metric_type == 'unit_e' - assert bern_fit.metric is None + assert bern_fit.inv_metric is None assert bern_fit.step_size.shape == (2,) - with caplog.at_level(logging.INFO): - logging.getLogger() - assert bern_fit.metric is None - check_present( - caplog, - ( - 'cmdstanpy', - 'INFO', - 'Unit diagnonal metric, inverse mass matrix size unknown.', - ), - ) + assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS)) @@ -535,7 +523,7 @@ def test_fixed_param_good() -> None: ) assert datagen_fit.runset._args.method == Method.SAMPLE assert datagen_fit.metric_type is None - assert datagen_fit.metric is None + assert datagen_fit.inv_metric is None assert datagen_fit.step_size is None assert datagen_fit.divergences is None assert datagen_fit.max_treedepths is None @@ -638,7 +626,7 @@ def test_fixed_param_good() -> None: assert datagen_fit.column_names == tuple(column_names) assert datagen_fit.num_draws_sampling == 100 assert datagen_fit.draws().shape == (100, 1, len(column_names)) - assert datagen_fit.metric is None + assert datagen_fit.inv_metric is None assert datagen_fit.metric_type is None assert datagen_fit.step_size is None @@ -860,7 +848,7 @@ def test_validate_big_run() -> None: assert fit.column_names == tuple(column_names) assert fit.metric_type == 'diag_e' assert fit.step_size.shape == (2,) - assert fit.metric.shape == (2, 2095) + assert fit.inv_metric.shape == (2, 2095) assert fit.draws().shape == (1000, 2, 2102) assert fit.draws_pd(vars=['phi']).shape == (2000, 2095) with raises_nested(ValueError, r'Unknown variable: gamma'): @@ -2136,8 +2124,8 @@ def test_sample_dense_mass_matrix(): linear_model = CmdStanModel(stan_file=stan) fit = linear_model.sample(data=jdata, metric="dense_e", chains=2) - assert fit.metric is not None - assert fit.metric.shape == (2, 3, 3) + assert fit.inv_metric is not None + assert fit.inv_metric.shape == (2, 3, 3) def test_no_output_draws(): From 4db7ecb73f35247898921c0e9492a18d417ae73c Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 14 Aug 2025 15:45:17 -0400 Subject: [PATCH 2/5] Rework how initial inverse mass matrix can be supplied, deprecate former overloading of metric argument --- cmdstanpy/cmdstan_args.py | 153 ++++------------------------------ cmdstanpy/model.py | 122 ++++++++++++++++++--------- cmdstanpy/utils/filesystem.py | 35 +++++++- cmdstanpy/utils/stancsv.py | 38 +++++++++ test/test_cmdstan_args.py | 23 ++--- test/test_sample.py | 26 +++--- test/test_utils.py | 15 ++++ 7 files changed, 211 insertions(+), 201 deletions(-) diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 07040d6d..60fc8b6e 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -1,23 +1,16 @@ """ CmdStan arguments """ + import os from enum import Enum, auto from time import time -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, List, Mapping, Optional, Union import numpy as np from numpy.random import default_rng -from cmdstanpy import _TMPDIR -from cmdstanpy.utils import ( - cmdstan_path, - cmdstan_version_before, - create_named_text_file, - get_logger, - read_metric, - write_stan_json, -) +from cmdstanpy.utils import cmdstan_path, cmdstan_version_before, get_logger OptionalPath = Union[str, os.PathLike, None] @@ -64,9 +57,8 @@ def __init__( save_warmup: bool = False, thin: Optional[int] = None, max_treedepth: Optional[int] = None, - metric: Union[ - str, Dict[str, Any], List[str], List[Dict[str, Any]], None - ] = None, + metric_type: Optional[str] = None, + metric_file: Union[str, List[str], None] = None, step_size: Union[float, List[float], None] = None, adapt_engaged: bool = True, adapt_delta: Optional[float] = None, @@ -82,9 +74,8 @@ def __init__( self.save_warmup = save_warmup self.thin = thin self.max_treedepth = max_treedepth - self.metric = metric - self.metric_type: Optional[str] = None - self.metric_file: Union[str, List[str], None] = None + self.metric_type: Optional[str] = metric_type + self.metric_file: Union[str, List[str], None] = metric_file self.step_size = step_size self.adapt_engaged = adapt_engaged self.adapt_delta = adapt_delta @@ -176,124 +167,15 @@ def validate(self, chains: Optional[int]) -> None: 'Argument "step_size" must be > 0, ' 'chain {}, found {}.'.format(i + 1, step_size) ) - if self.metric is not None: - if isinstance(self.metric, str): - if self.metric in ['diag', 'diag_e']: - self.metric_type = 'diag_e' - elif self.metric in ['dense', 'dense_e']: - self.metric_type = 'dense_e' - elif self.metric in ['unit', 'unit_e']: - self.metric_type = 'unit_e' - else: - if not os.path.exists(self.metric): - raise ValueError('no such file {}'.format(self.metric)) - dims = read_metric(self.metric) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - self.metric_file = self.metric - elif isinstance(self.metric, dict): - if 'inv_metric' not in self.metric: - raise ValueError( - 'Entry "inv_metric" not found in metric dict.' - ) - dims = list(np.asarray(self.metric['inv_metric']).shape) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - dict_file = create_named_text_file( - dir=_TMPDIR, prefix="metric", suffix=".json" - ) - write_stan_json(dict_file, self.metric) - self.metric_file = dict_file - elif isinstance(self.metric, (list, tuple)): - if len(self.metric) != chains: - raise ValueError( - 'Number of metric files must match number of chains,' - ' found {} metric files for {} chains.'.format( - len(self.metric), chains - ) - ) - if all(isinstance(elem, dict) for elem in self.metric): - metric_files: List[str] = [] - for i, metric in enumerate(self.metric): - metric_dict: Dict[str, Any] = metric # type: ignore - if 'inv_metric' not in metric_dict: - raise ValueError( - 'Entry "inv_metric" not found in metric dict ' - 'for chain {}.'.format(i + 1) - ) - if i == 0: - dims = list( - np.asarray(metric_dict['inv_metric']).shape - ) - else: - dims2 = list( - np.asarray(metric_dict['inv_metric']).shape - ) - if dims != dims2: - raise ValueError( - 'Found inconsistent "inv_metric" entry ' - 'for chain {}: entry has dims ' - '{}, expected {}.'.format( - i + 1, dims, dims2 - ) - ) - dict_file = create_named_text_file( - dir=_TMPDIR, prefix="metric", suffix=".json" - ) - write_stan_json(dict_file, metric_dict) - metric_files.append(dict_file) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - self.metric_file = metric_files - elif all(isinstance(elem, str) for elem in self.metric): - metric_files = [] - for i, metric in enumerate(self.metric): - assert isinstance(metric, str) # typecheck - if not os.path.exists(metric): - raise ValueError('no such file {}'.format(metric)) - if i == 0: - dims = read_metric(metric) - else: - dims2 = read_metric(metric) - if len(dims) != len(dims2): - raise ValueError( - 'Metrics files {}, {},' - ' inconsistent metrics'.format( - self.metric[0], metric - ) - ) - if dims != dims2: - raise ValueError( - 'Metrics files {}, {},' - ' inconsistent metrics'.format( - self.metric[0], metric - ) - ) - metric_files.append(metric) - if len(dims) == 1: - self.metric_type = 'diag_e' - else: - self.metric_type = 'dense_e' - self.metric_file = metric_files - else: - raise ValueError( - 'Argument "metric" must be a list of pathnames or ' - 'Python dicts, found list of {}.'.format( - type(self.metric[0]) - ) - ) - else: + if self.metric_type is not None: + if self.metric_type in ['diag', 'dense', 'unit']: + self.metric_type += '_e' + if self.metric_type not in ['diag_e', 'dense_e', 'unit_e']: raise ValueError( - 'Invalid metric specified, not a recognized metric type, ' - 'must be either a metric type name, a filepath, dict, ' - 'or list of per-chain filepaths or dicts. Found ' - 'an object of type {}.'.format(type(self.metric)) + 'Argument "metric" must be one of [diag, dense, unit,' + ' diag_e, dense_e, unit_e], found {}.'.format( + self.metric_type + ) ) if self.adapt_delta is not None: @@ -330,7 +212,8 @@ def validate(self, chains: Optional[int]) -> None: if self.fixed_param and ( self.max_treedepth is not None - or self.metric is not None + or self.metric_type is not None + or self.metric_file is not None or self.step_size is not None or not ( self.adapt_delta is None @@ -369,7 +252,7 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]: cmd.append(f'stepsize={self.step_size}') else: cmd.append(f'stepsize={self.step_size[idx]}') - if self.metric is not None: + if self.metric_type is not None: cmd.append(f'metric={self.metric_type}') if self.metric_file is not None: if not isinstance(self.metric_file, list): diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 116e11b4..33b08bda 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -26,6 +26,7 @@ Union, ) +import numpy as np import pandas as pd from tqdm.auto import tqdm @@ -57,7 +58,12 @@ get_logger, returncode_msg, ) -from cmdstanpy.utils.filesystem import temp_inits, temp_single_json +from cmdstanpy.utils.filesystem import ( + temp_inits, + temp_metrics, + temp_single_json, +) +from cmdstanpy.utils.stancsv import try_deduce_metric_type from . import progress as progbar @@ -698,6 +704,13 @@ def sample( timeout: Optional[float] = None, *, force_one_process_per_chain: Optional[bool] = None, + inv_metric: Union[ + str, + np.ndarray, + Mapping[str, Any], + List[Union[str, np.ndarray, Mapping[str, Any]]], + None, + ] = None, ) -> CmdStanMCMC: """ Run or more chains of the NUTS-HMC sampler to produce a set of draws @@ -786,29 +799,23 @@ def sample( :param max_treedepth: Maximum depth of trees evaluated by NUTS sampler per iteration. - :param metric: Specification of the mass matrix, either as a - vector consisting of the diagonal elements of the covariance - matrix ('diag' or 'diag_e') or the full covariance matrix - ('dense' or 'dense_e'). - - If the value of the metric argument is a string other than - 'diag', 'diag_e', 'dense', or 'dense_e', it must be - a valid filepath to a JSON or Rdump file which contains an entry - 'inv_metric' whose value is either the diagonal vector or - the full covariance matrix. + :param metric: Specify the type of the mass matrix. Options are + 'diag' or 'diag_e' for diagonal matrix, 'dense' or 'dense_e' + for a dense matrix, or 'unit_e' an identity mass matrix. - If the value of the metric argument is a list of paths, its - length must match the number of chains and all paths must be - unique. + :param inv_metric: Provide an initial value for the inverse + mass matrix. - If the value of the metric argument is a Python dict object, it - must contain an entry 'inv_metric' which specifies either the - diagnoal or dense matrix. - - If the value of the metric argument is a list of Python dicts, - its length must match the number of chains and all dicts must - containan entry 'inv_metric' and all 'inv_metric' entries must - have the same shape. + Valid options include: + - a string, which must be a valid filepath to a JSON or + Rdump file which contains an entry 'inv_metric' whose value + is either a diagonal vector or dense matrix. + - a numpy array containing either the diagonal vector or dense + matrix. + - a dictionary containing an entry 'inv_metric' whose value + is either a diagonal vector or dense matrix. + - a list of any of the above, of length num_chains, with + the same shape of metric in each entry. :param step_size: Initial step size for HMC sampler. The value is either a single number or a list of numbers which will be used @@ -1002,34 +1009,71 @@ def sample( 'Chain_id must be a non-negative integer value,' ' found {}.'.format(chain_id) ) + if metric not in [None, 'diag', 'dense', 'unit_e', 'diag_e', 'dense_e']: + get_logger().warning( + "Providing anything other than metric type for" + " 'metric' is deprecated and will be removed" + " in the next major release." + " Please provide such information via" + " 'inv_metric' argument." + ) + if inv_metric is not None: + raise ValueError( + "Cannot provide both (deprecated) non-metric-type 'metric'" + " argument and 'inv_metric' argument." + ) + inv_metric = metric # type: ignore # for backwards compatibility + metric = None - sampler_args = SamplerArgs( - num_chains=1 if one_process_per_chain else chains, - iter_warmup=iter_warmup, - iter_sampling=iter_sampling, - save_warmup=save_warmup, - thin=thin, - max_treedepth=max_treedepth, - metric=metric, - step_size=step_size, - adapt_engaged=adapt_engaged, - adapt_delta=adapt_delta, - adapt_init_phase=adapt_init_phase, - adapt_metric_window=adapt_metric_window, - adapt_step_size=adapt_step_size, - fixed_param=fixed_param, - ) + if metric is None and inv_metric is not None: + metric = try_deduce_metric_type(inv_metric) + + if isinstance(inv_metric, list): + if not len(inv_metric) == chains: + raise ValueError( + 'Number of metric files must match number of chains,' + ' found {} metric files for {} chains.'.format( + len(inv_metric), chains + ) + ) with temp_single_json(data) as _data, temp_inits( inits, id=chain_ids[0] - ) as _inits: + ) as _inits, temp_metrics(inv_metric, id=chain_ids[0]) as _inv_metric: cmdstan_inits: Union[str, List[str], int, float, None] + cmdstan_metrics: Union[str, List[str], None] + if one_process_per_chain and isinstance(inits, list): # legacy cmdstan_inits = [ f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore ] else: cmdstan_inits = _inits + if one_process_per_chain and isinstance(inv_metric, list): # legacy + cmdstan_metrics = [ + f"{_inv_metric[:-5]}_{i}.json" # type: ignore + for i in chain_ids + ] + else: + cmdstan_metrics = _inv_metric + + sampler_args = SamplerArgs( + num_chains=1 if one_process_per_chain else chains, + iter_warmup=iter_warmup, + iter_sampling=iter_sampling, + save_warmup=save_warmup, + thin=thin, + max_treedepth=max_treedepth, + metric_type=metric, # type: ignore + metric_file=cmdstan_metrics, + step_size=step_size, + adapt_engaged=adapt_engaged, + adapt_delta=adapt_delta, + adapt_init_phase=adapt_init_phase, + adapt_metric_window=adapt_metric_window, + adapt_step_size=adapt_step_size, + fixed_param=fixed_param, + ) args = CmdStanArgs( self._name, diff --git a/cmdstanpy/utils/filesystem.py b/cmdstanpy/utils/filesystem.py index 233898e1..4cb5784a 100644 --- a/cmdstanpy/utils/filesystem.py +++ b/cmdstanpy/utils/filesystem.py @@ -1,6 +1,7 @@ """ Utilities for interacting with the filesystem on multiple platforms """ + import contextlib import os import platform @@ -9,6 +10,8 @@ import tempfile from typing import Any, Iterator, List, Mapping, Optional, Tuple, Union +import numpy as np + from cmdstanpy import _TMPDIR from .json import write_stan_json @@ -104,7 +107,7 @@ def pushd(new_dir: str) -> Iterator[None]: def _temp_single_json( - data: Union[str, os.PathLike, Mapping[str, Any], None] + data: Union[str, os.PathLike, Mapping[str, Any], None], ) -> Iterator[Optional[str]]: """Context manager for json files.""" if data is None: @@ -164,6 +167,36 @@ def _temp_multiinput( yield from _temp_single_json(input) +@contextlib.contextmanager +def temp_metrics( + metrics: Union[ + str, os.PathLike, Mapping[str, Any], np.ndarray, List[Any], None + ], + *, + id: int = 1, +) -> Iterator[Union[str, None]]: + if isinstance(metrics, dict): + if 'inv_metric' not in metrics: + raise ValueError('Entry "inv_metric" not found in metric dict.') + if isinstance(metrics, np.ndarray): + metrics = {"inv_metric": metrics} + + if isinstance(metrics, list): + metrics_processed = [] + for init in metrics: + if isinstance(init, np.ndarray): + metrics_processed.append({"inv_metric": init}) + else: + metrics_processed.append(init) + if isinstance(metrics_processed, dict): + if 'inv_metric' not in metrics_processed: + raise ValueError( + 'Entry "inv_metric" not found in metric dict.' + ) + metrics = metrics_processed + yield from _temp_multiinput(metrics, base=id) + + @contextlib.contextmanager def temp_inits( inits: Union[ diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index 74830994..3e6d64d7 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -13,6 +13,7 @@ Dict, Iterator, List, + Mapping, MutableMapping, Optional, TextIO, @@ -673,3 +674,40 @@ def parse_rdump_value(rhs: str) -> Union[int, float, np.ndarray]: except TypeError as e: raise ValueError('bad value in Rdump file: {}'.format(rhs)) from e return val + + +def try_deduce_metric_type( + inv_metric: Union[ + str, + np.ndarray, + Mapping[str, Any], + List[Union[str, np.ndarray, Mapping[str, Any]]], + ], +) -> Optional[str]: + """Given a user-supplied metric, try to infer the correct metric type.""" + if isinstance(inv_metric, list): + if inv_metric: + inv_metric = inv_metric[0] + + if isinstance(inv_metric, Mapping): + if (metric_type := inv_metric.get("metric_type")) in ( + 'diag_e', + 'dense_e', + ): + return metric_type # type: ignore + inv_metric = inv_metric.get('inv_metric', None) + + if isinstance(inv_metric, np.ndarray): + if len(inv_metric.shape) == 1: + return 'diag_e' + else: + return 'dense_e' + + if isinstance(inv_metric, str): + dims = read_metric(inv_metric) + if len(dims) == 1: + return 'diag_e' + else: + return 'dense_e' + + return None diff --git a/test/test_cmdstan_args.py b/test/test_cmdstan_args.py index 7587564d..14832374 100644 --- a/test/test_cmdstan_args.py +++ b/test/test_cmdstan_args.py @@ -151,7 +151,7 @@ def test_bad() -> None: with pytest.raises(ValueError): args.validate(chains=2) - args = SamplerArgs(metric='dense', fixed_param=True) + args = SamplerArgs(metric_type='dense', fixed_param=True) with pytest.raises(ValueError): args.validate(chains=2) @@ -221,22 +221,22 @@ def test_adapt() -> None: def test_metric() -> None: - args = SamplerArgs(metric='dense_e') + args = SamplerArgs(metric_type='dense_e') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=dense_e' in ' '.join(cmd) - args = SamplerArgs(metric='dense') + args = SamplerArgs(metric_type='dense') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=dense_e' in ' '.join(cmd) - args = SamplerArgs(metric='diag_e') + args = SamplerArgs(metric_type='diag_e') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=diag_e' in ' '.join(cmd) - args = SamplerArgs(metric='diag') + args = SamplerArgs(metric_type='diag') args.validate(chains=4) cmd = args.compose(1, cmd=[]) assert 'method=sample algorithm=hmc metric=diag_e' in ' '.join(cmd) @@ -247,29 +247,20 @@ def test_metric() -> None: assert 'metric=' not in ' '.join(cmd) jmetric = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json') - args = SamplerArgs(metric=jmetric) + args = SamplerArgs(metric_file=jmetric) args.validate(chains=4) cmd = args.compose(1, cmd=[]) - assert 'metric=diag_e' in ' '.join(cmd) assert 'metric_file=' in ' '.join(cmd) assert 'bernoulli.metric.json' in ' '.join(cmd) jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') - args = SamplerArgs(metric=[jmetric, jmetric2]) + args = SamplerArgs(metric_file=[jmetric, jmetric2]) args.validate(chains=2) cmd = args.compose(0, cmd=[]) assert 'bernoulli.metric.json' in ' '.join(cmd) cmd = args.compose(1, cmd=[]) assert 'bernoulli.metric-2.json' in ' '.join(cmd) - args = SamplerArgs(metric=[jmetric, jmetric2]) - with pytest.raises(ValueError): - args.validate(chains=4) - - args = SamplerArgs(metric='/no/such/path/to.file') - with pytest.raises(ValueError): - args.validate(chains=4) - def test_fixed_param() -> None: args = SamplerArgs(fixed_param=True) diff --git a/test/test_sample.py b/test/test_sample.py index e7fb90d8..9f637ad6 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -1004,7 +1004,7 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=jmetric, + inv_metric=jmetric, ) jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') bern_model.sample( @@ -1014,7 +1014,7 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=[jmetric, jmetric2], + inv_metric=[jmetric, jmetric2], ) # read json in as dict with open(jmetric) as fd: @@ -1028,7 +1028,7 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=metric_dict_1, + inv_metric=metric_dict_1, ) bern_model.sample( data=jdata, @@ -1036,7 +1036,15 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=[metric_dict_1, metric_dict_2], + inv_metric=[metric_dict_1, metric_dict_2], + ) + bern_model.sample( + data=jdata, + chains=2, + seed=12345, + iter_warmup=100, + iter_sampling=200, + inv_metric=[np.array(metric_dict_1['inv_metric']), jmetric2], ) with pytest.raises( ValueError, @@ -1049,23 +1057,21 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=[metric_dict_1, metric_dict_2], + inv_metric=[metric_dict_1, metric_dict_2], ) # metric mismatches - (not appropriate for bernoulli) with open(os.path.join(DATAFILES_PATH, 'metric_diag.data.json')) as fd: metric_dict_1 = json.load(fd) with open(os.path.join(DATAFILES_PATH, 'metric_dense.data.json')) as fd: metric_dict_2 = json.load(fd) - with pytest.raises( - ValueError, match='Found inconsistent "inv_metric" entry' - ): + with pytest.raises(RuntimeError, match='Error during sampling'): bern_model.sample( data=jdata, chains=2, seed=12345, iter_warmup=100, iter_sampling=200, - metric=[metric_dict_1, metric_dict_2], + inv_metric=[metric_dict_1, metric_dict_2], ) # metric dict, no "inv_metric": some_dict = {"foo": [1, 2, 3]} @@ -1078,7 +1084,7 @@ def test_custom_metric() -> None: seed=12345, iter_warmup=100, iter_sampling=200, - metric=some_dict, + inv_metric=some_dict, ) diff --git a/test/test_utils.py b/test/test_utils.py index 27e5db53..5769abc7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -457,6 +457,21 @@ def test_metric_missing() -> None: read_metric(metric_file) +def test_deduce_metric_type() -> None: + assert stancsv.try_deduce_metric_type(np.zeros((3, 3))) == 'dense_e' + assert stancsv.try_deduce_metric_type(np.zeros((3,))) == 'diag_e' + + assert stancsv.try_deduce_metric_type([np.zeros((3, 3))]) == 'dense_e' + assert ( + stancsv.try_deduce_metric_type({"inv_metric": np.zeros((3,))}) + == 'diag_e' + ) + assert ( + stancsv.try_deduce_metric_type([{"inv_metric": np.zeros((3,))}]) + == 'diag_e' + ) + + @mark_windows_only def test_windows_short_path_directory() -> None: with tempfile.TemporaryDirectory( From ad1fde0580a1e3a5c7d8c6e77e6ce2ddf361a14f Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 14 Aug 2025 16:08:15 -0400 Subject: [PATCH 3/5] Actually assert in custom metric test --- cmdstanpy/model.py | 8 ++++- test/test_sample.py | 82 +++++++++++++++++++++++++++++++-------------- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 33b08bda..08f220b9 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -1009,7 +1009,13 @@ def sample( 'Chain_id must be a non-negative integer value,' ' found {}.'.format(chain_id) ) - if metric not in [None, 'diag', 'dense', 'unit_e', 'diag_e', 'dense_e']: + if metric is not None and metric not in ( + 'diag', + 'dense', + 'unit_e', + 'diag_e', + 'dense_e', + ): get_logger().warning( "Providing anything other than metric type for" " 'metric' is deprecated and will be removed" diff --git a/test/test_sample.py b/test/test_sample.py index 9f637ad6..2554e30e 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -996,56 +996,88 @@ def test_custom_metric() -> None: jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') bern_model = CmdStanModel(stan_file=stan) jmetric = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json') + jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') + # read json in as dict + with open(jmetric) as fd: + metric_dict_1 = json.load(fd) + with open(jmetric2) as fd: + metric_dict_2 = json.load(fd) # just test that it runs without error - bern_model.sample( + fit1 = bern_model.sample( data=jdata, chains=2, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=jmetric, ) - jmetric2 = os.path.join(DATAFILES_PATH, 'bernoulli.metric-2.json') - bern_model.sample( + np.testing.assert_allclose( + fit1.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit1.inv_metric[1], metric_dict_1['inv_metric'], atol=1e-6 + ) + + fit2 = bern_model.sample( data=jdata, chains=2, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=[jmetric, jmetric2], ) - # read json in as dict - with open(jmetric) as fd: - metric_dict_1 = json.load(fd) - with open(jmetric2) as fd: - metric_dict_2 = json.load(fd) - bern_model.sample( + np.testing.assert_allclose( + fit2.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit2.inv_metric[1], metric_dict_2['inv_metric'], atol=1e-6 + ) + + fit3 = bern_model.sample( data=jdata, chains=4, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=metric_dict_1, ) - bern_model.sample( + for i in range(4): + np.testing.assert_allclose( + fit3.inv_metric[i], metric_dict_1['inv_metric'], atol=1e-6 + ) + fit4 = bern_model.sample( data=jdata, chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=[metric_dict_1, metric_dict_2], ) - bern_model.sample( + np.testing.assert_allclose( + fit4.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit4.inv_metric[1], metric_dict_2['inv_metric'], atol=1e-6 + ) + + fit5 = bern_model.sample( data=jdata, chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=[np.array(metric_dict_1['inv_metric']), jmetric2], ) + np.testing.assert_allclose( + fit5.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 + ) + np.testing.assert_allclose( + fit5.inv_metric[1], metric_dict_2['inv_metric'], atol=1e-6 + ) + with pytest.raises( ValueError, match='Number of metric files must match number of chains,', @@ -1055,8 +1087,8 @@ def test_custom_metric() -> None: chains=4, parallel_chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=[metric_dict_1, metric_dict_2], ) # metric mismatches - (not appropriate for bernoulli) @@ -1069,8 +1101,8 @@ def test_custom_metric() -> None: data=jdata, chains=2, seed=12345, - iter_warmup=100, - iter_sampling=200, + iter_warmup=10, + iter_sampling=10, inv_metric=[metric_dict_1, metric_dict_2], ) # metric dict, no "inv_metric": From de18290652ddb1c91d26e46387c78db02884c53a Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 14 Aug 2025 16:26:44 -0400 Subject: [PATCH 4/5] Test with both parallelism setups --- test/test_sample.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_sample.py b/test/test_sample.py index 2554e30e..026b9144 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -991,7 +991,8 @@ def test_from_csv_no_param_hmc() -> None: assert no_parameters_sample.draws_pd().shape == (100, 93) -def test_custom_metric() -> None: +@pytest.mark.parametrize('force_one_process_per_chain', [True, False]) +def test_custom_metric(force_one_process_per_chain: bool) -> None: stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') bern_model = CmdStanModel(stan_file=stan) @@ -1011,6 +1012,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=jmetric, + force_one_process_per_chain=force_one_process_per_chain, ) np.testing.assert_allclose( fit1.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 @@ -1027,6 +1029,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=[jmetric, jmetric2], + force_one_process_per_chain=force_one_process_per_chain, ) np.testing.assert_allclose( fit2.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 @@ -1043,6 +1046,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=metric_dict_1, + force_one_process_per_chain=force_one_process_per_chain, ) for i in range(4): np.testing.assert_allclose( @@ -1055,6 +1059,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=[metric_dict_1, metric_dict_2], + force_one_process_per_chain=force_one_process_per_chain, ) np.testing.assert_allclose( fit4.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 @@ -1070,6 +1075,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=[np.array(metric_dict_1['inv_metric']), jmetric2], + force_one_process_per_chain=force_one_process_per_chain, ) np.testing.assert_allclose( fit5.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6 @@ -1090,6 +1096,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=[metric_dict_1, metric_dict_2], + force_one_process_per_chain=force_one_process_per_chain, ) # metric mismatches - (not appropriate for bernoulli) with open(os.path.join(DATAFILES_PATH, 'metric_diag.data.json')) as fd: @@ -1104,6 +1111,7 @@ def test_custom_metric() -> None: iter_warmup=10, iter_sampling=10, inv_metric=[metric_dict_1, metric_dict_2], + force_one_process_per_chain=force_one_process_per_chain, ) # metric dict, no "inv_metric": some_dict = {"foo": [1, 2, 3]} @@ -1117,6 +1125,7 @@ def test_custom_metric() -> None: iter_warmup=100, iter_sampling=200, inv_metric=some_dict, + force_one_process_per_chain=force_one_process_per_chain, ) From cd5b5deb3c52c809cd57b8c058b016f1993a43d2 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 28 Aug 2025 13:59:38 -0400 Subject: [PATCH 5/5] Docstring updates --- cmdstanpy/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 86157215..3335d216 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -798,9 +798,11 @@ def sample( :param max_treedepth: Maximum depth of trees evaluated by NUTS sampler per iteration. - :param metric: Specify the type of the mass matrix. Options are + :param metric: Specify the type of the inverse mass matrix. Options are 'diag' or 'diag_e' for diagonal matrix, 'dense' or 'dense_e' - for a dense matrix, or 'unit_e' an identity mass matrix. + for a dense matrix, or 'unit_e' an identity mass matrix. To provide + an initial value for the inverse mass matrix, use the ``inv_metric`` + argument. :param inv_metric: Provide an initial value for the inverse mass matrix.