@@ -187,6 +187,7 @@ def posterior(
187187 output_indices : Optional [list [int ]] = None ,
188188 observation_noise : Union [bool , Tensor ] = False ,
189189 posterior_transform : Optional [PosteriorTransform ] = None ,
190+ negate_train_ys : bool = False ,
190191 ) -> BoundedRiemannPosterior :
191192 r"""Computes the posterior over model outputs at the provided points.
192193
@@ -200,6 +201,8 @@ def posterior(
200201 output_indices: **Currently not supported for PFNModel.**
201202 observation_noise: **Currently not supported for PFNModel**.
202203 posterior_transform: **Currently not supported for PFNModel**.
204+ negate_train_ys: Whether to negate the training Ys. This is useful
205+ for minimization.
203206
204207 Returns:
205208 A `BoundedRiemannPosterior`, representing a batch of b? x q?`
@@ -218,17 +221,14 @@ def posterior(
218221 if posterior_transform is not None :
219222 raise UnsupportedError ("posterior_transform is not supported for PFNModel." )
220223
221- X , train_X , train_Y , orig_X_shape = self ._prepare_data (X )
224+ X , train_X , train_Y , orig_X_shape = self ._prepare_data (
225+ X , negate_train_ys = negate_train_ys
226+ )
222227
223- styles = get_styles (
224- model = self .pfn ,
228+ styles = self ._get_styles (
225229 hps = self .style_hyperparameters ,
226230 batch_size = X .shape [0 ],
227- device = X .device ,
228231 )
229- if self .style is not None :
230- assert styles == {}, "Cannot provide both style and style_hyperparameters."
231- styles ["style" ] = self .style [None ].repeat (X .shape [0 ], 1 , 1 ).to (X .device )
232232
233233 probabilities = self .pfn_predict (
234234 X = X ,
@@ -246,7 +246,9 @@ def posterior(
246246 probabilities = probabilities ,
247247 )
248248
249- def _prepare_data (self , X : Tensor ) -> tuple [Tensor , Tensor , Tensor , torch .Size ]:
249+ def _prepare_data (
250+ self , X : Tensor , negate_train_ys : bool = False
251+ ) -> tuple [Tensor , Tensor , Tensor , torch .Size ]:
250252 orig_X_shape = X .shape # X has shape b? x q? x d
251253 if len (X .shape ) > 3 :
252254 raise UnsupportedError (f"X must be at most 3-d, got { X .shape } ." )
@@ -256,9 +258,29 @@ def _prepare_data(self, X: Tensor) -> tuple[Tensor, Tensor, Tensor, torch.Size]:
256258 X = self .transform_inputs (X ) # shape (b , q, d)
257259
258260 train_X = match_batch_shape (self .transformed_X , X ) # shape (b, n, d)
259- train_Y = match_batch_shape (self .train_Y , X ) # shape (b, n, 1)
261+ if negate_train_ys :
262+ assert self .train_Y .mean ().abs () < 1e-4 , "train_Y must be zero-centered."
263+ train_Y = match_batch_shape (
264+ - self .train_Y if negate_train_ys else self .train_Y , X
265+ ) # shape (b, n, 1)
260266 return X , train_X , train_Y , orig_X_shape
261267
268+ def _get_styles (self , hps , batch_size ) -> dict [str , Tensor ]:
269+ style_kwargs = get_styles (
270+ model = self .pfn ,
271+ hps = hps ,
272+ batch_size = batch_size ,
273+ device = self .train_X .device ,
274+ )
275+ if self .style is not None :
276+ assert (
277+ style_kwargs == {}
278+ ), "Cannot provide both style and style_hyperparameters."
279+ style_kwargs ["style" ] = (
280+ self .style [None ].repeat (batch_size , 1 , 1 ).to (self .train_X .device )
281+ )
282+ return style_kwargs
283+
262284 def pfn_predict (
263285 self ,
264286 X : Tensor ,
@@ -277,6 +299,7 @@ def pfn_predict(
277299
278300 Returns: probabilities (b, q, num_buckets) for Riemann posterior.
279301 """
302+
280303 if not self .batch_first :
281304 X = X .transpose (0 , 1 ) # shape (q, b, d)
282305 train_X = train_X .transpose (0 , 1 ) # shape (n, b, d)
@@ -300,6 +323,93 @@ def borders(self):
300323 return self .pfn .criterion .borders .to (self .train_X .dtype )
301324
302325
326+ class PFNModelWithPendingPoints (PFNModel ):
327+ def posterior (
328+ self ,
329+ X : Tensor ,
330+ output_indices : Optional [list [int ]] = None ,
331+ observation_noise : Union [bool , Tensor ] = False ,
332+ posterior_transform : Optional [PosteriorTransform ] = None ,
333+ pending_X : Optional [Tensor ] = None ,
334+ negate_train_ys : bool = False ,
335+ ) -> BoundedRiemannPosterior :
336+ r"""Computes the posterior over model outputs at the provided points.
337+
338+ Note: The input transforms should be applied here using
339+ `self.transform_inputs(X)` after the `self.eval()` call and before
340+ any `model.forward` or `model.likelihood` calls.
341+
342+ Args:
343+ X: A b? x q? x d`-dim Tensor, where `d` is the dimension of the
344+ feature space.
345+ output_indices: **Currently not supported for PFNModel.**
346+ observation_noise: **Currently not supported for PFNModel**.
347+ posterior_transform: **Currently not supported for PFNModel**.
348+ pending_X: A tensor of shape n'' x d, where n'' is the number of
349+ pending points, which are to be observed but the value is
350+ not yet known.
351+ negate_train_ys: Whether to negate the training Ys. This is useful
352+ for minimization.
353+
354+ Returns:
355+ A `BoundedRiemannPosterior`, representing a batch of b? x q?`
356+ distributions.
357+ """
358+ self .pfn .eval ()
359+ if output_indices is not None :
360+ raise UnsupportedError (
361+ "output_indices is not None. PFNModel should not "
362+ "be a multi-output model."
363+ )
364+ if observation_noise :
365+ logger .warning (
366+ "observation_noise is not supported for PFNModel and is being ignored."
367+ )
368+ if posterior_transform is not None :
369+ raise UnsupportedError ("posterior_transform is not supported for PFNModel." )
370+
371+ X , train_X , train_Y , orig_X_shape = self ._prepare_data (
372+ X , negate_train_ys = negate_train_ys
373+ )
374+
375+ styles = self ._get_styles (
376+ hps = self .style_hyperparameters ,
377+ batch_size = X .shape [0 ],
378+ )
379+
380+ if pending_X is not None :
381+ assert pending_X .dim () == 2 , "pending_X must be 2-dimensional."
382+ pending_X = pending_X [None ].repeat (X .shape [0 ], 1 , 1 ) # shape (b, n', d)
383+ train_X = torch .cat ([train_X , pending_X ], dim = 1 ) # shape (b, n+n', d)
384+ train_Y = torch .cat (
385+ [
386+ train_Y ,
387+ torch .full (
388+ (train_Y .shape [0 ], pending_X .shape [1 ], 1 ),
389+ torch .nan ,
390+ device = train_Y .device ,
391+ ),
392+ ],
393+ dim = 1 ,
394+ ) # shape (b, n+n', 1)
395+
396+ probabilities = self .pfn_predict (
397+ X = X ,
398+ train_X = train_X ,
399+ train_Y = train_Y ,
400+ ** self .constant_model_kwargs ,
401+ ** styles ,
402+ ) # (b, q, num_buckets)
403+ probabilities = probabilities .view (
404+ * orig_X_shape [:- 1 ], - 1
405+ ) # (b?, q?, num_buckets)
406+
407+ return BoundedRiemannPosterior (
408+ borders = self .borders ,
409+ probabilities = probabilities ,
410+ )
411+
412+
303413class MultivariatePFNModel (PFNModel ):
304414 """A multivariate PFN model that returns a joint posterior over q batch inputs.
305415
0 commit comments