Skip to content

PyMC/PyTensor Implementation of Pathfinder VI #387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4540b84
renamed samples argument name and pathfinder variables to avoid confu…
aphc14 Oct 3, 2024
0c880d2
Minor changes made to the `fit_pathfinder` function and added test
aphc14 Oct 19, 2024
8835cd5
extract additional pathfinder objects from high level API for debugging
aphc14 Sep 17, 2024
663a60a
changed pathfinder samples argument to num_draws
aphc14 Oct 26, 2024
05aeeaf
Merge branch 'replicate_pathfinder_w_pytensor' into scipy_lbfgs
aphc14 Oct 26, 2024
0db91fe
feat(pathfinder): add PyMC-based Pathfinder VI implementation
aphc14 Oct 31, 2024
cb4436c
Multipath Pathfinder VI implementation in pymc-experimental
aphc14 Nov 4, 2024
2efb511
Added type hints and epsilon parameter to fit_pathfinder
aphc14 Nov 7, 2024
fdc3f38
Removed initial point values (l=0) to reduce iterations. Simplified …
aphc14 Nov 7, 2024
1fd7a11
Added placeholder/reminder to remove jax dependency when converting t…
aphc14 Nov 7, 2024
ef2956f
Sync updates with draft PR #386. \n- Added pytensor.function for bfgs…
aphc14 Nov 7, 2024
8b134b7
Reduced size of compute graph with pathfinder_body_fn
aphc14 Nov 11, 2024
6484b3d
- Added TODO comments for implementing Taylor approximation methods: …
aphc14 Nov 14, 2024
aa765fb
fix: correct posterior approximations in Pathfinder VI
aphc14 Nov 21, 2024
4299a58
feat: Add dense BFGS sampling for Pathfinder VI
aphc14 Nov 21, 2024
f1a54c6
feat: improve Pathfinder performance and compatibility
aphc14 Nov 24, 2024
ea802fc
minor: improve error handling in Pathfinder VI
aphc14 Nov 25, 2024
a77f2c8
Progress bar and other minor changes
aphc14 Nov 27, 2024
9faaa72
set maxcor to max(5, floor(N / 1.9)). max=1 will cause error
aphc14 Nov 27, 2024
2815c4f
Merge branch 'main' into pathfinder_w_pytensor_symbolic
aphc14 Dec 7, 2024
e4b8996
Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improve…
aphc14 Dec 7, 2024
885afaa
Improvements to Importance Sampling and InferenceData shape
aphc14 Dec 8, 2024
ba85587
Display summary of results, Improve error handling, General improvements
aphc14 Jan 22, 2025
382aeb7
Merge branch 'main' into pathfinder_w_pytensor_symbolic
aphc14 Jan 22, 2025
baad3d9
Move pathfinder module to pymc_extras
aphc14 Jan 22, 2025
862627e
Improve pathfinder error handling and type hints
aphc14 Jan 24, 2025
03e9dd0
fix: Use typing_extensions.Self for Python 3.10 compatibility
aphc14 Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pymc_extras/inference/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.util import find_spec


def fit(method, **kwargs):
Expand All @@ -31,9 +30,6 @@ def fit(method, **kwargs):
arviz.InferenceData
"""
if method == "pathfinder":
if find_spec("blackjax") is None:
raise RuntimeError("Need BlackJAX to use `pathfinder`")

from pymc_extras.inference.pathfinder import fit_pathfinder

return fit_pathfinder(**kwargs)
Expand Down
134 changes: 0 additions & 134 deletions pymc_extras/inference/pathfinder.py

This file was deleted.

3 changes: 3 additions & 0 deletions pymc_extras/inference/pathfinder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit_pathfinder"]
139 changes: 139 additions & 0 deletions pymc_extras/inference/pathfinder/importance_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import logging
import warnings as _warnings

from dataclasses import dataclass, field
from typing import Literal

import arviz as az
import numpy as np

from numpy.typing import NDArray
from scipy.special import logsumexp

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ImportanceSamplingResult:
"""container for importance sampling results"""

samples: NDArray
pareto_k: float | None = None
warnings: list[str] = field(default_factory=list)
method: str = "none"


def importance_sampling(
samples: NDArray,
logP: NDArray,
logQ: NDArray,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"] | None,
random_seed: int | None = None,
) -> ImportanceSamplingResult:
"""Pareto Smoothed Importance Resampling (PSIR)
This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 PSIS diagnostic from Yao et al., (2018). However, before computing the the importance ratio r_s, the logP and logQ are adjusted to account for the number multiple estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.

Parameters
----------
samples : NDArray
samples from proposal distribution, shape (L, M, N)
logP : NDArray
log probability values of target distribution, shape (L, M)
logQ : NDArray
log probability values of proposal distribution, shape (L, M)
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
method : str, optional
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
random_seed : int | None

Returns
-------
ImportanceSamplingResult
importance sampled draws and other info based on the specified method

Future work!
----------
- Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
- Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
- Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.

References
----------
Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668

Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538

Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

warnings = []
num_paths, _, N = samples.shape

if method == "none":
warnings.append(
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
)
return ImportanceSamplingResult(samples=samples, warnings=warnings)
else:
samples = samples.reshape(-1, N)
logP = logP.ravel()
logQ = logQ.ravel()

# adjust log densities
log_I = np.log(num_paths)
logP -= log_I
logQ -= log_I
logiw = logP - logQ

with _warnings.catch_warnings():
_warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
)
if method == "psis":
replace = False
logiw, pareto_k = az.psislw(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = az.psislw(logiw)
elif method == "identity":
replace = False
pareto_k = None
else:
raise ValueError(f"Invalid importance sampling method: {method}")

# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
# Pareto k may not be a good diagnostic for Pathfinder.
# TODO: Find replacement diagnostics for Pathfinder.

p = np.exp(logiw - logsumexp(logiw))
rng = np.random.default_rng(random_seed)

try:
resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
return ImportanceSamplingResult(
samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
)
except ValueError as e1:
if "Fewer non-zero entries in p than size" in str(e1):
num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
warnings.append(
f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
)
try:
resampled = rng.choice(
samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
)
return ImportanceSamplingResult(
samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
)
except ValueError as e2:
logger.error(
"Importance sampling failed even with psir importance sampling. "
"This might indicate invalid probability weights or insufficient valid samples."
)
raise ValueError(
"Importance sampling failed for both with and without replacement"
) from e2
raise
Loading
Loading