Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
"murmurhash>=1.0.2,<1.1.0",
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"blis>=0.4.0,<0.8.0",
"blis>=0.9.0,<0.10.0",
"numpy>=1.15.0",
]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
murmurhash>=1.0.2,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
blis>=0.4.0,<0.8.0
blis>=0.9.0,<0.10.0
srsly>=2.4.0,<3.0.0
wasabi>=0.8.1,<1.1.0
catalogue>=2.0.4,<2.1.0
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=1.0.2,<1.1.0
blis>=0.4.0,<0.8.0
blis>=0.9.0,<0.10.0
install_requires =
# Explosion-provided dependencies
blis>=0.4.0,<0.8.0
blis>=0.9.0,<0.10.0
murmurhash>=1.0.2,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

PACKAGES = find_packages()
MOD_NAMES = [
"thinc.backends.cblas",
"thinc.backends.linalg",
"thinc.backends.numpy_ops",
"thinc.extra.search",
Expand Down
24 changes: 24 additions & 0 deletions thinc/backends/cblas.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from libcpp.memory cimport shared_ptr


ctypedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K,
float alpha, const float* A, int lda, const float *B,
int ldb, float beta, float* C, int ldc) nogil


ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
float *Y, int incY) nogil


# Forward-declaration of the BlasFuncs struct. This struct must be opaque, so
# that consumers of the CBlas class cannot become dependent on its size or
# ordering.
cdef struct BlasFuncs


cdef class CBlas:
cdef shared_ptr[BlasFuncs] ptr
cdef saxpy_ptr saxpy(self) nogil
cdef sgemm_ptr sgemm(self) nogil
cdef void set_saxpy(self, saxpy_ptr saxpy) nogil
cdef void set_sgemm(self, sgemm_ptr sgemm) nogil
32 changes: 32 additions & 0 deletions thinc/backends/cblas.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
cimport blis.cy
from cython.operator cimport dereference as deref
from libcpp.memory cimport make_shared


cdef struct BlasFuncs:
saxpy_ptr saxpy
sgemm_ptr sgemm


cdef class CBlas:
__slots__ = []

def __init__(self):
"""Construct a CBlas instance set to use BLIS implementations of the
supported BLAS functions."""
cdef BlasFuncs funcs
funcs.saxpy = blis.cy.saxpy
funcs.sgemm = blis.cy.sgemm
self.ptr = make_shared[BlasFuncs](funcs)

cdef saxpy_ptr saxpy(self) nogil:
return deref(self.ptr).saxpy

cdef sgemm_ptr sgemm(self) nogil:
return deref(self.ptr).sgemm

cdef void set_saxpy(self, saxpy_ptr saxpy) nogil:
deref(self.ptr).saxpy = saxpy

cdef void set_sgemm(self, sgemm_ptr sgemm) nogil:
deref(self.ptr).sgemm = sgemm
4 changes: 4 additions & 0 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cimport blis.cy
from .. import registry
from ..util import copy_array, get_array_module
from ..types import DeviceTypes, DTypes, Shape, ArrayXd
from .cblas cimport CBlas
from .linalg cimport VecVec, Vec
from .ops import Ops

Expand Down Expand Up @@ -82,6 +83,9 @@ class NumpyOps(Ops):
else:
return self.xp.empty(shape, dtype=dtype)

def cblas(self) -> CBlas:
return CBlas()

def gemm(self, np.ndarray x, np.ndarray y, *, np.ndarray out=None, trans1=False, trans2=False):
if x.ndim != 2:
raise ValueError(f"Provided 'x' array should be 2-dimensional, but found {x.ndim} dimension(s).")
Expand Down
6 changes: 6 additions & 0 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator
from ..util import get_array_module, is_xp_array, to_numpy

from .cblas import CBlas

ArrayT = TypeVar("ArrayT", bound=ArrayXd)
FloatsT = TypeVar("FloatsT", bound=_Floats)
Expand All @@ -31,6 +32,11 @@ def __init__(
self.device_type = device_type
self.device_id = device_id

def cblas(self) -> CBlas:
"""Return C BLAS function table."""
err = f"{type(self).__name__} does not provide C BLAS functions"
raise NotImplementedError(err)

def to_numpy(self, data, *, byte_order=None): # pragma: no cover
if isinstance(data, numpy.ndarray):
if byte_order:
Expand Down
21 changes: 21 additions & 0 deletions website/docs/api-backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,27 @@ the inputs and outputs.
| `zeros` | <tt>bool</tt> | Fill the array with zeros (default: `True`). |
| **RETURNS** | <tt>ArrayXd</tt> | An array of the correct shape and data type. |

### Ops.cblas {#cblas tag="method"}

<inline-list>

- **default:** <i name="no"></i>
- **numpy:** <i name="yes"></i>
- **cupy:** <i name="no"></i>

</inline-list>

Get a table of C BLAS functions usable in Cython `cdef nogil` functions. This
method does not take any arguments.

<infobox variant="warning">

This method is only supported by `NumpyOps` and subclasses of `NumpyOps`. A
`NotImplementedError` exception is raised when calling this method on `Ops` or
`CupyOps`.

</infobox>

### Ops.to_numpy {#to_numpy tag="method"}

<inline-list>
Expand Down