Skip to content

Commit bd72bc7

Browse files
committed
Vectorize shape operations
1 parent 7bb18f3 commit bd72bc7

File tree

5 files changed

+168
-24
lines changed

5 files changed

+168
-24
lines changed

pytensor/graph/replace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def toposort_key(
202202

203203

204204
@singledispatch
205-
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
205+
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
206206
# Default implementation is provided in pytensor.tensor.blockwise
207207
raise NotImplementedError
208208

@@ -213,6 +213,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
213213
return _vectorize_node(op, node, *batched_inputs)
214214

215215

216+
def _vectorize_not_needed(op, node, *batch_inputs):
217+
return op.make_node(*batch_inputs)
218+
219+
216220
@overload
217221
def vectorize(
218222
outputs: Variable,

pytensor/tensor/blockwise.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytensor.graph.basic import Apply, Constant, Variable
99
from pytensor.graph.null_type import NullType
1010
from pytensor.graph.op import Op
11-
from pytensor.graph.replace import _vectorize_node, vectorize
11+
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed, vectorize
1212
from pytensor.tensor import as_tensor_variable
1313
from pytensor.tensor.shape import shape_padleft
1414
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -72,17 +72,6 @@ def operand_sig(operand: Variable, prefix: str) -> str:
7272
return f"{inputs_sig}->{outputs_sig}"
7373

7474

75-
@_vectorize_node.register(Op)
76-
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
77-
if hasattr(op, "gufunc_signature"):
78-
signature = op.gufunc_signature
79-
else:
80-
# TODO: This is pretty bad for shape inference and merge optimization!
81-
# Should get better as we add signatures to our Ops
82-
signature = safe_signature(node.inputs, node.outputs)
83-
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
84-
85-
8675
class Blockwise(Op):
8776
"""Generalizes a core `Op` to work with batched dimensions.
8877
@@ -378,6 +367,15 @@ def __str__(self):
378367
return self.name
379368

380369

381-
@_vectorize_node.register(Blockwise)
382-
def vectorize_not_needed(op, node, *batch_inputs):
383-
return op.make_node(*batch_inputs)
370+
@_vectorize_node.register(Op)
371+
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
372+
if hasattr(op, "gufunc_signature"):
373+
signature = op.gufunc_signature
374+
else:
375+
# TODO: This is pretty bad for shape inference and merge optimization!
376+
# Should get better as we add signatures to our Ops
377+
signature = safe_signature(node.inputs, node.outputs)
378+
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
379+
380+
381+
_vectorize_node.register(Blockwise, _vectorize_not_needed)

pytensor/tensor/elemwise.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytensor.gradient import DisconnectedType
99
from pytensor.graph.basic import Apply
1010
from pytensor.graph.null_type import NullType
11-
from pytensor.graph.replace import _vectorize_node
11+
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
1212
from pytensor.graph.utils import MethodNotDefined
1313
from pytensor.link.c.basic import failure_code
1414
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
@@ -23,7 +23,6 @@
2323
from pytensor.tensor import elemwise_cgen as cgen
2424
from pytensor.tensor import get_vector_length
2525
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
26-
from pytensor.tensor.blockwise import vectorize_not_needed
2726
from pytensor.tensor.type import (
2827
TensorType,
2928
continuous_dtypes,
@@ -1742,7 +1741,7 @@ def _get_vector_length_Elemwise(op, var):
17421741
raise ValueError(f"Length of {var} cannot be determined")
17431742

17441743

1745-
_vectorize_node.register(Elemwise, vectorize_not_needed)
1744+
_vectorize_node.register(Elemwise, _vectorize_not_needed)
17461745

17471746

17481747
@_vectorize_node.register(DimShuffle)

pytensor/tensor/shape.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytensor
99
from pytensor.gradient import DisconnectedType
1010
from pytensor.graph.basic import Apply, Variable
11+
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
1112
from pytensor.graph.type import HasShape
1213
from pytensor.link.c.op import COp
1314
from pytensor.link.c.params_type import ParamsType
@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var):
154155
return var.owner.inputs[0].type.ndim
155156

156157

158+
_vectorize_node.register(Shape, _vectorize_not_needed)
159+
160+
157161
def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]:
158162
r"""Get a tuple of symbolic shape values.
159163
@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var):
580584
raise ValueError(f"Length of {var} cannot be determined")
581585

582586

587+
@_vectorize_node.register(SpecifyShape)
588+
def _vectorize_specify_shape(op, node, x, *shape):
589+
old_x, *old_shape = node.inputs
590+
batched_ndims = x.type.ndim - old_x.type.ndim
591+
592+
if any(
593+
as_tensor_variable(dim).type.ndim != 0
594+
for dim in shape
595+
if not (NoneConst.equals(dim) or dim is None)
596+
):
597+
raise NotImplementedError(
598+
"It is not possible to vectorize the shape argument of SpecifyShape"
599+
)
600+
601+
if len(shape) == len(old_shape):
602+
new_shape = tuple([None] * batched_ndims) + shape
603+
elif len(shape) == (len(old_shape) + batched_ndims):
604+
new_shape = shape
605+
else:
606+
raise ValueError(
607+
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
608+
)
609+
610+
return specify_shape(x, new_shape).owner
611+
612+
583613
class Reshape(COp):
584614
"""Perform a reshape operation of the input x to the new shape shp.
585615
The number of dimensions to which to reshape to (ndim) must be
@@ -638,7 +668,7 @@ def make_node(self, x, shp):
638668

639669
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
640670

641-
def perform(self, node, inp, out_, params):
671+
def perform(self, node, inp, out_, params=None):
642672
x, shp = inp
643673
(out,) = out_
644674
if len(shp) != self.ndim:
@@ -770,6 +800,26 @@ def c_code(self, node, name, inputs, outputs, sub):
770800
"""
771801

772802

803+
@_vectorize_node.register(Reshape)
804+
def _vectorize_reshape(op, node, x, shape):
805+
old_x, old_shape = node.inputs
806+
batched_ndims = x.type.ndim - old_x.type.ndim
807+
808+
if as_tensor_variable(shape).type.ndim != 1:
809+
raise NotImplementedError(
810+
"It is not possible to vectorize the shape argument of Reshape"
811+
)
812+
813+
if len(tuple(old_shape)) == len(tuple(shape)):
814+
new_shape = [*x.shape[:batched_ndims], *shape]
815+
elif len(tuple(old_shape)) == (len(tuple(shape)) - batched_ndims):
816+
new_shape = shape
817+
else:
818+
raise ValueError("Invalid shape length passed into vectorize node of Reshape")
819+
820+
return reshape(x, new_shape, ndim=len(new_shape)).owner
821+
822+
773823
def reshape(x, newshape, ndim=None):
774824
if ndim is None:
775825
newshape = at.as_tensor_variable(newshape)
@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes):
10341084
if not unbroadcasted_axes:
10351085
return x
10361086
return Unbroadcast(*unbroadcasted_axes)(x)
1087+
1088+
1089+
@_vectorize_node.register(Unbroadcast)
1090+
def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
1091+
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
1092+
old_axes = op.axes
1093+
new_axes = (old_axis + batched_ndims for old_axis in old_axes)
1094+
return unbroadcast(x, *new_axes).owner

tests/tensor/test_shape.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
from pytensor import Mode, function, grad
66
from pytensor.compile.ops import DeepCopyOp
77
from pytensor.configdefaults import config
8-
from pytensor.graph.basic import Variable
8+
from pytensor.graph.basic import Variable, equal_computations
99
from pytensor.graph.fg import FunctionGraph
10-
from pytensor.graph.replace import clone_replace
10+
from pytensor.graph.replace import clone_replace, vectorize_node
1111
from pytensor.graph.type import Type
1212
from pytensor.misc.safe_asarray import _asarray
1313
from pytensor.scalar.basic import ScalarConstant
14-
from pytensor.tensor import as_tensor_variable, get_vector_length, row
15-
from pytensor.tensor.basic import MakeVector, constant
14+
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
15+
from pytensor.tensor.basic import MakeVector, as_tensor, constant
1616
from pytensor.tensor.elemwise import DimShuffle, Elemwise
1717
from pytensor.tensor.rewriting.shape import ShapeFeature
1818
from pytensor.tensor.shape import (
@@ -706,3 +706,88 @@ def test_shape_tuple():
706706
assert isinstance(res[1], ScalarConstant)
707707
assert res[1].data == 2
708708
assert not isinstance(res[2], ScalarConstant)
709+
710+
711+
class TestVectorize:
712+
def test_shape(self):
713+
vec = tensor(shape=(None,))
714+
mat = tensor(shape=(None, None))
715+
716+
node = shape(vec).owner
717+
vect_node = vectorize_node(node, mat)
718+
assert equal_computations(vect_node.outputs, [shape(mat)])
719+
720+
def test_reshape(self):
721+
x = scalar("x", dtype=int)
722+
vec = tensor(shape=(None,))
723+
mat = tensor(shape=(None, None))
724+
725+
shape = (2, x)
726+
node = reshape(vec, shape).owner
727+
vect_node = vectorize_node(node, mat, shape)
728+
assert equal_computations(
729+
vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))]
730+
)
731+
732+
new_shape = (5, 2, x)
733+
vect_node = vectorize_node(node, mat, new_shape)
734+
assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)])
735+
736+
with pytest.raises(NotImplementedError):
737+
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))
738+
739+
with pytest.raises(
740+
ValueError,
741+
match="Invalid shape length passed into vectorize node of Reshape",
742+
):
743+
vectorize_node(node, vec, (5, 2, x))
744+
745+
with pytest.raises(
746+
ValueError,
747+
match="Invalid shape length passed into vectorize node of Reshape",
748+
):
749+
vectorize_node(node, mat, (5, 3, 2, x))
750+
751+
def test_specify_shape(self):
752+
x = scalar("x", dtype=int)
753+
mat = tensor(shape=(None, None))
754+
tns = tensor(shape=(None, None, None))
755+
756+
shape = (x, None)
757+
node = specify_shape(mat, shape).owner
758+
vect_node = vectorize_node(node, tns, *shape)
759+
assert equal_computations(
760+
vect_node.outputs, [specify_shape(tns, (None, x, None))]
761+
)
762+
763+
new_shape = (5, 2, x)
764+
vect_node = vectorize_node(node, tns, *new_shape)
765+
assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))])
766+
767+
with pytest.raises(NotImplementedError):
768+
vectorize_node(node, mat, *([x, x], None))
769+
770+
with pytest.raises(
771+
ValueError,
772+
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
773+
):
774+
vectorize_node(node, mat, *(5, 2, x))
775+
776+
with pytest.raises(
777+
ValueError,
778+
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
779+
):
780+
vectorize_node(node, tns, *(5, 3, 2, x))
781+
782+
def test_unbroadcast(self):
783+
mat = tensor(
784+
shape=(
785+
1,
786+
1,
787+
)
788+
)
789+
tns = tensor(shape=(4, 1, 1, 1))
790+
791+
node = unbroadcast(mat, 0).owner
792+
vect_node = vectorize_node(node, tns)
793+
assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])

0 commit comments

Comments
 (0)