File tree 1 file changed +8
-2
lines changed
1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -555,12 +555,15 @@ def test_outputs_consistency(self):
555
555
556
556
def test_explicit_input_from_constant (self ):
557
557
x = pt .dscalar ("x" )
558
- y = constant (1.0 , name = "y" )
558
+ y = constant (1.0 , dtype = x . type . dtype , name = "y" )
559
559
test_ofg = OpFromGraph ([x , y ], [x + y ])
560
560
561
561
out = test_ofg (x , y )
562
562
assert out .eval ({x : 5 }) == 6
563
563
564
+ out = test_ofg (x , x )
565
+ assert out .eval ({x : 5 }) == 10
566
+
564
567
def test_explicit_input_from_shared (self ):
565
568
x = pt .dscalar ("x" )
566
569
y = shared (1.0 , name = "y" )
@@ -576,7 +579,10 @@ def test_explicit_input_from_shared(self):
576
579
out = test_ofg (x , y )
577
580
assert out .eval ({x : 5 }) == 6
578
581
y .set_value (2.0 )
579
- assert out .eval ({x : 6 })
582
+ assert out .eval ({x : 6 }) == 8
583
+
584
+ out = test_ofg (y , y )
585
+ assert out .eval () == 4
580
586
581
587
582
588
@config .change_flags (floatX = "float64" )
You can’t perform that action at this time.
0 commit comments