23
23
24
24
from collections import defaultdict
25
25
from copy import copy , deepcopy
26
- from typing import Any , Dict , Iterable , List , Optional , Set , Union , cast
26
+ from typing import Dict , Iterable , List , Optional , Sequence , Set , Union , cast
27
27
28
28
import aesara .gradient as tg
29
29
import cloudpickle
@@ -253,7 +253,7 @@ def sample(
253
253
step = None ,
254
254
init = "auto" ,
255
255
n_init = 200_000 ,
256
- start = None ,
256
+ start : Optional [ Union [ PointType , Sequence [ Optional [ PointType ]]]] = None ,
257
257
trace : Optional [Union [BaseTrace , List [str ]]] = None ,
258
258
chain_idx = 0 ,
259
259
chains = None ,
@@ -689,7 +689,7 @@ def _sample_many(
689
689
draws ,
690
690
chain : int ,
691
691
chains : int ,
692
- start : list ,
692
+ start : Sequence [ PointType ] ,
693
693
random_seed : list ,
694
694
step ,
695
695
callback = None ,
@@ -746,7 +746,7 @@ def _sample_population(
746
746
draws : int ,
747
747
chain : int ,
748
748
chains : int ,
749
- start ,
749
+ start : Sequence [ PointType ] ,
750
750
random_seed ,
751
751
step ,
752
752
tune ,
@@ -809,7 +809,7 @@ def _sample(
809
809
chain : int ,
810
810
progressbar : bool ,
811
811
random_seed ,
812
- start ,
812
+ start : PointType ,
813
813
draws : int ,
814
814
step = None ,
815
815
trace : Optional [Union [BaseTrace , List [str ]]] = None ,
@@ -875,7 +875,7 @@ def _sample(
875
875
def iter_sample (
876
876
draws : int ,
877
877
step ,
878
- start : Optional [ Dict [ Any , Any ]] = None ,
878
+ start : PointType ,
879
879
trace = None ,
880
880
chain = 0 ,
881
881
tune : Optional [int ] = None ,
@@ -895,8 +895,7 @@ def iter_sample(
895
895
step : function
896
896
Step function
897
897
start : dict
898
- Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
899
- there is a trace provided and model.initial_point if not (defaults to empty dict)
898
+ Starting point in parameter space (or partial point).
900
899
trace : backend or list
901
900
This should be a backend instance, or a list of variables to track.
902
901
If None or a list of variables, the NDArray backend is used.
@@ -935,7 +934,7 @@ def iter_sample(
935
934
def _iter_sample (
936
935
draws ,
937
936
step ,
938
- start = None ,
937
+ start : Optional [ PointType ] ,
939
938
trace : Optional [Union [BaseTrace , List [str ]]] = None ,
940
939
chain = 0 ,
941
940
tune = None ,
@@ -951,8 +950,8 @@ def _iter_sample(
951
950
The number of samples to draw
952
951
step : function
953
952
Step function
954
- start : dict, optional
955
- Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict)
953
+ start : dict
954
+ Starting point in parameter space (or partial point).
956
955
trace : backend or list
957
956
This should be a backend instance, or a list of variables to track.
958
957
If None or a list of variables, the NDArray backend is used.
@@ -978,18 +977,16 @@ def _iter_sample(
978
977
if draws < 1 :
979
978
raise ValueError ("Argument `draws` must be greater than 0." )
980
979
981
- if start is None :
982
- start = {}
983
-
984
980
strace = _choose_backend (trace , model = model )
985
981
986
- model .update_start_vals (start , model .initial_point )
987
-
988
982
try :
989
983
step = CompoundStep (step )
990
984
except TypeError :
991
985
pass
992
986
987
+ if start is None :
988
+ start = {}
989
+ model .update_start_vals (start , model .initial_point )
993
990
point = Point (start , model = model , filter_model_vars = True )
994
991
995
992
if step .generates_stats and strace .supports_sampler_stats :
@@ -1200,7 +1197,7 @@ def _prepare_iter_population(
1200
1197
draws : int ,
1201
1198
chains : list ,
1202
1199
step ,
1203
- start : list ,
1200
+ start : Sequence [ PointType ] ,
1204
1201
parallelize : bool ,
1205
1202
tune = None ,
1206
1203
model = None ,
@@ -1253,10 +1250,7 @@ def _prepare_iter_population(
1253
1250
traces = [_choose_backend (None , model = model ) for chain in chains ]
1254
1251
for c , strace in enumerate (traces ):
1255
1252
# initialize the trace size and variable transforms
1256
- if len (strace ) > 0 :
1257
- model .update_start_vals (start [c ], strace .point (- 1 ))
1258
- else :
1259
- model .update_start_vals (start [c ], model .initial_point )
1253
+ model .update_start_vals (start [c ], model .initial_point )
1260
1254
1261
1255
# 2. create a population (points) that tracks each chain
1262
1256
# it is updated as the chains are advanced
@@ -1390,7 +1384,7 @@ def _mp_sample(
1390
1384
cores : int ,
1391
1385
chain : int ,
1392
1386
random_seed : list ,
1393
- start : list ,
1387
+ start : Sequence [ PointType ] ,
1394
1388
progressbar = True ,
1395
1389
trace : Optional [Union [BaseTrace , List [str ]]] = None ,
1396
1390
model = None ,
@@ -1448,9 +1442,11 @@ def _mp_sample(
1448
1442
strace = _choose_backend (copy (trace ), model = model )
1449
1443
else :
1450
1444
strace = _choose_backend (None , model = model )
1445
+
1451
1446
# for user supplied start value, fill-in missing value if the supplied
1452
1447
# dict does not contain all parameters
1453
1448
model .update_start_vals (start [idx - chain ], model .initial_point )
1449
+
1454
1450
if step .generates_stats and strace .supports_sampler_stats :
1455
1451
strace .setup (draws + tune , idx , step .stats_dtypes )
1456
1452
else :
0 commit comments