@@ -393,14 +393,19 @@ class AR(SymbolicDistribution):
393
393
394
394
"""
395
395
396
- def __new__ (cls , * args , steps = None , ** kwargs ):
396
+ def __new__ (cls , name , rho , * args , steps = None , constant = False , ar_order = None , ** kwargs ):
397
+ rhos = at .atleast_1d (at .as_tensor_variable (floatX (rho )))
398
+ ar_order = cls ._get_ar_order (rhos = rhos , constant = constant , ar_order = ar_order )
397
399
steps = get_steps (
398
400
steps = steps ,
399
401
shape = None , # Shape will be checked in `cls.dist`
400
402
dims = kwargs .get ("dims" , None ),
401
403
observed = kwargs .get ("observed" , None ),
404
+ step_shape_offset = ar_order ,
405
+ )
406
+ return super ().__new__ (
407
+ cls , name , rhos , * args , steps = steps , constant = constant , ar_order = ar_order , ** kwargs
402
408
)
403
- return super ().__new__ (cls , * args , steps = steps , ** kwargs )
404
409
405
410
@classmethod
406
411
def dist (
@@ -426,34 +431,12 @@ def dist(
426
431
)
427
432
init_dist = kwargs ["init" ]
428
433
429
- steps = get_steps (steps = steps , shape = kwargs .get ("shape" , None ))
434
+ ar_order = cls ._get_ar_order (rhos = rhos , constant = constant , ar_order = ar_order )
435
+ steps = get_steps (steps = steps , shape = kwargs .get ("shape" , None ), step_shape_offset = ar_order )
430
436
if steps is None :
431
437
raise ValueError ("Must specify steps or shape parameter" )
432
438
steps = at .as_tensor_variable (intX (steps ), ndim = 0 )
433
439
434
- if ar_order is None :
435
- # If ar_order is not specified we do constant folding on the shape of rhos
436
- # to retrieve it. For example, this will detect that
437
- # Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
438
- shape_fg = FunctionGraph (
439
- outputs = [rhos .shape [- 1 ]],
440
- features = [ShapeFeature ()],
441
- clone = True ,
442
- )
443
- (folded_shape ,) = optimize_graph (shape_fg , custom_opt = topo_constant_folding ).outputs
444
- folded_shape = getattr (folded_shape , "data" , None )
445
- if folded_shape is None :
446
- raise ValueError (
447
- "Could not infer ar_order from last dimension of rho. Pass it "
448
- "explictily or make sure rho have a static shape"
449
- )
450
- ar_order = int (folded_shape ) - int (constant )
451
- if ar_order < 1 :
452
- raise ValueError (
453
- "Inferred ar_order is smaller than 1. Increase the last dimension "
454
- "of rho or remove constant_term"
455
- )
456
-
457
440
if init_dist is not None :
458
441
if not isinstance (init_dist , TensorVariable ) or not isinstance (
459
442
init_dist .owner .op , RandomVariable
@@ -477,6 +460,41 @@ def dist(
477
460
478
461
return super ().dist ([rhos , sigma , init_dist , steps , ar_order , constant ], ** kwargs )
479
462
463
+ @classmethod
464
+ def _get_ar_order (cls , rhos : TensorVariable , ar_order : Optional [int ], constant : bool ) -> int :
465
+ """Compute ar_order given inputs
466
+
467
+ If ar_order is not specified we do constant folding on the shape of rhos
468
+ to retrieve it. For example, this will detect that
469
+ Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
470
+
471
+ Raises
472
+ ------
473
+ ValueError
474
+ If inferred ar_order cannot be inferred from rhos or if it is less than 1
475
+ """
476
+ if ar_order is None :
477
+ shape_fg = FunctionGraph (
478
+ outputs = [rhos .shape [- 1 ]],
479
+ features = [ShapeFeature ()],
480
+ clone = True ,
481
+ )
482
+ (folded_shape ,) = optimize_graph (shape_fg , custom_opt = topo_constant_folding ).outputs
483
+ folded_shape = getattr (folded_shape , "data" , None )
484
+ if folded_shape is None :
485
+ raise ValueError (
486
+ "Could not infer ar_order from last dimension of rho. Pass it "
487
+ "explictily or make sure rho have a static shape"
488
+ )
489
+ ar_order = int (folded_shape ) - int (constant )
490
+ if ar_order < 1 :
491
+ raise ValueError (
492
+ "Inferred ar_order is smaller than 1. Increase the last dimension "
493
+ "of rho or remove constant_term"
494
+ )
495
+
496
+ return ar_order
497
+
480
498
@classmethod
481
499
def num_rngs (cls , * args , ** kwargs ):
482
500
return 2
@@ -540,7 +558,7 @@ def step(*args):
540
558
fn = step ,
541
559
outputs_info = [{"initial" : init_ .T , "taps" : range (- ar_order , 0 )}],
542
560
non_sequences = [rhos_bcast_ .T [::- 1 ], sigma_ .T , noise_rng ],
543
- n_steps = at . max (( 0 , steps_ - ar_order )) ,
561
+ n_steps = steps_ ,
544
562
strict = True ,
545
563
)
546
564
(noise_next_rng ,) = tuple (innov_updates_ .values ())
0 commit comments