diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 3c07e21232..8688edcf91 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -204,7 +204,7 @@ def toposort_key( @singledispatch -def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: +def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply: # Default implementation is provided in pytensor.tensor.blockwise raise NotImplementedError @@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply: return _vectorize_node(op, node, *batched_inputs) +def _vectorize_not_needed(op, node, *batched_inputs): + return op.make_node(*batched_inputs) + + @overload def vectorize_graph( outputs: Variable, diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 8e44fbbbc6..96357f59f8 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -8,7 +8,11 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op -from pytensor.graph.replace import _vectorize_node, vectorize_graph +from pytensor.graph.replace import ( + _vectorize_node, + _vectorize_not_needed, + vectorize_graph, +) from pytensor.tensor import as_tensor_variable from pytensor.tensor.shape import shape_padleft from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor @@ -37,17 +41,6 @@ def operand_sig(operand: Variable, prefix: str) -> str: return f"{inputs_sig}->{outputs_sig}" -@_vectorize_node.register(Op) -def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: - if hasattr(op, "gufunc_signature"): - signature = op.gufunc_signature - else: - # TODO: This is pretty bad for shape inference and merge optimization! 
-        # Should get better as we add signatures to our Ops
-        signature = safe_signature(node.inputs, node.outputs)
-    return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
-
-
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
@@ -361,6 +354,15 @@ def __str__(self):
         return self.name


-@_vectorize_node.register(Blockwise)
-def vectorize_not_needed(op, node, *batch_inputs):
-    return op.make_node(*batch_inputs)
+@_vectorize_node.register(Op)
+def vectorize_node_fallback(op: Op, node: Apply, *batched_inputs) -> Apply:
+    if hasattr(op, "gufunc_signature"):
+        signature = op.gufunc_signature
+    else:
+        # TODO: This is pretty bad for shape inference and merge optimization!
+        # Should get better as we add signatures to our Ops
+        signature = safe_signature(node.inputs, node.outputs)
+    return cast(Apply, Blockwise(op, signature=signature).make_node(*batched_inputs))
+
+
+_vectorize_node.register(Blockwise, _vectorize_not_needed)
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 9b7d1d59df..f43ec0846b 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -7,7 +7,7 @@
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.null_type import NullType
-from pytensor.graph.replace import _vectorize_node
+from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.basic import failure_code
 from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
@@ -22,7 +22,6 @@
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
-from pytensor.tensor.blockwise import vectorize_not_needed
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var):
     raise ValueError(f"Length of {var} 
cannot be determined") -_vectorize_node.register(Elemwise, vectorize_not_needed) +_vectorize_node.register(Elemwise, _vectorize_not_needed) @_vectorize_node.register(DimShuffle) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 2284d563bc..7f83b7c197 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -8,6 +8,7 @@ import pytensor from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Variable +from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed from pytensor.graph.type import HasShape from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType @@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var): return var.owner.inputs[0].type.ndim +_vectorize_node.register(Shape, _vectorize_not_needed) + + def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]: r"""Get a tuple of symbolic shape values. @@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var): raise ValueError(f"Length of {var} cannot be determined") +@_vectorize_node.register(SpecifyShape) +def _vectorize_specify_shape(op, node, x, *shape): + old_x, *old_shape = node.inputs + batched_ndims = x.type.ndim - old_x.type.ndim + + if any( + as_tensor_variable(dim).type.ndim != 0 + for dim in shape + if not (NoneConst.equals(dim) or dim is None) + ): + raise NotImplementedError( + "It is not possible to vectorize the shape argument of SpecifyShape" + ) + + if len(shape) == len(old_shape): + new_shape = tuple([None] * batched_ndims) + shape + elif len(shape) == (len(old_shape) + batched_ndims): + new_shape = shape + else: + raise ValueError( + "Invalid number of shape arguments passed into vectorize node of SpecifyShape" + ) + + return specify_shape(x, new_shape).owner + + class Reshape(COp): """Perform a reshape operation of the input x to the new shape shp. 
The number of dimensions to which to reshape to (ndim) must be @@ -638,7 +668,7 @@ def make_node(self, x, shp): return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)]) - def perform(self, node, inp, out_, params): + def perform(self, node, inp, out_, params=None): x, shp = inp (out,) = out_ if len(shp) != self.ndim: @@ -770,6 +800,26 @@ def c_code(self, node, name, inputs, outputs, sub): """ +@_vectorize_node.register(Reshape) +def _vectorize_reshape(op, node, x, shape): + old_x, old_shape = node.inputs + batched_ndims = x.type.ndim - old_x.type.ndim + + if as_tensor_variable(shape).type.ndim != 1: + raise NotImplementedError( + "It is not possible to vectorize the shape argument of Reshape" + ) + + if len(tuple(old_shape)) == len(tuple(shape)): + new_shape = [*x.shape[:batched_ndims], *shape] + elif len(tuple(old_shape)) == (len(tuple(shape)) - batched_ndims): + new_shape = shape + else: + raise ValueError("Invalid shape length passed into vectorize node of Reshape") + + return reshape(x, new_shape, ndim=len(new_shape)).owner + + def reshape(x, newshape, ndim=None): if ndim is None: newshape = at.as_tensor_variable(newshape) @@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes): if not unbroadcasted_axes: return x return Unbroadcast(*unbroadcasted_axes)(x) + + +@_vectorize_node.register(Unbroadcast) +def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + old_axes = op.axes + new_axes = (old_axis + batched_ndims for old_axis in old_axes) + return unbroadcast(x, *new_axes).owner diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 67a47bdee1..32392949ea 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -5,14 +5,14 @@ from pytensor import Mode, function, grad from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config -from pytensor.graph.basic import Variable +from pytensor.graph.basic 
import Variable, equal_computations from pytensor.graph.fg import FunctionGraph -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import clone_replace, vectorize_node from pytensor.graph.type import Type from pytensor.misc.safe_asarray import _asarray from pytensor.scalar.basic import ScalarConstant -from pytensor.tensor import as_tensor_variable, get_vector_length, row -from pytensor.tensor.basic import MakeVector, constant +from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row +from pytensor.tensor.basic import MakeVector, as_tensor, constant from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.shape import ( @@ -706,3 +706,88 @@ def test_shape_tuple(): assert isinstance(res[1], ScalarConstant) assert res[1].data == 2 assert not isinstance(res[2], ScalarConstant) + + +class TestVectorize: + def test_shape(self): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + node = shape(vec).owner + vect_node = vectorize_node(node, mat) + assert equal_computations(vect_node.outputs, [shape(mat)]) + + def test_reshape(self): + x = scalar("x", dtype=int) + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + shape = (2, x) + node = reshape(vec, shape).owner + vect_node = vectorize_node(node, mat, shape) + assert equal_computations( + vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))] + ) + + new_shape = (5, 2, x) + vect_node = vectorize_node(node, mat, new_shape) + assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)]) + + with pytest.raises(NotImplementedError): + vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3))) + + with pytest.raises( + ValueError, + match="Invalid shape length passed into vectorize node of Reshape", + ): + vectorize_node(node, vec, (5, 2, x)) + + with pytest.raises( + ValueError, + match="Invalid shape length passed into vectorize node of 
Reshape", + ): + vectorize_node(node, mat, (5, 3, 2, x)) + + def test_specify_shape(self): + x = scalar("x", dtype=int) + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + shape = (x, None) + node = specify_shape(mat, shape).owner + vect_node = vectorize_node(node, tns, *shape) + assert equal_computations( + vect_node.outputs, [specify_shape(tns, (None, x, None))] + ) + + new_shape = (5, 2, x) + vect_node = vectorize_node(node, tns, *new_shape) + assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))]) + + with pytest.raises(NotImplementedError): + vectorize_node(node, mat, *([x, x], None)) + + with pytest.raises( + ValueError, + match="Invalid number of shape arguments passed into vectorize node of SpecifyShape", + ): + vectorize_node(node, mat, *(5, 2, x)) + + with pytest.raises( + ValueError, + match="Invalid number of shape arguments passed into vectorize node of SpecifyShape", + ): + vectorize_node(node, tns, *(5, 3, 2, x)) + + def test_unbroadcast(self): + mat = tensor( + shape=( + 1, + 1, + ) + ) + tns = tensor(shape=(4, 1, 1, 1)) + + node = unbroadcast(mat, 0).owner + vect_node = vectorize_node(node, tns) + assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])