Fix bug in fit_MAP when shared variables are used in graph #468

Merged: 4 commits, May 2, 2025

2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
@@ -6,7 +6,7 @@ dependencies:
 - pymc>=5.21
 - pytest-cov>=2.5
 - pytest>=3.0
-- dask
+- dask<2025.1.1
 - xhistogram
 - statsmodels
 - numba<=0.60.0
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
@@ -6,7 +6,7 @@ dependencies:
 - pip
 - pytest-cov>=2.5
 - pytest>=3.0
-- dask
+- dask<2025.1.1
 - xhistogram
 - statsmodels
 - numba<=0.60.0
10 changes: 9 additions & 1 deletion pymc_extras/inference/find_map.py
@@ -146,7 +146,7 @@ def _compile_grad_and_hess_to_jax(
     orig_loss_fn = f_loss.vm.jit_fn

     @jax.jit
-    def loss_fn_jax_grad(x, *shared):
+    def loss_fn_jax_grad(x):
         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)

     f_loss_and_grad = loss_fn_jax_grad
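
For reference on the pattern in the hunk above: jax.value_and_grad wraps a scalar function so that a single call returns both the loss value and its gradient, and jax.jit compiles the wrapper. A minimal, self-contained sketch of that pattern (toy_loss is an illustrative stand-in, not a pymc-extras function):

import jax
import jax.numpy as jnp

def toy_loss(x):
    # A simple scalar loss so value_and_grad has something to differentiate
    return jnp.sum((x - 3.0) ** 2)

# One call returns (loss_value, gradient), mirroring loss_fn_jax_grad above
loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))
value, grad = loss_and_grad(jnp.zeros(2))

Dropping the *shared argument is only safe because the shared variables are rewritten out of the graph beforehand, which is what the second hunk below does.
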
@@ -301,6 +301,14 @@ def scipy_optimize_funcs_from_loss(
         point=initial_point_dict, outputs=[loss], inputs=inputs
     )

+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        from pymc.sampling.jax import _replace_shared_variables
+
+        [loss] = _replace_shared_variables([loss])
+
Review comment (Member): btw this is the sort of stuff that can lead to big constant foldings in JAX
     compute_grad = use_grad and not use_jax_gradients
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
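
The comment in the hunk above is the core of the fix: when the loss graph is handed to JAX directly, the pytensor function wrapper that would normally feed shared variables at call time is discarded, so every SharedVariable has to be baked into the graph as a constant first. pymc's _replace_shared_variables handles this; the following is a rough, independent sketch of the same idea using public pytensor APIs (the toy graph, names, and data are made up for illustration, and this is not the pymc implementation):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import graph_replace

# A toy loss that depends on a shared variable (an implicit input)
data = pytensor.shared(np.arange(3.0), name="data")
x = pt.vector("x")
loss = ((x - data) ** 2).sum()

# Swap every shared variable feeding the loss for a constant holding its
# current value, so the graph has no hidden inputs left.
shared_vars = [v for v in graph_inputs([loss]) if isinstance(v, SharedVariable)]
replacements = {v: pt.constant(v.get_value(), name=v.name) for v in shared_vars}
[loss_without_shared] = graph_replace([loss], replacements)

With the shared data folded in as constants, the compiled JAX loss depends only on the explicit input x, which is what the jitted loss_fn_jax_grad in the first hunk expects; this baking-in of values is presumably also what the review comment above means by big constant foldings in JAX.
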
2 changes: 1 addition & 1 deletion setup.py
@@ -60,7 +60,7 @@


 extras_require = dict(
-    dask_histogram=["dask[complete]", "xhistogram"],
+    dask_histogram=["dask[complete]<2025.1.1", "xhistogram"],
     histogram=["xhistogram"],
 )
 extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values())))
23 changes: 23 additions & 0 deletions tests/test_find_map.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pymc as pm
+import pytensor
 import pytensor.tensor as pt
 import pytest

@@ -101,3 +102,25 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: Gradie

     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+def test_JAX_map_shared_variables():
+    with pm.Model() as m:
+        data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data")
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data)
+
+        optimized_point = find_MAP(
+            method="L-BFGS-B",
+            use_grad=True,
+            use_hess=False,
+            use_hessp=False,
+            progressbar=False,
+            gradient_backend="jax",
+            compile_kwargs={"mode": "JAX"},
+        )
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)