Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions thinc/backends/cpu_kernels.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include <type_traits>


// Pointer type for a BLAS-style saxpy kernel (Y += alpha * X over N
// strided elements); lets CPU kernels receive whichever saxpy
// implementation the backend provides at runtime.
typedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
                          float *Y, int incY);

// All elementwise functions, such as most activations, work in-place.

template <typename A, typename L>
Expand Down Expand Up @@ -395,4 +398,20 @@ void backprop_seq2col(A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L
}
}


// Accumulate embedding rows: for each (b, k), add row indices_bk[b*K + k]
// of the (T, O) table into row b of the (B, O) output buffer.
//
//   saxpy      - BLAS-style kernel performing Y += alpha * X over strided
//                spans (same function type as the saxpy_ptr typedef above).
//   out_bo     - (B, O) output, accumulated in place (callers are expected
//                to pre-initialize it, e.g. with zeros).
//   table_to   - (T, O) embedding table.
//   indices_bk - (B, K) table row indices.
//
// Throws std::out_of_range when an index does not name a valid table row.
template <typename I, typename L>
void cpu_gather_add(void (*saxpy)(int N, float alpha, const float* X,
                                  int incX, float* Y, int incY),
                    float* out_bo, const float* table_to, const I* indices_bk,
                    L T, L O, L B, L K)
{
    for (L b = 0; b < B; ++b) {
        for (L k = 0; k < K; ++k) {
            I idx = indices_bk[b * K + k];
            // Valid rows are 0..T-1, so idx == T is already out of bounds.
            // The previous `idx > T` comparison let idx == T slip through
            // and read one full row past the end of the table.
            if (idx >= T) {
                throw std::out_of_range("Embedding index out-of-bounds");
            }
            (*saxpy)(O, 1.0, table_to + idx * O, 1, out_bo + b * O, 1);
        }
    }
}


#endif // CPU_KERNELS_HH
4 changes: 4 additions & 0 deletions thinc/backends/numpy_ops.pxd
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .cblas cimport saxpy_ptr

ctypedef double[:, ::1] double2d_t
ctypedef double[:, :, ::1] double3d_t
ctypedef float[:, ::1] float2d_t
Expand Down Expand Up @@ -35,3 +37,5 @@ cdef extern from "cpu_kernels.hh":
void cpu_relu[A, L](A* X, L N)
void backprop_seq2col[A, L](A* d_seqs, const A* d_cols, const L* lengths, L B, L I, L nW, L nL)
void seq2col[A, L](A* output, const A* X, const L* lengths, L nW, L B, L I, L nL)
void cpu_gather_add[I, L](saxpy_ptr saxpy, float* out_bo, const float* table_to, const I* indices_bk,
L T, L O, L B, L K) except +
9 changes: 9 additions & 0 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,15 @@ class NumpyOps(Ops):

return dX

def gather_add(self, float[:, ::1] table, unsigned int[:, ::1] indices):
    # Gather-and-sum rows of `table`: for each row b of `indices`,
    # output[b] = sum_k table[indices[b, k]]. Delegates to the C++
    # cpu_gather_add kernel, which accumulates rows via BLAS saxpy and
    # throws std::out_of_range (surfaced to Python through the `except +`
    # on the .pxd declaration) for an out-of-bounds index.
    cdef CBlas cblas = self.cblas()
    rows = indices.shape[0]
    dims = table.shape[1]
    # Zero-filled so the kernel's in-place accumulation starts from 0.
    cdef np.ndarray output = self.xp.zeros((rows, dims), dtype="float32")
    # NOTE(review): `&table[0, 0]` / `&indices[0, 0]` assume non-empty
    # inputs -- confirm callers never pass zero-row arrays.
    cpu_gather_add(cblas.saxpy(), <float *>output.data, &table[0, 0], &indices[0, 0],
                   table.shape[0], dims, rows, indices.shape[1])
    return output

def scatter_add(self, np.ndarray table, np.ndarray indices, np.ndarray values):
if table.dtype == 'float32' \
and indices.dtype == 'int32' \
Expand Down
3 changes: 3 additions & 0 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,9 @@ def position_encode(
numpy_ops = NumpyOps()
return self.asarray2f(numpy_ops.position_encode(N, D, period, out))

def gather_add(self, table: FloatsXd, indices: IntsXd) -> FloatsXd:
    """Gather rows of `table` selected by `indices` and sum them over the
    last index axis: output[b] = table[indices[b]].sum(axis=0).

    table: Array whose leading axis indexes the rows to gather.
    indices: Integer array of row indices; with 2d indices of shape
        (B, K) the result has shape (B, table.shape[-1]).

    Return annotation added for consistency with the sibling
    `scatter_add`, which already declares `-> FloatsXd`.
    """
    return table[indices].sum(axis=1)

def scatter_add(
self, table: FloatsXd, indices: IntsXd, values: FloatsXd
) -> FloatsXd:
Expand Down
2 changes: 1 addition & 1 deletion thinc/layers/hashembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def forward(
nN = ids.shape[0]
seed: int = model.attrs["seed"]
keys = model.ops.hash(ids, seed) % nV
output = vectors[keys].sum(axis=1)
output = model.ops.gather_add(vectors, keys)
drop_mask = None
if is_train:
dropout: Optional[float] = model.attrs.get("dropout_rate")
Expand Down