1
1
"""
2
2
CmdStan arguments
3
3
"""
4
+
4
5
import os
5
6
from enum import Enum , auto
6
7
from time import time
7
- from typing import Any , Dict , List , Mapping , Optional , Union
8
+ from typing import Any , Mapping , Optional , Union
8
9
9
10
import numpy as np
10
11
from numpy .random import default_rng
@@ -65,9 +66,9 @@ def __init__(
65
66
thin : Optional [int ] = None ,
66
67
max_treedepth : Optional [int ] = None ,
67
68
metric : Union [
68
- str , Dict [str , Any ], List [str ], List [ Dict [str , Any ]], None
69
+ str , dict [str , Any ], list [str ], list [ dict [str , Any ]], None
69
70
] = None ,
70
- step_size : Union [float , List [float ], None ] = None ,
71
+ step_size : Union [float , list [float ], None ] = None ,
71
72
adapt_engaged : bool = True ,
72
73
adapt_delta : Optional [float ] = None ,
73
74
adapt_init_phase : Optional [int ] = None ,
@@ -84,7 +85,7 @@ def __init__(
84
85
self .max_treedepth = max_treedepth
85
86
self .metric = metric
86
87
self .metric_type : Optional [str ] = None
87
- self .metric_file : Union [str , List [str ], None ] = None
88
+ self .metric_file : Union [str , list [str ], None ] = None
88
89
self .step_size = step_size
89
90
self .adapt_engaged = adapt_engaged
90
91
self .adapt_delta = adapt_delta
@@ -161,8 +162,9 @@ def validate(self, chains: Optional[int]) -> None:
161
162
):
162
163
if self .step_size <= 0 :
163
164
raise ValueError (
164
- 'Argument "step_size" must be > 0, '
165
- 'found {}.' .format (self .step_size )
165
+ 'Argument "step_size" must be > 0, found {}.' .format (
166
+ self .step_size
167
+ )
166
168
)
167
169
else :
168
170
if len (self .step_size ) != chains :
@@ -217,9 +219,9 @@ def validate(self, chains: Optional[int]) -> None:
217
219
)
218
220
)
219
221
if all (isinstance (elem , dict ) for elem in self .metric ):
220
- metric_files : List [str ] = []
222
+ metric_files : list [str ] = []
221
223
for i , metric in enumerate (self .metric ):
222
- metric_dict : Dict [str , Any ] = metric # type: ignore
224
+ metric_dict : dict [str , Any ] = metric # type: ignore
223
225
if 'inv_metric' not in metric_dict :
224
226
raise ValueError (
225
227
'Entry "inv_metric" not found in metric dict '
@@ -343,7 +345,7 @@ def validate(self, chains: Optional[int]) -> None:
343
345
'When fixed_param=True, cannot specify adaptation parameters.'
344
346
)
345
347
346
- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
348
+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
347
349
"""
348
350
Compose CmdStan command for method-specific non-default arguments.
349
351
"""
@@ -467,7 +469,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
467
469
positive_float (self .tol_param , 'tol_param' )
468
470
positive_int (self .history_size , 'history_size' )
469
471
470
- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
472
+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
471
473
"""compose command string for CmdStan for non-default arg values."""
472
474
cmd .append ('method=optimize' )
473
475
if self .algorithm :
@@ -511,7 +513,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
511
513
raise ValueError (f'Invalid path for mode file: { self .mode } ' )
512
514
positive_int (self .draws , 'draws' )
513
515
514
- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
516
+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
515
517
"""compose command string for CmdStan for non-default arg values."""
516
518
cmd .append ('method=laplace' )
517
519
cmd .append (f'mode={ self .mode } ' )
@@ -579,7 +581,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
579
581
positive_int (self .num_draws , 'num_draws' )
580
582
positive_int (self .num_elbo_draws , 'num_elbo_draws' )
581
583
582
- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
584
+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
583
585
"""compose command string for CmdStan for non-default arg values."""
584
586
cmd .append ('method=pathfinder' )
585
587
@@ -624,12 +626,13 @@ def compose(self, _idx: int, cmd: List[str]) -> List[str]:
624
626
class GenerateQuantitiesArgs :
625
627
"""Arguments needed for generate_quantities method."""
626
628
627
- def __init__ (self , csv_files : List [str ]) -> None :
629
+ def __init__ (self , csv_files : list [str ]) -> None :
628
630
"""Initialize object."""
629
631
self .sample_csv_files = csv_files
630
632
631
633
def validate (
632
- self , chains : Optional [int ] = None # pylint: disable=unused-argument
634
+ self ,
635
+ chains : Optional [int ] = None , # pylint: disable=unused-argument
633
636
) -> None :
634
637
"""
635
638
Check arguments correctness and consistency.
@@ -642,7 +645,7 @@ def validate(
642
645
'Invalid path for sample csv file: {}' .format (csv )
643
646
)
644
647
645
- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
648
+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
646
649
"""
647
650
Compose CmdStan command for method-specific non-default arguments.
648
651
"""
@@ -681,7 +684,8 @@ def __init__(
681
684
self .output_samples = output_samples
682
685
683
686
def validate (
684
- self , chains : Optional [int ] = None # pylint: disable=unused-argument
687
+ self ,
688
+ chains : Optional [int ] = None , # pylint: disable=unused-argument
685
689
) -> None :
686
690
"""
687
691
Check arguments correctness and consistency.
@@ -705,7 +709,7 @@ def validate(
705
709
positive_int (self .output_samples , 'output_samples' )
706
710
707
711
# pylint: disable=unused-argument
708
- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
712
+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
709
713
"""
710
714
Compose CmdStan command for method-specific non-default arguments.
711
715
"""
@@ -747,7 +751,7 @@ def __init__(
747
751
self ,
748
752
model_name : str ,
749
753
model_exe : OptionalPath ,
750
- chain_ids : Optional [List [int ]],
754
+ chain_ids : Optional [list [int ]],
751
755
method_args : Union [
752
756
SamplerArgs ,
753
757
OptimizeArgs ,
@@ -757,8 +761,8 @@ def __init__(
757
761
PathfinderArgs ,
758
762
],
759
763
data : Union [Mapping [str , Any ], str , None ] = None ,
760
- seed : Union [int , List [int ], None ] = None ,
761
- inits : Union [int , float , str , List [str ], None ] = None ,
764
+ seed : Union [int , list [int ], None ] = None ,
765
+ inits : Union [int , float , str , list [str ], None ] = None ,
762
766
output_dir : OptionalPath = None ,
763
767
sig_figs : Optional [int ] = None ,
764
768
save_latent_dynamics : bool = False ,
@@ -959,11 +963,11 @@ def compose_command(
959
963
* ,
960
964
diagnostic_file : Optional [str ] = None ,
961
965
profile_file : Optional [str ] = None ,
962
- ) -> List [str ]:
966
+ ) -> list [str ]:
963
967
"""
964
968
Compose CmdStan command for non-default arguments.
965
969
"""
966
- cmd : List [str ] = []
970
+ cmd : list [str ] = []
967
971
if idx is not None and self .chain_ids is not None :
968
972
if idx < 0 or idx > len (self .chain_ids ) - 1 :
969
973
raise ValueError (
0 commit comments