1010from botorch .exceptions .errors import InputDataError
1111from botorch .models .deterministic import GenericDeterministicModel
1212from botorch .models .model import Model , ModelDict , ModelList
13+ from botorch .models .transforms .input import Normalize
1314from botorch .posteriors .ensemble import EnsemblePosterior
1415from botorch .posteriors .posterior_list import PosteriorList
1516from botorch .utils .datasets import SupervisedDataset
@@ -22,6 +23,22 @@ def posterior(self, X, output_indices, observation_noise, **kwargs):
2223 pass
2324
2425
26+ class ModelWithInputTransformButNoTrainInputs (Model ):
27+ """A model that has input_transform but no train_inputs attribute."""
28+
29+ def __init__ (self , input_transform ):
30+ """Initialize the model with an input transform.
31+
32+ Args:
33+ input_transform: The input transform to apply.
34+ """
35+ super ().__init__ ()
36+ self .input_transform = input_transform
37+
38+ def posterior (self , X , output_indices , observation_noise , ** kwargs ):
39+ pass
40+
41+
2542class GenericDeterministicModelWithBatchShape (GenericDeterministicModel ):
2643 # mocking torch.nn.Module components is kind of funky, so let's do this instead
2744 @property
@@ -58,6 +75,33 @@ def test_not_so_abstract_base_model(self):
5875 with self .assertRaises (NotImplementedError ):
5976 model .subset_output ([0 ])
6077
78+ def test_set_transformed_inputs_warning_without_train_inputs (self ) -> None :
79+ # Test that a RuntimeWarning is raised when a model has an input_transform
80+ # but no train_inputs attribute.
81+ input_transform = Normalize (d = 2 )
82+ model = ModelWithInputTransformButNoTrainInputs (input_transform = input_transform )
83+
84+ # Verify the model has input_transform but no train_inputs
85+ self .assertTrue (hasattr (model , "input_transform" ))
86+ self .assertFalse (hasattr (model , "train_inputs" ))
87+
88+ # Test cases: (method_name, callable that triggers _set_transformed_inputs)
89+ test_cases = [
90+ ("_set_transformed_inputs" , lambda : model ._set_transformed_inputs ()),
91+ ("eval" , lambda : model .eval ()),
92+ ("train(mode=False)" , lambda : model .train (mode = False )),
93+ ]
94+
95+ for method_name , trigger_fn in test_cases :
96+ with self .subTest (method = method_name ):
97+ with self .assertWarnsRegex (
98+ RuntimeWarning ,
99+ "Could not update `train_inputs` with transformed inputs since "
100+ "ModelWithInputTransformButNoTrainInputs does not have a "
101+ "`train_inputs` attribute" ,
102+ ):
103+ trigger_fn ()
104+
61105 def test_construct_inputs (self ) -> None :
62106 model = NotSoAbstractBaseModel ()
63107 with self .subTest ("Wrong training data type" ), self .assertRaisesRegex (
0 commit comments