Skip to content

Commit e45e36b

Browse files
Define types of start kwargs
Take out leftover start-from-trace support. And rearrange some code blocks for easier refactoring later.
1 parent c6e9153 commit e45e36b

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

pymc/parallel_sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import traceback
2222

2323
from collections import namedtuple
24+
from typing import Dict, Sequence
2425

2526
import cloudpickle
2627
import numpy as np
@@ -227,7 +228,7 @@ def __init__(
227228
step_method_pickled,
228229
chain: int,
229230
seed,
230-
start,
231+
start: Dict[str, np.ndarray],
231232
mp_ctx,
232233
):
233234
self.chain = chain
@@ -389,7 +390,7 @@ def __init__(
389390
chains: int,
390391
cores: int,
391392
seeds: list,
392-
start_points: list,
393+
start_points: Sequence[Dict[str, np.ndarray]],
393394
step_method,
394395
start_chain_num: int = 0,
395396
progressbar: bool = True,

pymc/sampling.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from collections import defaultdict
2525
from copy import copy, deepcopy
26-
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
26+
from typing import Dict, Iterable, List, Optional, Sequence, Set, Union, cast
2727

2828
import aesara.gradient as tg
2929
import cloudpickle
@@ -253,7 +253,7 @@ def sample(
253253
step=None,
254254
init="auto",
255255
n_init=200_000,
256-
start=None,
256+
start: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None,
257257
trace: Optional[Union[BaseTrace, List[str]]] = None,
258258
chain_idx=0,
259259
chains=None,
@@ -689,7 +689,7 @@ def _sample_many(
689689
draws,
690690
chain: int,
691691
chains: int,
692-
start: list,
692+
start: Sequence[PointType],
693693
random_seed: list,
694694
step,
695695
callback=None,
@@ -746,7 +746,7 @@ def _sample_population(
746746
draws: int,
747747
chain: int,
748748
chains: int,
749-
start,
749+
start: Sequence[PointType],
750750
random_seed,
751751
step,
752752
tune,
@@ -809,7 +809,7 @@ def _sample(
809809
chain: int,
810810
progressbar: bool,
811811
random_seed,
812-
start,
812+
start: PointType,
813813
draws: int,
814814
step=None,
815815
trace: Optional[Union[BaseTrace, List[str]]] = None,
@@ -875,7 +875,7 @@ def _sample(
875875
def iter_sample(
876876
draws: int,
877877
step,
878-
start: Optional[Dict[Any, Any]] = None,
878+
start: PointType,
879879
trace=None,
880880
chain=0,
881881
tune: Optional[int] = None,
@@ -895,8 +895,7 @@ def iter_sample(
895895
step : function
896896
Step function
897897
start : dict
898-
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
899-
there is a trace provided and model.initial_point if not (defaults to empty dict)
898+
Starting point in parameter space (or partial point).
900899
trace : backend or list
901900
This should be a backend instance, or a list of variables to track.
902901
If None or a list of variables, the NDArray backend is used.
@@ -935,7 +934,7 @@ def iter_sample(
935934
def _iter_sample(
936935
draws,
937936
step,
938-
start=None,
937+
start: Optional[PointType],
939938
trace: Optional[Union[BaseTrace, List[str]]] = None,
940939
chain=0,
941940
tune=None,
@@ -951,8 +950,8 @@ def _iter_sample(
951950
The number of samples to draw
952951
step : function
953952
Step function
954-
start : dict, optional
955-
Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict)
953+
start : dict
954+
Starting point in parameter space (or partial point).
956955
trace : backend or list
957956
This should be a backend instance, or a list of variables to track.
958957
If None or a list of variables, the NDArray backend is used.
@@ -978,18 +977,16 @@ def _iter_sample(
978977
if draws < 1:
979978
raise ValueError("Argument `draws` must be greater than 0.")
980979

981-
if start is None:
982-
start = {}
983-
984980
strace = _choose_backend(trace, model=model)
985981

986-
model.update_start_vals(start, model.initial_point)
987-
988982
try:
989983
step = CompoundStep(step)
990984
except TypeError:
991985
pass
992986

987+
if start is None:
988+
start = {}
989+
model.update_start_vals(start, model.initial_point)
993990
point = Point(start, model=model, filter_model_vars=True)
994991

995992
if step.generates_stats and strace.supports_sampler_stats:
@@ -1200,7 +1197,7 @@ def _prepare_iter_population(
12001197
draws: int,
12011198
chains: list,
12021199
step,
1203-
start: list,
1200+
start: Sequence[PointType],
12041201
parallelize: bool,
12051202
tune=None,
12061203
model=None,
@@ -1253,10 +1250,7 @@ def _prepare_iter_population(
12531250
traces = [_choose_backend(None, model=model) for chain in chains]
12541251
for c, strace in enumerate(traces):
12551252
# initialize the trace size and variable transforms
1256-
if len(strace) > 0:
1257-
model.update_start_vals(start[c], strace.point(-1))
1258-
else:
1259-
model.update_start_vals(start[c], model.initial_point)
1253+
model.update_start_vals(start[c], model.initial_point)
12601254

12611255
# 2. create a population (points) that tracks each chain
12621256
# it is updated as the chains are advanced
@@ -1390,7 +1384,7 @@ def _mp_sample(
13901384
cores: int,
13911385
chain: int,
13921386
random_seed: list,
1393-
start: list,
1387+
start: Sequence[PointType],
13941388
progressbar=True,
13951389
trace: Optional[Union[BaseTrace, List[str]]] = None,
13961390
model=None,
@@ -1448,9 +1442,11 @@ def _mp_sample(
14481442
strace = _choose_backend(copy(trace), model=model)
14491443
else:
14501444
strace = _choose_backend(None, model=model)
1445+
14511446
# for user supplied start value, fill-in missing value if the supplied
14521447
# dict does not contain all parameters
14531448
model.update_start_vals(start[idx - chain], model.initial_point)
1449+
14541450
if step.generates_stats and strace.supports_sampler_stats:
14551451
strace.setup(draws + tune, idx, step.stats_dtypes)
14561452
else:

0 commit comments

Comments
 (0)