From 553572122cd6c69383ed4329a16f2d076a33200c Mon Sep 17 00:00:00 2001 From: Martin Ingram Date: Wed, 27 Jan 2021 12:41:41 +1100 Subject: [PATCH 1/4] Transform samples from sample_numpyro_nuts * Add `pymc3.sampling_jax._transform_samples` function which transforms draws * Modify `pymc3.sampling_jax.sample_numpyro_nuts` function to use this function to return transformed samples * Add release note --- RELEASE-NOTES.md | 1 + pymc3/sampling_jax.py | 46 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 65cde7cc6e..2e11c944f4 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,6 +5,7 @@ ### Breaking Changes ### New Features +- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)). ### Maintenance - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). diff --git a/pymc3/sampling_jax.py b/pymc3/sampling_jax.py index e02d21cb3c..a0f8d1ebf6 100644 --- a/pymc3/sampling_jax.py +++ b/pymc3/sampling_jax.py @@ -3,6 +3,8 @@ import re import warnings +from collections import defaultdict + xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--") xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split() os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)]) @@ -175,8 +177,50 @@ def _sample(current_state, seed): # print("Sampling time = ", tic4 - tic3) posterior = {k: v for k, v in zip(rv_names, mcmc_samples)} + tic3 = pd.Timestamp.now() + posterior = _transform_samples(posterior, model, keep_untransformed=False) + tic4 = pd.Timestamp.now() az_trace = az.from_dict(posterior=posterior) - tic3 = pd.Timestamp.now() print("Compilation + sampling time = ", tic3 - tic2) + print("Transformation time = ", tic4 - tic3) + return az_trace # , leapfrogs_taken, tic3 - tic2 + + +def _transform_samples(samples, model, keep_untransformed=False): + + # Find out which RVs we need to compute: + free_rv_names = {x.name for x in model.free_RVs} + unobserved_names = {x.name for x in model.unobserved_RVs} + + names_to_compute = unobserved_names - free_rv_names + ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute] + + # Create function graph for these: + fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute) + + # Jaxify, which returns a list of functions, one for each op + jax_fns = jax_funcify(fgraph) + + # Put together the inputs + inputs = [samples[x.name] for x in model.free_RVs] + + for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns): + + # We need a function taking a single argument to run vmap, while the + # jax_fn takes a list, so: + to_run = lambda x: cur_jax_fn(*x) + + result = jax.vmap(jax.vmap(to_run))(inputs) + + # Add to sample dict + samples[cur_op.name] = result + + # Discard unwanted transformed variables, if desired: + vars_to_keep = set( + pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed) + ) + samples = {x: y for x, y in samples.items() if x in vars_to_keep} + + return samples From c18f0ce4b2c67955a5d6840963eaaf42f0ed0d98 Mon Sep 17 00:00:00 2001 From: Martin Ingram Date: Mon, 8 Feb 2021 12:45:06 +1100 Subject: [PATCH 2/4] Update pymc3/sampling_jax.py Co-authored-by: Junpeng Lao --- pymc3/sampling_jax.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc3/sampling_jax.py b/pymc3/sampling_jax.py index a0f8d1ebf6..915b9e19d4 100644 --- a/pymc3/sampling_jax.py +++ b/pymc3/sampling_jax.py @@ -210,9 +210,7 @@ def _transform_samples(samples, model, keep_untransformed=False): # We need a function taking a single argument to run vmap, while the # jax_fn takes a list, so: - to_run = lambda x: cur_jax_fn(*x) - - result = jax.vmap(jax.vmap(to_run))(inputs) + result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs) # Add to sample dict samples[cur_op.name] = result From 78d15f4600aabc7c907503fffacefc313210c153 Mon Sep 17 00:00:00 2001 From: Martin Ingram Date: Mon, 8 Feb 2021 13:04:02 +1100 Subject: [PATCH 3/4] Added a small test --- pymc3/sampling_jax.py | 3 ++- pymc3/tests/test_sampling_jax.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 pymc3/tests/test_sampling_jax.py diff --git a/pymc3/sampling_jax.py b/pymc3/sampling_jax.py index 915b9e19d4..522bca7b12 100644 --- a/pymc3/sampling_jax.py +++ b/pymc3/sampling_jax.py @@ -123,6 +123,7 @@ def sample_numpyro_nuts( random_seed=10, model=None, progress_bar=True, + keep_untransformed=False, ): from numpyro.infer import MCMC, NUTS @@ -178,7 +179,7 @@ def _sample(current_state, seed): posterior = {k: v for k, v in zip(rv_names, mcmc_samples)} tic3 = pd.Timestamp.now() - posterior = _transform_samples(posterior, model, keep_untransformed=False) + posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed) tic4 = pd.Timestamp.now() az_trace = az.from_dict(posterior=posterior) diff --git a/pymc3/tests/test_sampling_jax.py b/pymc3/tests/test_sampling_jax.py new file mode 100644 index 0000000000..46a406833c --- /dev/null +++ b/pymc3/tests/test_sampling_jax.py @@ -0,0 +1,19 @@ +import numpy as np + +import pymc3 as pm + +from pymc3.sampling_jax import sample_numpyro_nuts + + +def test_transform_samples(): + + with pm.Model() as model: + + sigma = pm.HalfNormal("sigma") + b = pm.Normal("b", sigma=sigma) + trace = sample_numpyro_nuts(keep_untransformed=True) + + log_vals = trace.posterior["sigma_log__"].values + trans_vals = trace.posterior["sigma"].values + + assert np.allclose(np.exp(log_vals), trans_vals) From 07b715c2ad6f0a756b7e86d88f4b2313b52622df Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Thu, 11 Feb 2021 12:45:15 +0100 Subject: [PATCH 4/4] Split jax tests into their own workflow --- .github/workflows/jaxtests.yml | 64 ++++++++++++++++++++++++++ .github/workflows/pytest.yml | 1 + scripts/check_all_tests_are_covered.py | 20 ++++---- 3 files changed, 77 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/jaxtests.yml diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml new file mode 100644 index 0000000000..c5b3f23963 --- /dev/null +++ b/.github/workflows/jaxtests.yml @@ -0,0 +1,64 @@ +name: jax-sampling + +on: + pull_request: + push: + branches: [master] + +jobs: + pytest: + strategy: + matrix: + os: [ubuntu-latest] + floatx: [float64] + test-subset: + - pymc3/tests/test_sampling_jax.py + fail-fast: false + runs-on: ${{ matrix.os }} + env: + TEST_SUBSET: ${{ matrix.test-subset }} + THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native' + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v2 + - name: Cache conda + uses: actions/cache@v1 + env: + # Increase this value to reset cache if environment-dev-py39.yml has not changed + CACHE_NUMBER: 0 + with: + path: ~/conda_pkgs_dir + key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ + hashFiles('conda-envs/environment-dev-py39.yml') }} + - name: Cache multiple paths + uses: actions/cache@v2 + env: + # Increase this value to reset cache if requirements.txt has not changed + CACHE_NUMBER: 0 + with: + path: | + ~/.cache/pip + $RUNNER_TOOL_CACHE/Python/* + ~\AppData\Local\pip\Cache + key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{ + hashFiles('requirements.txt') }} + - uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: pymc3-dev-py39 + channel-priority: strict + environment-file: conda-envs/environment-dev-py39.yml + use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! + - name: Install pymc3 + run: | + conda activate pymc3-dev-py39 + pip install -e . + python --version + - name: Install jax specific dependencies + run: | + conda activate pymc3-dev-py39 + pip install numpyro tensorflow_probability + - name: Run tests + run: | + python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1b33e899d3..e492c7e705 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -27,6 +27,7 @@ jobs: --ignore=pymc3/tests/test_quadpotential.py --ignore=pymc3/tests/test_random.py --ignore=pymc3/tests/test_sampling.py + --ignore=pymc3/tests/test_sampling_jax.py --ignore=pymc3/tests/test_shape_handling.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py diff --git a/scripts/check_all_tests_are_covered.py b/scripts/check_all_tests_are_covered.py index 2882f57346..f02f90d509 100644 --- a/scripts/check_all_tests_are_covered.py +++ b/scripts/check_all_tests_are_covered.py @@ -12,13 +12,17 @@ from pathlib import Path if __name__ == "__main__": - pytest_ci_job = Path(".github") / "workflows/pytest.yml" - txt = pytest_ci_job.read_text() - ignored_tests = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt)) - non_ignored_tests = set(re.findall(r"(?