2
2
3
3
import glob
4
4
import os
5
- from typing import Any , Dict , List , Optional , Union
5
+ from typing import List , Optional , Union
6
6
7
7
from cmdstanpy .cmdstan_args import (
8
8
CmdStanArgs ,
12
12
SamplerArgs ,
13
13
VariationalArgs ,
14
14
)
15
- from cmdstanpy .utils import check_sampler_csv , get_logger , scan_config
15
+ from cmdstanpy .utils import check_sampler_csv , get_logger , stancsv
16
16
17
17
from .gq import CmdStanGQ
18
18
from .laplace import CmdStanLaplace
@@ -103,10 +103,9 @@ def from_csv(
103
103
' includes non-csv file: {}' .format (file )
104
104
)
105
105
106
- config_dict : Dict [str , Any ] = {}
107
106
try :
108
- with open (csvfiles [0 ], 'r' ) as fd :
109
- scan_config ( fd , config_dict , 0 )
107
+ comments , * _ = stancsv . parse_comments_header_and_draws (csvfiles [0 ])
108
+ config_dict = stancsv . parse_config ( comments )
110
109
except (IOError , OSError , PermissionError ) as e :
111
110
raise ValueError ('Cannot read CSV file: {}' .format (csvfiles [0 ])) from e
112
111
if 'model' not in config_dict or 'method' not in config_dict :
@@ -118,39 +117,43 @@ def from_csv(
118
117
method , config_dict ['method' ]
119
118
)
120
119
)
120
+ model : str = config_dict ['model' ] # type: ignore
121
121
try :
122
122
if config_dict ['method' ] == 'sample' :
123
+ save_warmup = config_dict ['save_warmup' ] == 1
123
124
chains = len (csvfiles )
125
+ num_samples : int = config_dict ['num_samples' ] # type: ignore
126
+ num_warmup : int = config_dict ['num_warmup' ] # type: ignore
127
+ thin : int = config_dict ['thin' ] # type: ignore
124
128
sampler_args = SamplerArgs (
125
- iter_sampling = config_dict [ ' num_samples' ] ,
126
- iter_warmup = config_dict [ ' num_warmup' ] ,
127
- thin = config_dict [ ' thin' ] ,
128
- save_warmup = config_dict [ ' save_warmup' ] ,
129
+ iter_sampling = num_samples ,
130
+ iter_warmup = num_warmup ,
131
+ thin = thin ,
132
+ save_warmup = save_warmup ,
129
133
)
130
134
# bugfix 425, check for fixed_params output
131
135
try :
132
136
check_sampler_csv (
133
137
csvfiles [0 ],
134
- iter_sampling = config_dict [ ' num_samples' ] ,
135
- iter_warmup = config_dict [ ' num_warmup' ] ,
136
- thin = config_dict [ ' thin' ] ,
137
- save_warmup = config_dict [ ' save_warmup' ] ,
138
+ iter_sampling = num_samples ,
139
+ iter_warmup = num_warmup ,
140
+ thin = thin ,
141
+ save_warmup = save_warmup ,
138
142
)
139
143
except ValueError :
140
144
try :
141
145
check_sampler_csv (
142
146
csvfiles [0 ],
143
- is_fixed_param = True ,
144
- iter_sampling = config_dict ['num_samples' ],
145
- iter_warmup = config_dict ['num_warmup' ],
146
- thin = config_dict ['thin' ],
147
- save_warmup = config_dict ['save_warmup' ],
147
+ iter_sampling = num_samples ,
148
+ iter_warmup = num_warmup ,
149
+ thin = thin ,
150
+ save_warmup = save_warmup ,
148
151
)
149
152
sampler_args = SamplerArgs (
150
- iter_sampling = config_dict [ ' num_samples' ] ,
151
- iter_warmup = config_dict [ ' num_warmup' ] ,
152
- thin = config_dict [ ' thin' ] ,
153
- save_warmup = config_dict [ ' save_warmup' ] ,
153
+ iter_sampling = num_samples ,
154
+ iter_warmup = num_warmup ,
155
+ thin = thin ,
156
+ save_warmup = save_warmup ,
154
157
fixed_param = True ,
155
158
)
156
159
except ValueError as e :
@@ -159,8 +162,8 @@ def from_csv(
159
162
) from e
160
163
161
164
cmdstan_args = CmdStanArgs (
162
- model_name = config_dict [ ' model' ] ,
163
- model_exe = config_dict [ ' model' ] ,
165
+ model_name = model ,
166
+ model_exe = model ,
164
167
chain_ids = [x + 1 for x in range (chains )],
165
168
method_args = sampler_args ,
166
169
)
@@ -177,14 +180,18 @@ def from_csv(
177
180
"Cannot find optimization algorithm"
178
181
" in file {}." .format (csvfiles [0 ])
179
182
)
183
+ algorithm : str = config_dict ['algorithm' ] # type: ignore
184
+ save_iterations = config_dict ['save_iterations' ] == 1
185
+ jacobian = config_dict .get ('jacobian' , 0 ) == 1
186
+
180
187
optimize_args = OptimizeArgs (
181
- algorithm = config_dict [ ' algorithm' ] ,
182
- save_iterations = config_dict [ ' save_iterations' ] ,
183
- jacobian = config_dict . get ( ' jacobian' , 0 ) ,
188
+ algorithm = algorithm ,
189
+ save_iterations = save_iterations ,
190
+ jacobian = jacobian ,
184
191
)
185
192
cmdstan_args = CmdStanArgs (
186
- model_name = config_dict [ ' model' ] ,
187
- model_exe = config_dict [ ' model' ] ,
193
+ model_name = model ,
194
+ model_exe = model ,
188
195
chain_ids = None ,
189
196
method_args = optimize_args ,
190
197
)
@@ -200,18 +207,18 @@ def from_csv(
200
207
" in file {}." .format (csvfiles [0 ])
201
208
)
202
209
variational_args = VariationalArgs (
203
- algorithm = config_dict ['algorithm' ],
204
- iter = config_dict ['iter' ],
205
- grad_samples = config_dict ['grad_samples' ],
206
- elbo_samples = config_dict ['elbo_samples' ],
207
- eta = config_dict ['eta' ],
208
- tol_rel_obj = config_dict ['tol_rel_obj' ],
209
- eval_elbo = config_dict ['eval_elbo' ],
210
- output_samples = config_dict ['output_samples' ],
210
+ algorithm = config_dict ['algorithm' ], # type: ignore
211
+ iter = config_dict ['iter' ], # type: ignore
212
+ grad_samples = config_dict ['grad_samples' ], # type: ignore
213
+ elbo_samples = config_dict ['elbo_samples' ], # type: ignore
214
+ eta = config_dict ['eta' ], # type: ignore
215
+ tol_rel_obj = config_dict ['tol_rel_obj' ], # type: ignore
216
+ eval_elbo = config_dict ['eval_elbo' ], # type: ignore
217
+ output_samples = config_dict ['output_samples' ], # type: ignore
211
218
)
212
219
cmdstan_args = CmdStanArgs (
213
- model_name = config_dict [ ' model' ] ,
214
- model_exe = config_dict [ ' model' ] ,
220
+ model_name = model ,
221
+ model_exe = model ,
215
222
chain_ids = None ,
216
223
method_args = variational_args ,
217
224
)
@@ -221,14 +228,15 @@ def from_csv(
221
228
runset ._set_retcode (i , 0 )
222
229
return CmdStanVB (runset )
223
230
elif config_dict ['method' ] == 'laplace' :
231
+ jacobian = config_dict ['jacobian' ] == 1
224
232
laplace_args = LaplaceArgs (
225
- mode = config_dict ['mode' ],
226
- draws = config_dict ['draws' ],
227
- jacobian = config_dict [ ' jacobian' ] ,
233
+ mode = config_dict ['mode' ], # type: ignore
234
+ draws = config_dict ['draws' ], # type: ignore
235
+ jacobian = jacobian ,
228
236
)
229
237
cmdstan_args = CmdStanArgs (
230
- model_name = config_dict [ ' model' ] ,
231
- model_exe = config_dict [ ' model' ] ,
238
+ model_name = model ,
239
+ model_exe = model ,
232
240
chain_ids = None ,
233
241
method_args = laplace_args ,
234
242
)
@@ -237,18 +245,18 @@ def from_csv(
237
245
for i in range (len (runset ._retcodes )):
238
246
runset ._set_retcode (i , 0 )
239
247
mode : CmdStanMLE = from_csv (
240
- config_dict ['mode' ],
248
+ config_dict ['mode' ], # type: ignore
241
249
method = 'optimize' ,
242
250
) # type: ignore
243
251
return CmdStanLaplace (runset , mode = mode )
244
252
elif config_dict ['method' ] == 'pathfinder' :
245
253
pathfinder_args = PathfinderArgs (
246
- num_draws = config_dict ['num_draws' ],
247
- num_paths = config_dict ['num_paths' ],
254
+ num_draws = config_dict ['num_draws' ], # type: ignore
255
+ num_paths = config_dict ['num_paths' ], # type: ignore
248
256
)
249
257
cmdstan_args = CmdStanArgs (
250
- model_name = config_dict [ ' model' ] ,
251
- model_exe = config_dict [ ' model' ] ,
258
+ model_name = model ,
259
+ model_exe = model ,
252
260
chain_ids = None ,
253
261
method_args = pathfinder_args ,
254
262
)
0 commit comments