diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml new file mode 100644 index 0000000000..c5b3f23963 --- /dev/null +++ b/.github/workflows/jaxtests.yml @@ -0,0 +1,64 @@ +name: jax-sampling + +on: + pull_request: + push: + branches: [master] + +jobs: + pytest: + strategy: + matrix: + os: [ubuntu-latest] + floatx: [float64] + test-subset: + - pymc3/tests/test_sampling_jax.py + fail-fast: false + runs-on: ${{ matrix.os }} + env: + TEST_SUBSET: ${{ matrix.test-subset }} + THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native' + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v2 + - name: Cache conda + uses: actions/cache@v1 + env: + # Increase this value to reset cache if environment-dev-py39.yml has not changed + CACHE_NUMBER: 0 + with: + path: ~/conda_pkgs_dir + key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ + hashFiles('conda-envs/environment-dev-py39.yml') }} + - name: Cache multiple paths + uses: actions/cache@v2 + env: + # Increase this value to reset cache if requirements.txt has not changed + CACHE_NUMBER: 0 + with: + path: | + ~/.cache/pip + $RUNNER_TOOL_CACHE/Python/* + ~\AppData\Local\pip\Cache + key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{ + hashFiles('requirements.txt') }} + - uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: pymc3-dev-py39 + channel-priority: strict + environment-file: conda-envs/environment-dev-py39.yml + use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! + - name: Install pymc3 + run: | + conda activate pymc3-dev-py39 + pip install -e . + python --version + - name: Install jax specific dependencies + run: | + conda activate pymc3-dev-py39 + pip install numpyro tensorflow_probability + - name: Run tests + run: | + python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1b33e899d3..e492c7e705 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -27,6 +27,7 @@ jobs: --ignore=pymc3/tests/test_quadpotential.py --ignore=pymc3/tests/test_random.py --ignore=pymc3/tests/test_sampling.py + --ignore=pymc3/tests/test_sampling_jax.py --ignore=pymc3/tests/test_shape_handling.py --ignore=pymc3/tests/test_shared.py --ignore=pymc3/tests/test_smc.py diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 99e35d659b..68c2ea597d 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,7 +5,8 @@ ### Breaking Changes ### New Features -+ Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)). +- Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)). +- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)). ### Maintenance - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). diff --git a/pymc3/sampling_jax.py b/pymc3/sampling_jax.py index e02d21cb3c..522bca7b12 100644 --- a/pymc3/sampling_jax.py +++ b/pymc3/sampling_jax.py @@ -3,6 +3,8 @@ import re import warnings +from collections import defaultdict + xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--") xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split() os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)]) @@ -121,6 +123,7 @@ def sample_numpyro_nuts( random_seed=10, model=None, progress_bar=True, + keep_untransformed=False, ): from numpyro.infer import MCMC, NUTS @@ -175,8 +178,48 @@ def _sample(current_state, seed): # print("Sampling time = ", tic4 - tic3) posterior = {k: v for k, v in zip(rv_names, mcmc_samples)} + tic3 = pd.Timestamp.now() + posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed) + tic4 = pd.Timestamp.now() az_trace = az.from_dict(posterior=posterior) - tic3 = pd.Timestamp.now() print("Compilation + sampling time = ", tic3 - tic2) + print("Transformation time = ", tic4 - tic3) + return az_trace # , leapfrogs_taken, tic3 - tic2 + + +def _transform_samples(samples, model, keep_untransformed=False): + + # Find out which RVs we need to compute: + free_rv_names = {x.name for x in model.free_RVs} + unobserved_names = {x.name for x in model.unobserved_RVs} + + names_to_compute = unobserved_names - free_rv_names + ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute] + + # Create function graph for these: + fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute) + + # Jaxify, which returns a list of functions, one for each op + jax_fns = jax_funcify(fgraph) + + # Put together the inputs + inputs = [samples[x.name] for x in model.free_RVs] + + for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns): + + # We need a function taking a single argument to run vmap, while the + # jax_fn takes a list, so: + result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs) + + # Add to sample dict + samples[cur_op.name] = result + + # Discard unwanted transformed variables, if desired: + vars_to_keep = set( + pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed) + ) + samples = {x: y for x, y in samples.items() if x in vars_to_keep} + + return samples diff --git a/pymc3/tests/test_sampling_jax.py b/pymc3/tests/test_sampling_jax.py new file mode 100644 index 0000000000..46a406833c --- /dev/null +++ b/pymc3/tests/test_sampling_jax.py @@ -0,0 +1,19 @@ +import numpy as np + +import pymc3 as pm + +from pymc3.sampling_jax import sample_numpyro_nuts + + +def test_transform_samples(): + + with pm.Model() as model: + + sigma = pm.HalfNormal("sigma") + b = pm.Normal("b", sigma=sigma) + trace = sample_numpyro_nuts(keep_untransformed=True) + + log_vals = trace.posterior["sigma_log__"].values + trans_vals = trace.posterior["sigma"].values + + assert np.allclose(np.exp(log_vals), trans_vals) diff --git a/scripts/check_all_tests_are_covered.py b/scripts/check_all_tests_are_covered.py index 2882f57346..f02f90d509 100644 --- a/scripts/check_all_tests_are_covered.py +++ b/scripts/check_all_tests_are_covered.py @@ -12,13 +12,17 @@ from pathlib import Path if __name__ == "__main__": - pytest_ci_job = Path(".github") / "workflows/pytest.yml" - txt = pytest_ci_job.read_text() - ignored_tests = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt)) - non_ignored_tests = set(re.findall(r"(?