Skip to content

Commit 8be7d28

Browse files
authored
Merge 7ce1389 into 63dd0cd
2 parents 63dd0cd + 7ce1389 commit 8be7d28

File tree

13 files changed

+1419
-50
lines changed

13 files changed

+1419
-50
lines changed

botorch/acquisition/fixed_feature.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
import torch
1818
from botorch.acquisition.acquisition import AcquisitionFunction
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
1920
from torch import Tensor
20-
from torch.nn import Module
2121

2222

23-
class FixedFeatureAcquisitionFunction(AcquisitionFunction):
23+
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
2424
"""A wrapper around AquisitionFunctions to fix a subset of features.
2525
2626
Example:
@@ -56,8 +56,7 @@ def __init__(
5656
combination of `Tensor`s and numbers which can be broadcasted
5757
to form a tensor with trailing dimension size of `d_f`.
5858
"""
59-
Module.__init__(self)
60-
self.acq_func = acq_function
59+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
6160
dtype = torch.float
6261
device = torch.device("cpu")
6362
self.d = d
@@ -126,24 +125,13 @@ def forward(self, X: Tensor):
126125
X_full = self._construct_X_full(X)
127126
return self.acq_func(X_full)
128127

129-
@property
130-
def X_pending(self):
131-
r"""Return the `X_pending` of the base acquisition function."""
132-
try:
133-
return self.acq_func.X_pending
134-
except (ValueError, AttributeError):
135-
raise ValueError(
136-
f"Base acquisition function {type(self.acq_func).__name__} "
137-
"does not have an `X_pending` attribute."
138-
)
139-
140-
@X_pending.setter
141-
def X_pending(self, X_pending: Optional[Tensor]):
128+
def set_X_pending(self, X_pending: Optional[Tensor]):
142129
r"""Sets the `X_pending` of the base acquisition function."""
143130
if X_pending is not None:
144-
self.acq_func.X_pending = self._construct_X_full(X_pending)
131+
full_X_pending = self._construct_X_full(X_pending)
145132
else:
146-
self.acq_func.X_pending = X_pending
133+
full_X_pending = None
134+
self.acq_func.set_X_pending(full_X_pending)
147135

148136
def _construct_X_full(self, X: Tensor) -> Tensor:
149137
r"""Constructs the full input for the base acquisition function.

botorch/acquisition/penalized.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515

1616
import torch
1717
from botorch.acquisition.acquisition import AcquisitionFunction
18-
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
1918
from botorch.acquisition.objective import GenericMCObjective
20-
from botorch.exceptions import UnsupportedError
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
2120
from torch import Tensor
2221

2322

@@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
139138
return regularization_term
140139

141140

142-
class PenalizedAcquisitionFunction(AcquisitionFunction):
141+
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
143142
r"""Single-outcome acquisition function regularized by the given penalty.
144143
145144
The usage is similar to:
@@ -161,29 +160,16 @@ def __init__(
161160
penalty_func: The regularization function.
162161
regularization_parameter: Regularization parameter used in optimization.
163162
"""
164-
super().__init__(model=raw_acqf.model)
165-
self.raw_acqf = raw_acqf
163+
AcquisitionFunction.__init__(self, model=raw_acqf.model)
164+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
166165
self.penalty_func = penalty_func
167166
self.regularization_parameter = regularization_parameter
168167

169168
def forward(self, X: Tensor) -> Tensor:
170-
raw_value = self.raw_acqf(X=X)
169+
raw_value = self.acq_func(X=X)
171170
penalty_term = self.penalty_func(X)
172171
return raw_value - self.regularization_parameter * penalty_term
173172

174-
@property
175-
def X_pending(self) -> Optional[Tensor]:
176-
return self.raw_acqf.X_pending
177-
178-
def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
179-
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
180-
self.raw_acqf.set_X_pending(X_pending=X_pending)
181-
else:
182-
raise UnsupportedError(
183-
"The raw acquisition function is Analytic and does not account "
184-
"for X_pending yet."
185-
)
186-
187173

188174
def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
189175
r"""Computes the group lasso regularization function for the given point.

0 commit comments

Comments
 (0)