Wrong gradients when inputs are dynamically broadcasted #1089

@ricardoV94

This bug is an unexpected consequence of #928 and rewrites that make certain assumptions: #1089 (comment)

import aesara
import aesara.tensor as at
import numpy as np

x_row = at.row("x_row")
x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_row = aesara.function([x_row, y], x_row_grad)
print(f_row(np.ones((1, 5)), np.ones((5, 5))))
# [[5. 5. 5. 5. 5.]]

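f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
print(f_matrix(np.ones((1, 5)), np.ones((5, 5))))
# Wrong: the gradient should have the shape of the input, (1, 5),
# i.e. [[5. 5. 5. 5. 5.]], but the broadcasted axis is never summed:
# [[1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]]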

The faulty logic is found here:

# sum out the broadcasted dimensions
for i, ipt in enumerate(inputs):
    if isinstance(rval[i].type, (NullType, DisconnectedType)):
        continue
    # List of all the dimensions that are broadcastable for input[i] so
    # we can sum over them
    # TODO: only count dimensions that were effectively broadcasted
    to_sum = [
        j
        for j, bcast in enumerate(ipt.type.broadcastable)
        if bcast and not outs[0].broadcastable[j]
    ]
    if to_sum:
        sr = at_sum(rval[i], axis=to_sum, keepdims=True)
        rval[i] = sr
return rval
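
To make the failure mode concrete, here is a minimal pure-Python sketch of the axis selection. `static_axes_to_sum` mirrors the list comprehension above; `runtime_axes_to_sum` is a hypothetical helper (not part of Aesara) that shows what the selection would be if the runtime shapes were known. Both assume the input and output already have the same number of dimensions.

```python
def static_axes_to_sum(in_broadcastable, out_broadcastable):
    """Mirror of the list comprehension above: uses only static flags."""
    return [
        j
        for j, bcast in enumerate(in_broadcastable)
        if bcast and not out_broadcastable[j]
    ]


def runtime_axes_to_sum(in_shape, out_shape):
    """Hypothetical runtime check: an axis was broadcast if its length
    is 1 in the input but not in the output."""
    return [
        j
        for j, (n_in, n_out) in enumerate(zip(in_shape, out_shape))
        if n_in == 1 and n_out != 1
    ]


# at.row has broadcastable flags (True, False); at.matrix has (False, False).
row_flags = (True, False)
matrix_flags = (False, False)
out_flags = (False, False)  # x + y, where y is a matrix

static_row = static_axes_to_sum(row_flags, out_flags)
static_matrix = static_axes_to_sum(matrix_flags, out_flags)
runtime = runtime_axes_to_sum((1, 5), (5, 5))

print(static_row)     # axis 0 is summed for the row input
print(static_matrix)  # nothing is summed for the matrix input
print(runtime)        # but a (1, 5) value was broadcast along axis 0
```

For the matrix input the static flags select no axes, even though a `(1, 5)` value passed at runtime is broadcast along axis 0, which is why the gradient comes back with shape `(5, 5)`.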

The same problem likely affects the gradient of BroadcastTo, which calls infer_broadcastable. That function assumes a dimension was not broadcast unless a static shape of 1 can be inferred for it.

_, shape_bcast = at.infer_broadcastable(shape)

def infer_broadcastable(shape):

GEMM is likely affected as well, since #986.

I am not sure if there's a good solution to this problem, as we would need an expression with different output shapes depending on whether the runtime inputs are broadcasted or not.
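
As an illustration of that shape dependence, here is a NumPy-only sketch of the eager computation. It is not Aesara code and not a symbolic-graph solution; `unbroadcast` is a hypothetical name, and it assumes the input shape has the same number of dimensions as the upstream gradient.

```python
import numpy as np


def unbroadcast(grad, in_shape):
    """Reduce an upstream gradient back to the runtime shape of the input.

    Sums over every axis where the input had length 1 but the gradient
    does not, keeping the reduced axes so the result matches in_shape.
    """
    axes = tuple(
        j
        for j, (n_in, n_out) in enumerate(zip(in_shape, grad.shape))
        if n_in == 1 and n_out != 1
    )
    return grad.sum(axis=axes, keepdims=True) if axes else grad


upstream = np.ones((5, 5))  # gradient of sum(x + y) w.r.t. (x + y)

g_broadcast = unbroadcast(upstream, (1, 5))  # x was broadcast at runtime
g_plain = unbroadcast(upstream, (5, 5))      # x was a full matrix

print(g_broadcast.shape, g_plain.shape)
```

The same function returns a `(1, 5)` result in one case and a `(5, 5)` result in the other, which is exactly the behavior a single static graph cannot express without a runtime shape check.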

A solution might look something like: #1089 (comment)
