
Commit 645d9e5

SebastianAment authored and meta-codesync[bot] committed
Efficient leave-one-out cross-validation for ensemble models (#3103)
Summary:
Pull Request resolved: #3103

This commit introduces a new function `ensemble_loo_cv` in `botorch/cross_validation.py` for performing efficient leave-one-out cross-validation (LOOCV) on ensemble models, along with a new high-level `loo_cv` function that automatically dispatches to the appropriate efficient CV implementation (ensemble or non-ensemble) depending on the model's `_is_ensemble` attribute.

Reviewed By: Balandat

Differential Revision: D88506875

fbshipit-source-id: 6605b5dd3206ec80d9af3d6ef76456f543ef2664
1 parent 976a944 commit 645d9e5
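
For orientation (an editor's sketch, not part of the commit), here is how the new entry point is meant to be used, based on the docstrings in the diff below; `loo_cv` is the single call site and routes on `_is_ensemble`:

import torch
from botorch.cross_validation import loo_cv
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)

# Standard (non-ensemble) model: loo_cv dispatches to efficient_loo_cv.
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
results = loo_cv(model)
print(results.posterior.mean.shape)  # torch.Size([20, 1, 1])

# An ensemble model (e.g., SaasFullyBayesianSingleTaskGP, _is_ensemble=True)
# would instead be routed to ensemble_loo_cv and return a
# GaussianMixturePosterior.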

File tree

2 files changed (+768 -0 lines changed)


botorch/cross_validation.py

Lines changed: 311 additions & 0 deletions
@@ -17,6 +17,7 @@
 from botorch.fit import fit_gpytorch_mll
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.multitask import MultiTaskGP
+from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
@@ -220,6 +221,72 @@ def batch_cross_validation(
     )
 
 
+def loo_cv(model: GPyTorchModel, observation_noise: bool = True) -> CVResults:
+    r"""Compute efficient Leave-One-Out cross-validation for a GP model.
+
+    This is a high-level convenience function that automatically dispatches to
+    the appropriate LOO CV implementation based on the model type:
+
+    - For ensemble models (``_is_ensemble=True``): Uses ``ensemble_loo_cv``,
+      which returns a ``GaussianMixturePosterior`` with both per-member and
+      mixture statistics.
+    - For standard GP models: Uses ``efficient_loo_cv``, which returns a
+      ``GPyTorchPosterior`` with the LOO predictive distributions.
+
+    Both implementations use efficient O(n³) matrix algebra rather than the
+    naive O(n⁴) approach of refitting the model for each fold.
+
+    NOTE: This function does not refit the model to each LOO fold. The model
+    hyperparameters are kept fixed, providing a fast approximation to full
+    LOO CV. For models where hyperparameter changes are significant, consider
+    using ``batch_cross_validation`` instead.
+
+    Args:
+        model: A fitted GPyTorchModel. The model type determines which LOO CV
+            implementation is used.
+        observation_noise: If True (default), return the posterior
+            predictive variance (including observation noise). If False,
+            return the posterior variance of the latent function (excluding
+            observation noise), computed by subtracting the observation
+            noise from the posterior predictive variance.
+
+    Returns:
+        CVResults: A named tuple containing:
+            - model: The fitted GP model.
+            - posterior: The LOO predictive distributions. For ensemble models,
+              this is a ``GaussianMixturePosterior``; otherwise, it is a
+              ``GPyTorchPosterior``.
+            - observed_Y: The observed Y values.
+            - observed_Yvar: The observed noise variances (if applicable).
+
+    Example:
+        >>> import torch
+        >>> from botorch.cross_validation import loo_cv
+        >>> from botorch.models import SingleTaskGP
+        >>> from botorch.fit import fit_gpytorch_mll
+        >>> from gpytorch.mlls import ExactMarginalLogLikelihood
+        >>>
+        >>> train_X = torch.rand(20, 2, dtype=torch.float64)
+        >>> train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
+        >>> fit_gpytorch_mll(mll)
+        >>> loo_results = loo_cv(model)
+        >>> loo_results.posterior.mean.shape
+        torch.Size([20, 1, 1])
+
+    See Also:
+        - ``efficient_loo_cv``: Direct access to the standard GP implementation.
+        - ``ensemble_loo_cv``: Direct access to the ensemble model implementation.
+        - ``batch_cross_validation``: Full LOO CV with model refitting.
+    """
+    if getattr(model, "_is_ensemble", False):
+        return ensemble_loo_cv(model, observation_noise=observation_noise)
+    else:
+        return efficient_loo_cv(model, observation_noise=observation_noise)
+
+
 def efficient_loo_cv(
     model: GPyTorchModel,
     observation_noise: bool = True,
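
An aside on the O(n³) claim above (an editor's note, not part of the diff): for exact GP regression there is a classic identity (Rasmussen & Williams, Eq. 5.12) that yields all n LOO means and variances from a single factorization of the noisy kernel matrix, rather than n separate refits. Below is a minimal torch sketch with a toy RBF kernel; the actual `efficient_loo_cv` implementation is not shown in this diff and may differ in its details:

import torch

torch.manual_seed(0)
n = 8
X = torch.rand(n, 1, dtype=torch.float64)
y = torch.sin(3 * X).squeeze(-1)
# Toy RBF kernel plus diagonal observation noise (hyperparameters arbitrary).
K = torch.exp(-0.5 * (X - X.T) ** 2) + 1e-2 * torch.eye(n, dtype=torch.float64)

# One O(n^3) inversion gives all n LOO predictions at once.
K_inv = torch.linalg.inv(K)
loo_mean = y - (K_inv @ y) / K_inv.diagonal()
loo_var = 1.0 / K_inv.diagonal()

# Brute-force check for fold i=0: condition on the remaining n-1 points.
i, mask = 0, torch.arange(n) != 0
k_star = K[i, mask]  # noise-free cross-covariances (noise is diagonal-only)
K_rest = K[mask][:, mask]
assert torch.isclose(loo_mean[i], k_star @ torch.linalg.solve(K_rest, y[mask]))
assert torch.isclose(loo_var[i], K[i, i] - k_star @ torch.linalg.solve(K_rest, k_star))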
@@ -562,3 +629,247 @@ def _reshape_to_loo_cv_format(tensor: Tensor, num_outputs: int) -> Tensor:
     else:
         # Single-output: n -> n x 1 -> n x 1 x 1
         return tensor.unsqueeze(-1).unsqueeze(-1)
+
+
+def ensemble_loo_cv(
+    model: GPyTorchModel,
+    observation_noise: bool = True,
+) -> CVResults:
+    r"""Compute efficient LOO cross-validation for ensemble models.
+
+    This function computes Leave-One-Out cross-validation for ensemble models
+    like `SaasFullyBayesianSingleTaskGP`. For these models, the `forward` method
+    returns a `MultivariateNormal` with a batch dimension containing statistics
+    for all models in the ensemble.
+
+    The LOO predictions from each ensemble member form a Gaussian mixture.
+    This function returns a `CVResults` with a `GaussianMixturePosterior` that
+    provides both per-member statistics (via `posterior.mean` and
+    `posterior.variance`) and aggregated mixture statistics (via
+    `posterior.mixture_mean` and `posterior.mixture_variance`).
+
+    The mixture statistics are computed using the law of total variance:
+
+    .. math::
+
+        \mu_{mix} = \frac{1}{K} \sum_{k=1}^{K} \mu_k
+
+        \sigma^2_{mix} = \frac{1}{K} \sum_{k=1}^{K} \sigma^2_k +
+            \frac{1}{K} \sum_{k=1}^{K} \mu_k^2 - \mu_{mix}^2
+
+    where K is the number of ensemble members.
+
+    NOTE: This function assumes the model has already been fitted (e.g., using
+    `fit_fully_bayesian_model_nuts`) and that the model is an ensemble model
+    with `_is_ensemble = True`.
+
+    Args:
+        model: An ensemble GPyTorchModel (e.g., SaasFullyBayesianSingleTaskGP)
+            whose `forward` method returns a `MultivariateNormal` distribution
+            with a batch dimension for ensemble members.
+        observation_noise: If True (default), return the posterior
+            predictive variance (including observation noise). If False,
+            return the posterior variance of the latent function (excluding
+            observation noise).
+
+    Returns:
+        CVResults: A named tuple containing:
+            - model: The fitted ensemble GP model.
+            - posterior: A `GaussianMixturePosterior` with per-member shape
+              ``n x num_models x 1 x 1``. Access per-member statistics via
+              ``posterior.mean`` and ``posterior.variance``, and mixture
+              statistics via ``posterior.mixture_mean`` and
+              ``posterior.mixture_variance``.
+            - observed_Y: The observed Y values with shape ``n x 1 x 1``.
+            - observed_Yvar: The observed noise variances (if provided).
+
+    Example:
+        >>> import torch
+        >>> from botorch.cross_validation import ensemble_loo_cv
+        >>> from botorch.fit import fit_fully_bayesian_model_nuts
+        >>> from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+        >>>
+        >>> train_X = torch.rand(20, 2, dtype=torch.float64)
+        >>> train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
+        >>> model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
+        >>> fit_fully_bayesian_model_nuts(model, warmup_steps=64, num_samples=32)
+        >>> loo_results = ensemble_loo_cv(model)
+        >>> loo_results.posterior.mean.shape  # Per-member means
+        torch.Size([20, 32, 1, 1])
+        >>> loo_results.posterior.mixture_mean.shape  # Aggregated mixture mean
+        torch.Size([20, 1, 1])
+    """
+    # Check that this is an ensemble model
+    if not getattr(model, "_is_ensemble", False):
+        raise UnsupportedError(
+            "ensemble_loo_cv requires an ensemble model (with _is_ensemble=True). "
+            f"Got model of type {type(model).__name__}. "
+            "For non-ensemble models, use efficient_loo_cv instead."
+        )
+
+    # Compute raw LOO predictions
+    # For ensemble models, shapes are: num_models x n x 1
+    loo_mean, loo_variance, train_Y = _compute_loo_predictions(
+        model, observation_noise=observation_noise
+    )
+
+    # Validate that we have the expected batch dimension for ensemble
+    if loo_mean.dim() < 3:
+        raise UnsupportedError(
+            "Expected ensemble model to produce batched LOO results with shape "
+            f"(batch_shape x num_models x n x 1), but got shape {loo_mean.shape}."
+        )
+
+    # Get the number of outputs
+    num_outputs = getattr(model, "_num_outputs", 1)
+
+    # Build the GaussianMixturePosterior
+    posterior = _build_ensemble_loo_posterior(
+        loo_mean=loo_mean, loo_variance=loo_variance, num_outputs=num_outputs
+    )
+
+    # Extract observed data (first ensemble member) and reshape to LOO CV format
+    observed_Y, observed_Yvar = _get_ensemble_observed_data(
+        model=model, train_Y=train_Y, num_outputs=num_outputs
+    )
+
+    return CVResults(
+        model=model,
+        posterior=posterior,
+        observed_Y=observed_Y,
+        observed_Yvar=observed_Yvar,
+    )
+
+
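A quick numeric sanity check of the mixture formulas above (an editor's note, not part of the diff; toy sizes, plain torch): pooling equally many samples from each member reproduces the law-of-total-variance moments:

import torch

torch.manual_seed(0)
K, n = 16, 20  # ensemble members x training points (arbitrary toy sizes)
member_mean = torch.randn(K, n, dtype=torch.float64)
member_var = 0.1 + torch.rand(K, n, dtype=torch.float64)

# Mixture moments, matching the docstring's law of total variance.
mix_mean = member_mean.mean(dim=0)
mix_var = member_var.mean(dim=0) + (member_mean**2).mean(dim=0) - mix_mean**2

# Monte Carlo check: draw from every member and pool the samples.
normal = torch.distributions.Normal(member_mean, member_var.sqrt())
pooled = normal.sample(torch.Size([200_000])).reshape(-1, n)
assert torch.allclose(pooled.mean(dim=0), mix_mean, atol=0.02)
assert torch.allclose(pooled.var(dim=0), mix_var, atol=0.02)
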
+def _build_ensemble_loo_posterior(
+    loo_mean: Tensor,
+    loo_variance: Tensor,
+    num_outputs: int,
+) -> GaussianMixturePosterior:
+    r"""Build a GaussianMixturePosterior from raw ensemble LOO predictions.
+
+    This function takes raw LOO means and variances from an ensemble model
+    (computed by `_compute_loo_predictions`) and packages them into a
+    GaussianMixturePosterior that provides both per-member and mixture statistics.
+
+    Args:
+        loo_mean: LOO means with shape ``batch_shape x num_models x n x 1``
+            (single-output) or ``batch_shape x num_models x m x n x 1``
+            (multi-output).
+        loo_variance: LOO variances with the same shape as loo_mean.
+        num_outputs: Number of outputs (m). 1 for single-output models.
+
+    Returns:
+        GaussianMixturePosterior with shape ``batch_shape x n x num_models x 1 x m``.
+        The num_models dimension is at MCMC_DIM=-3.
+    """
+    # Normalize shapes: add m=1 dimension for single-output to match multi-output
+    if num_outputs == 1:
+        # Single-output: ... x num_models x n x 1 -> ... x num_models x 1 x n x 1
+        loo_mean = loo_mean.unsqueeze(-3)
+        loo_variance = loo_variance.unsqueeze(-3)
+
+    # Now both cases have shape: ... x num_models x m x n x 1
+    # Transform to target shape: ... x n x num_models x 1 x m
+    # 1. squeeze(-1): ... x num_models x m x n
+    # 2. movedim(-1, -3): ... x n x num_models x m (move n before num_models)
+    # 3. unsqueeze(-2): ... x n x num_models x 1 x m
+    loo_mean = loo_mean.squeeze(-1).movedim(-1, -3).unsqueeze(-2)
+    loo_variance = loo_variance.squeeze(-1).movedim(-1, -3).unsqueeze(-2)
+
+    # Create distribution: iterate over outputs to create independent MVNs
+    # After indexing with [..., t], shape is: batch_shape x n x num_models x 1
+    mvns = [
+        MultivariateNormal(
+            mean=loo_mean[..., t],
+            covariance_matrix=DiagLinearOperator(loo_variance[..., t]),
+        )
+        for t in range(num_outputs)
+    ]
+
+    if num_outputs > 1:
+        mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
+    else:
+        mvn = mvns[0]
+
+    return GaussianMixturePosterior(distribution=mvn)
+
+
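To make the shape bookkeeping above concrete (an editor's note, not part of the diff), a tiny dummy-tensor walk-through of the squeeze/movedim/unsqueeze pipeline for a hypothetical multi-output case:

import torch

num_models, m, n = 4, 2, 5  # hypothetical ensemble size, outputs, points
x = torch.zeros(num_models, m, n, 1)  # raw LOO stats in multi-output layout

step1 = x.squeeze(-1)          # num_models x m x n
step2 = step1.movedim(-1, -3)  # n x num_models x m
step3 = step2.unsqueeze(-2)    # n x num_models x 1 x m
assert step3.shape == (n, num_models, 1, m)
# Indexing one output t then gives n x num_models x 1, the per-output mean
# shape fed to MultivariateNormal with a diagonal covariance.
assert step3[..., 0].shape == (n, num_models, 1)
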
+def _get_ensemble_observed_data(
+    model: GPyTorchModel,
+    train_Y: Tensor,
+    num_outputs: int,
+) -> tuple[Tensor, Tensor | None]:
+    r"""Extract observed data from an ensemble model for LOO CV.
+
+    Extracts the first ensemble member's training targets and observation noise,
+    verifies all members share the same data, and reshapes to LOO CV format.
+
+    Args:
+        model: The ensemble GP model.
+        train_Y: Training targets with shape ``... x num_models x n`` (single-output)
+            or ``... x num_models x m x n`` (multi-output).
+        num_outputs: Number of outputs (m).
+
+    Returns:
+        (observed_Y, observed_Yvar) with shape ``... x n x 1 x m``.
+
+    Raises:
+        UnsupportedError: If ensemble members have different training data.
+    """
+    # num_models is at dim -2 for single-output, -3 for multi-output
+    num_models_dim = -2 if num_outputs == 1 else -3
+
+    # Verify all ensemble members share the same training data
+    _verify_ensemble_data_consistency(train_Y, num_models_dim, "train_Y")
+
+    # Extract first ensemble member's data (they're all the same)
+    train_Y_first = train_Y.select(num_models_dim, 0)
+    observed_Y = _reshape_to_loo_cv_format(train_Y_first, num_outputs)
+
+    # Get observed Yvar if available (for fixed noise models)
+    observed_Yvar = None
+    if isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
+        noise = model.likelihood.noise
+        # Noise has the same shape structure as train_Y
+        # Verify consistency and extract first member
+        if noise.dim() > 1:
+            _verify_ensemble_data_consistency(
+                noise, num_models_dim, "observation noise"
+            )
+            noise = noise.select(num_models_dim, 0)
+        observed_Yvar = _reshape_to_loo_cv_format(noise, num_outputs)
+
+    return observed_Y, observed_Yvar
+
+
+def _verify_ensemble_data_consistency(
+    tensor: Tensor,
+    num_models_dim: int,
+    tensor_name: str,
+) -> None:
+    r"""Verify all ensemble members have identical data along ``num_models_dim``.
+
+    Args:
+        tensor: Data tensor with a num_models dimension.
+        num_models_dim: Dimension index for num_models (typically -2 or -3).
+        tensor_name: Name for error messages (e.g., "train_Y").
+
+    Raises:
+        UnsupportedError: If data differs across ensemble members.
+    """
+    num_models = tensor.shape[num_models_dim]
+    if num_models <= 1:
+        return
+
+    first_member = tensor.select(num_models_dim, 0)
+    first_expanded = first_member.unsqueeze(num_models_dim).expand_as(tensor)
+
+    if not torch.allclose(tensor, first_expanded):
+        raise UnsupportedError(
+            f"Ensemble members have different {tensor_name}. "
+            "ensemble_loo_cv only supports ensembles where all members share the "
+            "same training data (e.g., fully Bayesian models with MCMC samples). "
+            "For ensembles with different data per member, cross-validate each "
+            "member individually using efficient_loo_cv."
+        )
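
A toy illustration of the consistency check above (an editor's note, not part of the diff); the hypothetical tensors stand in for per-member train_Y with num_models at dim -2:

import torch

shared = torch.randn(1, 5)                # one member's targets (n=5)
consistent = shared.expand(3, 5).clone()  # 3 members, identical data
inconsistent = torch.randn(3, 5)          # 3 members, different data

def members_match(t: torch.Tensor, dim: int = -2) -> bool:
    # Same select/expand_as/allclose pattern as _verify_ensemble_data_consistency.
    first = t.select(dim, 0).unsqueeze(dim).expand_as(t)
    return torch.allclose(t, first)

assert members_match(consistent)
assert not members_match(inconsistent)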
