@@ -500,12 +500,13 @@ def test_infer_shape(self):
500
500
501
501
502
502
class ApplyDefaultTestOp (Op ):
503
- def __init__ (self , id ):
503
+ def __init__ (self , id , n_outs = 1 ):
504
504
self .default_output = id
505
+ self .n_outs = n_outs
505
506
506
507
def make_node (self , x ):
507
508
x = at .as_tensor_variable (x )
508
- return Apply (self , [x ], [x .type ()])
509
+ return Apply (self , [x ], [x .type () for _ in range ( self . n_outs ) ])
509
510
510
511
def perform (self , * args , ** kwargs ):
511
512
raise NotImplementedError ()
@@ -556,16 +557,26 @@ def test_tensor_from_scalar(self):
556
557
y = as_tensor_variable (aes .int8 ())
557
558
assert isinstance (y .owner .op , TensorFromScalar )
558
559
559
- def test_multi_outputs (self ):
560
- good_apply_var = ApplyDefaultTestOp (0 ).make_node (self .x )
561
- as_tensor_variable (good_apply_var )
560
+ def test_default_output (self ):
561
+ good_apply_var = ApplyDefaultTestOp (0 , n_outs = 1 ).make_node (self .x )
562
+ as_tensor_variable (good_apply_var ) is good_apply_var
562
563
563
- bad_apply_var = ApplyDefaultTestOp (- 1 ).make_node (self .x )
564
- with pytest .raises (ValueError ):
564
+ good_apply_var = ApplyDefaultTestOp (- 1 , n_outs = 1 ).make_node (self .x )
565
+ as_tensor_variable (good_apply_var ) is good_apply_var
566
+
567
+ bad_apply_var = ApplyDefaultTestOp (1 , n_outs = 1 ).make_node (self .x )
568
+ with pytest .raises (IndexError ):
565
569
_ = as_tensor_variable (bad_apply_var )
566
570
567
- bad_apply_var = ApplyDefaultTestOp (2 ).make_node (self .x )
568
- with pytest .raises (ValueError ):
571
+ bad_apply_var = ApplyDefaultTestOp (2.0 , n_outs = 1 ).make_node (self .x )
572
+ with pytest .raises (TypeError ):
573
+ _ = as_tensor_variable (bad_apply_var )
574
+
575
+ good_apply_var = ApplyDefaultTestOp (1 , n_outs = 2 ).make_node (self .x )
576
+ as_tensor_variable (good_apply_var ) is good_apply_var .outputs [1 ]
577
+
578
+ bad_apply_var = ApplyDefaultTestOp (None , n_outs = 2 ).make_node (self .x )
579
+ with pytest .raises (TypeError , match = "Multi-output Op without default_output" ):
569
580
_ = as_tensor_variable (bad_apply_var )
570
581
571
582
def test_list (self ):
@@ -578,7 +589,7 @@ def test_list(self):
578
589
_ = as_tensor_variable (y )
579
590
580
591
bad_apply_var = ApplyDefaultTestOp ([0 , 1 ]).make_node (self .x )
581
- with pytest .raises (ValueError ):
592
+ with pytest .raises (TypeError ):
582
593
as_tensor_variable (bad_apply_var )
583
594
584
595
def test_ndim_strip_leading_broadcastable (self ):
0 commit comments