Skip to content

Commit ce66620

Browse files
michaelosthegealoctavodia
authored andcommitted
Avoid SamplerReport methods in convergence checks
The goal was to uncouple sampling functions from `MultiTrace` and `SamplerReport`. Some calls to `SamplerReport._log_summary()` were unnecessary because `MultiTrace._add_warnings()` was never called inbetween instantiation and `_log_summary()`, therefore the traces never contained warnings. Running convergence checks and logging the warnings can also be done without needing `MultiTrace` or `SamplerReport` instances/methods.
1 parent db65421 commit ce66620

File tree

5 files changed

+26
-55
lines changed

5 files changed

+26
-55
lines changed

pymc/backends/report.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,7 @@
1717

1818
from typing import Dict, List, Optional
1919

20-
import arviz
21-
22-
from pymc.stats.convergence import (
23-
_LEVELS,
24-
SamplerWarning,
25-
log_warnings,
26-
run_convergence_checks,
27-
)
20+
from pymc.stats.convergence import _LEVELS, SamplerWarning
2821

2922
logger = logging.getLogger("pymc")
3023

@@ -73,22 +66,13 @@ def raise_ok(self, level="error"):
7366
if errors:
7467
raise ValueError("Serious convergence issues during sampling.")
7568

76-
def _run_convergence_checks(self, idata: arviz.InferenceData, model):
77-
warnings = run_convergence_checks(idata, model)
78-
self._add_warnings(warnings)
79-
8069
def _add_warnings(self, warnings, chain=None):
8170
if chain is None:
8271
warn_list = self._global_warnings
8372
else:
8473
warn_list = self._chain_warnings.setdefault(chain, [])
8574
warn_list.extend(warnings)
8675

87-
def _log_summary(self):
88-
for chain, warns in self._chain_warnings.items():
89-
log_warnings(warns)
90-
log_warnings(self._global_warnings)
91-
9276
def _slice(self, start, stop, step):
9377
report = SamplerReport()
9478

pymc/sampling/mcmc.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
from pymc.model import Model, modelcontext
4141
from pymc.sampling.parallel import Draw, _cpu_count
4242
from pymc.sampling.population import _sample_population
43-
from pymc.stats.convergence import log_warning_stats, run_convergence_checks
43+
from pymc.stats.convergence import (
44+
log_warning_stats,
45+
log_warnings,
46+
run_convergence_checks,
47+
)
4448
from pymc.step_methods import NUTS, CompoundStep
4549
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
4650
from pymc.step_methods.hmc import quadpotential
@@ -602,7 +606,6 @@ def sample(
602606
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
603607
f"took {t_sampling:.0f} seconds."
604608
)
605-
mtrace.report._log_summary()
606609

607610
idata = None
608611
if compute_convergence_checks or return_inferencedata:
@@ -612,14 +615,9 @@ def sample(
612615
idata = pm.to_inference_data(mtrace, **ikwargs)
613616

614617
if compute_convergence_checks:
615-
if draws - tune < 100:
616-
warnings.warn(
617-
"The number of samples is too small to check convergence reliably.",
618-
stacklevel=2,
619-
)
620-
else:
621-
convergence_warnings = run_convergence_checks(idata, model)
622-
mtrace.report._add_warnings(convergence_warnings)
618+
warns = run_convergence_checks(idata, model)
619+
mtrace.report._add_warnings(warns)
620+
log_warnings(warns)
623621

624622
if return_inferencedata:
625623
# By default we drop the "warning" stat which contains `SamplerWarning`
@@ -925,9 +923,6 @@ def _mp_sample(
925923
strace = traces[error._chain]
926924
for strace in traces:
927925
strace.close()
928-
929-
multitrace = MultiTrace(traces)
930-
multitrace._report._log_summary()
931926
raise
932927
except KeyboardInterrupt:
933928
pass

pymc/smc/sampling.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pymc.model import Model, modelcontext
3535
from pymc.sampling.parallel import _cpu_count
3636
from pymc.smc.kernels import IMH
37+
from pymc.stats.convergence import log_warnings, run_convergence_checks
3738
from pymc.util import RandomState, _get_seeds_per_chain
3839

3940

@@ -237,7 +238,11 @@ def sample_smc(
237238
)
238239

239240
if compute_convergence_checks:
240-
_compute_convergence_checks(idata, draws, model, trace)
241+
if idata is None:
242+
idata = to_inference_data(trace, log_likelihood=False)
243+
warns = run_convergence_checks(idata, model)
244+
trace.report._add_warnings(warns)
245+
log_warnings(warns)
241246

242247
if return_inferencedata:
243248
assert idata is not None
@@ -298,21 +303,6 @@ def _save_sample_stats(
298303
return sample_stats, idata
299304

300305

301-
def _compute_convergence_checks(
302-
idata: Optional[InferenceData], draws: int, model: Model, trace: MultiTrace
303-
):
304-
if draws < 100:
305-
warnings.warn(
306-
"The number of samples is too small to check convergence reliably.",
307-
stacklevel=2,
308-
)
309-
else:
310-
if idata is None:
311-
idata = to_inference_data(trace, log_likelihood=False)
312-
trace.report._run_convergence_checks(idata, model)
313-
trace.report._log_summary()
314-
315-
316306
def _sample_smc_int(
317307
draws,
318308
kernel,

pymc/stats/convergence.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWar
5353
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
5454
return [warn]
5555

56+
if idata["posterior"].sizes["draw"] < 100:
57+
msg = "The number of samples is too small to check convergence reliably."
58+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
59+
return [warn]
60+
5661
if idata["posterior"].sizes["chain"] == 1:
57-
msg = (
58-
"Only one chain was sampled, this makes it impossible to " "run some convergence checks"
59-
)
62+
msg = "Only one chain was sampled, this makes it impossible to run some convergence checks"
6063
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
6164
return [warn]
6265

pymc/tests/smc/test_smc.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
1415
import warnings
1516

1617
import numpy as np
@@ -215,13 +216,11 @@ def test_return_datatype(self, chains):
215216
assert mt.nchains == chains
216217
assert mt["x"].size == chains * draws
217218

218-
def test_convergence_checks(self):
219-
with self.fast_model:
220-
with pytest.warns(
221-
UserWarning,
222-
match="The number of samples is too small",
223-
):
219+
def test_convergence_checks(self, caplog):
220+
with caplog.at_level(logging.INFO):
221+
with self.fast_model:
224222
pm.sample_smc(draws=99)
223+
assert "The number of samples is too small" in caplog.text
225224

226225
def test_deprecated_parallel_arg(self):
227226
with self.fast_model:

0 commit comments

Comments
 (0)