Skip to content

Commit 84428f0

Browse files
committed
Merge branch 'develop' into 2.37-tests
2 parents 87a1c9b + b3c6c47 commit 84428f0

23 files changed

+1582
-711
lines changed

.github/codecov.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
comment: false

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
5050
steps:
5151
- name: Check out github
52-
uses: actions/checkout@v4
52+
uses: actions/checkout@v5
5353

5454
- name: Set up Python ${{ matrix.python-version }}
5555
uses: actions/setup-python@v5

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
fail-fast: false
1919
steps:
2020
- name: Check out source code
21-
uses: actions/checkout@v4
21+
uses: actions/checkout@v5
2222

2323
- name: Set up Python ${{ matrix.python-version }}
2424
uses: actions/setup-python@v5

cmdstanpy/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def _cleanup_tmpdir() -> None:
4242
set_make_env,
4343
show_versions,
4444
write_stan_json,
45+
enable_logging,
46+
disable_logging,
4547
)
4648

4749
__all__ = [
@@ -63,4 +65,6 @@ def _cleanup_tmpdir() -> None:
6365
'show_versions',
6466
'rebuild_cmdstan',
6567
'cmdstan_version',
68+
"enable_logging",
69+
"disable_logging",
6670
]

cmdstanpy/stanfit/__init__.py

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import glob
44
import os
5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import List, Optional, Union
66

77
from cmdstanpy.cmdstan_args import (
88
CmdStanArgs,
@@ -12,7 +12,7 @@
1212
SamplerArgs,
1313
VariationalArgs,
1414
)
15-
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
15+
from cmdstanpy.utils import check_sampler_csv, get_logger, stancsv
1616

1717
from .gq import CmdStanGQ
1818
from .laplace import CmdStanLaplace
@@ -103,10 +103,9 @@ def from_csv(
103103
' includes non-csv file: {}'.format(file)
104104
)
105105

106-
config_dict: Dict[str, Any] = {}
107106
try:
108-
with open(csvfiles[0], 'r') as fd:
109-
scan_config(fd, config_dict, 0)
107+
comments, *_ = stancsv.parse_comments_header_and_draws(csvfiles[0])
108+
config_dict = stancsv.parse_config(comments)
110109
except (IOError, OSError, PermissionError) as e:
111110
raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
112111
if 'model' not in config_dict or 'method' not in config_dict:
@@ -118,39 +117,43 @@ def from_csv(
118117
method, config_dict['method']
119118
)
120119
)
120+
model: str = config_dict['model'] # type: ignore
121121
try:
122122
if config_dict['method'] == 'sample':
123+
save_warmup = config_dict['save_warmup'] == 1
123124
chains = len(csvfiles)
125+
num_samples: int = config_dict['num_samples'] # type: ignore
126+
num_warmup: int = config_dict['num_warmup'] # type: ignore
127+
thin: int = config_dict['thin'] # type: ignore
124128
sampler_args = SamplerArgs(
125-
iter_sampling=config_dict['num_samples'],
126-
iter_warmup=config_dict['num_warmup'],
127-
thin=config_dict['thin'],
128-
save_warmup=config_dict['save_warmup'],
129+
iter_sampling=num_samples,
130+
iter_warmup=num_warmup,
131+
thin=thin,
132+
save_warmup=save_warmup,
129133
)
130134
# bugfix 425, check for fixed_params output
131135
try:
132136
check_sampler_csv(
133137
csvfiles[0],
134-
iter_sampling=config_dict['num_samples'],
135-
iter_warmup=config_dict['num_warmup'],
136-
thin=config_dict['thin'],
137-
save_warmup=config_dict['save_warmup'],
138+
iter_sampling=num_samples,
139+
iter_warmup=num_warmup,
140+
thin=thin,
141+
save_warmup=save_warmup,
138142
)
139143
except ValueError:
140144
try:
141145
check_sampler_csv(
142146
csvfiles[0],
143-
is_fixed_param=True,
144-
iter_sampling=config_dict['num_samples'],
145-
iter_warmup=config_dict['num_warmup'],
146-
thin=config_dict['thin'],
147-
save_warmup=config_dict['save_warmup'],
147+
iter_sampling=num_samples,
148+
iter_warmup=num_warmup,
149+
thin=thin,
150+
save_warmup=save_warmup,
148151
)
149152
sampler_args = SamplerArgs(
150-
iter_sampling=config_dict['num_samples'],
151-
iter_warmup=config_dict['num_warmup'],
152-
thin=config_dict['thin'],
153-
save_warmup=config_dict['save_warmup'],
153+
iter_sampling=num_samples,
154+
iter_warmup=num_warmup,
155+
thin=thin,
156+
save_warmup=save_warmup,
154157
fixed_param=True,
155158
)
156159
except ValueError as e:
@@ -159,8 +162,8 @@ def from_csv(
159162
) from e
160163

161164
cmdstan_args = CmdStanArgs(
162-
model_name=config_dict['model'],
163-
model_exe=config_dict['model'],
165+
model_name=model,
166+
model_exe=model,
164167
chain_ids=[x + 1 for x in range(chains)],
165168
method_args=sampler_args,
166169
)
@@ -177,14 +180,18 @@ def from_csv(
177180
"Cannot find optimization algorithm"
178181
" in file {}.".format(csvfiles[0])
179182
)
183+
algorithm: str = config_dict['algorithm'] # type: ignore
184+
save_iterations = config_dict['save_iterations'] == 1
185+
jacobian = config_dict.get('jacobian', 0) == 1
186+
180187
optimize_args = OptimizeArgs(
181-
algorithm=config_dict['algorithm'],
182-
save_iterations=config_dict['save_iterations'],
183-
jacobian=config_dict.get('jacobian', 0),
188+
algorithm=algorithm,
189+
save_iterations=save_iterations,
190+
jacobian=jacobian,
184191
)
185192
cmdstan_args = CmdStanArgs(
186-
model_name=config_dict['model'],
187-
model_exe=config_dict['model'],
193+
model_name=model,
194+
model_exe=model,
188195
chain_ids=None,
189196
method_args=optimize_args,
190197
)
@@ -200,18 +207,18 @@ def from_csv(
200207
" in file {}.".format(csvfiles[0])
201208
)
202209
variational_args = VariationalArgs(
203-
algorithm=config_dict['algorithm'],
204-
iter=config_dict['iter'],
205-
grad_samples=config_dict['grad_samples'],
206-
elbo_samples=config_dict['elbo_samples'],
207-
eta=config_dict['eta'],
208-
tol_rel_obj=config_dict['tol_rel_obj'],
209-
eval_elbo=config_dict['eval_elbo'],
210-
output_samples=config_dict['output_samples'],
210+
algorithm=config_dict['algorithm'], # type: ignore
211+
iter=config_dict['iter'], # type: ignore
212+
grad_samples=config_dict['grad_samples'], # type: ignore
213+
elbo_samples=config_dict['elbo_samples'], # type: ignore
214+
eta=config_dict['eta'], # type: ignore
215+
tol_rel_obj=config_dict['tol_rel_obj'], # type: ignore
216+
eval_elbo=config_dict['eval_elbo'], # type: ignore
217+
output_samples=config_dict['output_samples'], # type: ignore
211218
)
212219
cmdstan_args = CmdStanArgs(
213-
model_name=config_dict['model'],
214-
model_exe=config_dict['model'],
220+
model_name=model,
221+
model_exe=model,
215222
chain_ids=None,
216223
method_args=variational_args,
217224
)
@@ -221,14 +228,15 @@ def from_csv(
221228
runset._set_retcode(i, 0)
222229
return CmdStanVB(runset)
223230
elif config_dict['method'] == 'laplace':
231+
jacobian = config_dict['jacobian'] == 1
224232
laplace_args = LaplaceArgs(
225-
mode=config_dict['mode'],
226-
draws=config_dict['draws'],
227-
jacobian=config_dict['jacobian'],
233+
mode=config_dict['mode'], # type: ignore
234+
draws=config_dict['draws'], # type: ignore
235+
jacobian=jacobian,
228236
)
229237
cmdstan_args = CmdStanArgs(
230-
model_name=config_dict['model'],
231-
model_exe=config_dict['model'],
238+
model_name=model,
239+
model_exe=model,
232240
chain_ids=None,
233241
method_args=laplace_args,
234242
)
@@ -237,18 +245,18 @@ def from_csv(
237245
for i in range(len(runset._retcodes)):
238246
runset._set_retcode(i, 0)
239247
mode: CmdStanMLE = from_csv(
240-
config_dict['mode'],
248+
config_dict['mode'], # type: ignore
241249
method='optimize',
242250
) # type: ignore
243251
return CmdStanLaplace(runset, mode=mode)
244252
elif config_dict['method'] == 'pathfinder':
245253
pathfinder_args = PathfinderArgs(
246-
num_draws=config_dict['num_draws'],
247-
num_paths=config_dict['num_paths'],
254+
num_draws=config_dict['num_draws'], # type: ignore
255+
num_paths=config_dict['num_paths'], # type: ignore
248256
)
249257
cmdstan_args = CmdStanArgs(
250-
model_name=config_dict['model'],
251-
model_exe=config_dict['model'],
258+
model_name=model,
259+
model_exe=model,
252260
chain_ids=None,
253261
method_args=pathfinder_args,
254262
)

cmdstanpy/stanfit/gq.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@
3131

3232

3333
from cmdstanpy.cmdstan_args import Method
34-
from cmdstanpy.utils import build_xarray_data, flatten_chains, get_logger
35-
from cmdstanpy.utils.stancsv import scan_generic_csv
34+
from cmdstanpy.utils import (
35+
build_xarray_data,
36+
flatten_chains,
37+
get_logger,
38+
stancsv,
39+
)
3640

3741
from .mcmc import CmdStanMCMC
3842
from .metadata import InferenceMetadata
@@ -65,8 +69,7 @@ def __init__(
6569
self.previous_fit: Fit = previous_fit
6670

6771
self._draws: np.ndarray = np.array(())
68-
config = self._validate_csv_files()
69-
self._metadata = InferenceMetadata(config)
72+
self._metadata = self._validate_csv_files()
7073

7174
def __repr__(self) -> str:
7275
repr = 'CmdStanGQ: model={} chains={}{}'.format(
@@ -99,48 +102,38 @@ def __getstate__(self) -> dict:
99102
self._assemble_generated_quantities()
100103
return self.__dict__
101104

102-
def _validate_csv_files(self) -> Dict[str, Any]:
105+
def _validate_csv_files(self) -> InferenceMetadata:
103106
"""
104107
Checks that Stan CSV output files for all chains are consistent
105-
and returns dict containing config and column names.
108+
and returns InferenceMetadata object containing config and column names.
106109
107-
Raises exception when inconsistencies detected.
110+
Raises exception if inconsistencies are detected.
108111
"""
109-
dzero = {}
110-
for i in range(self.chains):
111-
if i == 0:
112-
dzero = scan_generic_csv(
113-
path=self.runset.csv_files[i],
114-
)
115-
else:
116-
drest = scan_generic_csv(
117-
path=self.runset.csv_files[i],
118-
)
119-
for key in dzero:
120-
if (
121-
key
122-
not in [
123-
'id',
124-
'fitted_params',
125-
'diagnostic_file',
126-
'metric_file',
127-
'profile_file',
128-
'init',
129-
'seed',
130-
'start_datetime',
131-
]
132-
and dzero[key] != drest[key]
133-
):
134-
raise ValueError(
135-
'CmdStan config mismatch in Stan CSV file {}: '
136-
'arg {} is {}, expected {}'.format(
137-
self.runset.csv_files[i],
138-
key,
139-
dzero[key],
140-
drest[key],
141-
)
112+
excluded_fields = {
113+
'id',
114+
'fitted_params',
115+
'diagnostic_file',
116+
'metric_file',
117+
'profile_file',
118+
'init',
119+
'seed',
120+
'start_datetime',
121+
}
122+
meta0 = InferenceMetadata.from_csv(self.runset.csv_files[0])
123+
for i in range(1, self.chains):
124+
meta = InferenceMetadata.from_csv(self.runset.csv_files[i])
125+
for key in set(meta._cmdstan_config.keys()) - excluded_fields:
126+
if meta0[key] != meta[key]:
127+
raise ValueError(
128+
'CmdStan config mismatch in Stan CSV file {}: '
129+
'arg {} is {}, expected {}'.format(
130+
self.runset.csv_files[i],
131+
key,
132+
meta0[key],
133+
meta[key],
142134
)
143-
return dzero
135+
)
136+
return meta0
144137

145138
@property
146139
def chains(self) -> int:
@@ -157,7 +150,7 @@ def column_names(self) -> Tuple[str, ...]:
157150
"""
158151
Names of generated quantities of interest.
159152
"""
160-
return self._metadata.cmdstan_config['column_names'] # type: ignore
153+
return self._metadata.column_names
161154

162155
@property
163156
def metadata(self) -> InferenceMetadata:
@@ -633,11 +626,17 @@ def _assemble_generated_quantities(self) -> None:
633626
order='F',
634627
)
635628
for chain in range(self.chains):
636-
with open(self.runset.csv_files[chain], 'r') as fd:
637-
lines = (line for line in fd if not line.startswith('#'))
638-
gq_sample[:, chain, :] = np.loadtxt(
639-
lines, dtype=np.ndarray, ndmin=2, skiprows=1, delimiter=','
629+
csv_file = self.runset.csv_files[chain]
630+
try:
631+
*_, draws = stancsv.parse_comments_header_and_draws(
632+
self.runset.csv_files[chain]
640633
)
634+
gq_sample[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws)
635+
except Exception as exc:
636+
raise ValueError(
637+
f"An error occurred when parsing Stan csv {csv_file}"
638+
f" for chain {chain}"
639+
) from exc
641640
self._draws = gq_sample
642641

643642
def _draws_start(self, inc_warmup: bool) -> Tuple[int, int]:

0 commit comments

Comments
 (0)