15
15
16
16
import torch
17
17
from botorch .acquisition .acquisition import AcquisitionFunction
18
- from botorch .acquisition .analytic import AnalyticAcquisitionFunction
19
18
from botorch .acquisition .objective import GenericMCObjective
20
- from botorch .exceptions import UnsupportedError
19
+ from botorch .acquisition . wrapper import AbstractAcquisitionFunctionWrapper
21
20
from torch import Tensor
22
21
23
22
@@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
139
138
return regularization_term
140
139
141
140
142
- class PenalizedAcquisitionFunction (AcquisitionFunction ):
141
+ class PenalizedAcquisitionFunction (AbstractAcquisitionFunctionWrapper ):
143
142
r"""Single-outcome acquisition function regularized by the given penalty.
144
143
145
144
The usage is similar to:
@@ -161,29 +160,16 @@ def __init__(
161
160
penalty_func: The regularization function.
162
161
regularization_parameter: Regularization parameter used in optimization.
163
162
"""
164
- super () .__init__ (model = raw_acqf .model )
165
- self . raw_acqf = raw_acqf
163
+ AcquisitionFunction .__init__ (self , model = raw_acqf .model )
164
+ AbstractAcquisitionFunctionWrapper . __init__ ( self , acq_function = raw_acqf )
166
165
self .penalty_func = penalty_func
167
166
self .regularization_parameter = regularization_parameter
168
167
169
168
def forward (self , X : Tensor ) -> Tensor :
170
- raw_value = self .raw_acqf (X = X )
169
+ raw_value = self .acq_func (X = X )
171
170
penalty_term = self .penalty_func (X )
172
171
return raw_value - self .regularization_parameter * penalty_term
173
172
174
- @property
175
- def X_pending (self ) -> Optional [Tensor ]:
176
- return self .raw_acqf .X_pending
177
-
178
- def set_X_pending (self , X_pending : Optional [Tensor ] = None ) -> None :
179
- if not isinstance (self .raw_acqf , AnalyticAcquisitionFunction ):
180
- self .raw_acqf .set_X_pending (X_pending = X_pending )
181
- else :
182
- raise UnsupportedError (
183
- "The raw acquisition function is Analytic and does not account "
184
- "for X_pending yet."
185
- )
186
-
187
173
188
174
def group_lasso_regularizer (X : Tensor , groups : List [List [int ]]) -> Tensor :
189
175
r"""Computes the group lasso regularization function for the given point.
0 commit comments