Description
This bug is an unexpected consequence of #928 and of rewrites that make certain assumptions: #1089 (comment). A minimal reproduction:
import aesara
import aesara.tensor as at
import numpy as np
x_row = at.row("x_row")
x_matrix = at.matrix("x_matrix")
y = at.matrix("y")
x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)
f_row = aesara.function([x_row, y], x_row_grad)
print(f_row(np.ones((1, 5)), np.ones((5, 5))))
# [[5. 5. 5. 5. 5.]]
f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
print(f_matrix(np.ones((1, 5)), np.ones((5, 5))))
# [[1. 1. 1. 1. 1.]
# [1. 1. 1. 1. 1.]
# [1. 1. 1. 1. 1.]
# [1. 1. 1. 1. 1.]
# [1. 1. 1. 1. 1.]]
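The second gradient should have been summed back to shape (1, 5) like the first, but it keeps the full (5, 5) shape. The discrepancy tracks the static broadcastable flags of the two inputs rather than their runtime shapes; a quick check with the variables defined above:

print(x_row.type.broadcastable)     # (True, False)  -> gradient gets summed over axis 0
print(x_matrix.type.broadcastable)  # (False, False) -> no summation, despite the (1, 5) runtime shape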
The faulty logic is found here:
aesara/aesara/tensor/elemwise.py, lines 552 to 570 in 7f4e0ab:

# sum out the broadcasted dimensions
for i, ipt in enumerate(inputs):
    if isinstance(rval[i].type, (NullType, DisconnectedType)):
        continue

    # List of all the dimensions that are broadcastable for input[i] so
    # we can sum over them
    # TODO: only count dimensions that were effectively broadcasted
    to_sum = [
        j
        for j, bcast in enumerate(ipt.type.broadcastable)
        if bcast and not outs[0].broadcastable[j]
    ]

    if to_sum:
        sr = at_sum(rval[i], axis=to_sum, keepdims=True)
        rval[i] = sr

return rval
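To make the faulty assumption concrete, here is a small re-enactment of that selection logic outside the Op (my own sketch; axes_summed is a hypothetical helper mirroring the to_sum comprehension above):

out_bcast = (x_matrix + y).type.broadcastable  # (False, False), same for both gradients above

def axes_summed(input_bcast, output_bcast):
    # Mirrors the to_sum computation: the decision is made purely from static flags.
    return [j for j, b in enumerate(input_bcast) if b and not output_bcast[j]]

print(axes_summed(x_row.type.broadcastable, out_bcast))     # [0] -> grad summed back to (1, 5)
print(axes_summed(x_matrix.type.broadcastable, out_bcast))  # []  -> no sum, even though the
                                                            #       runtime input was (1, 5)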
This is also likely a problem in the grad of BroadcastTo, which calls infer_broadcastable, which in turn defaults to assuming a dimension will not have been broadcast if a static shape of 1 cannot be inferred.
aesara/aesara/tensor/extra_ops.py, line 1593 in 7f8af9b:

_, shape_bcast = at.infer_broadcastable(shape)

Line 1313 in 7f8af9b:

def infer_broadcastable(shape):
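As I understand the described default, a dimension only counts as broadcastable when its shape entry can be shown to equal 1 at compile time; anything that cannot be inferred is treated as non-broadcasting. A rough stand-in for that behaviour (my sketch, not Aesara's actual implementation):

def infer_broadcastable_sketch(shape):
    # Hypothetical stand-in for the default described above, not the real code.
    def statically_one(s):
        try:
            return int(at.get_scalar_constant_value(at.as_tensor_variable(s))) == 1
        except Exception:
            # The value cannot be inferred at compile time -> assume it will not broadcast.
            return False
    return tuple(statically_one(s) for s in shape)

print(infer_broadcastable_sketch((1, 5)))                    # (True, False)
print(infer_broadcastable_sketch([y.shape[0], y.shape[1]]))  # (False, False): unknown statically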
This also applies to GEMM since #986.
I am not sure if there's a good solution to this problem, as we would need an expression with different output shapes depending on whether the runtime inputs are broadcasted or not.
A solution might look something like: #1089 (comment)
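For concreteness, here is my own untested sketch of the kind of runtime-conditional expression that seems to be needed for the addition example above (not necessarily what the linked comment proposes); it uses aesara.ifelse so that the two branches can have different runtime shapes:

from aesara.ifelse import ifelse

# Upstream gradient of at.sum(x_matrix + y); it has the output's (5, 5) shape.
grad_out = at.ones_like(x_matrix + y)
# Branch for "axis 0 did broadcast at runtime"; unbroadcast so both branches share a static type.
summed = at.unbroadcast(at.sum(grad_out, axis=0, keepdims=True), 0)
runtime_grad = ifelse(at.eq(x_matrix.shape[0], 1), summed, grad_out)

f_runtime = aesara.function([x_matrix, y], runtime_grad)
print(f_runtime(np.ones((1, 5)), np.ones((5, 5))).shape)  # (1, 5)
print(f_runtime(np.ones((5, 5)), np.ones((5, 5))).shape)  # (5, 5)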