2 changes: 1 addition & 1 deletion thinc/api.py
@@ -27,7 +27,7 @@
from .layers import CauchySimilarity, ParametricAttention, Logistic
from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear
from .layers import ClippedLinear, ReluK, HardTanh, HardSigmoid
-from .layers import HardSwish, HardSwishMobilenet, Swish, Gelu
+from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
from .layers import PyTorchWrapper_v2, Softmax_v2
31 changes: 31 additions & 0 deletions thinc/backends/_custom_kernels.cu
@@ -161,6 +161,20 @@ __global__ void clipped_linear(T* Y, const T* X, double slope, double offset, do
}


template <typename T>
__global__ void dish(T* Y, const T* X, int N)
{
int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
int _loop_stride = blockDim.x * gridDim.x;

for (int i = _loop_start; i < N; i += _loop_stride)
{
T x = X[i];
Y[i] = 0.5 * x * (x / sqrt(1 + x * x) + 1);
}
}
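
For reference, the kernel above evaluates the Dish activation element-wise; written out as an equation (derived from the code itself, not stated elsewhere in this diff):

    \operatorname{dish}(x) = \frac{x}{2}\left(\frac{x}{\sqrt{1 + x^{2}}} + 1\right)

The term x / sqrt(1 + x^2) maps the reals monotonically into (-1, 1), so (x / sqrt(1 + x^2) + 1) / 2 acts as a sigmoid-like gate and the activation is Swish/GELU-shaped while needing only a square root.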


template <typename T>
__global__ void gelu(T* Y, const T* X, double threshold, int N)
{
@@ -414,6 +428,23 @@ __global__ void backprop_hard_swish_mobilenet(T* dX, const T* dY, const T* X, in
}


template <typename T>
__global__ void backprop_dish(T* dX, const T* dY, const T* X, int N)
{

int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
int _loop_stride = blockDim.x * gridDim.x;

for (int i = _loop_start; i < N; i += _loop_stride)
{
T x = X[i];
T x_sq = x * x;
T x_sq_plus_one = x_sq + 1.0;
dX[i] = dY[i] * (x/sqrt(x_sq_plus_one) - (0.5 * x * x_sq)
/ pow(x_sq_plus_one, static_cast<T>(1.5)) + 0.5);
}
}
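
The expression inside the loop is the analytic derivative of the forward kernel; writing out the chain rule (a derivation from the code, not text taken from the PR):

    \frac{d}{dx}\operatorname{dish}(x)
      = \frac{1}{2}\left(\frac{x}{\sqrt{1+x^{2}}} + 1\right) + \frac{x}{2\,(1+x^{2})^{3/2}}
      = \frac{x}{\sqrt{1+x^{2}}} - \frac{x^{3}}{2\,(1+x^{2})^{3/2}} + \frac{1}{2}

which matches the dX[i] assignment above, scaled by the incoming gradient dY[i].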


template <typename T>
__global__ void backprop_gelu(T* dX, const T* dY, const T* X,
48 changes: 48 additions & 0 deletions thinc/backends/_custom_kernels.py
@@ -10,6 +10,8 @@
KERNELS_LIST = [
"backprop_clipped_linear<double>",
"backprop_clipped_linear<float>",
"backprop_dish<double>",
"backprop_dish<float>",
"backprop_gelu<double>",
"backprop_gelu<float>",
"backprop_hard_swish<double>",
@@ -32,6 +34,8 @@
"backprop_swish<float>",
"clipped_linear<double>",
"clipped_linear<float>",
"dish<double>",
"dish<float>",
"gather_add<double>",
"gather_add<float>",
"gelu<double>",
@@ -78,6 +82,8 @@ def compile_mmh(src):

clipped_linear_kernel_float = _get_kernel("clipped_linear<float>")
clipped_linear_kernel_double = _get_kernel("clipped_linear<double>")
dish_kernel_float = _get_kernel("dish<float>")
dish_kernel_double = _get_kernel("dish<double>")
gather_add_kernel_float = _get_kernel("gather_add<float>")
gather_add_kernel_double = _get_kernel("gather_add<double>")
gelu_kernel_float = _get_kernel("gelu<float>")
@@ -98,6 +104,8 @@ def compile_mmh(src):

backprop_clipped_linear_kernel_double = _get_kernel("backprop_clipped_linear<double>")
backprop_clipped_linear_kernel_float = _get_kernel("backprop_clipped_linear<float>")
backprop_dish_kernel_double = _get_kernel("backprop_dish<double>")
backprop_dish_kernel_float = _get_kernel("backprop_dish<float>")
backprop_gelu_kernel_double = _get_kernel("backprop_gelu<double>")
backprop_gelu_kernel_float = _get_kernel("backprop_gelu<float>")
backprop_hard_swish_kernel_double = _get_kernel("backprop_hard_swish<double>")
@@ -199,6 +207,19 @@ def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
return out


def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
_is_float_array(X)

out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
else:
dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
return out


def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)

@@ -483,6 +504,33 @@ def backprop_hard_swish_mobilenet(
return out


def backprop_dish(
dY,
X,
*,
inplace: bool = False,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)

out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_dish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_dish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)

return out


def backprop_gelu(
dY,
X,
12 changes: 12 additions & 0 deletions thinc/backends/cupy_ops.py
@@ -36,6 +36,18 @@ def gather_add(self, table, indices):
else:
return super().gather_add(table, indices)

def dish(self, X, inplace=False):
if X.dtype in ("float32", "float64"):
return _custom_kernels.dish(X, inplace=inplace)
else:
return super().dish(X, inplace=inplace)

def backprop_dish(self, dY, X, inplace=False):
if X.dtype == dY.dtype and X.dtype in ("float32", "float64"):
return _custom_kernels.backprop_dish(dY, X, inplace=inplace)
else:
return super().backprop_dish(dY, X, inplace=inplace)
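
A minimal sketch of how these CupyOps methods would be exercised, assuming a CUDA device and CuPy are available (other dtypes such as float16 fall back to the base-class implementation via super()):

    import numpy
    from thinc.api import CupyOps

    ops = CupyOps()
    X = ops.asarray2f(numpy.random.uniform(-3, 3, (4, 8)).astype("float32"))
    Y = ops.dish(X)                    # forward pass via the custom CUDA kernel
    dY = ops.alloc2f(*Y.shape) + 1.0   # stand-in upstream gradient
    dX = ops.backprop_dish(dY, X)      # element-wise dY * dish'(X)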

def gelu(self, X, inplace=False):
if X.dtype in ("float32", "float64"):
return _custom_kernels.gelu(X, inplace=inplace, threshold=6.0)
29 changes: 29 additions & 0 deletions thinc/backends/ops.py
@@ -976,6 +976,35 @@ def backprop_hard_swish_mobilenet(
return dY
return dX * dY

def dish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
tmp = self.xp.square(X)
tmp += 1.0
self.xp.sqrt(tmp, out=tmp)
tmp = X / tmp
tmp += 1
tmp *= 0.5
if inplace:
X *= tmp
return X
else:
return X * tmp

def backprop_dish(
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
x_sq = self.xp.square(X)
x_sq_plus_one = x_sq + 1.0
deriv = X / self.xp.sqrt(x_sq_plus_one)
second = 0.5 * X * x_sq
second /= x_sq_plus_one**1.5
deriv -= second
deriv += 0.5
if inplace:
dY *= deriv
return dY
else:
return dY * deriv
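
A small self-contained sanity check of the formula used by the default Ops implementation, comparing the analytic gradient against a central finite difference (this mirrors the code above in plain NumPy rather than calling Thinc, so the names here are local to the sketch):

    import numpy as np

    def dish(x):
        return 0.5 * x * (x / np.sqrt(1.0 + x * x) + 1.0)

    def dish_grad(x):
        x_sq = x * x
        x_sq_plus_one = x_sq + 1.0
        return x / np.sqrt(x_sq_plus_one) - 0.5 * x * x_sq / x_sq_plus_one**1.5 + 0.5

    x = np.linspace(-5.0, 5.0, 101)
    eps = 1e-6
    numeric = (dish(x + eps) - dish(x - eps)) / (2.0 * eps)
    assert np.allclose(dish_grad(x), numeric, atol=1e-6)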

# Code snippet taken from:
# https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
def erf(self, X: FloatsXdT) -> FloatsXdT:
1 change: 1 addition & 0 deletions thinc/layers/__init__.py
@@ -1,5 +1,6 @@
# Weights layers
from .cauchysimilarity import CauchySimilarity
from .dish import Dish
from .dropout import Dropout
from .embed import Embed
from .expand_window import expand_window
66 changes: 66 additions & 0 deletions thinc/layers/dish.py
@@ -0,0 +1,66 @@
from typing import Tuple, Optional, Callable, cast

from ..config import registry
from ..model import Model
from .chain import chain
from .layernorm import LayerNorm
from .dropout import Dropout
from ..types import Floats1d, Floats2d
from ..util import partial, get_width
from ..initializers import he_normal_init, zero_init


@registry.layers("Dish.v1")
def Dish(
nO: Optional[int] = None,
nI: Optional[int] = None,
*,
init_W: Callable = he_normal_init,
init_b: Callable = zero_init,
dropout: Optional[float] = None,
normalize: bool = False,
) -> Model[Floats2d, Floats2d]:
model: Model[Floats2d, Floats2d] = Model(
"dish",
forward,
init=partial(init, init_W, init_b),
dims={"nO": nO, "nI": nI},
params={"W": None, "b": None},
)
if normalize:
model = chain(model, LayerNorm(nI=nO))
if dropout is not None:
model = chain(model, cast(Model[Floats2d, Floats2d], Dropout(dropout)))
return model


def forward(
model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool
) -> Tuple[Floats2d, Callable]:
W = cast(Floats2d, model.get_param("W"))
b = cast(Floats1d, model.get_param("b"))
Y_preact = model.ops.affine(X, W, b)
Y = model.ops.dish(Y_preact)

def backprop(dY: Floats2d) -> Floats2d:
dY = model.ops.backprop_dish(dY, X, inplace=False)
model.inc_grad("b", dY.sum(axis=0))
model.inc_grad("W", model.ops.gemm(dY, X, trans1=True))
return model.ops.gemm(dY, W)

return Y, backprop


def init(
init_W: Callable,
init_b: Callable,
model: Model[Floats2d, Floats2d],
X: Optional[Floats2d] = None,
Y: Optional[Floats2d] = None,
) -> None:
if X is not None:
model.set_dim("nI", get_width(X))
if Y is not None:
model.set_dim("nO", get_width(Y))
model.set_param("W", init_W(model.ops, (model.get_dim("nO"), model.get_dim("nI"))))
model.set_param("b", init_b(model.ops, (model.get_dim("nO"),)))
29 changes: 22 additions & 7 deletions thinc/tests/backends/test_ops.py
@@ -64,6 +64,9 @@ def torch_hard_swish_mobilenet(x):
def torch_sigmoid(x):
return torch.sigmoid(x)

def torch_dish(x):
return 0.5 * x * (x / (1 + x * x).sqrt() + 1)
Review comment from a Contributor on lines +67 to +68: Why is torch getting credit here? :p


# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L37
def torch_gelu_approx(x):
return (
@@ -89,6 +92,7 @@ def torch_gelu(x):
("swish", torch_swish),
("hard_swish", torch_hard_swish),
("hard_swish_mobilenet", torch_hard_swish_mobilenet),
("dish", torch_dish),
("gelu_approx", torch_gelu_approx),
("gelu", torch_gelu),
("sigmoid", torch_sigmoid),
@@ -1043,6 +1047,7 @@ def test_mish(ops, X):
"op",
[
"backprop_clipped_linear",
"backprop_dish",
"backprop_gelu",
"backprop_gelu_approx",
"backprop_hard_sigmoid",
@@ -1160,6 +1165,16 @@ def test_gelu_approx(ops, X):
assert not ops.xp.isnan(Y).any()


@pytest.mark.parametrize("ops", ALL_OPS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(X=strategies.arrays_BI())
def test_dish(ops, X):
X = ops.asarray(X)
Y = ops.dish(X)
assert Y.shape == X.shape
assert not ops.xp.isnan(Y).any()


@pytest.mark.parametrize("ops", ALL_OPS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(X=strategies.arrays_BI())
@@ -1350,8 +1365,8 @@ def test_ngrams():
@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("torch_func", TORCH_FUNCS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
-@given(x=strategies.floats(min_value=-30, max_value=30))
-def test_compare_activations_to_torch(ops, dtype, x, torch_func):
+@given(x=strategies.floats(min_value=-30, max_value=30), dY=strategies.floats(min_value=-1, max_value=1))
+def test_compare_activations_to_torch(ops, dtype, x, dY, torch_func):
import torch

func_name, pytorch_func = torch_func
@@ -1369,9 +1384,9 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
y_think_inplace = forward(x_thinc, inplace=True)
assert y_think_inplace is x_thinc
assert ops.xp.isclose(y_thinc, y_think_inplace, atol=1e-06)
-assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-06)
+assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-05)
x_thinc = ops.asarray([x], dtype=dtype)
-dY_thinc = ops.asarray([1.0], dtype=dtype)
+dY_thinc = ops.asarray([dY], dtype=dtype)
dY_thinc_inplace = dY_thinc.copy()

s = inspect.signature(backward)
@@ -1386,23 +1401,23 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
)
assert dx_thinc_inplace is dY_thinc_inplace
assert ops.xp.isclose(dx_thinc, dx_thinc_inplace)
-assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
+assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06)
elif params == {"Y", "dY"}:
dx_thinc = backward(dY_thinc, Y=y_thinc)
assert dx_thinc.dtype == x_thinc.dtype
assert ops.xp.isclose(
dx_thinc,
backward(dY=dY_thinc_inplace, Y=y_thinc, inplace=True),
)
-assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
+assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06)
elif params == {"dY", "X"}:
dx_thinc = backward(dY_thinc, X=x_thinc)
assert dx_thinc.dtype == x_thinc.dtype
assert ops.xp.isclose(
dx_thinc, backward(dY=dY_thinc_inplace, X=x_thinc, inplace=True)
)
assert ops.xp.isclose(
-x_torch.grad.item(), float(backward(dY_thinc, X=x_thinc)), atol=1e-06
+x_torch.grad.item() * dY, float(backward(dY_thinc, X=x_thinc)), atol=1e-06
)
else:
raise NotImplementedError(
2 changes: 2 additions & 0 deletions thinc/tests/layers/test_layers_api.py
@@ -57,6 +57,8 @@ def assert_data_match(Y, out_data):

TEST_CASES_SUMMABLE = [
# Array to array
("Dish.v1", {}, array2d, array2d),
("Dish.v1", {"nO": 4, "nI": 4}, array2d, array2d),
("Dropout.v1", {}, array2d, array2d),
("LayerNorm.v1", {}, array2d, array2d),
("Linear.v1", {}, array2d, array2d),
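
Because the factory is registered as "Dish.v1" (the name used by the test cases above), it can also be pulled out of Thinc's layer registry, which is what config-driven construction relies on; a minimal sketch, assuming the usual registry machinery:

    from thinc.api import registry

    make_dish = registry.layers.get("Dish.v1")   # same factory the test cases refer to
    model = make_dish(nO=4, nI=4)
    model.initialize()
    Y = model.predict(model.ops.alloc2f(2, 4))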