68 changes: 44 additions & 24 deletions botorch_community/models/prior_fitted_network.py
@@ -51,6 +51,7 @@ def get_styles(
.hyperparameters_dict_to_tensor(hps_subset)
.repeat(batch_size, 1)
.to(device)
.float()
) # shape (batch_size, num_styles)
style_kwargs["style"] = style

@@ -65,6 +66,7 @@ def get_styles(
.hyperparameters_dict_to_tensor(hps_subset)
.repeat(batch_size, 1)
.to(device)
.float()
) # shape (batch_size, num_styles)
style_kwargs["y_style"] = y_style
return style_kwargs
@@ -172,6 +174,7 @@ def __init__(
# so here we use a FixedNoiseGaussianLikelihood that is unused.
if train_Yvar is None:
train_Yvar = torch.zeros_like(train_Y)
self.train_Yvar = train_Yvar # shape: (n, 1)
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar)
self.pfn = model.to(device=train_X.device)
self.batch_first = batch_first
@@ -180,6 +183,17 @@ def __init__(
self.style = style
if input_transform is not None:
self.input_transform = input_transform
self._compute_styles()

def _compute_styles(self):
"""
Can be used to compute styles to be used for PFN prediction based on
training data.

When implemented, will directly modify self.style_hyperparameters or
self.style.
"""
pass

def posterior(
self,
@@ -221,15 +235,10 @@ def posterior(
if posterior_transform is not None:
raise UnsupportedError("posterior_transform is not supported for PFNModel.")

X, train_X, train_Y, orig_X_shape = self._prepare_data(
X, train_X, train_Y, orig_X_shape, styles = self._prepare_data(
X, negate_train_ys=negate_train_ys
)

styles = self._get_styles(
hps=self.style_hyperparameters,
batch_size=X.shape[0],
)

probabilities = self.pfn_predict(
X=X,
train_X=train_X,
@@ -248,7 +257,7 @@ def _prepare_data(

def _prepare_data(
self, X: Tensor, negate_train_ys: bool = False
) -> tuple[Tensor, Tensor, Tensor, torch.Size]:
) -> tuple[Tensor, Tensor, Tensor, torch.Size, dict[str, Tensor]]:
orig_X_shape = X.shape # X has shape b? x q? x d
if len(X.shape) > 3:
raise UnsupportedError(f"X must be at most 3-d, got {X.shape}.")
@@ -258,17 +267,19 @@ def _prepare_data(
X = self.transform_inputs(X) # shape (b , q, d)

train_X = match_batch_shape(self.transformed_X, X) # shape (b, n, d)
train_Y = match_batch_shape(self.train_Y, X) # shape (b, n, 1)
if negate_train_ys:
assert self.train_Y.mean().abs() < 1e-4, "train_Y must be zero-centered."
train_Y = match_batch_shape(
-self.train_Y if negate_train_ys else self.train_Y, X
) # shape (b, n, 1)
return X, train_X, train_Y, orig_X_shape
train_Y = -train_Y
styles = self._get_styles(
batch_size=X.shape[0],
) # shape (b, num_styles)
return X, train_X, train_Y, orig_X_shape, styles

def _get_styles(self, hps, batch_size) -> dict[str, Tensor]:
def _get_styles(self, batch_size) -> dict[str, Tensor]:
style_kwargs = get_styles(
model=self.pfn,
hps=hps,
hps=self.style_hyperparameters,
batch_size=batch_size,
device=self.train_X.device,
)
@@ -277,7 +288,10 @@ def _get_styles(self, hps, batch_size) -> dict[str, Tensor]:
style_kwargs == {}
), "Cannot provide both style and style_hyperparameters."
style_kwargs["style"] = (
self.style[None].repeat(batch_size, 1, 1).to(self.train_X.device)
self.style[None]
.repeat(batch_size, 1, 1)
.to(self.train_X.device)
.float()
)
return style_kwargs

@@ -306,9 +320,9 @@ def pfn_predict(
train_Y = train_Y.transpose(0, 1) # shape (n, b, 1)

logits = self.pfn(
train_X.float(),
train_Y.float(),
X.float(),
x=train_X.float(),
y=train_Y.float(),
test_x=X.float(),
**forward_kwargs,
)
if not self.batch_first:
@@ -368,15 +382,10 @@ def posterior(
if posterior_transform is not None:
raise UnsupportedError("posterior_transform is not supported for PFNModel.")

X, train_X, train_Y, orig_X_shape = self._prepare_data(
X, train_X, train_Y, orig_X_shape, styles = self._prepare_data(
X, negate_train_ys=negate_train_ys
)

styles = self._get_styles(
hps=self.style_hyperparameters,
batch_size=X.shape[0],
)

if pending_X is not None:
assert pending_X.dim() == 2, "pending_X must be 2-dimensional."
pending_X = pending_X[None].repeat(X.shape[0], 1, 1) # shape (b, n', d)
@@ -452,12 +461,13 @@ def posterior(
if len(X.shape) == 1 or X.shape[-2] == 1:
# No q dimension, or q=1
return marginals
X, train_X, train_Y, orig_X_shape = self._prepare_data(X)
X, train_X, train_Y, orig_X_shape, styles = self._prepare_data(X)
# Estimate correlation structure, making another forward pass.
R = self.estimate_correlations(
X=X,
train_X=train_X,
train_Y=train_Y,
styles=styles,
marginals=marginals,
) # (b, q, q)
R = R.view(*orig_X_shape[:-2], X.shape[-2], X.shape[-2]) # (b?, q, q)
@@ -472,6 +482,7 @@ def estimate_correlations(
X: Tensor,
train_X: Tensor,
train_Y: Tensor,
styles: dict[str, Tensor],
marginals: BoundedRiemannPosterior,
) -> Tensor:
"""
@@ -488,6 +499,7 @@ def estimate_correlations(
X: evaluation point, shape (b, q, d)
train_X: Training X, shape (b, n, d)
train_Y: Training Y, shape (b, n, 1)
styles: dict from name to tensor shaped (b, ns) for any styles.
marginals: A posterior object with marginal posteriors for f(X), but no
correlation structure yet added. posterior.probabilities has
shape (b?, q, num_buckets).
@@ -499,6 +511,7 @@
X=X,
train_X=train_X,
train_Y=train_Y,
styles=styles,
marginals=marginals,
)
# Get marginal moments
@@ -525,6 +538,7 @@ def _compute_conditional_means(
X: Tensor,
train_X: Tensor,
train_Y: Tensor,
styles: dict[str, Tensor],
marginals: BoundedRiemannPosterior,
) -> tuple[Tensor, Tensor]:
"""
@@ -538,6 +552,7 @@
X: evaluation point, shape (b, q, d)
train_X: Training X, shape (b, n, d)
train_Y: Training Y, shape (b, n, 1)
styles: dict from name to tensor shaped (b, ns) for any styles.
marginals: A posterior object with marginal posteriors for f(X), but no
correlation structure yet added. posterior.probabilities has
shape (b?, q, num_buckets).
@@ -561,13 +576,18 @@ def _compute_conditional_means(
train_Y = train_Y.unsqueeze(1).expand(b, q, n, 1)
cond_Y = cond_val.unsqueeze(-1).unsqueeze(-1) # (b, q, 1, 1)
train_Y = torch.cat((train_Y, cond_Y), dim=-2) # (b, q, n+1, 1)
cond_styles = {}
for name, style in styles.items():
ns = style.shape[-1]
cond_styles[name] = style.unsqueeze(-2).expand(b, q, ns).reshape(b * q, ns)
# Construct eval points
eval_X = X.unsqueeze(1).expand(b, q, q, d)
# Squeeze everything into necessary 2 batch dims, and do PFN forward pass
cond_probabilities = self.pfn_predict(
X=eval_X.reshape(b * q, q, d),
train_X=train_X.reshape(b * q, n + 1, d),
train_Y=train_Y.reshape(b * q, n + 1, 1),
**cond_styles,
) # (b * q, q, num_buckets)
# Object for conditional posteriors
cond_posterior = BoundedRiemannPosterior(
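For reference, a minimal standalone sketch (not part of the diff) of the raw-style broadcasting that the updated _get_styles path performs. Only torch is assumed; the shapes and the float32 cast mirror the diff, and the tensors here are illustrative.

import torch

batch_size, num_styles = 5, 4
style = torch.rand(num_styles, dtype=torch.float64)  # raw style tensor, shape (num_styles,)

# Mirrors: self.style[None].repeat(batch_size, 1, 1).to(device).float()
broadcast = style[None].repeat(batch_size, 1, 1).float()
assert broadcast.shape == (batch_size, 1, num_styles)
assert broadcast.dtype == torch.float32  # the added .float() ensures float32 inputs to the PFN

# Mirrors the hyperparameter-encoder path, where a (num_styles,) tensor is repeated per batch
hp_style = torch.rand(num_styles, dtype=torch.float64).repeat(batch_size, 1).float()
assert hp_style.shape == (batch_size, num_styles)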
29 changes: 23 additions & 6 deletions test_community/models/test_prior_fitted_network.py
@@ -49,8 +49,15 @@ def __init__(self, n_buckets: int = 1000):
self.style_encoder = None
self.y_style_encoder = None

def forward(self, train_X: Tensor, train_Y: Tensor, test_X: Tensor) -> Tensor:
return torch.zeros(*test_X.shape[:-1], self.n_buckets, device=test_X.device)
def forward(
self,
x: Tensor,
y: Tensor,
test_x: Tensor,
style: Tensor | None = None,
y_style: Tensor | None = None,
) -> Tensor:
return torch.zeros(*test_x.shape[:-1], self.n_buckets, device=test_x.device)


class TestPriorFittedNetwork(BotorchTestCase):
@@ -162,11 +169,15 @@ def test_shapes(self):

# prepare_data
X = torch.rand(5, 3, **tkwargs)
X, train_X, train_Y, orig_X_shape = pfn._prepare_data(X)
X, train_X, train_Y, orig_X_shape, styles = pfn._prepare_data(X)
self.assertEqual(X.shape, torch.Size([1, 5, 3]))
self.assertEqual(train_X.shape, torch.Size([1, 10, 3]))
self.assertEqual(train_Y.shape, torch.Size([1, 10, 1]))
self.assertEqual(orig_X_shape, torch.Size([5, 3]))
self.assertEqual(styles, {})
pfn.style = torch.rand(4, **tkwargs)
X, train_X, train_Y, orig_X_shape, styles = pfn._prepare_data(X)
self.assertEqual(styles["style"].shape, torch.Size([1, 1, 4]))

def test_input_transform(self):
model = PFNModel(
@@ -204,7 +215,7 @@ def test_style_hyperparameters(self):
orig_forward = dummy_pfn.forward
dummy_pfn.forward = lambda *a, **kw: (
captured.update(kw),
orig_forward(*a[:3]),
orig_forward(**kw),
)[1]

pfn.posterior(torch.rand(5, 3))
@@ -227,7 +238,7 @@ def test_style_params_require_style_hyperparameters(self):

captured = {}
orig = dummy_pfn.forward
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(*a[:3]))[1]
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(**kw))[1]

PFNModel(train_X, train_Y, dummy_pfn).posterior(torch.rand(5, 3))
self.assertNotIn("style", captured)
@@ -244,7 +255,7 @@ def test_raw_style_tensor(self):
captured = {}
dummy_pfn = DummyPFN()
orig = dummy_pfn.forward
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(*a[:3]))[1]
dummy_pfn.forward = lambda *a, **kw: (captured.update(kw), orig(**kw))[1]

pfn = PFNModel(train_X, train_Y, dummy_pfn, style=style)
pfn.posterior(torch.rand(5, 3))
@@ -444,6 +455,7 @@ def test_compute_conditional_means(self):
X=X,
train_X=torch.zeros(3, 4, 5),
train_Y=torch.zeros(3, 4, 1),
styles={"style": torch.zeros(3, 7)},
marginals=marginals,
)
res = mock_pfn_predict.call_args[1]
@@ -463,6 +475,9 @@
self.assertTrue(
torch.equal(torch.round(res["train_Y"], decimals=2), torch.cat(a, dim=0))
)
# Verify style and y_style are passed correctly
self.assertTrue(torch.equal(res["style"], torch.zeros(6, 7)))
self.assertNotIn("y_style", res)

def test_estimate_correlations(self):
probabilities = torch.ones(2, 3, 1000)
@@ -481,6 +496,7 @@ def test_estimate_correlations(self):
X=torch.ones(2, 3, 5),
train_X=torch.zeros(2, 4, 5),
train_Y=torch.zeros(2, 4, 1),
styles={"style": torch.zeros(2, 4), "y_style": torch.ones(2, 4)},
marginals=marginals,
)
self.assertAllClose(torch.diagonal(R, dim1=-2, dim2=-1), torch.ones(2, 3))
@@ -499,6 +515,7 @@
X=torch.ones(1, 3, 5),
train_X=torch.zeros(1, 4, 5),
train_Y=torch.zeros(1, 4, 1),
styles={"style": torch.zeros(1, 4), "y_style": torch.ones(1, 4)},
marginals=marginals,
)
self.assertEqual(R.shape, torch.Size([1, 3, 3]))
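Similarly, a standalone shape check (illustrative values, only torch assumed) of the per-candidate style expansion added in _compute_conditional_means, which aligns each (b, ns) style with the (b * q, ...) tensors fed to pfn_predict.

import torch

b, q, ns = 3, 2, 7  # batch size, q candidates, style width; matches the (6, 7) checked in the test above
styles = {"style": torch.zeros(b, ns)}

cond_styles = {}
for name, style in styles.items():
    # Mirrors: style.unsqueeze(-2).expand(b, q, ns).reshape(b * q, ns)
    cond_styles[name] = style.unsqueeze(-2).expand(b, q, ns).reshape(b * q, ns)

assert cond_styles["style"].shape == (b * q, ns)  # one style row per (batch, candidate) pair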