Skip to content

Commit acfa749

Browse files
ricardoV94michaelosthege
authored andcommitted
Allow default_output to be any valid Python index
1 parent ae7f39e commit acfa749

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

pytensor/graph/basic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,9 @@ def default_output(self):
193193
if len(self.outputs) == 1:
194194
return self.outputs[0]
195195
else:
196-
raise ValueError(f"{self.op}.default_output should be an output index.")
197-
elif not isinstance(do, int):
198-
raise ValueError(f"{self.op}.default_output should be an int or long")
199-
elif do < 0 or do >= len(self.outputs):
200-
raise ValueError(f"{self.op}.default_output is out of range.")
196+
raise ValueError(
197+
f"Multi-output Op {self.op} default_output not specified"
198+
)
201199
return self.outputs[do]
202200

203201
def __str__(self):

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _as_tensor_Apply(x, name, ndim, **kwargs):
8080
# use Apply's default output mechanism
8181
if (x.op.default_output is None) and (len(x.outputs) != 1):
8282
raise TypeError(
83-
"Multi-output Op encountered. "
83+
"Multi-output Op without default_output encountered. "
8484
"Retry using only one of the outputs directly."
8585
)
8686

tests/tensor/test_basic.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -500,12 +500,13 @@ def test_infer_shape(self):
500500

501501

502502
class ApplyDefaultTestOp(Op):
503-
def __init__(self, id):
503+
def __init__(self, id, n_outs=1):
504504
self.default_output = id
505+
self.n_outs = n_outs
505506

506507
def make_node(self, x):
507508
x = at.as_tensor_variable(x)
508-
return Apply(self, [x], [x.type()])
509+
return Apply(self, [x], [x.type() for _ in range(self.n_outs)])
509510

510511
def perform(self, *args, **kwargs):
511512
raise NotImplementedError()
@@ -556,16 +557,26 @@ def test_tensor_from_scalar(self):
556557
y = as_tensor_variable(aes.int8())
557558
assert isinstance(y.owner.op, TensorFromScalar)
558559

559-
def test_multi_outputs(self):
560-
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x)
561-
as_tensor_variable(good_apply_var)
560+
def test_default_output(self):
561+
good_apply_var = ApplyDefaultTestOp(0, n_outs=1).make_node(self.x)
562+
as_tensor_variable(good_apply_var) is good_apply_var
562563

563-
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
564-
with pytest.raises(ValueError):
564+
good_apply_var = ApplyDefaultTestOp(-1, n_outs=1).make_node(self.x)
565+
as_tensor_variable(good_apply_var) is good_apply_var
566+
567+
bad_apply_var = ApplyDefaultTestOp(1, n_outs=1).make_node(self.x)
568+
with pytest.raises(IndexError):
565569
_ = as_tensor_variable(bad_apply_var)
566570

567-
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
568-
with pytest.raises(ValueError):
571+
bad_apply_var = ApplyDefaultTestOp(2.0, n_outs=1).make_node(self.x)
572+
with pytest.raises(TypeError):
573+
_ = as_tensor_variable(bad_apply_var)
574+
575+
good_apply_var = ApplyDefaultTestOp(1, n_outs=2).make_node(self.x)
576+
as_tensor_variable(good_apply_var) is good_apply_var.outputs[1]
577+
578+
bad_apply_var = ApplyDefaultTestOp(None, n_outs=2).make_node(self.x)
579+
with pytest.raises(TypeError, match="Multi-output Op without default_output"):
569580
_ = as_tensor_variable(bad_apply_var)
570581

571582
def test_list(self):
@@ -578,7 +589,7 @@ def test_list(self):
578589
_ = as_tensor_variable(y)
579590

580591
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
581-
with pytest.raises(ValueError):
592+
with pytest.raises(TypeError):
582593
as_tensor_variable(bad_apply_var)
583594

584595
def test_ndim_strip_leading_broadcastable(self):

0 commit comments

Comments
 (0)