Vectorize dispatch for shape operations #454

Merged 1 commit on Nov 17, 2023
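
For orientation: this PR adds dedicated `_vectorize_node` registrations for the shape Ops (`Shape`, `SpecifyShape`, `Reshape`, `Unbroadcast`) instead of letting them fall through to the generic `Blockwise` fallback. A minimal sketch of the resulting behavior, mirroring the tests at the bottom of this diff (assumes a PyTensor build that includes this change):

```python
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor.shape import shape

vec = tensor(shape=(None,))       # core input: a vector
mat = tensor(shape=(None, None))  # batched input: one extra leading dim

node = shape(vec).owner           # the core Shape node
vect_node = vectorize_node(node, mat)
# Shape already works on any ndim, so the new registration simply rebuilds
# the node on the batched input; the result is equivalent to shape(mat).
```
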
6 changes: 5 additions & 1 deletion pytensor/graph/replace.py
@@ -204,7 +204,7 @@ def toposort_key(


@singledispatch
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError

@@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
return _vectorize_node(op, node, *batched_inputs)


def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs)


@overload
def vectorize_graph(
outputs: Variable,
...
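
`_vectorize_not_needed` is for Ops whose `make_node` already accepts inputs with extra leading batch dimensions; for those, "vectorizing" a node is just rebuilding it on the new inputs. A hedged sketch using `Elemwise`, which this diff registers with the helper further down:

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_node

x = pt.tensor(shape=(None,))
y = pt.tensor(shape=(None,))
out = pt.add(x, y)  # Elemwise addition of two vectors

bx = pt.tensor(shape=(None, None))  # same graph, one extra batch dim
by = pt.tensor(shape=(None, None))
# Dispatch hits _vectorize_not_needed: Elemwise broadcasting already covers
# the batch dimension, so op.make_node(bx, by) is the vectorized node.
batched = vectorize_node(out.owner, bx, by)
```
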
32 changes: 17 additions & 15 deletions pytensor/tensor/blockwise.py
@@ -8,7 +8,11 @@
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.graph.replace import (
_vectorize_node,
_vectorize_not_needed,
vectorize_graph,
)
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -37,17 +41,6 @@ def operand_sig(operand: Variable, prefix: str) -> str:
return f"{inputs_sig}->{outputs_sig}"


@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature
else:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature = safe_signature(node.inputs, node.outputs)
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))


class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions.

@@ -361,6 +354,15 @@ def __str__(self):
return self.name


@_vectorize_node.register(Blockwise)
def vectorize_not_needed(op, node, *batch_inputs):
return op.make_node(*batch_inputs)
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature
else:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature = safe_signature(node.inputs, node.outputs)
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))


_vectorize_node.register(Blockwise, _vectorize_not_needed)
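
For contrast, the generic fallback above wraps an Op in `Blockwise`, which needs a gufunc-style signature to tell core dimensions from batch dimensions. A sketch of the convention; the matrix-product signature is illustrative, not taken from this diff:

```python
# NumPy generalized-ufunc notation: "(m,n),(n,p)->(m,p)" declares two inputs
# with core dims (m,n) and (n,p) and one output with core dims (m,p); any
# extra leading dims are batch dims that Blockwise broadcasts and loops over.
signature = "(m,n),(n,p)->(m,p)"

# Ops without a gufunc_signature get safe_signature(node.inputs, node.outputs),
# which treats every dimension as a core dim -- always valid, but opaque to
# shape inference and merge rewrites, hence the TODO in vectorize_node_fallback.
```
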
5 changes: 2 additions & 3 deletions pytensor/tensor/elemwise.py
@@ -7,7 +7,7 @@
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
@@ -22,7 +22,6 @@
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.blockwise import vectorize_not_needed
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var):
raise ValueError(f"Length of {var} cannot be determined")


_vectorize_node.register(Elemwise, vectorize_not_needed)
_vectorize_node.register(Elemwise, _vectorize_not_needed)


@_vectorize_node.register(DimShuffle)
...
60 changes: 59 additions & 1 deletion pytensor/tensor/shape.py
@@ -8,6 +8,7 @@
import pytensor
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim


_vectorize_node.register(Shape, _vectorize_not_needed)


def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
r"""Get a tuple of symbolic shape values.

@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var):
raise ValueError(f"Length of {var} cannot be determined")


@_vectorize_node.register(SpecifyShape)
def _vectorize_specify_shape(op, node, x, *shape):
old_x, *old_shape = node.inputs
batched_ndims = x.type.ndim - old_x.type.ndim

if any(
as_tensor_variable(dim).type.ndim != 0
for dim in shape
if not (NoneConst.equals(dim) or dim is None)
):
raise NotImplementedError(
"It is not possible to vectorize the shape argument of SpecifyShape"
)

if len(shape) == len(old_shape):
new_shape = tuple([None] * batched_ndims) + shape
elif len(shape) == (len(old_shape) + batched_ndims):
new_shape = shape
else:
raise ValueError(
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
)

return specify_shape(x, new_shape).owner


class Reshape(COp):
"""Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be
@@ -638,7 +668,7 @@ def make_node(self, x, shp):

return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])

def perform(self, node, inp, out_, params):
def perform(self, node, inp, out_, params=None):
x, shp = inp
(out,) = out_
if len(shp) != self.ndim:
@@ -770,6 +800,26 @@ def c_code(self, node, name, inputs, outputs, sub):
"""


@_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape):
old_x, old_shape = node.inputs
batched_ndims = x.type.ndim - old_x.type.ndim

if as_tensor_variable(shape).type.ndim != 1:
raise NotImplementedError(
"It is not possible to vectorize the shape argument of Reshape"
)

if len(tuple(old_shape)) == len(tuple(shape)):
new_shape = [*x.shape[:batched_ndims], *shape]
elif len(tuple(old_shape)) == (len(tuple(shape)) - batched_ndims):
new_shape = shape
else:
raise ValueError("Invalid shape length passed into vectorize node of Reshape")

return reshape(x, new_shape, ndim=len(new_shape)).owner


def reshape(x, newshape, ndim=None):
if ndim is None:
newshape = at.as_tensor_variable(newshape)
@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes):
if not unbroadcasted_axes:
return x
return Unbroadcast(*unbroadcasted_axes)(x)


@_vectorize_node.register(Unbroadcast)
def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
old_axes = op.axes
new_axes = (old_axis + batched_ndims for old_axis in old_axes)
return unbroadcast(x, *new_axes).owner
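
The three registrations above share one rule: compute `batched_ndims = x.type.ndim - old_x.type.ndim`, then either left-pad the core arguments past the batch dimensions (`SpecifyShape` pads with `None`, `Reshape` with `x.shape[:batched_ndims]`, `Unbroadcast` shifts each axis) or accept arguments that already account for them. A small runnable sketch of the `SpecifyShape` case, mirroring the tests below:

```python
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import scalar, tensor
from pytensor.tensor.shape import specify_shape

x = scalar("x", dtype=int)
mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None))  # one extra leading batch dim

node = specify_shape(mat, (x, None)).owner
# batched_ndims = 3 - 2 = 1, so the shape is left-padded with one None:
vect_node = vectorize_node(node, tns, x, None)
# ...equivalent to specify_shape(tns, (None, x, None)).
```
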
93 changes: 89 additions & 4 deletions tests/tensor/test_shape.py
@@ -5,14 +5,14 @@
from pytensor import Mode, function, grad
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, as_tensor, constant
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
@@ -706,3 +706,88 @@ def test_shape_tuple():
assert isinstance(res[1], ScalarConstant)
assert res[1].data == 2
assert not isinstance(res[2], ScalarConstant)


class TestVectorize:
def test_shape(self):
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))

node = shape(vec).owner
vect_node = vectorize_node(node, mat)
assert equal_computations(vect_node.outputs, [shape(mat)])

def test_reshape(self):
x = scalar("x", dtype=int)
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))

shape = (2, x)
node = reshape(vec, shape).owner
vect_node = vectorize_node(node, mat, shape)
assert equal_computations(
vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))]
)

new_shape = (5, 2, x)
vect_node = vectorize_node(node, mat, new_shape)
assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)])

with pytest.raises(NotImplementedError):
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))

with pytest.raises(
ValueError,
match="Invalid shape length passed into vectorize node of Reshape",
):
vectorize_node(node, vec, (5, 2, x))

with pytest.raises(
ValueError,
match="Invalid shape length passed into vectorize node of Reshape",
):
vectorize_node(node, mat, (5, 3, 2, x))

def test_specify_shape(self):
x = scalar("x", dtype=int)
mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None))

shape = (x, None)
node = specify_shape(mat, shape).owner
vect_node = vectorize_node(node, tns, *shape)
assert equal_computations(
vect_node.outputs, [specify_shape(tns, (None, x, None))]
)

new_shape = (5, 2, x)
vect_node = vectorize_node(node, tns, *new_shape)
assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))])

with pytest.raises(NotImplementedError):
vectorize_node(node, mat, *([x, x], None))

with pytest.raises(
ValueError,
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, mat, *(5, 2, x))

with pytest.raises(
ValueError,
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, tns, *(5, 3, 2, x))

def test_unbroadcast(self):
mat = tensor(
shape=(
1,
1,
)
)
tns = tensor(shape=(4, 1, 1, 1))

node = unbroadcast(mat, 0).owner
vect_node = vectorize_node(node, tns)
assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])