137
137
except ImportError :
138
138
pass
139
139
140
+ from functools import reduce
140
141
from typing import Tuple
141
142
142
143
import pytensor .scalar
@@ -639,10 +640,8 @@ def c_header_dirs(self, **kwargs):
639
640
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
640
641
"""
641
642
642
- # broadcast_xy = None
643
-
644
643
check_dims = """
645
- if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
644
+ if (Nx[0] != Nz[0])
646
645
{
647
646
PyErr_Format(PyExc_ValueError,
648
647
"Shape mismatch: x has %%ld rows but z has %%ld rows",
@@ -656,7 +655,7 @@ def c_header_dirs(self, **kwargs):
656
655
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
657
656
%(fail)s;
658
657
}
659
- if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
658
+ if (Ny[1] != Nz[1])
660
659
{
661
660
PyErr_Format(PyExc_ValueError,
662
661
"Shape mismatch: y has %%ld cols but z has %%ld cols",
@@ -842,14 +841,14 @@ def build_gemm_call(self):
842
841
else :
843
842
setup_z_Nz_Sz = self .setup_z_Nz_Sz
844
843
845
- return "" .join (
844
+ return reduce (
845
+ str .__add__ ,
846
846
(
847
847
self .declare_NS ,
848
848
self .check_xyz_rank2 ,
849
849
setup_z_Nz_Sz ,
850
850
self .check_xyz_double_or_float ,
851
851
self .check_ab_double_or_float ,
852
- self .broadcast_xy ,
853
852
self .check_dims ,
854
853
self .check_strides ,
855
854
self .encode_strides_in_unit ,
@@ -862,7 +861,8 @@ def build_gemm_call(self):
862
861
self .case_double_ab_constants ,
863
862
self .case_double_gemm ,
864
863
self .end_switch_typenum ,
865
- )
864
+ ),
865
+ "" ,
866
866
)
867
867
868
868
def build_gemm_version (self ):
@@ -992,11 +992,6 @@ def perform(self, node, inp, out, params):
992
992
z .itemset (z * a + b * np .dot (x , y ))
993
993
zout [0 ] = z
994
994
else :
995
- # Broadcast Z if needed
996
- if (x .shape [0 ] > z .shape [0 ]) or (y .shape [1 ] > z .shape [1 ]):
997
- z = np .broadcast_to (
998
- z , (max (x .shape [0 ], z .shape [0 ]), max (y .shape [1 ], z .shape [1 ]))
999
- ).copy ()
1000
995
if b == 0.0 :
1001
996
if a == 1.0 :
1002
997
z [:] = np .dot (x , y )
@@ -1017,135 +1012,88 @@ def perform(self, node, inp, out, params):
1017
1012
zout [0 ] = z
1018
1013
1019
1014
def infer_shape (self , fgraph , node , input_shapes ):
1020
- z_shape , _ , x_shape , y_shape , _ = input_shapes
1021
- return [
1022
- (
1023
- pytensor .scalar .scalar_maximum (z_shape [0 ], x_shape [0 ]),
1024
- pytensor .scalar .scalar_maximum (z_shape [1 ], y_shape [1 ]),
1025
- )
1026
- ]
1015
+ return [input_shapes [0 ]]
1027
1016
1028
1017
setup_z_Nz_Sz_inplace = """
1029
- // Needs broadcasting
1030
- if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){
1031
-
1032
- npy_intp dims[2];
1033
- dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1034
- dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1035
-
1036
- // Check if we need to allocate new array
1037
- if((NULL == %(_zout)s)
1038
- || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1039
- || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
1040
- {
1041
- // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\ n", dims[0], dims[1]);
1042
- Py_XDECREF(%(_zout)s);
1043
- %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1044
- }
1045
-
1046
- // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\ n", dims[0], dims[1]);
1047
- if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1048
- {
1049
- %(fail)s;
1050
- }
1051
-
1052
- } else {
1053
- if (%(_zout)s != %(_z)s)
1018
+ if (%(_zout)s != %(_z)s)
1019
+ {
1020
+ if (%(_zout)s)
1054
1021
{
1055
- Py_XDECREF(%(_zout)s);
1056
- %(_zout)s = %(_z)s;
1057
- Py_INCREF(%(_zout)s);
1022
+ Py_DECREF(%(_zout)s);
1058
1023
}
1024
+ %(_zout)s = %(_z)s;
1025
+ Py_INCREF(%(_zout)s);
1059
1026
}
1060
-
1061
- Nz = PyArray_DIMS(%(_zout)s);
1062
- Sz = PyArray_STRIDES(%(_zout)s);
1027
+ Nz = PyArray_DIMS(%(_z)s);
1028
+ Sz = PyArray_STRIDES(%(_z)s);
1063
1029
"""
1064
1030
1065
1031
setup_z_Nz_Sz_outplace = """
1066
- npy_intp dims[2];
1067
- dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
1068
- dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
1069
-
1070
- // Check if we need to allocate new array
1071
1032
if ((NULL == %(_zout)s)
1072
- || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
1073
- || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
1033
+ || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0])
1034
+ || (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1])
1035
+ || (PyArray_STRIDES(%(_zout)s)[0] <= 0)
1036
+ || (PyArray_STRIDES(%(_zout)s)[1] <= 0)
1037
+ || (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
1038
+ || (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
1039
+ || ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
1040
+ && (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
1074
1041
{
1075
1042
Py_XDECREF(%(_zout)s);
1076
- %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
1077
- // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\ n", dims[0], dims[1]);
1043
+ npy_intp dims[2];
1044
+ dims[0] = PyArray_DIMS(%(_z)s)[0];
1045
+ dims[1] = PyArray_DIMS(%(_z)s)[1];
1046
+ %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
1047
+ PyArray_TYPE(%(_z)s));
1048
+ //fprintf(stderr, "Gemm Allocating %%i %%i\\ n", dims[0], dims[1]);
1078
1049
if(!%(_zout)s) {
1079
1050
PyErr_SetString(PyExc_MemoryError,
1080
1051
"failed to alloc gemm_no_inplace output");
1081
1052
%(fail)s
1082
1053
}
1083
1054
}
1084
-
1085
- // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\ n", dims[0], dims[1]);
1086
- if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
1087
- {
1088
- %(fail)s
1089
- }
1090
-
1091
1055
Nz = PyArray_DIMS(%(_zout)s);
1092
1056
Sz = PyArray_STRIDES(%(_zout)s);
1093
- """
1094
1057
1095
- broadcast_xy = """
1096
- // Broadcast X if needed
1097
- if (Nz[0] > Nx[0])
1058
+ if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
1098
1059
{
1099
- npy_intp dims[2];
1100
- dims[0] = Nz[0];
1101
- dims[1] = Nx[1];
1102
- // fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\ n", dims[0], dims[1]);
1103
- PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1104
- if(!x_new) {
1105
- PyErr_SetString(PyExc_MemoryError,
1106
- "failed to alloc gemm_inplace input");
1107
- %(fail)s
1108
- }
1109
-
1110
- if(PyArray_MoveInto(x_new, %(_x)s) == -1)
1060
+ float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
1061
+ int zoi = Sz[0] / sizeof(float);
1062
+ int zoj = Sz[1] / sizeof(float);
1063
+ const float * zdata = (float*)PyArray_DATA(%(_z)s);
1064
+ int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float);
1065
+ int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float);
1066
+ for (int i = 0; i < Nz[0]; ++i)
1111
1067
{
1112
- %(fail)s
1068
+ for (int j = 0; j < Nz[1]; ++j)
1069
+ {
1070
+ zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1071
+ }
1113
1072
}
1114
-
1115
- Py_DECREF(%(_x)s);
1116
- %(_x)s = x_new;
1117
-
1118
- Nx = PyArray_DIMS(%(_x)s);
1119
- Sx = PyArray_STRIDES(%(_x)s);
1120
1073
}
1121
-
1122
- // Broadcast Y if needed
1123
- if (Nz[1] > Ny[1])
1074
+ else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
1124
1075
{
1125
- npy_intp dims[2];
1126
- dims[0] = Ny[0];
1127
- dims[1] = Nz[1];
1128
- // fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\ n", dims[0], dims[1]);
1129
- PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
1130
- if(!y_new) {
1131
- PyErr_SetString(PyExc_MemoryError,
1132
- "failed to alloc gemm_inplace input");
1133
- %(fail)s
1134
- }
1135
-
1136
- if(PyArray_MoveInto(y_new, %(_y)s) == -1)
1076
+ double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
1077
+ int zoi = Sz[0] / sizeof(double);
1078
+ int zoj = Sz[1] / sizeof(double);
1079
+ const double * zdata = (double*)PyArray_DATA(%(_z)s);
1080
+ int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double);
1081
+ int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double);
1082
+ for (int i = 0; i < Nz[0]; ++i)
1137
1083
{
1138
- %(fail)s
1084
+ for (int j = 0; j < Nz[1]; ++j)
1085
+ {
1086
+ zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
1087
+ }
1139
1088
}
1140
-
1141
- Py_DECREF(%(_y)s);
1142
- %(_y)s = y_new;
1143
-
1144
- Ny = PyArray_DIMS(%(_y)s);
1145
- Sy = PyArray_STRIDES(%(_y)s);
1146
1089
}
1147
-
1148
- """
1090
+ else
1091
+ {
1092
+ PyErr_SetString(PyExc_AssertionError,
1093
+ "neither float nor double dtype");
1094
+ %(fail)s
1095
+ }
1096
+ """
1149
1097
1150
1098
case_float_ab_constants = """
1151
1099
#define REAL float
@@ -1179,7 +1127,7 @@ def c_code(self, node, name, inp, out, sub):
1179
1127
def c_code_cache_version (self ):
1180
1128
gv = self .build_gemm_version ()
1181
1129
if gv :
1182
- return (7 ,) + gv
1130
+ return (6 ,) + gv
1183
1131
else :
1184
1132
return gv
1185
1133
@@ -1253,6 +1201,7 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
1253
1201
if M .owner and M .owner .op == _dot22 :
1254
1202
Ml , Mr = M .owner .inputs
1255
1203
rval = [gemm_no_inplace (L , alpha , Ml , Mr , beta )]
1204
+ # print 'GEMM 0', rval, beta, L, alpha, M
1256
1205
return rval , M
1257
1206
1258
1207
# it also might be the case that there is a dimshuffle between the +
@@ -1719,7 +1668,6 @@ def infer_shape(self, fgraph, node, input_shapes):
1719
1668
Sz = PyArray_STRIDES(%(_zout)s);
1720
1669
1721
1670
"""
1722
- broadcast_xy = ""
1723
1671
check_ab_double_or_float = ""
1724
1672
case_float_ab_constants = """
1725
1673
float a = 1.0;
@@ -2003,7 +1951,6 @@ def infer_shape(self, fgraph, node, input_shapes):
2003
1951
return [[input_shapes [0 ][0 ], input_shapes [1 ][1 ]]]
2004
1952
2005
1953
setup_z_Nz_Sz = Dot22 .setup_z_Nz_Sz
2006
- broadcast_xy = ""
2007
1954
2008
1955
check_ab_double_or_float = """
2009
1956
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
0 commit comments