Skip to content

Commit d3a2721

Browse files
blethammeta-codesync[bot]
authored andcommitted
Add support for styles to botorch community PFN (#3099)
Summary: Pull Request resolved: #3099 Style tensors can be specified via a method override, and from there are correctly passed along in the call to the PFN torch model. Reviewed By: SamuelGabriel Differential Revision: D87847406 fbshipit-source-id: 81d9fdc238274217c13fc5f777a44db2db107cfc
1 parent 5c769ef commit d3a2721

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

botorch_community/models/prior_fitted_network.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def get_styles(
5151
.hyperparameters_dict_to_tensor(hps_subset)
5252
.repeat(batch_size, 1)
5353
.to(device)
54+
.float()
5455
) # shape (batch_size, num_styles)
5556
style_kwargs["style"] = style
5657

@@ -65,6 +66,7 @@ def get_styles(
6566
.hyperparameters_dict_to_tensor(hps_subset)
6667
.repeat(batch_size, 1)
6768
.to(device)
69+
.float()
6870
) # shape (batch_size, num_styles)
6971
style_kwargs["y_style"] = y_style
7072
return style_kwargs
@@ -172,6 +174,7 @@ def __init__(
172174
# so here we use a FixedNoiseGaussianLikelihood that is unused.
173175
if train_Yvar is None:
174176
train_Yvar = torch.zeros_like(train_Y)
177+
self.train_Yvar = train_Yvar # shape: (n, 1)
175178
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar)
176179
self.pfn = model.to(device=train_X.device)
177180
self.batch_first = batch_first
@@ -180,6 +183,17 @@ def __init__(
180183
self.style = style
181184
if input_transform is not None:
182185
self.input_transform = input_transform
186+
self._compute_styles()
187+
188+
def _compute_styles(self):
189+
"""
190+
Can be used to compute styles to be used for PFN prediction based on
191+
training data.
192+
193+
When implemented, will directly modify self.style_hyperparameters or
194+
self.style.
195+
"""
196+
pass
183197

184198
def posterior(
185199
self,
@@ -221,15 +235,10 @@ def posterior(
221235
if posterior_transform is not None:
222236
raise UnsupportedError("posterior_transform is not supported for PFNModel.")
223237

224-
X, train_X, train_Y, orig_X_shape = self._prepare_data(
238+
X, train_X, train_Y, orig_X_shape, styles = self._prepare_data(
225239
X, negate_train_ys=negate_train_ys
226240
)
227241

228-
styles = self._get_styles(
229-
hps=self.style_hyperparameters,
230-
batch_size=X.shape[0],
231-
)
232-
233242
probabilities = self.pfn_predict(
234243
X=X,
235244
train_X=train_X,
@@ -248,7 +257,7 @@ def posterior(
248257

249258
def _prepare_data(
250259
self, X: Tensor, negate_train_ys: bool = False
251-
) -> tuple[Tensor, Tensor, Tensor, torch.Size]:
260+
) -> tuple[Tensor, Tensor, Tensor, torch.Size, dict[str, Tensor]]:
252261
orig_X_shape = X.shape # X has shape b? x q? x d
253262
if len(X.shape) > 3:
254263
raise UnsupportedError(f"X must be at most 3-d, got {X.shape}.")
@@ -258,17 +267,19 @@ def _prepare_data(
258267
X = self.transform_inputs(X) # shape (b , q, d)
259268

260269
train_X = match_batch_shape(self.transformed_X, X) # shape (b, n, d)
270+
train_Y = match_batch_shape(self.train_Y, X) # shape (b, n, 1)
261271
if negate_train_ys:
262272
assert self.train_Y.mean().abs() < 1e-4, "train_Y must be zero-centered."
263-
train_Y = match_batch_shape(
264-
-self.train_Y if negate_train_ys else self.train_Y, X
265-
) # shape (b, n, 1)
266-
return X, train_X, train_Y, orig_X_shape
273+
train_Y = -train_Y
274+
styles = self._get_styles(
275+
batch_size=X.shape[0],
276+
) # shape (b, num_styles)
277+
return X, train_X, train_Y, orig_X_shape, styles
267278

268-
def _get_styles(self, hps, batch_size) -> dict[str, Tensor]:
279+
def _get_styles(self, batch_size) -> dict[str, Tensor]:
269280
style_kwargs = get_styles(
270281
model=self.pfn,
271-
hps=hps,
282+
hps=self.style_hyperparameters,
272283
batch_size=batch_size,
273284
device=self.train_X.device,
274285
)
@@ -277,7 +288,10 @@ def _get_styles(self, hps, batch_size) -> dict[str, Tensor]:
277288
style_kwargs == {}
278289
), "Cannot provide both style and style_hyperparameters."
279290
style_kwargs["style"] = (
280-
self.style[None].repeat(batch_size, 1, 1).to(self.train_X.device)
291+
self.style[None]
292+
.repeat(batch_size, 1, 1)
293+
.to(self.train_X.device)
294+
.float()
281295
)
282296
return style_kwargs
283297

@@ -306,9 +320,9 @@ def pfn_predict(
306320
train_Y = train_Y.transpose(0, 1) # shape (n, b, 1)
307321

308322
logits = self.pfn(
309-
train_X.float(),
310-
train_Y.float(),
311-
X.float(),
323+
x=train_X.float(),
324+
y=train_Y.float(),
325+
test_x=X.float(),
312326
**forward_kwargs,
313327
)
314328
if not self.batch_first:
@@ -368,15 +382,10 @@ def posterior(
368382
if posterior_transform is not None:
369383
raise UnsupportedError("posterior_transform is not supported for PFNModel.")
370384

371-
X, train_X, train_Y, orig_X_shape = self._prepare_data(
385+
X, train_X, train_Y, orig_X_shape, styles = self._prepare_data(
372386
X, negate_train_ys=negate_train_ys
373387
)
374388

375-
styles = self._get_styles(
376-
hps=self.style_hyperparameters,
377-
batch_size=X.shape[0],
378-
)
379-
380389
if pending_X is not None:
381390
assert pending_X.dim() == 2, "pending_X must be 2-dimensional."
382391
pending_X = pending_X[None].repeat(X.shape[0], 1, 1) # shape (b, n', d)
@@ -452,12 +461,13 @@ def posterior(
452461
if len(X.shape) == 1 or X.shape[-2] == 1:
453462
# No q dimension, or q=1
454463
return marginals
455-
X, train_X, train_Y, orig_X_shape = self._prepare_data(X)
464+
X, train_X, train_Y, orig_X_shape, styles = self._prepare_data(X)
456465
# Estimate correlation structure, making another forward pass.
457466
R = self.estimate_correlations(
458467
X=X,
459468
train_X=train_X,
460469
train_Y=train_Y,
470+
styles=styles,
461471
marginals=marginals,
462472
) # (b, q, q)
463473
R = R.view(*orig_X_shape[:-2], X.shape[-2], X.shape[-2]) # (b?, q, q)
@@ -472,6 +482,7 @@ def estimate_correlations(
472482
X: Tensor,
473483
train_X: Tensor,
474484
train_Y: Tensor,
485+
styles: dict[str, Tensor],
475486
marginals: BoundedRiemannPosterior,
476487
) -> Tensor:
477488
"""
@@ -488,6 +499,7 @@ def estimate_correlations(
488499
X: evaluation point, shape (b, q, d)
489500
train_X: Training X, shape (b, n, d)
490501
train_Y: Training Y, shape (b, n, 1)
502+
styles: dict from name to tensor shaped (b, ns) for any styles.
491503
marginals: A posterior object with marginal posteriors for f(X), but no
492504
correlation structure yet added. posterior.probabilities has
493505
shape (b?, q, num_buckets).
@@ -499,6 +511,7 @@ def estimate_correlations(
499511
X=X,
500512
train_X=train_X,
501513
train_Y=train_Y,
514+
styles=styles,
502515
marginals=marginals,
503516
)
504517
# Get marginal moments
@@ -525,6 +538,7 @@ def _compute_conditional_means(
525538
X: Tensor,
526539
train_X: Tensor,
527540
train_Y: Tensor,
541+
styles: dict[str, Tensor],
528542
marginals: BoundedRiemannPosterior,
529543
) -> tuple[Tensor, Tensor]:
530544
"""
@@ -538,6 +552,7 @@ def _compute_conditional_means(
538552
X: evaluation point, shape (b, q, d)
539553
train_X: Training X, shape (b, n, d)
540554
train_Y: Training Y, shape (b, n, 1)
555+
styles: dict from name to tensor shaped (b, ns) for any styles.
541556
marginals: A posterior object with marginal posteriors for f(X), but no
542557
correlation structure yet added. posterior.probabilities has
543558
shape (b?, q, num_buckets).
@@ -561,13 +576,18 @@ def _compute_conditional_means(
561576
train_Y = train_Y.unsqueeze(1).expand(b, q, n, 1)
562577
cond_Y = cond_val.unsqueeze(-1).unsqueeze(-1) # (b, q, 1, 1)
563578
train_Y = torch.cat((train_Y, cond_Y), dim=-2) # (b, q, n+1, 1)
579+
cond_styles = {}
580+
for name, style in styles.items():
581+
ns = style.shape[-1]
582+
cond_styles[name] = style.unsqueeze(-2).expand(b, q, ns).reshape(b * q, ns)
564583
# Construct eval points
565584
eval_X = X.unsqueeze(1).expand(b, q, q, d)
566585
# Squeeze everything into necessary 2 batch dims, and do PFN forward pass
567586
cond_probabilities = self.pfn_predict(
568587
X=eval_X.reshape(b * q, q, d),
569588
train_X=train_X.reshape(b * q, n + 1, d),
570589
train_Y=train_Y.reshape(b * q, n + 1, 1),
590+
**cond_styles,
571591
) # (b * q, q, num_buckets)
572592
# Object for conditional posteriors
573593
cond_posterior = BoundedRiemannPosterior(

test_community/models/test_prior_fitted_network.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,15 @@ def __init__(self, n_buckets: int = 1000):
4949
self.style_encoder = None
5050
self.y_style_encoder = None
5151

52-
def forward(self, train_X: Tensor, train_Y: Tensor, test_X: Tensor) -> Tensor:
53-
return torch.zeros(*test_X.shape[:-1], self.n_buckets, device=test_X.device)
52+
def forward(
53+
self,
54+
x: Tensor,
55+
y: Tensor,
56+
test_x: Tensor,
57+
style: Tensor | None = None,
58+
y_style: Tensor | None = None,
59+
) -> Tensor:
60+
return torch.zeros(*test_x.shape[:-1], self.n_buckets, device=test_x.device)
5461

5562

5663
class TestPriorFittedNetwork(BotorchTestCase):
@@ -162,11 +169,15 @@ def test_shapes(self):
162169

163170
# prepare_data
164171
X = torch.rand(5, 3, **tkwargs)
165-
X, train_X, train_Y, orig_X_shape = pfn._prepare_data(X)
172+
X, train_X, train_Y, orig_X_shape, styles = pfn._prepare_data(X)
166173
self.assertEqual(X.shape, torch.Size([1, 5, 3]))
167174
self.assertEqual(train_X.shape, torch.Size([1, 10, 3]))
168175
self.assertEqual(train_Y.shape, torch.Size([1, 10, 1]))
169176
self.assertEqual(orig_X_shape, torch.Size([5, 3]))
177+
self.assertEqual(styles, {})
178+
pfn.style = torch.rand(4, **tkwargs)
179+
X, train_X, train_Y, orig_X_shape, styles = pfn._prepare_data(X)
180+
self.assertEqual(styles["style"].shape, torch.Size([1, 1, 4]))
170181

171182
def test_input_transform(self):
172183
model = PFNModel(
@@ -204,7 +215,7 @@ def test_style_hyperparameters(self):
204215
orig_forward = dummy_pfn.forward
205216
dummy_pfn.forward = lambda *a, **kw: (
206217
captured.update(kw),
207-
orig_forward(*a[:3]),
218+
orig_forward(**kw),
208219
)[1]
209220

210221
pfn.posterior(torch.rand(5, 3))
@@ -227,7 +238,7 @@ def test_style_params_require_style_hyperparameters(self):
227238

228239
captured = {}
229240
orig = dummy_pfn.forward
230-
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(*a[:3]))[1]
241+
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(**kw))[1]
231242

232243
PFNModel(train_X, train_Y, dummy_pfn).posterior(torch.rand(5, 3))
233244
self.assertNotIn("style", captured)
@@ -244,7 +255,7 @@ def test_raw_style_tensor(self):
244255
captured = {}
245256
dummy_pfn = DummyPFN()
246257
orig = dummy_pfn.forward
247-
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(*a[:3]))[1]
258+
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(**kw))[1]
248259

249260
pfn = PFNModel(train_X, train_Y, dummy_pfn, style=style)
250261
pfn.posterior(torch.rand(5, 3))
@@ -444,6 +455,7 @@ def test_compute_conditional_means(self):
444455
X=X,
445456
train_X=torch.zeros(3, 4, 5),
446457
train_Y=torch.zeros(3, 4, 1),
458+
styles={"style": torch.zeros(3, 7)},
447459
marginals=marginals,
448460
)
449461
res = mock_pfn_predict.call_args[1]
@@ -463,6 +475,9 @@ def test_compute_conditional_means(self):
463475
self.assertTrue(
464476
torch.equal(torch.round(res["train_Y"], decimals=2), torch.cat(a, dim=0))
465477
)
478+
# Verify style and y_style are passed correctly
479+
self.assertTrue(torch.equal(res["style"], torch.zeros(6, 7)))
480+
self.assertNotIn("y_style", res)
466481

467482
def test_estimate_correlations(self):
468483
probabilities = torch.ones(2, 3, 1000)
@@ -481,6 +496,7 @@ def test_estimate_correlations(self):
481496
X=torch.ones(2, 3, 5),
482497
train_X=torch.zeros(2, 4, 5),
483498
train_Y=torch.zeros(2, 4, 1),
499+
styles={"style": torch.zeros(2, 4), "y_style": torch.ones(2, 4)},
484500
marginals=marginals,
485501
)
486502
self.assertAllClose(torch.diagonal(R, dim1=-2, dim2=-1), torch.ones(2, 3))
@@ -499,6 +515,7 @@ def test_estimate_correlations(self):
499515
X=torch.ones(1, 3, 5),
500516
train_X=torch.zeros(1, 4, 5),
501517
train_Y=torch.zeros(1, 4, 1),
518+
styles={"style": torch.zeros(1, 4), "y_style": torch.ones(1, 4)},
502519
marginals=marginals,
503520
)
504521
self.assertEqual(R.shape, torch.Size([1, 3, 3]))

0 commit comments

Comments
 (0)