@@ -156,9 +156,8 @@ def check_broadcast(v1, v2):
156
156
which may wrongly be interpreted as broadcastable.
157
157
158
158
"""
159
- if not isinstance (v1 . type , TensorType ) and not isinstance (v2 . type , TensorType ):
159
+ if not hasattr (v1 , "broadcastable" ) and not hasattr (v2 , "broadcastable" ):
160
160
return
161
-
162
161
msg = (
163
162
"The broadcast pattern of the output of scan (%s) is "
164
163
"inconsistent with the one provided in `output_info` "
@@ -169,13 +168,13 @@ def check_broadcast(v1, v2):
169
168
"them consistent, e.g. using pytensor.tensor."
170
169
"{unbroadcast, specify_broadcastable}."
171
170
)
172
- size = min (v1 .type . ndim , v2 .type . ndim )
171
+ size = min (len ( v1 .broadcastable ), len ( v2 .broadcastable ) )
173
172
for n , (b1 , b2 ) in enumerate (
174
- zip (v1 .type . broadcastable [- size :], v2 . type .broadcastable [- size :])
173
+ zip (v1 .broadcastable [- size :], v2 .broadcastable [- size :])
175
174
):
176
175
if b1 != b2 :
177
- a1 = n + size - v1 .type . ndim + 1
178
- a2 = n + size - v2 .type . ndim + 1
176
+ a1 = n + size - len ( v1 .broadcastable ) + 1
177
+ a2 = n + size - len ( v2 .broadcastable ) + 1
179
178
raise TypeError (msg % (v1 .type , v2 .type , a1 , b1 , b2 , a2 ))
180
179
181
180
@@ -624,7 +623,6 @@ def validate_inner_graph(self):
624
623
type_input = self .inner_inputs [inner_iidx ].type
625
624
type_output = self .inner_outputs [inner_oidx ].type
626
625
if (
627
- # TODO: Use the `Type` interface for this
628
626
type_input .dtype != type_output .dtype
629
627
or type_input .broadcastable != type_output .broadcastable
630
628
):
0 commit comments