20 changes: 20 additions & 0 deletions thinc/backends/_custom_kernels.cu
@@ -22,6 +22,26 @@ struct Constants<float> {
};


template <typename U>
__global__ void gather_add(U* out_bo, const U* table_to, const int* indices_bk,
                           int T, int O, int B, int K)
{
    int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
    int _loop_stride = blockDim.x * gridDim.x;

    for (int b = _loop_start; b < B; b += _loop_stride) {
        for (int k = 0; k < K; ++k) {
            int idx = indices_bk[b * K + k];
            const U* table = table_to + idx * O;
            U* out = out_bo + b * O;
            for (int o = 0; o < O; ++o) {
                out[o] += table[o];
            }
        }
    }
}


template <typename T>
__global__ void seq2col(T* output, const T* X, const int* lengths,
                        int nW, int B, int I, int nL)
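
The kernel assigns each output row `b` to one GPU thread via a grid-stride loop; each thread then walks the `K` indices and `O` columns sequentially. For reference, here is a minimal NumPy sketch of the computation the kernel parallelizes (the function name and loop structure are illustrative, not part of the PR):

```python
import numpy as np

def gather_add_reference(table, indices):
    # out[b] = sum over k of table[indices[b, k]], computed without
    # materializing the intermediate (B, K, O) array.
    B, K = indices.shape
    out = np.zeros((B, table.shape[1]), dtype=table.dtype)
    for b in range(B):        # parallelized across threads on the GPU
        for k in range(K):    # sequential within each thread
            out[b] += table[indices[b, k]]
    return out
```
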
37 changes: 37 additions & 0 deletions thinc/backends/_custom_kernels.py
@@ -32,6 +32,8 @@
"backprop_swish<float>",
"clipped_linear<double>",
"clipped_linear<float>",
"gather_add<double>",
"gather_add<float>",
"gelu<double>",
"gelu<float>",
"maxout<double>",
@@ -76,6 +78,8 @@ def compile_mmh(src):

clipped_linear_kernel_float = _get_kernel("clipped_linear<float>")
clipped_linear_kernel_double = _get_kernel("clipped_linear<double>")
gather_add_kernel_float = _get_kernel("gather_add<float>")
gather_add_kernel_double = _get_kernel("gather_add<double>")
gelu_kernel_float = _get_kernel("gelu<float>")
gelu_kernel_double = _get_kernel("gelu<double>")
hash_data_kernel = compile_mmh(MMH_SRC)
@@ -165,6 +169,32 @@ def clipped_linear(
    return out


def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
    if table.ndim != 2:
        raise ValueError(f"gather_add expects table with dimensionality 2, was: {table.ndim}")
    if indices.ndim != 2:
        raise ValueError(f"gather_add expects indices with dimensionality 2, was: {indices.ndim}")
    _is_float_array(table)
    indices = indices.astype("int32")
    _check_indices(indices, table.shape[0])

    B = indices.shape[0]
    K = indices.shape[1]
    T = table.shape[0]
    O = table.shape[1]

    out = _alloc((B, O), dtype=table.dtype, zeros=True)
    if table.dtype == "float32":
        gather_add_kernel_float(
            (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
        )
    else:
        gather_add_kernel_double(
            (num_blocks,), (threads_per_block,), (out, table, indices, T, O, B, K)
        )
    return out


def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)

@@ -647,6 +677,13 @@ def _check_lengths(lengths, n_elems: int, *, min_length=0):
        raise IndexError("lengths must sum up to the batch size")


def _check_indices(indices, n: int):
    assert indices.dtype == "int32", "indices should be encoded as 32-bit integers"

    if not _values_within_range(indices, 0, n):
        raise IndexError(f"index out of bounds, must be >= 0 && < {n}")


def _check_which_maxout(which, B: int, I: int, P: int):
    shape = (B, I)
    msg = "maximum index (which) should be encoded as 32-bit integers"
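
Assuming a CuPy-enabled install, the wrapper could be exercised directly along these lines (an illustrative sketch; the default launch configuration of 128 blocks × 128 threads per block comes from the keyword defaults above):

```python
import cupy as cp
from thinc.backends import _custom_kernels

table = cp.arange(12, dtype="float32").reshape(4, 3)   # (T, O)
indices = cp.array([[0, 2], [3, 1]], dtype="int32")    # (B, K)

# Indices are validated against table.shape[0] before the kernel launch.
out = _custom_kernels.gather_add(table, indices)       # shape (2, 3)
```
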
6 changes: 6 additions & 0 deletions thinc/backends/cblas.pxd
@@ -10,6 +10,10 @@ ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
                           float *Y, int incY) nogil


ctypedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX,
                           double *Y, int incY) nogil


# Forward-declaration of the BlasFuncs struct. This struct must be opaque, so
# that consumers of the CBlas class cannot become dependent on its size or
# ordering.
@@ -26,7 +30,9 @@ cdef class CBlas:
#
# See https://github.com/explosion/thinc/pull/700 for more information.

cdef daxpy_ptr daxpy(CBlas cblas) nogil
cdef saxpy_ptr saxpy(CBlas cblas) nogil
cdef sgemm_ptr sgemm(CBlas cblas) nogil
cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil
cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil
cdef void set_sgemm(CBlas cblas, sgemm_ptr sgemm) nogil
8 changes: 8 additions & 0 deletions thinc/backends/cblas.pyx
@@ -4,6 +4,7 @@ from libcpp.memory cimport make_shared


cdef struct BlasFuncs:
    daxpy_ptr daxpy
    saxpy_ptr saxpy
    sgemm_ptr sgemm

@@ -15,16 +16,23 @@ cdef class CBlas:
        """Construct a CBlas instance set to use BLIS implementations of the
        supported BLAS functions."""
        cdef BlasFuncs funcs
        funcs.daxpy = blis.cy.daxpy
        funcs.saxpy = blis.cy.saxpy
        funcs.sgemm = blis.cy.sgemm
        self.ptr = make_shared[BlasFuncs](funcs)

cdef daxpy_ptr daxpy(CBlas cblas) nogil:
    return deref(cblas.ptr).daxpy

cdef saxpy_ptr saxpy(CBlas cblas) nogil:
    return deref(cblas.ptr).saxpy

cdef sgemm_ptr sgemm(CBlas cblas) nogil:
    return deref(cblas.ptr).sgemm

cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil:
    deref(cblas.ptr).daxpy = daxpy

cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil:
    deref(cblas.ptr).saxpy = saxpy

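
For readers unfamiliar with BLAS naming: `daxpy` is the double-precision counterpart of `saxpy`, and both compute `y := alpha * x + y` over strided vectors — exactly the primitive `cpu_gather_add` below uses, with `alpha = 1.0` and unit strides, to add one table row into one output row per call. A minimal NumPy sketch of the contract (illustrative, not part of the PR):

```python
import numpy as np

def axpy_reference(n, alpha, x, inc_x, y, inc_y):
    # Accumulate n strided elements of x, scaled by alpha, into y in place.
    y[: n * inc_y : inc_y] += alpha * x[: n * inc_x : inc_x]

y = np.zeros(3, dtype="float64")
axpy_reference(3, 1.0, np.array([1.0, 2.0, 3.0]), 1, y, 1)
# y is now [1.0, 2.0, 3.0]
```
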
30 changes: 30 additions & 0 deletions thinc/backends/cpu_kernels.hh
@@ -8,6 +8,22 @@
#include <string>
#include <type_traits>

// Ideally we'd use an alias declaration for a generic definition of
// *axpy. But Cython doesn't support alias declarations yet:
//
// https://github.com/cython/cython/issues/3272
//
// template <typename T>
// using axpy = void (*)(int N, T alpha, const T* X, int incX,
// T *Y, int incY);
//
// So, instead we'll do this the pre-C++11 way:

template <typename T>
struct axpy {
    typedef void (*ptr)(int N, T alpha, const T* X, int incX, T *Y, int incY);
};


// All elementwise functions, such as most activations, work in-place.

@@ -395,4 +411,18 @@ void backprop_seq2col(A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L
    }
}

template <typename F, typename I, typename L>
void cpu_gather_add(typename axpy<F>::ptr axpy, F* out_bo, const F* table_to, const I* indices_bk, L T, L O, L B, L K) {
    for (L b = 0; b < B; ++b) {
        for (L k = 0; k < K; ++k) {
            I idx = indices_bk[b * K + k];
            // Valid rows are 0..T-1, so idx == T is out of bounds as well.
            if (idx >= T) {
                throw std::out_of_range("Embedding index out-of-bounds");
            }
            axpy(O, 1.0, table_to + idx * O, 1, out_bo + b * O, 1);
        }
    }
}


#endif // CPU_KERNELS_HH
6 changes: 6 additions & 0 deletions thinc/backends/cupy_ops.py
@@ -30,6 +30,12 @@ def to_numpy(self, data, *, byte_order=None):
                data = numpy.asarray(data, dtype=dtype)
            return data

    def gather_add(self, table, indices):
        if table.dtype in ("float32", "float64"):
            return _custom_kernels.gather_add(table, indices)
        else:
            return super().gather_add(table, indices)

    def gelu(self, X, inplace=False):
        if X.dtype in ("float32", "float64"):
            return _custom_kernels.gelu(X, inplace=inplace, threshold=6.0)
13 changes: 13 additions & 0 deletions thinc/backends/numpy_ops.pxd
@@ -1,7 +1,15 @@
from .cblas cimport saxpy_ptr

ctypedef double[:, ::1] double2d_t
ctypedef double[:, :, ::1] double3d_t
ctypedef float[:, ::1] float2d_t
ctypedef float[:, :, ::1] float3d_t
ctypedef int[:, ::1] int2d_t
ctypedef unsigned int[:, ::1] uint2d_t

cdef fused ints2d_ft:
    int2d_t
    uint2d_t

cdef fused reals2d_ft:
    float2d_t
@@ -13,6 +21,9 @@ cdef fused reals3d_ft:


cdef extern from "cpu_kernels.hh":
cdef cppclass axpy[T]:
ctypedef void (*ptr)(int N, T alpha, const T* X, int incX, T *Y, int incY);

void cpu_maxout[A, L](A* best__bo, L* which__bo, const A* cands_bop,
L B, L O, L P)
void cpu_backprop_maxout[A, L](A* dX__bop, const A* dX__bo, const L* which__bo,
Expand All @@ -35,3 +46,5 @@ cdef extern from "cpu_kernels.hh":
    void cpu_relu[A, L](A* X, L N)
    void backprop_seq2col[A, L](A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L nW, L nL)
    void seq2col[A, L](A* output, const A* X, const L* lengths, L nW, L B, L I, L nL)
    void cpu_gather_add[F, I, L](axpy[F].ptr axpy, F* out_bo, const F* table_to, const I* indices_bk,
                                 L T, L O, L B, L K) except +
19 changes: 18 additions & 1 deletion thinc/backends/numpy_ops.pyx
@@ -20,7 +20,7 @@ cimport blis.cy
from .. import registry
from ..util import copy_array, get_array_module
from ..types import DeviceTypes, DTypes, Shape, ArrayXd
-from .cblas cimport CBlas
+from .cblas cimport CBlas, daxpy, saxpy
from .linalg cimport VecVec, Vec
from .ops import Ops

@@ -437,6 +437,23 @@ class NumpyOps(Ops):

        return dX

    def gather_add(self, reals2d_ft table, ints2d_ft indices):
        cdef CBlas cblas = self.cblas()
        rows = indices.shape[0]
        dims = table.shape[1]

        cdef np.ndarray output
        if reals2d_ft is float2d_t:
            output = self.xp.zeros((rows, dims), dtype="float32")
            cpu_gather_add(saxpy(cblas), <float *>output.data, &table[0, 0], &indices[0, 0],
                           table.shape[0], dims, rows, indices.shape[1])
        else:
            output = self.xp.zeros((rows, dims), dtype="float64")
            cpu_gather_add(daxpy(cblas), <double *>output.data, &table[0, 0], &indices[0, 0],
                           table.shape[0], dims, rows, indices.shape[1])

        return output

    def scatter_add(self, np.ndarray table, np.ndarray indices, np.ndarray values):
        if table.dtype == 'float32' \
                and indices.dtype == 'int32' \
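
Because the arguments use the fused memoryview types from `numpy_ops.pxd`, a single method handles `float32`/`float64` tables and `int32`/`uint32` indices, with the dispatch resolved by Cython rather than Python-level checks. A usage sketch mirroring the new tests (assuming the usual `thinc.api` re-export of `NumpyOps`):

```python
import numpy as np
from thinc.api import NumpyOps

ops = NumpyOps()
table = np.arange(12, dtype="float32").reshape(4, 3)   # (T, O) = (4, 3)
indices = np.array([[0, 2], [3, 1]], dtype="int32")    # (B, K) = (2, 2)
out = ops.gather_add(table, indices)
print(out)  # [[ 6.  8. 10.]
            #  [12. 14. 16.]]
```
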
3 changes: 3 additions & 0 deletions thinc/backends/ops.py
@@ -1254,6 +1254,9 @@ def position_encode(
        numpy_ops = NumpyOps()
        return self.asarray2f(numpy_ops.position_encode(N, D, period, out))

    def gather_add(self, table: Floats2d, indices: Ints2d) -> Floats2d:
        return table[indices].sum(axis=1)  # type: ignore[return-value]

    def scatter_add(
        self, table: FloatsXd, indices: IntsXd, values: FloatsXd
    ) -> FloatsXd:
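
The default implementation relies on fancy indexing: `table[indices]` broadcasts to shape `(B, K, O)`, and summing over axis 1 reduces out `K`. A quick shape check (illustrative):

```python
import numpy as np

table = np.zeros((7, 5), dtype="float32")          # (T, O)
indices = np.zeros((2, 3), dtype="int32")          # (B, K)
assert table[indices].shape == (2, 3, 5)           # (B, K, O) intermediate
assert table[indices].sum(axis=1).shape == (2, 5)  # (B, O) result
```

The specialized NumPy and CuPy paths above compute the same result while avoiding the `(B, K, O)` allocation that this one-liner incurs.
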
2 changes: 1 addition & 1 deletion thinc/layers/hashembed.py
@@ -68,7 +68,7 @@ def forward(
    nN = ids.shape[0]
    seed: int = model.attrs["seed"]
    keys = model.ops.hash(ids, seed) % nV
-   output = vectors[keys].sum(axis=1)
+   output = model.ops.gather_add(vectors, keys)
    drop_mask = None
    if is_train:
        dropout: Optional[float] = model.attrs.get("dropout_rate")
34 changes: 33 additions & 1 deletion thinc/tests/backends/test_ops.py
@@ -16,7 +16,7 @@
import inspect

from .. import strategies
-from ..strategies import ndarrays_of_shape
+from ..strategies import arrays_BI, ndarrays_of_shape


MAX_EXAMPLES = 10
@@ -198,6 +198,38 @@ def test_get_dropout_not_empty(ops):
    assert mask.shape == shape


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
@pytest.mark.parametrize("index_dtype", ["int32", "uint32"])
def test_gather_add(ops, dtype, index_dtype):
table = ops.xp.arange(12, dtype=dtype).reshape(4, 3)
indices = ops.xp.array([[0, 2], [3, 1], [0, 1]], dtype=index_dtype)
gathered = ops.gather_add(table, indices)
ops.xp.testing.assert_allclose(
gathered, [[6.0, 8.0, 10.0], [12.0, 14.0, 16.0], [3.0, 5.0, 7.0]]
)


@pytest.mark.parametrize("ops", XP_OPS)
@given(table=strategies.arrays_BI())
def test_gather_add_against_numpy(ops, table):
table = ops.asarray(table)
indices = ops.xp.arange(100, dtype="i").reshape(25, 4) % table.shape[0]
ops.xp.testing.assert_allclose(
ops.gather_add(table, indices),
table[indices].sum(1),
atol=1e-5,
)


@pytest.mark.parametrize("ops", ALL_OPS)
def test_gather_add_oob_raises(ops):
table = ops.xp.arange(12, dtype="f").reshape(4, 3)
indices = ops.xp.array([[0, 2], [3, 1], [5, 1]], dtype="i")
with pytest.raises(IndexError):
ops.gather_add(table, indices)


@pytest.mark.parametrize("ops", CPU_OPS)
def test_seq2col_window_one_small(ops):
seq = ops.asarray([[1.0], [3.0], [4.0], [5]], dtype="float32")
23 changes: 21 additions & 2 deletions website/docs/api-backends.md
@@ -1298,8 +1298,8 @@ Backpropagate the `reduce_mean` operation.
</inline-list>

Perform sequence-wise max pooling for data in the ragged format. Zero-length
-sequences are not allowed. A `ValueError` is raised if any element in
-`lengths` is zero.
+sequences are not allowed. A `ValueError` is raised if any element in `lengths`
+is zero.

| Argument | Type | Description |
| ----------- | -------------------------------- | --------------------------- |
@@ -1364,6 +1364,25 @@ Create hashed ngram features.
| `keys` | <tt>Ints1d</tt> | The input sequence. |
| **RETURNS** | <tt>Ints1d</tt> | The hashed ngrams. |

### Ops.gather_add {#gather_add tag="method" new="8.1"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** <i name="yes"></i>
- **cupy:** <i name="yes"></i>

</inline-list>

Gather rows from `table` with shape `(T, O)` using array `indices` with shape
`(B, K)`, then sum the resulting array with shape `(B, K, O)` over the `K` axis.

| Argument | Type | Description |
| ----------- | ----------------- | ----------------------- |
| `table` | <tt>Floats2d</tt> | The array to gather from. |
| `indices` | <tt>Ints2d</tt> | The indices to use. |
| **RETURNS** | <tt>Floats2d</tt> | The summed rows. |

### Ops.scatter_add {#scatter_add tag="method"}

<inline-list>