@@ -86,10 +86,10 @@ def test_gaussian_random_walk_init_dist_shape(self, init):
86
86
grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , size = (5 ,))
87
87
assert tuple (grw .owner .inputs [- 2 ].shape .eval ()) == (5 ,)
88
88
89
- grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = 1 )
89
+ grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = 2 )
90
90
assert tuple (grw .owner .inputs [- 2 ].shape .eval ()) == ()
91
91
92
- grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = (5 , 1 ))
92
+ grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = (5 , 2 ))
93
93
assert tuple (grw .owner .inputs [- 2 ].shape .eval ()) == (5 ,)
94
94
95
95
grw = pm .GaussianRandomWalk .dist (mu = [0 , 0 ], sigma = 1 , steps = 1 , init = init )
@@ -113,6 +113,21 @@ def test_gaussianrandomwalk_broadcasted_by_init_dist(self):
113
113
assert tuple (grw .shape .eval ()) == (2 , 3 , 5 )
114
114
assert grw .eval ().shape == (2 , 3 , 5 )
115
115
116
+ @pytest .mark .parametrize ("shape" , ((6 ,), (3 , 6 )))
117
+ def test_inferred_steps_from_shape (self , shape ):
118
+ x = GaussianRandomWalk .dist (shape = shape )
119
+ steps = x .owner .inputs [- 1 ]
120
+ assert steps .eval () == 5
121
+
122
+ @pytest .mark .parametrize ("shape" , (None , (5 , ...)))
123
+ def test_missing_steps (self , shape ):
124
+ with pytest .raises (ValueError , match = "Must specify steps or shape parameter" ):
125
+ GaussianRandomWalk .dist (shape = shape )
126
+
127
+ def test_inconsistent_steps_and_shape (self ):
128
+ with pytest .raises (AssertionError , match = "Steps do not match last shape dimension" ):
129
+ x = GaussianRandomWalk .dist (steps = 12 , shape = 45 )
130
+
116
131
@pytest .mark .parametrize (
117
132
"init" ,
118
133
[
0 commit comments