Skip to content

Commit 9e98224

Browse files
committed
Revert "Clean up some usage of the TensorType interface in Scan"
This reverts commit 471657a.
1 parent 1972970 commit 9e98224

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

pytensor/scan/op.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,8 @@ def check_broadcast(v1, v2):
156156
which may wrongly be interpreted as broadcastable.
157157
158158
"""
159-
if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
159+
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
160160
return
161-
162161
msg = (
163162
"The broadcast pattern of the output of scan (%s) is "
164163
"inconsistent with the one provided in `output_info` "
@@ -169,13 +168,13 @@ def check_broadcast(v1, v2):
169168
"them consistent, e.g. using pytensor.tensor."
170169
"{unbroadcast, specify_broadcastable}."
171170
)
172-
size = min(v1.type.ndim, v2.type.ndim)
171+
size = min(len(v1.broadcastable), len(v2.broadcastable))
173172
for n, (b1, b2) in enumerate(
174-
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
173+
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
175174
):
176175
if b1 != b2:
177-
a1 = n + size - v1.type.ndim + 1
178-
a2 = n + size - v2.type.ndim + 1
176+
a1 = n + size - len(v1.broadcastable) + 1
177+
a2 = n + size - len(v2.broadcastable) + 1
179178
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
180179

181180

@@ -624,7 +623,6 @@ def validate_inner_graph(self):
624623
type_input = self.inner_inputs[inner_iidx].type
625624
type_output = self.inner_outputs[inner_oidx].type
626625
if (
627-
# TODO: Use the `Type` interface for this
628626
type_input.dtype != type_output.dtype
629627
or type_input.broadcastable != type_output.broadcastable
630628
):

0 commit comments

Comments
 (0)