
Static broadcast #149

Closed
wants to merge 3 commits into from
12 changes: 5 additions & 7 deletions pytensor/scan/op.py
@@ -156,9 +156,8 @@ def check_broadcast(v1, v2):
which may wrongly be interpreted as broadcastable.

"""
if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
return

msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
@@ -169,13 +168,13 @@ def check_broadcast(v1, v2):
"them consistent, e.g. using pytensor.tensor."
"{unbroadcast, specify_broadcastable}."
)
size = min(v1.type.ndim, v2.type.ndim)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - v1.type.ndim + 1
a2 = n + size - v2.type.ndim + 1
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
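The error message above points users at `pytensor.tensor.{unbroadcast, specify_broadcastable}`. As a hedged illustration (not part of this diff), this is how those helpers change a variable's static broadcast pattern, which is exactly what `check_broadcast` compares:

```python
import pytensor.tensor as pt

x = pt.matrix("x")                    # broadcastable == (False, False)
xb = pt.specify_broadcastable(x, 0)   # broadcastable == (True, False)
xu = pt.unbroadcast(xb, 0)            # back to (False, False)

print(x.type.broadcastable, xb.type.broadcastable, xu.type.broadcastable)
```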


@@ -624,7 +623,6 @@ def validate_inner_graph(self):
type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type
if (
# TODO: Use the `Type` interface for this
type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
):
2 changes: 1 addition & 1 deletion pytensor/sparse/basic.py
@@ -954,7 +954,7 @@ class DenseFromSparse(Op):

"""

__props__ = ()
__props__ = ("sparse_grad",)

def __init__(self, structured=True):
self.sparse_grad = structured
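Adding `"sparse_grad"` to `__props__` makes the flag part of the Op's identity, since equality and hashing are derived from `__props__`. A hedged sketch of the behaviour one would expect with this change applied (illustrative only):

```python
from pytensor.sparse.basic import DenseFromSparse

# With "sparse_grad" in __props__, ops built with different `structured`
# flags no longer compare (or hash) equal, so rewrites cannot merge them.
assert DenseFromSparse(structured=True) == DenseFromSparse(structured=True)
assert DenseFromSparse(structured=True) != DenseFromSparse(structured=False)
```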
4 changes: 2 additions & 2 deletions pytensor/sparse/rewriting.py
@@ -1099,13 +1099,13 @@ def c_code_cache_version(self):
csm_grad_c = CSMGradC()


@node_rewriter([csm_grad(None)])
@node_rewriter([csm_grad()])
def local_csm_grad_c(fgraph, node):
"""
csm_grad(None) -> csm_grad_c

"""
if node.op == csm_grad(None):
if node.op == csm_grad():
return [csm_grad_c(*node.inputs)]
return False

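For context, `@node_rewriter([...])` registers the decorated function as a local rewrite that is only invoked on nodes whose op matches the tracked op(s); returning `False` declines the rewrite. A hedged, generic sketch (not tied to the sparse ops above; `example_rewrite` and the tracked `Sum` op are only illustrative):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Sum

@node_rewriter([Sum])
def example_rewrite(fgraph, node):
    # Inspect node.op / node.inputs here and return a list of replacement
    # outputs, or False to leave the node unchanged.
    return False
```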
9 changes: 7 additions & 2 deletions pytensor/sparse/type.py
@@ -73,7 +73,10 @@ def __init__(
):
if shape is None and broadcastable is None:
shape = (None, None)

if broadcastable is None:
broadcastable = (False, False)
if broadcastable != (False, False):
raise ValueError("Broadcasting sparse types is not yet implemented")
if format not in self.format_cls:
raise ValueError(
f'unsupported format "{format}" not in list',
@@ -95,7 +98,9 @@ def clone(
dtype = self.dtype
if shape is None:
shape = self.shape
return type(self)(format, dtype, shape=shape, **kwargs)
return type(self)(
format, dtype, shape=shape, broadcastable=broadcastable, **kwargs
)

def filter(self, value, strict=False, allow_downcast=None):
if isinstance(value, Variable):
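With the extra validation above, `SparseTensorType` keeps accepting only the all-`False` broadcast pattern. A hedged sketch of the expected behaviour, assuming the constructor signature implied by this hunk (`broadcastable` as a keyword argument):

```python
from pytensor.sparse.type import SparseTensorType

ok = SparseTensorType("csr", "float64")  # broadcastable defaults to (False, False)

try:
    SparseTensorType("csr", "float64", broadcastable=(True, False))
except ValueError as err:
    print(err)  # Broadcasting sparse types is not yet implemented
```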
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
@@ -396,7 +396,7 @@ def get_scalar_constant_value(
for i in v.owner.inputs
]
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
v.owner.op.scalar_op.perform(v.owner, const, ret)
return np.asarray(ret[0][0].copy())
elif (
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)
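The changed call goes through the Elemwise node's underlying `scalar_op`, whose `perform` operates on the plain scalar constants extracted from the node's inputs. A hedged sketch of the distinction (illustrative only, not part of this diff):

```python
import numpy as np
import pytensor.tensor as pt

node = (pt.as_tensor_variable(2.0) * 3.0).owner
print(type(node.op).__name__, type(node.op.scalar_op).__name__)  # Elemwise Mul

# The scalar op's perform writes its result into a [[None]] output container,
# which is the pattern used by get_scalar_constant_value above.
ret = [[None]]
node.op.scalar_op.perform(node, [np.float64(2.0), np.float64(3.0)], ret)
print(ret[0][0])  # 6.0
```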
179 changes: 63 additions & 116 deletions pytensor/tensor/blas.py
@@ -137,6 +137,7 @@
except ImportError:
pass

from functools import reduce
from typing import Tuple

import pytensor.scalar
@@ -639,10 +640,8 @@ def c_header_dirs(self, **kwargs):
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
"""

# broadcast_xy = None

check_dims = """
if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
if (Nx[0] != Nz[0])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld rows but z has %%ld rows",
@@ -656,7 +655,7 @@ def c_header_dirs(self, **kwargs):
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
%(fail)s;
}
if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
if (Ny[1] != Nz[1])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: y has %%ld cols but z has %%ld cols",
@@ -842,14 +841,14 @@ def build_gemm_call(self):
else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz

return "".join(
return reduce(
str.__add__,
(
self.declare_NS,
self.check_xyz_rank2,
setup_z_Nz_Sz,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.broadcast_xy,
self.check_dims,
self.check_strides,
self.encode_strides_in_unit,
@@ -862,7 +861,8 @@ def build_gemm_call(self):
self.case_double_ab_constants,
self.case_double_gemm,
self.end_switch_typenum,
)
),
"",
)

def build_gemm_version(self):
@@ -992,11 +992,6 @@ def perform(self, node, inp, out, params):
z.itemset(z * a + b * np.dot(x, y))
zout[0] = z
else:
# Broadcast Z if needed
if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]):
z = np.broadcast_to(
z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1]))
).copy()
if b == 0.0:
if a == 1.0:
z[:] = np.dot(x, y)
@@ -1017,135 +1012,88 @@ def perform(self, node, inp, out, params):
zout[0] = z

def infer_shape(self, fgraph, node, input_shapes):
z_shape, _, x_shape, y_shape, _ = input_shapes
return [
(
pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]),
pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]),
)
]
return [input_shapes[0]]
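For reference, Gemm's rank-2 contract is `b * z + a * dot(x, y)`, and with this change the output shape is simply `z`'s shape: no runtime broadcasting of `z` is attempted in `perform` or `infer_shape`. A hedged NumPy sketch of that contract (not the actual implementation; `gemm_reference` is only illustrative):

```python
import numpy as np

def gemm_reference(z, a, x, y, b):
    # Same argument order as Gemm: gemm(z, a, x, y, b) == b*z + a*dot(x, y).
    # After this change, a z whose shape does not already match
    # (x.shape[0], y.shape[1]) is a shape error rather than something
    # to broadcast at runtime.
    assert z.shape == (x.shape[0], y.shape[1])
    return b * z + a * (x @ y)

z = np.zeros((2, 3))
x = np.ones((2, 4))
y = np.ones((4, 3))
print(gemm_reference(z, 0.5, x, y, 2.0))
```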

setup_z_Nz_Sz_inplace = """
// Needs broadcasting
if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){

npy_intp dims[2];
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];

// Check if we need to allocate new array
if((NULL == %(_zout)s)
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
{
// fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
}

// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
{
%(fail)s;
}

} else {
if (%(_zout)s != %(_z)s)
if (%(_zout)s != %(_z)s)
{
if (%(_zout)s)
{
Py_XDECREF(%(_zout)s);
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
Py_DECREF(%(_zout)s);
}
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}

Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s);
Nz = PyArray_DIMS(%(_z)s);
Sz = PyArray_STRIDES(%(_z)s);
"""

setup_z_Nz_Sz_outplace = """
npy_intp dims[2];
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];

// Check if we need to allocate new array
if ((NULL == %(_zout)s)
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0])
|| (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1])
|| (PyArray_STRIDES(%(_zout)s)[0] <= 0)
|| (PyArray_STRIDES(%(_zout)s)[1] <= 0)
|| (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
|| (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
|| ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
&& (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
{
Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
// fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
npy_intp dims[2];
dims[0] = PyArray_DIMS(%(_z)s)[0];
dims[1] = PyArray_DIMS(%(_z)s)[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
PyArray_TYPE(%(_z)s));
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_no_inplace output");
%(fail)s
}
}

// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
{
%(fail)s
}

Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s);
"""

broadcast_xy = """
// Broadcast X if needed
if (Nz[0] > Nx[0])
if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
{
npy_intp dims[2];
dims[0] = Nz[0];
dims[1] = Nx[1];
// fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]);
PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
if(!x_new) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_inplace input");
%(fail)s
}

if(PyArray_MoveInto(x_new, %(_x)s) == -1)
float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
int zoi = Sz[0] / sizeof(float);
int zoj = Sz[1] / sizeof(float);
const float * zdata = (float*)PyArray_DATA(%(_z)s);
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float);
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float);
for (int i = 0; i < Nz[0]; ++i)
{
%(fail)s
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
}
}

Py_DECREF(%(_x)s);
%(_x)s = x_new;

Nx = PyArray_DIMS(%(_x)s);
Sx = PyArray_STRIDES(%(_x)s);
}

// Broadcast Y if needed
if (Nz[1] > Ny[1])
else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
{
npy_intp dims[2];
dims[0] = Ny[0];
dims[1] = Nz[1];
// fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]);
PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
if(!y_new) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_inplace input");
%(fail)s
}

if(PyArray_MoveInto(y_new, %(_y)s) == -1)
double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
int zoi = Sz[0] / sizeof(double);
int zoj = Sz[1] / sizeof(double);
const double * zdata = (double*)PyArray_DATA(%(_z)s);
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double);
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double);
for (int i = 0; i < Nz[0]; ++i)
{
%(fail)s
for (int j = 0; j < Nz[1]; ++j)
{
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
}
}

Py_DECREF(%(_y)s);
%(_y)s = y_new;

Ny = PyArray_DIMS(%(_y)s);
Sy = PyArray_STRIDES(%(_y)s);
}

"""
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
"""

case_float_ab_constants = """
#define REAL float
@@ -1179,7 +1127,7 @@ def c_code(self, node, name, inp, out, sub):
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
return (7,) + gv
return (8,) + gv
else:
return gv

@@ -1253,6 +1201,7 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
if M.owner and M.owner.op == _dot22:
Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
# print 'GEMM 0', rval, beta, L, alpha, M
return rval, M

# it also might be the case that there is a dimshuffle between the +
@@ -1719,7 +1668,6 @@ def infer_shape(self, fgraph, node, input_shapes):
Sz = PyArray_STRIDES(%(_zout)s);

"""
broadcast_xy = ""
check_ab_double_or_float = ""
case_float_ab_constants = """
float a = 1.0;
@@ -2003,7 +1951,6 @@ def infer_shape(self, fgraph, node, input_shapes):
return [[input_shapes[0][0], input_shapes[1][1]]]

setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
broadcast_xy = ""

check_ab_double_or_float = """
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
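Overall, with the runtime `broadcast_xy`/`z` paths removed from the BLAS C templates, any broadcasting has to be expressed statically in the graph before it reaches Gemm. A hedged sketch of what that looks like from user code (illustrative variables only, not taken from this diff):

```python
import pytensor.tensor as pt

x = pt.matrix("x")  # (m, k)
y = pt.matrix("y")  # (k, n)
z = pt.row("z")     # (1, n): first dim statically broadcastable

# The elementwise add broadcasts z to (m, n) at the type level, so the
# shapes already line up before any BLAS-specific rewrite runs.
out = z + pt.dot(x, y)
```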