Skip to content

Commit 47b61de

Browse files
brandonwillardtwiecki
authored andcommitted
Set model-level RandomVariable seeds during sampling
1 parent 090fb88 commit 47b61de

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

pymc3/sampling.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,8 @@ def sample(
448448
random_seed = [random_seed]
449449
if random_seed is None or isinstance(random_seed, int):
450450
if random_seed is not None:
451-
np.random.seed(random_seed)
451+
# np.random.seed(random_seed)
452+
model.default_rng.get_value(borrow=True).seed(random_seed)
452453
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
453454
if not isinstance(random_seed, abc.Iterable):
454455
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
@@ -971,7 +972,8 @@ def _iter_sample(
971972
model = modelcontext(model)
972973
draws = int(draws)
973974
if random_seed is not None:
974-
np.random.seed(random_seed)
975+
# np.random.seed(random_seed)
976+
model.default_rng.get_value(borrow=True).seed(random_seed)
975977
if draws < 1:
976978
raise ValueError("Argument `draws` must be greater than 0.")
977979

@@ -1239,7 +1241,8 @@ def _prepare_iter_population(
12391241
model = modelcontext(model)
12401242
draws = int(draws)
12411243
if random_seed is not None:
1242-
np.random.seed(random_seed)
1244+
# np.random.seed(random_seed)
1245+
model.default_rng.get_value(borrow=True).seed(random_seed)
12431246
if draws < 1:
12441247
raise ValueError("Argument `draws` should be above 0.")
12451248

@@ -1710,7 +1713,8 @@ def sample_posterior_predictive(
17101713
vars_ = model.observed_RVs
17111714

17121715
if random_seed is not None:
1713-
np.random.seed(random_seed)
1716+
# np.random.seed(random_seed)
1717+
model.default_rng.get_value(borrow=True).seed(random_seed)
17141718

17151719
indices = np.arange(samples)
17161720

@@ -1820,7 +1824,7 @@ def sample_posterior_predictive_w(
18201824
Dictionary with the variables as keys. The values corresponding to the
18211825
posterior predictive samples from the weighted models.
18221826
"""
1823-
np.random.seed(random_seed)
1827+
# np.random.seed(random_seed)
18241828

18251829
if isinstance(traces[0], InferenceData):
18261830
n_samples = [
@@ -1837,6 +1841,8 @@ def sample_posterior_predictive_w(
18371841
models = [modelcontext(models)] * len(traces)
18381842

18391843
for model in models:
1844+
if random_seed:
1845+
model.default_rng.get_value(borrow=True).seed(random_seed)
18401846
if model.potentials:
18411847
warnings.warn(
18421848
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
@@ -1976,7 +1982,8 @@ def sample_prior_predictive(
19761982
vars_ = set(var_names)
19771983

19781984
if random_seed is not None:
1979-
np.random.seed(random_seed)
1985+
# np.random.seed(random_seed)
1986+
model.default_rng.get_value(borrow=True).seed(random_seed)
19801987

19811988
names = get_default_varnames(vars_, include_transformed=False)
19821989

@@ -2123,7 +2130,8 @@ def init_nuts(
21232130

21242131
if random_seed is not None:
21252132
random_seed = int(np.atleast_1d(random_seed)[0])
2126-
np.random.seed(random_seed)
2133+
# np.random.seed(random_seed)
2134+
model.default_rng.get_value(borrow=True).seed(random_seed)
21272135

21282136
cb = [
21292137
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),

0 commit comments

Comments
 (0)