@@ -105,7 +105,9 @@ def validate(self, chains: Optional[int]) -> None:
105
105
* length of per-chain lists equals specified # of chains
106
106
"""
107
107
if not isinstance (chains , (int , np .integer )) or chains < 1 :
108
- raise ValueError ('Sampler expects number of chains to be greater than 0.' )
108
+ raise ValueError (
109
+ 'Sampler expects number of chains to be greater than 0.'
110
+ )
109
111
if not (
110
112
self .adapt_delta is None
111
113
and self .adapt_init_phase is None
@@ -117,13 +119,17 @@ def validate(self, chains: Optional[int]) -> None:
117
119
if self .adapt_delta is not None :
118
120
msg = '{}, adapt_delta: {}' .format (msg , self .adapt_delta )
119
121
if self .adapt_init_phase is not None :
120
- msg = '{}, adapt_init_phase: {}' .format (msg , self .adapt_init_phase )
122
+ msg = '{}, adapt_init_phase: {}' .format (
123
+ msg , self .adapt_init_phase
124
+ )
121
125
if self .adapt_metric_window is not None :
122
126
msg = '{}, adapt_metric_window: {}' .format (
123
127
msg , self .adapt_metric_window
124
128
)
125
129
if self .adapt_step_size is not None :
126
- msg = '{}, adapt_step_size: {}' .format (msg , self .adapt_step_size )
130
+ msg = '{}, adapt_step_size: {}' .format (
131
+ msg , self .adapt_step_size
132
+ )
127
133
raise ValueError (msg )
128
134
129
135
if self .iter_warmup is not None :
@@ -151,7 +157,9 @@ def validate(self, chains: Optional[int]) -> None:
151
157
positive_int (self .max_treedepth , 'max_treedepth' )
152
158
153
159
if self .step_size is not None :
154
- if isinstance (self .step_size , (float , int , np .integer , np .floating )):
160
+ if isinstance (
161
+ self .step_size , (float , int , np .integer , np .floating )
162
+ ):
155
163
if self .step_size <= 0 :
156
164
raise ValueError (
157
165
'Argument "step_size" must be > 0, found {}.' .format (
@@ -189,7 +197,9 @@ def validate(self, chains: Optional[int]) -> None:
189
197
self .metric_file = self .metric
190
198
elif isinstance (self .metric , dict ):
191
199
if 'inv_metric' not in self .metric :
192
- raise ValueError ('Entry "inv_metric" not found in metric dict.' )
200
+ raise ValueError (
201
+ 'Entry "inv_metric" not found in metric dict.'
202
+ )
193
203
dims = list (np .asarray (self .metric ['inv_metric' ]).shape )
194
204
if len (dims ) == 1 :
195
205
self .metric_type = 'diag_e'
@@ -218,14 +228,20 @@ def validate(self, chains: Optional[int]) -> None:
218
228
'for chain {}.' .format (i + 1 )
219
229
)
220
230
if i == 0 :
221
- dims = list (np .asarray (metric_dict ['inv_metric' ]).shape )
231
+ dims = list (
232
+ np .asarray (metric_dict ['inv_metric' ]).shape
233
+ )
222
234
else :
223
- dims2 = list (np .asarray (metric_dict ['inv_metric' ]).shape )
235
+ dims2 = list (
236
+ np .asarray (metric_dict ['inv_metric' ]).shape
237
+ )
224
238
if dims != dims2 :
225
239
raise ValueError (
226
240
'Found inconsistent "inv_metric" entry '
227
241
'for chain {}: entry has dims '
228
- '{}, expected {}.' .format (i + 1 , dims , dims2 )
242
+ '{}, expected {}.' .format (
243
+ i + 1 , dims , dims2
244
+ )
229
245
)
230
246
dict_file = create_named_text_file (
231
247
dir = _TMPDIR , prefix = "metric" , suffix = ".json"
@@ -249,13 +265,15 @@ def validate(self, chains: Optional[int]) -> None:
249
265
dims2 = read_metric (metric )
250
266
if len (dims ) != len (dims2 ):
251
267
raise ValueError (
252
- 'Metrics files {}, {}, inconsistent metrics' .format (
268
+ 'Metrics files {}, {},'
269
+ ' inconsistent metrics' .format (
253
270
self .metric [0 ], metric
254
271
)
255
272
)
256
273
if dims != dims2 :
257
274
raise ValueError (
258
- 'Metrics files {}, {}, inconsistent metrics' .format (
275
+ 'Metrics files {}, {},'
276
+ ' inconsistent metrics' .format (
259
277
self .metric [0 ], metric
260
278
)
261
279
)
@@ -268,7 +286,9 @@ def validate(self, chains: Optional[int]) -> None:
268
286
else :
269
287
raise ValueError (
270
288
'Argument "metric" must be a list of pathnames or '
271
- 'Python dicts, found list of {}.' .format (type (self .metric [0 ]))
289
+ 'Python dicts, found list of {}.' .format (
290
+ type (self .metric [0 ])
291
+ )
272
292
)
273
293
else :
274
294
raise ValueError (
@@ -281,9 +301,8 @@ def validate(self, chains: Optional[int]) -> None:
281
301
if self .adapt_delta is not None :
282
302
if not 0 < self .adapt_delta < 1 :
283
303
raise ValueError (
284
- 'Argument "adapt_delta" must be between 0 and 1, found {}' .format (
285
- self .adapt_delta
286
- )
304
+ 'Argument "adapt_delta" must be between 0 and 1,'
305
+ ' found {}' .format (self .adapt_delta )
287
306
)
288
307
if self .adapt_init_phase is not None :
289
308
if self .adapt_init_phase < 0 or not isinstance (
@@ -437,7 +456,9 @@ def validate(self, _chains: Optional[int] = None) -> None:
437
456
)
438
457
if self .algorithm .lower () != 'lbfgs' :
439
458
if self .history_size is not None :
440
- raise ValueError ('history_size requires that algorithm be set to lbfgs' )
459
+ raise ValueError (
460
+ 'history_size requires that algorithm be set to lbfgs'
461
+ )
441
462
442
463
positive_float (self .init_alpha , 'init_alpha' )
443
464
positive_int (self .iter , 'iter' )
@@ -620,7 +641,9 @@ def validate(
620
641
"""
621
642
for csv in self .sample_csv_files :
622
643
if not os .path .exists (csv ):
623
- raise ValueError ('Invalid path for sample csv file: {}' .format (csv ))
644
+ raise ValueError (
645
+ 'Invalid path for sample csv file: {}' .format (csv )
646
+ )
624
647
625
648
def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
626
649
"""
@@ -667,7 +690,10 @@ def validate(
667
690
"""
668
691
Check arguments correctness and consistency.
669
692
"""
670
- if self .algorithm is not None and self .algorithm not in self .VARIATIONAL_ALGOS :
693
+ if (
694
+ self .algorithm is not None
695
+ and self .algorithm not in self .VARIATIONAL_ALGOS
696
+ ):
671
697
raise ValueError (
672
698
'Please specify variational algorithms as one of [{}]' .format (
673
699
', ' .join (self .VARIATIONAL_ALGOS )
@@ -794,16 +820,19 @@ def validate(self) -> None:
794
820
if chain_id < 1 :
795
821
raise ValueError ('invalid chain_id {}' .format (chain_id ))
796
822
if self .output_dir is not None :
797
- self .output_dir = os .path .realpath (os .path .expanduser (self .output_dir ))
823
+ self .output_dir = os .path .realpath (
824
+ os .path .expanduser (self .output_dir )
825
+ )
798
826
if not os .path .exists (self .output_dir ):
799
827
try :
800
828
os .makedirs (self .output_dir )
801
- get_logger ().info ('created output directory: %s' , self .output_dir )
829
+ get_logger ().info (
830
+ 'created output directory: %s' , self .output_dir
831
+ )
802
832
except (RuntimeError , PermissionError ) as exc :
803
833
raise ValueError (
804
- 'Invalid path for output files, no such dir: {}.' .format (
805
- self .output_dir
806
- )
834
+ 'Invalid path for output files, '
835
+ 'no such dir: {}.' .format (self .output_dir )
807
836
) from exc
808
837
if not os .path .isdir (self .output_dir ):
809
838
raise ValueError (
@@ -818,12 +847,14 @@ def validate(self) -> None:
818
847
os .remove (testpath ) # cleanup
819
848
except Exception as exc :
820
849
raise ValueError (
821
- 'Invalid path for output files, cannot write to dir: {}.' .format (
822
- self .output_dir
823
- )
850
+ 'Invalid path for output files,'
851
+ ' cannot write to dir: {}.' .format (self .output_dir )
824
852
) from exc
825
853
if self .refresh is not None :
826
- if not isinstance (self .refresh , (int , np .integer )) or self .refresh < 1 :
854
+ if (
855
+ not isinstance (self .refresh , (int , np .integer ))
856
+ or self .refresh < 1
857
+ ):
827
858
raise ValueError (
828
859
'Argument "refresh" must be a positive integer value, '
829
860
'found {}.' .format (self .refresh )
@@ -895,7 +926,9 @@ def validate(self) -> None:
895
926
if isinstance (self .inits , (float , int , np .floating , np .integer )):
896
927
if self .inits < 0 :
897
928
raise ValueError (
898
- 'Argument "inits" must be > 0, found {}' .format (self .inits )
929
+ 'Argument "inits" must be > 0, found {}' .format (
930
+ self .inits
931
+ )
899
932
)
900
933
elif isinstance (self .inits , str ):
901
934
if not (
0 commit comments