Skip to content

Commit d0aab78

Browse files
SebastianAmentmeta-codesync[bot]
authored andcommitted
Add test coverage for warning in Model._set_transformed_inputs (#3107)
Summary: Pull Request resolved: #3107 This commit adds test coverage for the warning in `Model._set_transformed_inputs`. [This is currently lacking coverage](https://app.codecov.io/gh/meta-pytorch/botorch/pull/3103/blob/botorch/models/model.py?dropdown=coverage). Reviewed By: esantorella Differential Revision: D88653645 fbshipit-source-id: 2b3831635711cc31816250741e494a1502b1cb8b
1 parent 3ff4d24 commit d0aab78

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

test/models/test_model.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from botorch.exceptions.errors import InputDataError
1111
from botorch.models.deterministic import GenericDeterministicModel
1212
from botorch.models.model import Model, ModelDict, ModelList
13+
from botorch.models.transforms.input import Normalize
1314
from botorch.posteriors.ensemble import EnsemblePosterior
1415
from botorch.posteriors.posterior_list import PosteriorList
1516
from botorch.utils.datasets import SupervisedDataset
@@ -22,6 +23,22 @@ def posterior(self, X, output_indices, observation_noise, **kwargs):
2223
pass
2324

2425

26+
class ModelWithInputTransformButNoTrainInputs(Model):
27+
"""A model that has input_transform but no train_inputs attribute."""
28+
29+
def __init__(self, input_transform):
30+
"""Initialize the model with an input transform.
31+
32+
Args:
33+
input_transform: The input transform to apply.
34+
"""
35+
super().__init__()
36+
self.input_transform = input_transform
37+
38+
def posterior(self, X, output_indices, observation_noise, **kwargs):
39+
pass
40+
41+
2542
class GenericDeterministicModelWithBatchShape(GenericDeterministicModel):
2643
# mocking torch.nn.Module components is kind of funky, so let's do this instead
2744
@property
@@ -58,6 +75,33 @@ def test_not_so_abstract_base_model(self):
5875
with self.assertRaises(NotImplementedError):
5976
model.subset_output([0])
6077

78+
def test_set_transformed_inputs_warning_without_train_inputs(self) -> None:
79+
# Test that a RuntimeWarning is raised when a model has an input_transform
80+
# but no train_inputs attribute.
81+
input_transform = Normalize(d=2)
82+
model = ModelWithInputTransformButNoTrainInputs(input_transform=input_transform)
83+
84+
# Verify the model has input_transform but no train_inputs
85+
self.assertTrue(hasattr(model, "input_transform"))
86+
self.assertFalse(hasattr(model, "train_inputs"))
87+
88+
# Test cases: (method_name, callable that triggers _set_transformed_inputs)
89+
test_cases = [
90+
("_set_transformed_inputs", lambda: model._set_transformed_inputs()),
91+
("eval", lambda: model.eval()),
92+
("train(mode=False)", lambda: model.train(mode=False)),
93+
]
94+
95+
for method_name, trigger_fn in test_cases:
96+
with self.subTest(method=method_name):
97+
with self.assertWarnsRegex(
98+
RuntimeWarning,
99+
"Could not update `train_inputs` with transformed inputs since "
100+
"ModelWithInputTransformButNoTrainInputs does not have a "
101+
"`train_inputs` attribute",
102+
):
103+
trigger_fn()
104+
61105
def test_construct_inputs(self) -> None:
62106
model = NotSoAbstractBaseModel()
63107
with self.subTest("Wrong training data type"), self.assertRaisesRegex(

0 commit comments

Comments
 (0)