Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions thinc/backends/_custom_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ struct Constants<float> {
};


template <typename U>
__global__ void gather_add(U* out_bo, const U* table_to, const int* indices_bk,
int T, int O, int B, int K)
{
int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
int _loop_stride = blockDim.x * gridDim.x;

for (int b = _loop_start; b < B; b += _loop_stride) {
for (int k = 0; k < K; ++k) {
int idx = indices_bk[b * K + k];
const U* table = table_to + idx * O;
U* out = out_bo + b * O;
for (int o = 0; o < O; ++o) {
out[o] += table[o];
}
}
}
}


template <typename T>
__global__ void seq2col(T* output, const T* X, const int* lengths,
int nW, int B, int I, int nL)
Expand Down
33 changes: 33 additions & 0 deletions thinc/backends/_custom_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
"backprop_swish<float>",
"clipped_linear<double>",
"clipped_linear<float>",
"gather_add<double>",
"gather_add<float>",
"gelu<double>",
"gelu<float>",
"maxout<double>",
Expand Down Expand Up @@ -76,6 +78,8 @@ def compile_mmh(src):

clipped_linear_kernel_float = _get_kernel("clipped_linear<float>")
clipped_linear_kernel_double = _get_kernel("clipped_linear<double>")
gather_add_kernel_float = _get_kernel("gather_add<float>")
gather_add_kernel_double = _get_kernel("gather_add<double>")
gelu_kernel_float = _get_kernel("gelu<float>")
gelu_kernel_double = _get_kernel("gelu<double>")
hash_data_kernel = compile_mmh(MMH_SRC)
Expand Down Expand Up @@ -165,6 +169,28 @@ def clipped_linear(
return out


def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
_is_float_array(table)
indices = indices.astype("int32")
_check_indices(indices, table.shape[0])

B = indices.shape[0]
K = indices.shape[1]
T = table.shape[0]
O = table.shape[1]

out = _alloc((B, O), dtype=table.dtype, zeros=True)
if table.dtype == "float32":
gather_add_kernel_float(
(num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
)
else:
gather_add_kernel_double(
(num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
)
return out


def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)

Expand Down Expand Up @@ -647,6 +673,13 @@ def _check_lengths(lengths, n_elems: int, *, min_length=0):
raise IndexError("lengths must sum up to the batch size")


def _check_indices(indices, n: int):
assert indices.dtype == "int32", "indices should be encoded as 32-bit integers"

if not _values_within_range(indices, 0, n):
raise IndexError(f"index out of bounds, must be >= 0 && < {n}")


def _check_which_maxout(which, B: int, I: int, P: int):
shape = (B, I)
msg = "maximum index (which) should be encoded as 32-bit integers"
Expand Down
6 changes: 6 additions & 0 deletions thinc/backends/cblas.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
float *Y, int incY) nogil


ctypedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX,
double *Y, int incY) nogil


# Forward-declaration of the BlasFuncs struct. This struct must be opaque, so
# that consumers of the CBlas class cannot become dependent on its size or
# ordering.
Expand All @@ -18,7 +22,9 @@ cdef struct BlasFuncs

cdef class CBlas:
cdef shared_ptr[BlasFuncs] ptr
cdef daxpy_ptr daxpy(self) nogil
cdef saxpy_ptr saxpy(self) nogil
cdef sgemm_ptr sgemm(self) nogil
cdef void set_daxpy(self, daxpy_ptr daxpy) nogil
cdef void set_saxpy(self, saxpy_ptr saxpy) nogil
cdef void set_sgemm(self, sgemm_ptr sgemm) nogil
8 changes: 8 additions & 0 deletions thinc/backends/cblas.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from libcpp.memory cimport make_shared


cdef struct BlasFuncs:
daxpy_ptr daxpy
saxpy_ptr saxpy
sgemm_ptr sgemm

Expand All @@ -15,16 +16,23 @@ cdef class CBlas:
"""Construct a CBlas instance set to use BLIS implementations of the
supported BLAS functions."""
cdef BlasFuncs funcs
funcs.daxpy = blis.cy.daxpy
funcs.saxpy = blis.cy.saxpy
funcs.sgemm = blis.cy.sgemm
self.ptr = make_shared[BlasFuncs](funcs)

cdef daxpy_ptr daxpy(self) nogil:
return deref(self.ptr).daxpy

cdef saxpy_ptr saxpy(self) nogil:
return deref(self.ptr).saxpy

cdef sgemm_ptr sgemm(self) nogil:
return deref(self.ptr).sgemm

cdef void set_daxpy(self, daxpy_ptr daxpy) nogil:
deref(self.ptr).daxpy = daxpy

cdef void set_saxpy(self, saxpy_ptr saxpy) nogil:
deref(self.ptr).saxpy = saxpy

Expand Down
30 changes: 30 additions & 0 deletions thinc/backends/cpu_kernels.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
#include <string>
#include <type_traits>

// Ideally we'd use an alias declaration for a generic definition of
// *axpy. But Cython does support alias declarations yet:
//
// https://github.com/cython/cython/issues/3272
//
// template <typename T>
// using axpy = void (*)(int N, float alpha, const float* X, int incX,
// float *Y, int incY);
//
// So, instead we'll do this the pre-C++11 way:

template <typename T>
struct axpy {
typedef void (*ptr)(int N, T alpha, const T* X, int incX, T *Y, int incY);
};


// All elementwise functions, such as most activations, work in-place.

Expand Down Expand Up @@ -395,4 +411,18 @@ void backprop_seq2col(A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L
}
}

template <typename F, typename I, typename L>
void cpu_gather_add(typename axpy<F>::ptr axpy, F* out_bo, const F* table_to, const I* indices_bk, L T, L O, L B, L K) {
for (L b = 0; b < B; ++b) {
for (L k = 0; k < K; ++k) {
I idx = indices_bk[b * K + k];
if (idx > T) {
throw std::out_of_range("Embedding index out-of-bounds");
}
axpy(O, 1.0, table_to + idx * O, 1, out_bo + b * O, 1);
}
}
}


#endif // CPU_KERNELS_HH
6 changes: 6 additions & 0 deletions thinc/backends/cupy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def to_numpy(self, data, *, byte_order=None):
data = numpy.asarray(data, dtype=dtype)
return data

def gather_add(self, table, indices):
if table.dtype in ("float32", "float64"):
return _custom_kernels.gather_add(table, indices)
else:
return super().gather_add(table, indices)

def gelu(self, X, inplace=False):
if X.dtype in ("float32", "float64"):
return _custom_kernels.gelu(X, inplace=inplace, threshold=6.0)
Expand Down
13 changes: 13 additions & 0 deletions thinc/backends/numpy_ops.pxd
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from .cblas cimport saxpy_ptr

ctypedef double[:, ::1] double2d_t
ctypedef double[:, :, ::1] double3d_t
ctypedef float[:, ::1] float2d_t
ctypedef float[:, :, ::1] float3d_t
ctypedef int[:, ::1] int2d_t
ctypedef unsigned int[:, ::1] uint2d_t

cdef fused ints2d_ft:
int2d_t
uint2d_t

cdef fused reals2d_ft:
float2d_t
Expand All @@ -13,6 +21,9 @@ cdef fused reals3d_ft:


cdef extern from "cpu_kernels.hh":
cdef cppclass axpy[T]:
ctypedef void (*ptr)(int N, T alpha, const T* X, int incX, T *Y, int incY);

void cpu_maxout[A, L](A* best__bo, L* which__bo, const A* cands_bop,
L B, L O, L P)
void cpu_backprop_maxout[A, L](A* dX__bop, const A* dX__bo, const L* which__bo,
Expand All @@ -35,3 +46,5 @@ cdef extern from "cpu_kernels.hh":
void cpu_relu[A, L](A* X, L N)
void backprop_seq2col[A, L](A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L nW, L nL)
void seq2col[A, L](A* output, const A* X, const L* lengths, L nW, L B, L I, L nL)
void cpu_gather_add[F, I, L](axpy[F].ptr axpy, F* out_bo, const F* table_to, const I* indices_bk,
L T, L O, L B, L K) except +
17 changes: 17 additions & 0 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,23 @@ class NumpyOps(Ops):

return dX

def gather_add(self, reals2d_ft table, ints2d_ft indices):
cdef CBlas cblas = self.cblas()
rows = indices.shape[0]
dims = table.shape[1]

cdef np.ndarray output
if reals2d_ft is float2d_t:
output = self.xp.zeros((rows, dims), dtype="float32")
cpu_gather_add(cblas.saxpy(), <float *>output.data, &table[0, 0], &indices[0, 0],
table.shape[0], dims, rows, indices.shape[1])
else:
output = self.xp.zeros((rows, dims), dtype="float64")
cpu_gather_add(cblas.daxpy(), <double *>output.data, &table[0, 0], &indices[0, 0],
table.shape[0], dims, rows, indices.shape[1])

return output

def scatter_add(self, np.ndarray table, np.ndarray indices, np.ndarray values):
if table.dtype == 'float32' \
and indices.dtype == 'int32' \
Expand Down
3 changes: 3 additions & 0 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,9 @@ def position_encode(
numpy_ops = NumpyOps()
return self.asarray2f(numpy_ops.position_encode(N, D, period, out))

def gather_add(self, table: FloatsXd, indices: IntsXd) -> FloatsXd:
return table[indices].sum(axis=1) # type: ignore[call-overload]

def scatter_add(
self, table: FloatsXd, indices: IntsXd, values: FloatsXd
) -> FloatsXd:
Expand Down
2 changes: 1 addition & 1 deletion thinc/layers/hashembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def forward(
nN = ids.shape[0]
seed: int = model.attrs["seed"]
keys = model.ops.hash(ids, seed) % nV
output = vectors[keys].sum(axis=1)
output = cast(Floats2d, model.ops.gather_add(vectors, keys))
drop_mask = None
if is_train:
dropout: Optional[float] = model.attrs.get("dropout_rate")
Expand Down
34 changes: 33 additions & 1 deletion thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import inspect

from .. import strategies
from ..strategies import ndarrays_of_shape
from ..strategies import arrays_BI, ndarrays_of_shape


MAX_EXAMPLES = 10
Expand Down Expand Up @@ -198,6 +198,38 @@ def test_get_dropout_not_empty(ops):
assert mask.shape == shape


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
@pytest.mark.parametrize("index_dtype", ["int32", "uint32"])
def test_gather_add(ops, dtype, index_dtype):
table = ops.xp.arange(12, dtype=dtype).reshape(4, 3)
indices = ops.xp.array([[0, 2], [3, 1], [0, 1]], dtype=index_dtype)
gathered = ops.gather_add(table, indices)
ops.xp.testing.assert_allclose(
gathered, [[6.0, 8.0, 10.0], [12.0, 14.0, 16.0], [3.0, 5.0, 7.0]]
)


@pytest.mark.parametrize("ops", XP_OPS)
@given(table=strategies.arrays_BI())
def test_gather_add_against_numpy(ops, table):
table = ops.asarray(table)
indices = ops.xp.arange(100, dtype="i").reshape(25, 4) % table.shape[0]
ops.xp.testing.assert_allclose(
ops.gather_add(table, indices),
table[indices].sum(1),
atol=1e-5,
)


@pytest.mark.parametrize("ops", ALL_OPS)
def test_gather_add_oob_raises(ops):
table = ops.xp.arange(12, dtype="f").reshape(4, 3)
indices = ops.xp.array([[0, 2], [3, 1], [5, 1]], dtype="i")
with pytest.raises(IndexError):
ops.gather_add(table, indices)


@pytest.mark.parametrize("ops", CPU_OPS)
def test_seq2col_window_one_small(ops):
seq = ops.asarray([[1.0], [3.0], [4.0], [5]], dtype="float32")
Expand Down