Skip to content

Commit 6f15cbb

Browse files
authored
Fix bug whereby partial traces have fewer draws than would be available (#4318)
* add test for _choose_chains * fix bug - choose overall maximum * update release notes * 📝 * 🎨 * minimise diff * minimise diff * Update pymc3/sampling.py
1 parent 2a38198 commit 6f15cbb

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## PyMC3 3.10.1 (on deck)
44

55
### Maintenance
6+
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
67
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
78

89
## PyMC3 3.10.0 (7 December 2020)

pymc3/sampling.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,14 @@ def _mp_sample(
15081508

15091509

15101510
def _choose_chains(traces, tune):
1511+
"""
1512+
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
1513+
1514+
We get here after a ``KeyboardInterrupt``, and so the different
1515+
traces have different lengths. We therefore pick the number of
1516+
traces such that (number of traces) * (length of shortest trace)
1517+
is maximised.
1518+
"""
15111519
if tune is None:
15121520
tune = 0
15131521

@@ -1518,22 +1526,13 @@ def _choose_chains(traces, tune):
15181526
if not sum(lengths):
15191527
raise ValueError("Not enough samples to build a trace.")
15201528

1521-
idxs = np.argsort(lengths)[::-1]
1529+
idxs = np.argsort(lengths)
15221530
l_sort = np.array(lengths)[idxs]
15231531

1524-
final_length = l_sort[0]
1525-
last_total = 0
1526-
for i, length in enumerate(l_sort):
1527-
total = (i + 1) * length
1528-
if total < last_total:
1529-
use_until = i
1530-
break
1531-
last_total = total
1532-
final_length = length
1533-
else:
1534-
use_until = len(lengths)
1532+
use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])
1533+
final_length = l_sort[use_until]
15351534

1536-
return [traces[idx] for idx in idxs[:use_until]], final_length + tune
1535+
return [traces[idx] for idx in idxs[use_until:]], final_length + tune
15371536

15381537

15391538
def stop_tuning(step):

pymc3/tests/test_sampling.py

+28
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import pymc3 as pm
3232

33+
from pymc3.backends.ndarray import NDArray
3334
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
3435
from pymc3.tests.helpers import SeededTest
3536
from pymc3.tests.models import simple_init
@@ -299,6 +300,33 @@ def test_partial_trace_sample():
299300
trace = pm.sample(trace=[a])
300301

301302

303+
@pytest.mark.parametrize(
304+
"n_points, tune, expected_length, expected_n_traces",
305+
[
306+
((5, 2, 2), 0, 2, 3),
307+
((6, 1, 1), 1, 6, 1),
308+
],
309+
)
310+
def test_choose_chains(n_points, tune, expected_length, expected_n_traces):
311+
with pm.Model() as model:
312+
a = pm.Normal("a", mu=0, sigma=1)
313+
trace_0 = NDArray(model)
314+
trace_1 = NDArray(model)
315+
trace_2 = NDArray(model)
316+
trace_0.setup(n_points[0], 1)
317+
trace_1.setup(n_points[1], 1)
318+
trace_2.setup(n_points[2], 1)
319+
for _ in range(n_points[0]):
320+
trace_0.record({"a": 0})
321+
for _ in range(n_points[1]):
322+
trace_1.record({"a": 0})
323+
for _ in range(n_points[2]):
324+
trace_2.record({"a": 0})
325+
traces, length = pm.sampling._choose_chains([trace_0, trace_1, trace_2], tune=tune)
326+
assert length == expected_length
327+
assert expected_n_traces == len(traces)
328+
329+
302330
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
303331
class TestNamedSampling(SeededTest):
304332
def test_shared_named(self):

0 commit comments

Comments
 (0)