@@ -51,6 +51,7 @@ def get_styles(
5151 .hyperparameters_dict_to_tensor (hps_subset )
5252 .repeat (batch_size , 1 )
5353 .to (device )
54+ .float ()
5455 ) # shape (batch_size, num_styles)
5556 style_kwargs ["style" ] = style
5657
@@ -65,6 +66,7 @@ def get_styles(
6566 .hyperparameters_dict_to_tensor (hps_subset )
6667 .repeat (batch_size , 1 )
6768 .to (device )
69+ .float ()
6870 ) # shape (batch_size, num_styles)
6971 style_kwargs ["y_style" ] = y_style
7072 return style_kwargs
@@ -172,6 +174,7 @@ def __init__(
172174 # so here we use a FixedNoiseGaussianLikelihood that is unused.
173175 if train_Yvar is None :
174176 train_Yvar = torch .zeros_like (train_Y )
177+ self .train_Yvar = train_Yvar # shape: (n, 1)
175178 self .likelihood = FixedNoiseGaussianLikelihood (noise = train_Yvar )
176179 self .pfn = model .to (device = train_X .device )
177180 self .batch_first = batch_first
@@ -180,6 +183,17 @@ def __init__(
180183 self .style = style
181184 if input_transform is not None :
182185 self .input_transform = input_transform
186+ self ._compute_styles ()
187+
188+ def _compute_styles (self ):
189+ """
190+ Hook for computing the styles used for PFN prediction from the
191+ training data. Invoked from ``__init__``.
192+
193+ This base implementation is a no-op. An implementation is expected to
194+ modify self.style_hyperparameters or self.style directly.
195+ """
196+ pass
183197
184198 def posterior (
185199 self ,
@@ -221,15 +235,10 @@ def posterior(
221235 if posterior_transform is not None :
222236 raise UnsupportedError ("posterior_transform is not supported for PFNModel." )
223237
224- X , train_X , train_Y , orig_X_shape = self ._prepare_data (
238+ X , train_X , train_Y , orig_X_shape , styles = self ._prepare_data (
225239 X , negate_train_ys = negate_train_ys
226240 )
227241
228- styles = self ._get_styles (
229- hps = self .style_hyperparameters ,
230- batch_size = X .shape [0 ],
231- )
232-
233242 probabilities = self .pfn_predict (
234243 X = X ,
235244 train_X = train_X ,
@@ -248,7 +257,7 @@ def posterior(
248257
249258 def _prepare_data (
250259 self , X : Tensor , negate_train_ys : bool = False
251- ) -> tuple [Tensor , Tensor , Tensor , torch .Size ]:
260+ ) -> tuple [Tensor , Tensor , Tensor , torch .Size , dict [ str , Tensor ] ]:
252261 orig_X_shape = X .shape # X has shape b? x q? x d
253262 if len (X .shape ) > 3 :
254263 raise UnsupportedError (f"X must be at most 3-d, got { X .shape } ." )
@@ -258,17 +267,19 @@ def _prepare_data(
258267 X = self .transform_inputs (X ) # shape (b , q, d)
259268
260269 train_X = match_batch_shape (self .transformed_X , X ) # shape (b, n, d)
270+ train_Y = match_batch_shape (self .train_Y , X ) # shape (b, n, 1)
261271 if negate_train_ys :
262272 assert self .train_Y .mean ().abs () < 1e-4 , "train_Y must be zero-centered."
263- train_Y = match_batch_shape (
264- - self .train_Y if negate_train_ys else self .train_Y , X
265- ) # shape (b, n, 1)
266- return X , train_X , train_Y , orig_X_shape
273+ train_Y = - train_Y
274+ styles = self ._get_styles (
275+ batch_size = X .shape [0 ],
276+ ) # shape (b, num_styles)
277+ return X , train_X , train_Y , orig_X_shape , styles
267278
268- def _get_styles (self , hps , batch_size ) -> dict [str , Tensor ]:
279+ def _get_styles (self , batch_size ) -> dict [str , Tensor ]:
269280 style_kwargs = get_styles (
270281 model = self .pfn ,
271- hps = hps ,
282+ hps = self . style_hyperparameters ,
272283 batch_size = batch_size ,
273284 device = self .train_X .device ,
274285 )
@@ -277,7 +288,10 @@ def _get_styles(self, hps, batch_size) -> dict[str, Tensor]:
277288 style_kwargs == {}
278289 ), "Cannot provide both style and style_hyperparameters."
279290 style_kwargs ["style" ] = (
280- self .style [None ].repeat (batch_size , 1 , 1 ).to (self .train_X .device )
291+ self .style [None ]
292+ .repeat (batch_size , 1 , 1 )
293+ .to (self .train_X .device )
294+ .float ()
281295 )
282296 return style_kwargs
283297
@@ -306,9 +320,9 @@ def pfn_predict(
306320 train_Y = train_Y .transpose (0 , 1 ) # shape (n, b, 1)
307321
308322 logits = self .pfn (
309- train_X .float (),
310- train_Y .float (),
311- X .float (),
323+ x = train_X .float (),
324+ y = train_Y .float (),
325+ test_x = X .float (),
312326 ** forward_kwargs ,
313327 )
314328 if not self .batch_first :
@@ -368,15 +382,10 @@ def posterior(
368382 if posterior_transform is not None :
369383 raise UnsupportedError ("posterior_transform is not supported for PFNModel." )
370384
371- X , train_X , train_Y , orig_X_shape = self ._prepare_data (
385+ X , train_X , train_Y , orig_X_shape , styles = self ._prepare_data (
372386 X , negate_train_ys = negate_train_ys
373387 )
374388
375- styles = self ._get_styles (
376- hps = self .style_hyperparameters ,
377- batch_size = X .shape [0 ],
378- )
379-
380389 if pending_X is not None :
381390 assert pending_X .dim () == 2 , "pending_X must be 2-dimensional."
382391 pending_X = pending_X [None ].repeat (X .shape [0 ], 1 , 1 ) # shape (b, n', d)
@@ -452,12 +461,13 @@ def posterior(
452461 if len (X .shape ) == 1 or X .shape [- 2 ] == 1 :
453462 # No q dimension, or q=1
454463 return marginals
455- X , train_X , train_Y , orig_X_shape = self ._prepare_data (X )
464+ X , train_X , train_Y , orig_X_shape , styles = self ._prepare_data (X )
456465 # Estimate correlation structure, making another forward pass.
457466 R = self .estimate_correlations (
458467 X = X ,
459468 train_X = train_X ,
460469 train_Y = train_Y ,
470+ styles = styles ,
461471 marginals = marginals ,
462472 ) # (b, q, q)
463473 R = R .view (* orig_X_shape [:- 2 ], X .shape [- 2 ], X .shape [- 2 ]) # (b?, q, q)
@@ -472,6 +482,7 @@ def estimate_correlations(
472482 X : Tensor ,
473483 train_X : Tensor ,
474484 train_Y : Tensor ,
485+ styles : dict [str , Tensor ],
475486 marginals : BoundedRiemannPosterior ,
476487 ) -> Tensor :
477488 """
@@ -488,6 +499,7 @@ def estimate_correlations(
488499 X: evaluation point, shape (b, q, d)
489500 train_X: Training X, shape (b, n, d)
490501 train_Y: Training Y, shape (b, n, 1)
502+ styles: dict from name to tensor shaped (b, ns) for any styles.
491503 marginals: A posterior object with marginal posteriors for f(X), but no
492504 correlation structure yet added. posterior.probabilities has
493505 shape (b?, q, num_buckets).
@@ -499,6 +511,7 @@ def estimate_correlations(
499511 X = X ,
500512 train_X = train_X ,
501513 train_Y = train_Y ,
514+ styles = styles ,
502515 marginals = marginals ,
503516 )
504517 # Get marginal moments
@@ -525,6 +538,7 @@ def _compute_conditional_means(
525538 X : Tensor ,
526539 train_X : Tensor ,
527540 train_Y : Tensor ,
541+ styles : dict [str , Tensor ],
528542 marginals : BoundedRiemannPosterior ,
529543 ) -> tuple [Tensor , Tensor ]:
530544 """
@@ -538,6 +552,7 @@ def _compute_conditional_means(
538552 X: evaluation point, shape (b, q, d)
539553 train_X: Training X, shape (b, n, d)
540554 train_Y: Training Y, shape (b, n, 1)
555+ styles: dict from name to tensor shaped (b, ns) for any styles.
541556 marginals: A posterior object with marginal posteriors for f(X), but no
542557 correlation structure yet added. posterior.probabilities has
543558 shape (b?, q, num_buckets).
@@ -561,13 +576,18 @@ def _compute_conditional_means(
561576 train_Y = train_Y .unsqueeze (1 ).expand (b , q , n , 1 )
562577 cond_Y = cond_val .unsqueeze (- 1 ).unsqueeze (- 1 ) # (b, q, 1, 1)
563578 train_Y = torch .cat ((train_Y , cond_Y ), dim = - 2 ) # (b, q, n+1, 1)
579+ cond_styles = {}
580+ for name , style in styles .items ():
581+ ns = style .shape [- 1 ]
582+ cond_styles [name ] = style .unsqueeze (- 2 ).expand (b , q , ns ).reshape (b * q , ns )
564583 # Construct eval points
565584 eval_X = X .unsqueeze (1 ).expand (b , q , q , d )
566585 # Squeeze everything into necessary 2 batch dims, and do PFN forward pass
567586 cond_probabilities = self .pfn_predict (
568587 X = eval_X .reshape (b * q , q , d ),
569588 train_X = train_X .reshape (b * q , n + 1 , d ),
570589 train_Y = train_Y .reshape (b * q , n + 1 , 1 ),
590+ ** cond_styles ,
571591 ) # (b * q, q, num_buckets)
572592 # Object for conditional posteriors
573593 cond_posterior = BoundedRiemannPosterior (
0 commit comments