Skip to content

Commit 5b92e35

Browse files
committed
Use correct linter settings
1 parent b5c9101 commit 5b92e35

20 files changed

+409
-188
lines changed

cmdstanpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def _cleanup_tmpdir() -> None:
3737
from .utils import (
3838
cmdstan_path,
3939
cmdstan_version,
40+
disable_logging,
41+
enable_logging,
4042
install_cmdstan,
4143
set_cmdstan_path,
4244
set_make_env,
4345
show_versions,
4446
write_stan_json,
45-
enable_logging,
46-
disable_logging,
4747
)
4848

4949
__all__ = [

cmdstanpy/cmdstan_args.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def validate(self, chains: Optional[int]) -> None:
105105
* length of per-chain lists equals specified # of chains
106106
"""
107107
if not isinstance(chains, (int, np.integer)) or chains < 1:
108-
raise ValueError('Sampler expects number of chains to be greater than 0.')
108+
raise ValueError(
109+
'Sampler expects number of chains to be greater than 0.'
110+
)
109111
if not (
110112
self.adapt_delta is None
111113
and self.adapt_init_phase is None
@@ -117,13 +119,17 @@ def validate(self, chains: Optional[int]) -> None:
117119
if self.adapt_delta is not None:
118120
msg = '{}, adapt_delta: {}'.format(msg, self.adapt_delta)
119121
if self.adapt_init_phase is not None:
120-
msg = '{}, adapt_init_phase: {}'.format(msg, self.adapt_init_phase)
122+
msg = '{}, adapt_init_phase: {}'.format(
123+
msg, self.adapt_init_phase
124+
)
121125
if self.adapt_metric_window is not None:
122126
msg = '{}, adapt_metric_window: {}'.format(
123127
msg, self.adapt_metric_window
124128
)
125129
if self.adapt_step_size is not None:
126-
msg = '{}, adapt_step_size: {}'.format(msg, self.adapt_step_size)
130+
msg = '{}, adapt_step_size: {}'.format(
131+
msg, self.adapt_step_size
132+
)
127133
raise ValueError(msg)
128134

129135
if self.iter_warmup is not None:
@@ -151,7 +157,9 @@ def validate(self, chains: Optional[int]) -> None:
151157
positive_int(self.max_treedepth, 'max_treedepth')
152158

153159
if self.step_size is not None:
154-
if isinstance(self.step_size, (float, int, np.integer, np.floating)):
160+
if isinstance(
161+
self.step_size, (float, int, np.integer, np.floating)
162+
):
155163
if self.step_size <= 0:
156164
raise ValueError(
157165
'Argument "step_size" must be > 0, found {}.'.format(
@@ -189,7 +197,9 @@ def validate(self, chains: Optional[int]) -> None:
189197
self.metric_file = self.metric
190198
elif isinstance(self.metric, dict):
191199
if 'inv_metric' not in self.metric:
192-
raise ValueError('Entry "inv_metric" not found in metric dict.')
200+
raise ValueError(
201+
'Entry "inv_metric" not found in metric dict.'
202+
)
193203
dims = list(np.asarray(self.metric['inv_metric']).shape)
194204
if len(dims) == 1:
195205
self.metric_type = 'diag_e'
@@ -218,14 +228,20 @@ def validate(self, chains: Optional[int]) -> None:
218228
'for chain {}.'.format(i + 1)
219229
)
220230
if i == 0:
221-
dims = list(np.asarray(metric_dict['inv_metric']).shape)
231+
dims = list(
232+
np.asarray(metric_dict['inv_metric']).shape
233+
)
222234
else:
223-
dims2 = list(np.asarray(metric_dict['inv_metric']).shape)
235+
dims2 = list(
236+
np.asarray(metric_dict['inv_metric']).shape
237+
)
224238
if dims != dims2:
225239
raise ValueError(
226240
'Found inconsistent "inv_metric" entry '
227241
'for chain {}: entry has dims '
228-
'{}, expected {}.'.format(i + 1, dims, dims2)
242+
'{}, expected {}.'.format(
243+
i + 1, dims, dims2
244+
)
229245
)
230246
dict_file = create_named_text_file(
231247
dir=_TMPDIR, prefix="metric", suffix=".json"
@@ -249,13 +265,15 @@ def validate(self, chains: Optional[int]) -> None:
249265
dims2 = read_metric(metric)
250266
if len(dims) != len(dims2):
251267
raise ValueError(
252-
'Metrics files {}, {}, inconsistent metrics'.format(
268+
'Metrics files {}, {},'
269+
' inconsistent metrics'.format(
253270
self.metric[0], metric
254271
)
255272
)
256273
if dims != dims2:
257274
raise ValueError(
258-
'Metrics files {}, {}, inconsistent metrics'.format(
275+
'Metrics files {}, {},'
276+
' inconsistent metrics'.format(
259277
self.metric[0], metric
260278
)
261279
)
@@ -268,7 +286,9 @@ def validate(self, chains: Optional[int]) -> None:
268286
else:
269287
raise ValueError(
270288
'Argument "metric" must be a list of pathnames or '
271-
'Python dicts, found list of {}.'.format(type(self.metric[0]))
289+
'Python dicts, found list of {}.'.format(
290+
type(self.metric[0])
291+
)
272292
)
273293
else:
274294
raise ValueError(
@@ -281,9 +301,8 @@ def validate(self, chains: Optional[int]) -> None:
281301
if self.adapt_delta is not None:
282302
if not 0 < self.adapt_delta < 1:
283303
raise ValueError(
284-
'Argument "adapt_delta" must be between 0 and 1, found {}'.format(
285-
self.adapt_delta
286-
)
304+
'Argument "adapt_delta" must be between 0 and 1,'
305+
' found {}'.format(self.adapt_delta)
287306
)
288307
if self.adapt_init_phase is not None:
289308
if self.adapt_init_phase < 0 or not isinstance(
@@ -437,7 +456,9 @@ def validate(self, _chains: Optional[int] = None) -> None:
437456
)
438457
if self.algorithm.lower() != 'lbfgs':
439458
if self.history_size is not None:
440-
raise ValueError('history_size requires that algorithm be set to lbfgs')
459+
raise ValueError(
460+
'history_size requires that algorithm be set to lbfgs'
461+
)
441462

442463
positive_float(self.init_alpha, 'init_alpha')
443464
positive_int(self.iter, 'iter')
@@ -620,7 +641,9 @@ def validate(
620641
"""
621642
for csv in self.sample_csv_files:
622643
if not os.path.exists(csv):
623-
raise ValueError('Invalid path for sample csv file: {}'.format(csv))
644+
raise ValueError(
645+
'Invalid path for sample csv file: {}'.format(csv)
646+
)
624647

625648
def compose(self, idx: int, cmd: list[str]) -> list[str]:
626649
"""
@@ -667,7 +690,10 @@ def validate(
667690
"""
668691
Check arguments correctness and consistency.
669692
"""
670-
if self.algorithm is not None and self.algorithm not in self.VARIATIONAL_ALGOS:
693+
if (
694+
self.algorithm is not None
695+
and self.algorithm not in self.VARIATIONAL_ALGOS
696+
):
671697
raise ValueError(
672698
'Please specify variational algorithms as one of [{}]'.format(
673699
', '.join(self.VARIATIONAL_ALGOS)
@@ -794,16 +820,19 @@ def validate(self) -> None:
794820
if chain_id < 1:
795821
raise ValueError('invalid chain_id {}'.format(chain_id))
796822
if self.output_dir is not None:
797-
self.output_dir = os.path.realpath(os.path.expanduser(self.output_dir))
823+
self.output_dir = os.path.realpath(
824+
os.path.expanduser(self.output_dir)
825+
)
798826
if not os.path.exists(self.output_dir):
799827
try:
800828
os.makedirs(self.output_dir)
801-
get_logger().info('created output directory: %s', self.output_dir)
829+
get_logger().info(
830+
'created output directory: %s', self.output_dir
831+
)
802832
except (RuntimeError, PermissionError) as exc:
803833
raise ValueError(
804-
'Invalid path for output files, no such dir: {}.'.format(
805-
self.output_dir
806-
)
834+
'Invalid path for output files, '
835+
'no such dir: {}.'.format(self.output_dir)
807836
) from exc
808837
if not os.path.isdir(self.output_dir):
809838
raise ValueError(
@@ -818,12 +847,14 @@ def validate(self) -> None:
818847
os.remove(testpath) # cleanup
819848
except Exception as exc:
820849
raise ValueError(
821-
'Invalid path for output files, cannot write to dir: {}.'.format(
822-
self.output_dir
823-
)
850+
'Invalid path for output files,'
851+
' cannot write to dir: {}.'.format(self.output_dir)
824852
) from exc
825853
if self.refresh is not None:
826-
if not isinstance(self.refresh, (int, np.integer)) or self.refresh < 1:
854+
if (
855+
not isinstance(self.refresh, (int, np.integer))
856+
or self.refresh < 1
857+
):
827858
raise ValueError(
828859
'Argument "refresh" must be a positive integer value, '
829860
'found {}.'.format(self.refresh)
@@ -895,7 +926,9 @@ def validate(self) -> None:
895926
if isinstance(self.inits, (float, int, np.floating, np.integer)):
896927
if self.inits < 0:
897928
raise ValueError(
898-
'Argument "inits" must be > 0, found {}'.format(self.inits)
929+
'Argument "inits" must be > 0, found {}'.format(
930+
self.inits
931+
)
899932
)
900933
elif isinstance(self.inits, str):
901934
if not (

cmdstanpy/compilation.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ def validate_stanc_opts(self) -> None:
165165
del self._stanc_options[deprecated]
166166
else:
167167
get_logger().warning(
168-
'compiler option "%s" is deprecated and should not be used',
168+
'compiler option "%s" is deprecated and should '
169+
'not be used',
169170
deprecated,
170171
)
171172
for key, val in self._stanc_options.items():
@@ -224,7 +225,8 @@ def validate_cpp_opts(self) -> None:
224225
val = self._cpp_options[key]
225226
if not isinstance(val, int) or val < 0:
226227
raise ValueError(
227-
f'{key} must be a non-negative integer value, found {val}.'
228+
f'{key} must be a non-negative integer '
229+
f'value, found {val}.'
228230
)
229231

230232
def validate_user_header(self) -> None:
@@ -234,7 +236,8 @@ def validate_user_header(self) -> None:
234236
"""
235237
if self._user_header != "":
236238
if not (
237-
os.path.exists(self._user_header) and os.path.isfile(self._user_header)
239+
os.path.exists(self._user_header)
240+
and os.path.isfile(self._user_header)
238241
):
239242
raise ValueError(
240243
f"User header file {self._user_header} cannot be found"
@@ -272,7 +275,9 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
272275
else:
273276
for key, val in new_opts.stanc_options.items():
274277
if key == 'include-paths':
275-
if isinstance(val, Iterable) and not isinstance(val, str):
278+
if isinstance(val, Iterable) and not isinstance(
279+
val, str
280+
):
276281
for path in val:
277282
self.add_include_path(str(path))
278283
else:
@@ -337,7 +342,9 @@ def compose(self, filename_in_msg: Optional[str] = None) -> list[str]:
337342
return opts
338343

339344

340-
def src_info(stan_file: str, compiler_options: CompilerOptions) -> dict[str, Any]:
345+
def src_info(
346+
stan_file: str, compiler_options: CompilerOptions
347+
) -> dict[str, Any]:
341348
"""
342349
Get source info for Stan program file.
343350
@@ -525,7 +532,9 @@ def format_stan_file(
525532
else:
526533
raise ValueError(
527534
"Invalid arguments passed for current CmdStan"
528-
+ " version({})\n".format(cmdstan_version() or "Unknown")
535+
+ " version({})\n".format(
536+
cmdstan_version() or "Unknown"
537+
)
529538
+ "--canonicalize requires 2.29 or higher"
530539
)
531540
else:

cmdstanpy/install_cmdstan.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ def latest_version() -> str:
110110
print('retry ({}/5)'.format(i + 1))
111111
sleep(1)
112112
continue
113-
raise CmdStanRetrieveError('Cannot connect to CmdStan github repo.') from e
113+
raise CmdStanRetrieveError(
114+
'Cannot connect to CmdStan github repo.'
115+
) from e
114116
content = json.loads(response.decode('utf-8'))
115117
tag = content['tag_name']
116118
match = re.search(r'v?(.+)', tag)
@@ -286,18 +288,26 @@ def build(verbose: bool = False, progress: bool = True, cores: int = 1) -> None:
286288
raise CmdStanInstallError(f'Command "make build" failed\n{str(e)}')
287289
if not os.path.exists(os.path.join('bin', 'stansummary' + EXTENSION)):
288290
raise CmdStanInstallError(
289-
f'bin/stansummary{EXTENSION} not found, please rebuild or report a bug!'
291+
f'bin/stansummary{EXTENSION} not found, please rebuild or '
292+
'report a bug!'
290293
)
291294
if not os.path.exists(os.path.join('bin', 'diagnose' + EXTENSION)):
292295
raise CmdStanInstallError(
293-
f'bin/stansummary{EXTENSION} not found, please rebuild or report a bug!'
296+
f'bin/stansummary{EXTENSION} not found, please rebuild or '
297+
'report a bug!'
294298
)
295299

296300
if is_windows():
297301
# Add tbb to the $PATH on Windows
298-
libtbb = os.path.join(os.getcwd(), 'stan', 'lib', 'stan_math', 'lib', 'tbb')
302+
libtbb = os.path.join(
303+
os.getcwd(), 'stan', 'lib', 'stan_math', 'lib', 'tbb'
304+
)
299305
os.environ['PATH'] = ';'.join(
300-
list(OrderedDict.fromkeys([libtbb] + os.environ.get('PATH', '').split(';')))
306+
list(
307+
OrderedDict.fromkeys(
308+
[libtbb] + os.environ.get('PATH', '').split(';')
309+
)
310+
)
301311
)
302312

303313

@@ -408,9 +418,8 @@ def install_version(
408418
)
409419
if overwrite and os.path.exists('.'):
410420
print(
411-
'Overwrite requested, remove existing build of version {}'.format(
412-
cmdstan_version
413-
)
421+
'Overwrite requested, remove existing build '
422+
'of version {}'.format(cmdstan_version)
414423
)
415424
clean_all(verbose)
416425
print('Rebuilding version {}'.format(cmdstan_version))
@@ -477,9 +486,9 @@ def retrieve_version(version: str, progress: bool = True) -> None:
477486
for i in range(6): # always retry to allow for transient URLErrors
478487
try:
479488
if progress and progbar.allow_show_progress():
480-
progress_hook: Optional[Callable[[int, int, int], None]] = (
481-
wrap_url_progress_hook()
482-
)
489+
progress_hook: Optional[
490+
Callable[[int, int, int], None]
491+
] = wrap_url_progress_hook()
483492
else:
484493
progress_hook = None
485494
file_tmp, _ = urllib.request.urlretrieve(
@@ -488,13 +497,14 @@ def retrieve_version(version: str, progress: bool = True) -> None:
488497
break
489498
except urllib.error.HTTPError as e:
490499
raise CmdStanRetrieveError(
491-
'HTTPError: {}\nVersion {} not available from github.com.'.format(
492-
e.code, version
493-
)
500+
'HTTPError: {}\nVersion {} not available from '
501+
'github.com.'.format(e.code, version)
494502
) from e
495503
except urllib.error.URLError as e:
496504
print(
497-
'Failed to download CmdStan version {} from github.com'.format(version)
505+
'Failed to download CmdStan version {} from github.com'.format(
506+
version
507+
)
498508
)
499509
print(e)
500510
if i < 5:
@@ -640,7 +650,8 @@ def parse_cmdline_args() -> dict[str, Any]:
640650
'--interactive',
641651
'-i',
642652
action='store_true',
643-
help="Ignore other arguments and run the installation in " + "interactive mode",
653+
help="Ignore other arguments and run the installation in "
654+
+ "interactive mode",
644655
)
645656
parser.add_argument(
646657
'--version',

0 commit comments

Comments
 (0)