Skip to content

Commit 1972970

Browse files
committed
Revert "Broadcast input matrices in Gemm"
This reverts commit a3dc0a7.
1 parent 289ddb4 commit 1972970

File tree

2 files changed

+63
-187
lines changed

2 files changed

+63
-187
lines changed

pytensor/tensor/blas.py

Lines changed: 63 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
except ImportError:
138138
pass
139139

140+
from functools import reduce
140141
from typing import Tuple
141142

142143
import pytensor.scalar
@@ -639,10 +640,8 @@ def c_header_dirs(self, **kwargs):
639640
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
640641
"""
641642

642-
# broadcast_xy = None
643-
644643
check_dims = """
645-
if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
644+
if (Nx[0] != Nz[0])
646645
{
647646
PyErr_Format(PyExc_ValueError,
648647
"Shape mismatch: x has %%ld rows but z has %%ld rows",
@@ -656,7 +655,7 @@ def c_header_dirs(self, **kwargs):
656655
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
657656
%(fail)s;
658657
}
659-
if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
658+
if (Ny[1] != Nz[1])
660659
{
661660
PyErr_Format(PyExc_ValueError,
662661
"Shape mismatch: y has %%ld cols but z has %%ld cols",
@@ -842,14 +841,14 @@ def build_gemm_call(self):
842841
else:
843842
setup_z_Nz_Sz = self.setup_z_Nz_Sz
844843

845-
return "".join(
844+
return reduce(
845+
str.__add__,
846846
(
847847
self.declare_NS,
848848
self.check_xyz_rank2,
849849
setup_z_Nz_Sz,
850850
self.check_xyz_double_or_float,
851851
self.check_ab_double_or_float,
852-
self.broadcast_xy,
853852
self.check_dims,
854853
self.check_strides,
855854
self.encode_strides_in_unit,
@@ -862,7 +861,8 @@ def build_gemm_call(self):
862861
self.case_double_ab_constants,
863862
self.case_double_gemm,
864863
self.end_switch_typenum,
865-
)
864+
),
865+
"",
866866
)
867867

868868
def build_gemm_version(self):
@@ -992,11 +992,6 @@ def perform(self, node, inp, out, params):
992992
z.itemset(z * a + b * np.dot(x, y))
993993
zout[0] = z
994994
else:
995-
# Broadcast Z if needed
996-
if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]):
997-
z = np.broadcast_to(
998-
z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1]))
999-
).copy()
1000995
if b == 0.0:
1001996
if a == 1.0:
1002997
z[:] = np.dot(x, y)
@@ -1017,135 +1012,88 @@ def perform(self, node, inp, out, params):
10171012
zout[0] = z
10181013

10191014
def infer_shape(self, fgraph, node, input_shapes):
1020-
z_shape, _, x_shape, y_shape, _ = input_shapes
1021-
return [
1022-
(
1023-
pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]),
1024-
pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]),
1025-
)
1026-
]
1015+
return [input_shapes[0]]
10271016

10281017
setup_z_Nz_Sz_inplace = """
1029-
// Needs broadcasting
1030-
if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){
1031-
1032-
npy_intp dims[2];
1033-
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1034-
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1035-
1036-
// Check if we need to allocate new array
1037-
if((NULL == %(_zout)s)
1038-
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1039-
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
1040-
{
1041-
// fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
1042-
Py_XDECREF(%(_zout)s);
1043-
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1044-
}
1045-
1046-
// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
1047-
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1048-
{
1049-
%(fail)s;
1050-
}
1051-
1052-
} else {
1053-
if (%(_zout)s != %(_z)s)
1018+
if (%(_zout)s != %(_z)s)
1019+
{
1020+
if (%(_zout)s)
10541021
{
1055-
Py_XDECREF(%(_zout)s);
1056-
%(_zout)s = %(_z)s;
1057-
Py_INCREF(%(_zout)s);
1022+
Py_DECREF(%(_zout)s);
10581023
}
1024+
%(_zout)s = %(_z)s;
1025+
Py_INCREF(%(_zout)s);
10591026
}
1060-
1061-
Nz = PyArray_DIMS(%(_zout)s);
1062-
Sz = PyArray_STRIDES(%(_zout)s);
1027+
Nz = PyArray_DIMS(%(_z)s);
1028+
Sz = PyArray_STRIDES(%(_z)s);
10631029
"""
10641030

10651031
setup_z_Nz_Sz_outplace = """
1066-
npy_intp dims[2];
1067-
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1068-
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1069-
1070-
// Check if we need to allocate new array
10711032
if ((NULL == %(_zout)s)
1072-
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1073-
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
1033+
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0])
1034+
|| (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1])
1035+
|| (PyArray_STRIDES(%(_zout)s)[0] <= 0)
1036+
|| (PyArray_STRIDES(%(_zout)s)[1] <= 0)
1037+
|| (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
1038+
|| (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
1039+
|| ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
1040+
&& (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
10741041
{
10751042
Py_XDECREF(%(_zout)s);
1076-
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1077-
// fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
1043+
npy_intp dims[2];
1044+
dims[0] = PyArray_DIMS(%(_z)s)[0];
1045+
dims[1] = PyArray_DIMS(%(_z)s)[1];
1046+
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
1047+
PyArray_TYPE(%(_z)s));
1048+
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
10781049
if(!%(_zout)s) {
10791050
PyErr_SetString(PyExc_MemoryError,
10801051
"failed to alloc gemm_no_inplace output");
10811052
%(fail)s
10821053
}
10831054
}
1084-
1085-
// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
1086-
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1087-
{
1088-
%(fail)s
1089-
}
1090-
10911055
Nz = PyArray_DIMS(%(_zout)s);
10921056
Sz = PyArray_STRIDES(%(_zout)s);
1093-
"""
10941057
1095-
broadcast_xy = """
1096-
// Broadcast X if needed
1097-
if (Nz[0] > Nx[0])
1058+
if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
10981059
{
1099-
npy_intp dims[2];
1100-
dims[0] = Nz[0];
1101-
dims[1] = Nx[1];
1102-
// fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]);
1103-
PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1104-
if(!x_new) {
1105-
PyErr_SetString(PyExc_MemoryError,
1106-
"failed to alloc gemm_inplace input");
1107-
%(fail)s
1108-
}
1109-
1110-
if(PyArray_MoveInto(x_new, %(_x)s) == -1)
1060+
float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
1061+
int zoi = Sz[0] / sizeof(float);
1062+
int zoj = Sz[1] / sizeof(float);
1063+
const float * zdata = (float*)PyArray_DATA(%(_z)s);
1064+
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float);
1065+
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float);
1066+
for (int i = 0; i < Nz[0]; ++i)
11111067
{
1112-
%(fail)s
1068+
for (int j = 0; j < Nz[1]; ++j)
1069+
{
1070+
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1071+
}
11131072
}
1114-
1115-
Py_DECREF(%(_x)s);
1116-
%(_x)s = x_new;
1117-
1118-
Nx = PyArray_DIMS(%(_x)s);
1119-
Sx = PyArray_STRIDES(%(_x)s);
11201073
}
1121-
1122-
// Broadcast Y if needed
1123-
if (Nz[1] > Ny[1])
1074+
else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
11241075
{
1125-
npy_intp dims[2];
1126-
dims[0] = Ny[0];
1127-
dims[1] = Nz[1];
1128-
// fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]);
1129-
PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1130-
if(!y_new) {
1131-
PyErr_SetString(PyExc_MemoryError,
1132-
"failed to alloc gemm_inplace input");
1133-
%(fail)s
1134-
}
1135-
1136-
if(PyArray_MoveInto(y_new, %(_y)s) == -1)
1076+
double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
1077+
int zoi = Sz[0] / sizeof(double);
1078+
int zoj = Sz[1] / sizeof(double);
1079+
const double * zdata = (double*)PyArray_DATA(%(_z)s);
1080+
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double);
1081+
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double);
1082+
for (int i = 0; i < Nz[0]; ++i)
11371083
{
1138-
%(fail)s
1084+
for (int j = 0; j < Nz[1]; ++j)
1085+
{
1086+
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1087+
}
11391088
}
1140-
1141-
Py_DECREF(%(_y)s);
1142-
%(_y)s = y_new;
1143-
1144-
Ny = PyArray_DIMS(%(_y)s);
1145-
Sy = PyArray_STRIDES(%(_y)s);
11461089
}
1147-
1148-
"""
1090+
else
1091+
{
1092+
PyErr_SetString(PyExc_AssertionError,
1093+
"neither float nor double dtype");
1094+
%(fail)s
1095+
}
1096+
"""
11491097

11501098
case_float_ab_constants = """
11511099
#define REAL float
@@ -1179,7 +1127,7 @@ def c_code(self, node, name, inp, out, sub):
11791127
def c_code_cache_version(self):
11801128
gv = self.build_gemm_version()
11811129
if gv:
1182-
return (7,) + gv
1130+
return (6,) + gv
11831131
else:
11841132
return gv
11851133

@@ -1253,6 +1201,7 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
12531201
if M.owner and M.owner.op == _dot22:
12541202
Ml, Mr = M.owner.inputs
12551203
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
1204+
# print 'GEMM 0', rval, beta, L, alpha, M
12561205
return rval, M
12571206

12581207
# it also might be the case that there is a dimshuffle between the +
@@ -1719,7 +1668,6 @@ def infer_shape(self, fgraph, node, input_shapes):
17191668
Sz = PyArray_STRIDES(%(_zout)s);
17201669
17211670
"""
1722-
broadcast_xy = ""
17231671
check_ab_double_or_float = ""
17241672
case_float_ab_constants = """
17251673
float a = 1.0;
@@ -2003,7 +1951,6 @@ def infer_shape(self, fgraph, node, input_shapes):
20031951
return [[input_shapes[0][0], input_shapes[1][1]]]
20041952

20051953
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
2006-
broadcast_xy = ""
20071954

20081955
check_ab_double_or_float = """
20091956
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)

tests/tensor/test_blas.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from copy import copy
22
from itertools import product
3-
from random import shuffle
43

54
import numpy as np
65
import pytest
@@ -68,7 +67,6 @@
6867
matrix,
6968
row,
7069
scalar,
71-
scalars,
7270
tensor,
7371
tensor3,
7472
tensor4,
@@ -1044,41 +1042,6 @@ def test_inplace1():
10441042
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
10451043

10461044

1047-
@pytest.mark.parametrize("linker", ("py", "cvm"))
1048-
@pytest.mark.parametrize("inplace", (False, True))
1049-
def test_gemm_broadcasting(inplace, linker):
1050-
a, b = scalars("a", "b")
1051-
z, x, y = matrices("z", "x", "y")
1052-
1053-
mode = Mode(linker=linker)
1054-
if inplace:
1055-
out = gemm_inplace(z, a, x, y, b)
1056-
f = pytensor.function([z, x, y, a, b], out, accept_inplace=True, mode=mode)
1057-
assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_inplace]
1058-
else:
1059-
out = gemm_no_inplace(z, a, x, y, b)
1060-
f = pytensor.function([z, x, y, a, b], out, mode=mode)
1061-
assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_no_inplace]
1062-
1063-
shapes_z = [(5, 3), (1, 3), (5, 1), (1, 1)]
1064-
shapes_x = [(5, 4), (1, 4)]
1065-
shapes_y = [(4, 3), (4, 1)]
1066-
1067-
rng = np.random.default_rng()
1068-
shuffle(shapes_z)
1069-
shuffle(shapes_x)
1070-
shuffle(shapes_y)
1071-
for shape_z, shape_x, shape_y in product(shapes_z, shapes_x, shapes_y):
1072-
z_v = rng.random(size=shape_z).astype(config.floatX)
1073-
x_v = rng.random(size=shape_x).astype(config.floatX)
1074-
y_v = rng.random(size=shape_y).astype(config.floatX)
1075-
# We have to copy for the inplace case
1076-
z_v_np = z_v.copy()
1077-
np.testing.assert_allclose(
1078-
f(z_v, x_v, y_v, 1, 1), z_v_np + np.dot(x_v, y_v), atol=2e-6
1079-
)
1080-
1081-
10821045
def test_dot22():
10831046
for dtype1 in ["float32", "float64", "complex64", "complex128"]:
10841047
a = matrix(dtype=dtype1)
@@ -2513,40 +2476,6 @@ def test_gemm(self):
25132476
Gemm,
25142477
)
25152478

2516-
def test_gemm_broadcast(self):
2517-
rng = np.random.default_rng(unittest_tools.fetch_seed())
2518-
x, y, z = matrices("xyz")
2519-
a = scalar("a")
2520-
b = scalar("b")
2521-
2522-
# Broadcast Z
2523-
self._compile_and_check(
2524-
[x, y, a, z, b],
2525-
[gemm(z, a, x, y, b)],
2526-
[
2527-
rng.random((2, 3)).astype(config.floatX),
2528-
rng.random((3, 4)).astype(config.floatX),
2529-
np.asarray(0.5, dtype=config.floatX),
2530-
rng.random((1, 4)).astype(config.floatX),
2531-
np.asarray(0.5, dtype=config.floatX),
2532-
],
2533-
Gemm,
2534-
)
2535-
2536-
# Broadcast dot(X, Y)
2537-
self._compile_and_check(
2538-
[x, y, a, z, b],
2539-
[gemm(z, a, x, y, b)],
2540-
[
2541-
rng.random((1, 3)).astype(config.floatX),
2542-
rng.random((3, 4)).astype(config.floatX),
2543-
np.asarray(0.5, dtype=config.floatX),
2544-
rng.random((5, 4)).astype(config.floatX),
2545-
np.asarray(1, dtype=config.floatX),
2546-
],
2547-
Gemm,
2548-
)
2549-
25502479
def test_gemv(self):
25512480
rng = np.random.default_rng(unittest_tools.fetch_seed())
25522481
A = matrix("A")

0 commit comments

Comments
 (0)