17 | 17 | from botorch.fit import fit_gpytorch_mll |
18 | 18 | from botorch.models.gpytorch import GPyTorchModel |
19 | 19 | from botorch.models.multitask import MultiTaskGP |
| 20 | +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior |
20 | 21 | from botorch.posteriors.gpytorch import GPyTorchPosterior |
21 | 22 | from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal |
22 | 23 | from gpytorch.likelihoods import FixedNoiseGaussianLikelihood |
@@ -220,6 +221,72 @@ def batch_cross_validation( |
220 | 221 | ) |
221 | 222 |
222 | 223 |
| 224 | +def loo_cv(model: GPyTorchModel, observation_noise: bool = True) -> CVResults: |
| 225 | + r"""Compute efficient Leave-One-Out cross-validation for a GP model. |
| 226 | +
| 227 | + This is a high-level convenience function that automatically dispatches to |
| 228 | + the appropriate LOO CV implementation based on the model type: |
| 229 | +
| 230 | + - For ensemble models (``_is_ensemble=True``): Uses ``ensemble_loo_cv`` which |
| 231 | + returns a ``GaussianMixturePosterior`` with both per-member and mixture |
| 232 | + statistics. |
| 233 | + - For standard GP models: Uses ``efficient_loo_cv`` which returns a |
| 234 | + ``GPyTorchPosterior`` with the LOO predictive distributions. |
| 235 | +
| 236 | + Both implementations use efficient O(n³) matrix algebra rather than the |
| 237 | + naive O(n⁴) approach of refitting models for each fold. |
| 238 | +
| 239 | + NOTE: This function does not refit the model to each LOO fold. The model |
| 240 | + hyperparameters are kept fixed, providing a fast approximation to full |
| 241 | +    LOO CV. Where refitting on each fold would change the hyperparameters
| 242 | +    materially, consider using ``batch_cross_validation`` instead.
| 243 | +
| 244 | + Args: |
| 245 | + model: A fitted GPyTorchModel. The model type determines which LOO CV |
| 246 | + implementation is used. |
| 247 | + observation_noise: If True (default), return the posterior |
| 248 | + predictive variance (including observation noise). If False, |
| 249 | + return the posterior variance of the latent function (excluding |
| 250 | +            observation noise). The latent variance is obtained by
| 251 | +            subtracting the observation noise from the posterior predictive
| 252 | +            variance.
| 253 | +
| 254 | + Returns: |
| 255 | + CVResults: A named tuple containing: |
| 256 | + - model: The fitted GP model. |
| 257 | + - posterior: The LOO predictive distributions. For ensemble models, |
| 258 | + this is a ``GaussianMixturePosterior``; otherwise, it's a |
| 259 | + ``GPyTorchPosterior``. |
| 260 | + - observed_Y: The observed Y values. |
| 261 | + - observed_Yvar: The observed noise variances (if applicable). |
| 262 | +
| 263 | + Example: |
| 264 | + >>> import torch |
| 265 | + >>> from botorch.cross_validation import loo_cv |
| 266 | + >>> from botorch.models import SingleTaskGP |
| 267 | + >>> from botorch.fit import fit_gpytorch_mll |
| 268 | + >>> from gpytorch.mlls import ExactMarginalLogLikelihood |
| 269 | + >>> |
| 270 | + >>> train_X = torch.rand(20, 2, dtype=torch.float64) |
| 271 | + >>> train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True) |
| 272 | + >>> model = SingleTaskGP(train_X, train_Y) |
| 273 | + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) |
| 274 | + >>> fit_gpytorch_mll(mll) |
| 275 | + >>> loo_results = loo_cv(model) |
| 276 | + >>> loo_results.posterior.mean.shape |
| 277 | + torch.Size([20, 1, 1]) |
| 278 | +
| 279 | + See Also: |
| 280 | + - ``efficient_loo_cv``: Direct access to the standard GP implementation. |
| 281 | + - ``ensemble_loo_cv``: Direct access to the ensemble model implementation. |
| 282 | + - ``batch_cross_validation``: Full LOO CV with model refitting. |
| 283 | + """ |
| 284 | + if getattr(model, "_is_ensemble", False): |
| 285 | + return ensemble_loo_cv(model, observation_noise=observation_noise) |
| 286 | + else: |
| 287 | + return efficient_loo_cv(model, observation_noise=observation_noise) |
| 288 | + |
| 289 | + |
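For orientation, here is a minimal sketch (not part of this change) of consuming the returned ``CVResults`` for LOO diagnostics. It mirrors the docstring example above; ``loo_mse`` and ``loo_lpd`` are hypothetical names, and shapes follow the documented ``n x 1 x 1`` layout.

```python
import torch
from botorch.cross_validation import loo_cv
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

results = loo_cv(model)
mean = results.posterior.mean.squeeze(-1).squeeze(-1)     # shape: n
var = results.posterior.variance.squeeze(-1).squeeze(-1)  # shape: n
y = results.observed_Y.squeeze(-1).squeeze(-1)            # shape: n

loo_mse = ((mean - y) ** 2).mean()
# Gaussian log predictive density of each held-out observation.
loo_lpd = torch.distributions.Normal(mean, var.sqrt()).log_prob(y).mean()
```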
223 | 290 | def efficient_loo_cv( |
224 | 291 | model: GPyTorchModel, |
225 | 292 | observation_noise: bool = True, |
@@ -562,3 +629,247 @@ def _reshape_to_loo_cv_format(tensor: Tensor, num_outputs: int) -> Tensor: |
562 | 629 | else: |
563 | 630 | # Single-output: n -> n x 1 -> n x 1 x 1 |
564 | 631 | return tensor.unsqueeze(-1).unsqueeze(-1) |
| 632 | + |
| 633 | + |
| 634 | +def ensemble_loo_cv( |
| 635 | + model: GPyTorchModel, |
| 636 | + observation_noise: bool = True, |
| 637 | +) -> CVResults: |
| 638 | + r"""Compute efficient LOO cross-validation for ensemble models. |
| 639 | +
| 640 | + This function computes Leave-One-Out cross-validation for ensemble models |
| 641 | + like `SaasFullyBayesianSingleTaskGP`. For these models, the `forward` method |
| 642 | + returns a `MultivariateNormal` with a batch dimension containing statistics |
| 643 | + for all models in the ensemble. |
| 644 | +
| 645 | + The LOO predictions from each ensemble member form a Gaussian mixture. |
| 646 | + This function returns a `CVResults` with a `GaussianMixturePosterior` that |
| 647 | + provides both per-member statistics (via `posterior.mean` and |
| 648 | + `posterior.variance`) and aggregated mixture statistics (via |
| 649 | + `posterior.mixture_mean` and `posterior.mixture_variance`). |
| 650 | +
| 651 | + The mixture statistics are computed using the law of total variance: |
| 652 | +
| 653 | + .. math:: |
| 654 | +
| 655 | + \mu_{mix} = \frac{1}{K} \sum_{k=1}^{K} \mu_k |
| 656 | +
| 657 | + \sigma^2_{mix} = \frac{1}{K} \sum_{k=1}^{K} \sigma^2_k + |
| 658 | + \frac{1}{K} \sum_{k=1}^{K} \mu_k^2 - \mu_{mix}^2 |
| 659 | +
| 660 | + where K is the number of ensemble members. |
| 661 | +
| 662 | + NOTE: This function assumes the model has already been fitted (e.g., using |
| 663 | + `fit_fully_bayesian_model_nuts`) and that the model is an ensemble model |
| 664 | + with `_is_ensemble = True`. |
| 665 | +
| 666 | + Args: |
| 667 | +        model: An ensemble GPyTorchModel (e.g., SaasFullyBayesianSingleTaskGP)
| 668 | + whose `forward` method returns a `MultivariateNormal` distribution |
| 669 | + with a batch dimension for ensemble members. |
| 670 | + observation_noise: If True (default), return the posterior |
| 671 | + predictive variance (including observation noise). If False, |
| 672 | + return the posterior variance of the latent function (excluding |
| 673 | + observation noise). |
| 674 | +
| 675 | + Returns: |
| 676 | + CVResults: A named tuple containing: |
| 677 | + - model: The fitted ensemble GP model. |
| 678 | + - posterior: A `GaussianMixturePosterior` with per-member shape |
| 679 | + ``n x num_models x 1 x 1``. Access per-member statistics via |
| 680 | + ``posterior.mean`` and ``posterior.variance``, and mixture |
| 681 | + statistics via ``posterior.mixture_mean`` and |
| 682 | + ``posterior.mixture_variance``. |
| 683 | + - observed_Y: The observed Y values with shape ``n x 1 x 1``. |
| 684 | + - observed_Yvar: The observed noise variances (if provided). |
| 685 | +
| 686 | + Example: |
| 687 | + >>> import torch |
| 688 | + >>> from botorch.cross_validation import ensemble_loo_cv |
| 689 | + >>> from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP |
| 690 | +        >>> from botorch.fit import fit_fully_bayesian_model_nuts
| 691 | + >>> |
| 692 | + >>> train_X = torch.rand(20, 2, dtype=torch.float64) |
| 693 | + >>> train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True) |
| 694 | + >>> model = SaasFullyBayesianSingleTaskGP(train_X, train_Y) |
| 695 | +        >>> fit_fully_bayesian_model_nuts(model, warmup_steps=64, num_samples=32, thinning=1)
| 696 | + >>> loo_results = ensemble_loo_cv(model) |
| 697 | + >>> loo_results.posterior.mean.shape # Per-member means |
| 698 | + torch.Size([20, 32, 1, 1]) |
| 699 | + >>> loo_results.posterior.mixture_mean.shape # Aggregated mixture mean |
| 700 | + torch.Size([20, 1, 1]) |
| 701 | + """ |
| 702 | + # Check that this is an ensemble model |
| 703 | + if not getattr(model, "_is_ensemble", False): |
| 704 | + raise UnsupportedError( |
| 705 | + "ensemble_loo_cv requires an ensemble model (with _is_ensemble=True). " |
| 706 | + f"Got model of type {type(model).__name__}. " |
| 707 | + "For non-ensemble models, use efficient_loo_cv instead." |
| 708 | + ) |
| 709 | + |
| 710 | + # Compute raw LOO predictions |
| 711 | + # For ensemble models, shapes are: num_models x n x 1 |
| 712 | + loo_mean, loo_variance, train_Y = _compute_loo_predictions( |
| 713 | + model, observation_noise=observation_noise |
| 714 | + ) |
| 715 | + |
| 716 | + # Validate that we have the expected batch dimension for ensemble |
| 717 | + if loo_mean.dim() < 3: |
| 718 | + raise UnsupportedError( |
| 719 | + "Expected ensemble model to produce batched LOO results with shape " |
| 720 | + f"(batch_shape x num_models x n x 1), but got shape {loo_mean.shape}." |
| 721 | + ) |
| 722 | + |
| 723 | + # Get the number of outputs |
| 724 | + num_outputs = getattr(model, "_num_outputs", 1) |
| 725 | + |
| 726 | + # Build the GaussianMixturePosterior |
| 727 | + posterior = _build_ensemble_loo_posterior( |
| 728 | + loo_mean=loo_mean, loo_variance=loo_variance, num_outputs=num_outputs |
| 729 | + ) |
| 730 | + |
| 731 | + # Extract observed data (first ensemble member) and reshape to LOO CV format |
| 732 | + observed_Y, observed_Yvar = _get_ensemble_observed_data( |
| 733 | + model=model, train_Y=train_Y, num_outputs=num_outputs |
| 734 | + ) |
| 735 | + |
| 736 | + return CVResults( |
| 737 | + model=model, |
| 738 | + posterior=posterior, |
| 739 | + observed_Y=observed_Y, |
| 740 | + observed_Yvar=observed_Yvar, |
| 741 | + ) |
| 742 | + |
| 743 | + |
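A small standalone sketch (synthetic tensors, not model output) that applies the docstring's law-of-total-variance formulas directly, as a sanity check on the mixture statistics:

```python
import torch

K, n = 16, 20
mu = torch.randn(K, n)   # hypothetical per-member LOO means
var = torch.rand(K, n)   # hypothetical per-member LOO variances

mix_mean = mu.mean(dim=0)  # (1/K) * sum_k mu_k
# E[Var] + Var[E]: mean within-member variance plus spread of member means.
mix_var = var.mean(dim=0) + (mu ** 2).mean(dim=0) - mix_mean ** 2
# The between-member term is itself a variance, so it is non-negative.
assert torch.all(mix_var >= var.mean(dim=0) - 1e-12)
```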
| 744 | +def _build_ensemble_loo_posterior( |
| 745 | + loo_mean: Tensor, |
| 746 | + loo_variance: Tensor, |
| 747 | + num_outputs: int, |
| 748 | +) -> GaussianMixturePosterior: |
| 749 | + r"""Build a GaussianMixturePosterior from raw ensemble LOO predictions. |
| 750 | +
| 751 | + This function takes raw LOO means and variances from an ensemble model |
| 752 | + (computed by `_compute_loo_predictions`) and packages them into a |
| 753 | + GaussianMixturePosterior that provides both per-member and mixture statistics. |
| 754 | +
| 755 | + Args: |
| 756 | + loo_mean: LOO means with shape ``batch_shape x num_models x n x 1`` |
| 757 | + (single-output) or ``batch_shape x num_models x m x n x 1`` |
| 758 | + (multi-output). |
| 759 | + loo_variance: LOO variances with same shape as loo_mean. |
| 760 | + num_outputs: Number of outputs (m). 1 for single-output models. |
| 761 | +
| 762 | + Returns: |
| 763 | + GaussianMixturePosterior with shape ``batch_shape x n x num_models x 1 x m``. |
| 764 | + The num_models dimension is at MCMC_DIM=-3. |
| 765 | + """ |
| 766 | + # Normalize shapes: add m=1 dimension for single-output to match multi-output |
| 767 | + if num_outputs == 1: |
| 768 | + # Single-output: ... x num_models x n x 1 -> ... x num_models x 1 x n x 1 |
| 769 | + loo_mean = loo_mean.unsqueeze(-3) |
| 770 | + loo_variance = loo_variance.unsqueeze(-3) |
| 771 | + |
| 772 | + # Now both cases have shape: ... x num_models x m x n x 1 |
| 773 | + # Transform to target shape: ... x n x num_models x 1 x m |
| 774 | + # 1. squeeze(-1): ... x num_models x m x n |
| 775 | + # 2. movedim(-1, -3): ... x n x num_models x m (move n before num_models) |
| 776 | + # 3. unsqueeze(-2): ... x n x num_models x 1 x m |
| 777 | + loo_mean = loo_mean.squeeze(-1).movedim(-1, -3).unsqueeze(-2) |
| 778 | + loo_variance = loo_variance.squeeze(-1).movedim(-1, -3).unsqueeze(-2) |
| 779 | + |
| 780 | + # Create distribution: iterate over outputs to create independent MVNs |
| 781 | + # After indexing with [..., t], shape is: batch_shape x n x num_models x 1 |
| 782 | + mvns = [ |
| 783 | + MultivariateNormal( |
| 784 | + mean=loo_mean[..., t], |
| 785 | + covariance_matrix=DiagLinearOperator(loo_variance[..., t]), |
| 786 | + ) |
| 787 | + for t in range(num_outputs) |
| 788 | + ] |
| 789 | + |
| 790 | + if num_outputs > 1: |
| 791 | + mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns) |
| 792 | + else: |
| 793 | + mvn = mvns[0] |
| 794 | + |
| 795 | + return GaussianMixturePosterior(distribution=mvn) |
| 796 | + |
| 797 | + |
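The squeeze/movedim/unsqueeze sequence above is easy to mis-trace; this standalone sketch (hypothetical sizes, single-output case) replays the same steps and checks the resulting layout:

```python
import torch

num_models, n = 16, 20
loo_mean = torch.randn(num_models, n, 1)  # raw single-output LOO means
x = loo_mean.unsqueeze(-3)                # num_models x 1 x n x 1 (m=1)
x = x.squeeze(-1)                         # num_models x 1 x n
x = x.movedim(-1, -3)                     # n x num_models x 1
x = x.unsqueeze(-2)                       # n x num_models x 1 x 1
assert x.shape == (n, num_models, 1, 1)
```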
| 798 | +def _get_ensemble_observed_data( |
| 799 | + model: GPyTorchModel, |
| 800 | + train_Y: Tensor, |
| 801 | + num_outputs: int, |
| 802 | +) -> tuple[Tensor, Tensor | None]: |
| 803 | + r"""Extract observed data from an ensemble model for LOO CV. |
| 804 | +
| 805 | + Extracts the first ensemble member's training targets and observation noise, |
| 806 | + verifies all members share the same data, and reshapes to LOO CV format. |
| 807 | +
| 808 | + Args: |
| 809 | + model: The ensemble GP model. |
| 810 | + train_Y: Training targets with shape ``... x num_models x n`` (single-output) |
| 811 | + or ``... x num_models x m x n`` (multi-output). |
| 812 | + num_outputs: Number of outputs (m). |
| 813 | +
| 814 | + Returns: |
| 815 | +        (observed_Y, observed_Yvar), each with shape ``... x n x 1 x m``; observed_Yvar is None unless the model uses a fixed-noise likelihood.
| 816 | +
| 817 | + Raises: |
| 818 | + UnsupportedError: If ensemble members have different training data. |
| 819 | + """ |
| 820 | + # num_models is at dim -2 for single-output, -3 for multi-output |
| 821 | + num_models_dim = -2 if num_outputs == 1 else -3 |
| 822 | + |
| 823 | + # Verify all ensemble members share the same training data |
| 824 | + _verify_ensemble_data_consistency(train_Y, num_models_dim, "train_Y") |
| 825 | + |
| 826 | + # Extract first ensemble member's data (they're all the same) |
| 827 | + train_Y_first = train_Y.select(num_models_dim, 0) |
| 828 | + observed_Y = _reshape_to_loo_cv_format(train_Y_first, num_outputs) |
| 829 | + |
| 830 | + # Get observed Yvar if available (for fixed noise models) |
| 831 | + observed_Yvar = None |
| 832 | + if isinstance(model.likelihood, FixedNoiseGaussianLikelihood): |
| 833 | + noise = model.likelihood.noise |
| 834 | + # Noise has the same shape structure as train_Y |
| 835 | + # Verify consistency and extract first member |
| 836 | + if noise.dim() > 1: |
| 837 | + _verify_ensemble_data_consistency( |
| 838 | + noise, num_models_dim, "observation noise" |
| 839 | + ) |
| 840 | + noise = noise.select(num_models_dim, 0) |
| 841 | + observed_Yvar = _reshape_to_loo_cv_format(noise, num_outputs) |
| 842 | + |
| 843 | + return observed_Y, observed_Yvar |
| 844 | + |
| 845 | + |
| 846 | +def _verify_ensemble_data_consistency( |
| 847 | + tensor: Tensor, |
| 848 | + num_models_dim: int, |
| 849 | + tensor_name: str, |
| 850 | +) -> None: |
| 851 | + r"""Verify all ensemble members have identical data along ``num_models_dim``. |
| 852 | +
| 853 | + Args: |
| 854 | + tensor: Data tensor with a num_models dimension. |
| 855 | + num_models_dim: Dimension index for num_models (typically -2 or -3). |
| 856 | + tensor_name: Name for error messages (e.g., "train_Y"). |
| 857 | +
| 858 | + Raises: |
| 859 | + UnsupportedError: If data differs across ensemble members. |
| 860 | + """ |
| 861 | + num_models = tensor.shape[num_models_dim] |
| 862 | + if num_models <= 1: |
| 863 | + return |
| 864 | + |
| 865 | + first_member = tensor.select(num_models_dim, 0) |
| 866 | + first_expanded = first_member.unsqueeze(num_models_dim).expand_as(tensor) |
| 867 | + |
| 868 | + if not torch.allclose(tensor, first_expanded): |
| 869 | + raise UnsupportedError( |
| 870 | + f"Ensemble members have different {tensor_name}. " |
| 871 | + "ensemble_loo_cv only supports ensembles where all members share the " |
| 872 | + "same training data (e.g., fully Bayesian models with MCMC samples). " |
| 873 | + "For ensembles with different data per member, cross-validate each " |
| 874 | + "member individually using efficient_loo_cv." |
| 875 | + ) |
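Finally, a tiny illustration (synthetic tensors) of the consistency check's accept/reject behavior, using the same select/expand_as logic inline:

```python
import torch

shared = torch.randn(1, 20).expand(8, 20)  # 8 ensemble members, identical data
first = shared.select(0, 0).unsqueeze(0).expand_as(shared)
assert torch.allclose(shared, first)       # identical members: check passes

mixed = shared.clone()
mixed[3] += 1.0                            # perturb a single member
first = mixed.select(0, 0).unsqueeze(0).expand_as(mixed)
assert not torch.allclose(mixed, first)    # differing members: check fails
```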