Merged
31 changes: 31 additions & 0 deletions thinc/backends/_custom_kernels.cu
@@ -161,6 +161,20 @@ __global__ void clipped_linear(T* Y, const T* X, double slope, double offset, do
}


template <typename T>
__global__ void dish(T* Y, const T* X, int N)
{
int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
int _loop_stride = blockDim.x * gridDim.x;

for (int i = _loop_start; i < N; i += _loop_stride)
{
T x = X[i];
Y[i] = 0.5 * x * (x / sqrt(1 + x * x) + 1);
}
}


template <typename T>
__global__ void gelu(T* Y, const T* X, double threshold, int N)
{
@@ -414,6 +428,23 @@ __global__ void backprop_hard_swish_mobilenet(T* dX, const T* dY, const T* X, in
}


template <typename T>
__global__ void backprop_dish(T* dX, const T* dY, const T* X, int N)
{

int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
int _loop_stride = blockDim.x * gridDim.x;

for (int i = _loop_start; i < N; i += _loop_stride)
{
T x = X[i];
T x_sq = x * x;
T x_sq_plus_one = x_sq + 1.0;
dX[i] = dY[i] * (x/sqrt(x_sq_plus_one) - (0.5 * x * x_sq)
/ pow(x_sq_plus_one, static_cast<T>(1.5)) + 0.5);
}
}


template <typename T>
__global__ void backprop_gelu(T* dX, const T* dY, const T* X,
48 changes: 48 additions & 0 deletions thinc/backends/_custom_kernels.py
@@ -10,6 +10,8 @@
KERNELS_LIST = [
"backprop_clipped_linear<double>",
"backprop_clipped_linear<float>",
"backprop_dish<double>",
"backprop_dish<float>",
"backprop_gelu<double>",
"backprop_gelu<float>",
"backprop_hard_swish<double>",
@@ -32,6 +34,8 @@
"backprop_swish<float>",
"clipped_linear<double>",
"clipped_linear<float>",
"dish<double>",
"dish<float>",
"gather_add<double>",
"gather_add<float>",
"gelu<double>",
@@ -78,6 +82,8 @@ def compile_mmh(src):

clipped_linear_kernel_float = _get_kernel("clipped_linear<float>")
clipped_linear_kernel_double = _get_kernel("clipped_linear<double>")
dish_kernel_float = _get_kernel("dish<float>")
dish_kernel_double = _get_kernel("dish<double>")
gather_add_kernel_float = _get_kernel("gather_add<float>")
gather_add_kernel_double = _get_kernel("gather_add<double>")
gelu_kernel_float = _get_kernel("gelu<float>")
@@ -98,6 +104,8 @@ def compile_mmh(src):

backprop_clipped_linear_kernel_double = _get_kernel("backprop_clipped_linear<double>")
backprop_clipped_linear_kernel_float = _get_kernel("backprop_clipped_linear<float>")
backprop_dish_kernel_double = _get_kernel("backprop_dish<double>")
backprop_dish_kernel_float = _get_kernel("backprop_dish<float>")
backprop_gelu_kernel_double = _get_kernel("backprop_gelu<double>")
backprop_gelu_kernel_float = _get_kernel("backprop_gelu<float>")
backprop_hard_swish_kernel_double = _get_kernel("backprop_hard_swish<double>")
@@ -199,6 +207,19 @@ def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
return out


def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
_is_float_array(X)

out = X
if not inplace:
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
else:
dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
return out


def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
_is_float_array(X)

@@ -483,6 +504,33 @@ def backprop_hard_swish_mobilenet(
return out


def backprop_dish(
dY,
X,
*,
inplace: bool = False,
threads_per_block=128,
num_blocks=128,
):
_is_float_array(dY)
_is_float_array(X, shape=dY.shape)

out = dY
if not inplace:
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_dish_kernel_float(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)
else:
backprop_dish_kernel_double(
(num_blocks,), (threads_per_block,), (out, dY, X, out.size)
)

return out


def backprop_gelu(
dY,
X,
12 changes: 12 additions & 0 deletions thinc/backends/cupy_ops.py
@@ -36,6 +36,18 @@ def gather_add(self, table, indices):
else:
return super().gather_add(table, indices)

def dish(self, X, inplace=False):
if X.dtype in ("float32", "float64"):
return _custom_kernels.dish(X, inplace=inplace)
else:
return super().dish(X, inplace=inplace)

def backprop_dish(self, dY, X, inplace=False):
if X.dtype == dY.dtype and X.dtype in ("float32", "float64"):
return _custom_kernels.backprop_dish(dY, X, inplace=inplace)
else:
return super().backprop_dish(dY, X, inplace=inplace)

def gelu(self, X, inplace=False):
if X.dtype in ("float32", "float64"):
return _custom_kernels.gelu(X, inplace=inplace, threshold=6.0)
29 changes: 29 additions & 0 deletions thinc/backends/ops.py
@@ -976,6 +976,35 @@ def backprop_hard_swish_mobilenet(
return dY
return dX * dY

def dish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
tmp = self.xp.square(X)
tmp += 1.0
self.xp.sqrt(tmp, out=tmp)
tmp = X / tmp
tmp += 1
tmp *= 0.5
if inplace:
X *= tmp
return X
else:
return X * tmp

def backprop_dish(
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
x_sq = self.xp.square(X)
x_sq_plus_one = x_sq + 1.0
deriv = X / self.xp.sqrt(x_sq_plus_one)
second = 0.5 * X * x_sq
second /= x_sq_plus_one**1.5
deriv -= second
deriv += 0.5
if inplace:
dY *= deriv
return dY
else:
return dY * deriv

# Code snippet taken from:
# https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
def erf(self, X: FloatsXdT) -> FloatsXdT:
29 changes: 22 additions & 7 deletions thinc/tests/backends/test_ops.py
@@ -64,6 +64,9 @@ def torch_hard_swish_mobilenet(x):
def torch_sigmoid(x):
return torch.sigmoid(x)

def torch_dish(x):
return 0.5 * x * (x / (1 + x * x).sqrt() + 1)
Review comment on lines +67 to +68 (Contributor): Why is torch getting credit here? :p

# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L37
def torch_gelu_approx(x):
return (
@@ -89,6 +92,7 @@ def torch_gelu(x):
("swish", torch_swish),
("hard_swish", torch_hard_swish),
("hard_swish_mobilenet", torch_hard_swish_mobilenet),
("dish", torch_dish),
("gelu_approx", torch_gelu_approx),
("gelu", torch_gelu),
("sigmoid", torch_sigmoid),
@@ -1043,6 +1047,7 @@ def test_mish(ops, X):
"op",
[
"backprop_clipped_linear",
"backprop_dish",
"backprop_gelu",
"backprop_gelu_approx",
"backprop_hard_sigmoid",
@@ -1160,6 +1165,16 @@ def test_gelu_approx(ops, X):
assert not ops.xp.isnan(Y).any()


@pytest.mark.parametrize("ops", ALL_OPS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(X=strategies.arrays_BI())
def test_dish(ops, X):
X = ops.asarray(X)
Y = ops.dish(X)
assert Y.shape == X.shape
assert not ops.xp.isnan(Y).any()


@pytest.mark.parametrize("ops", ALL_OPS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(X=strategies.arrays_BI())
@@ -1350,8 +1365,8 @@ def test_ngrams():
@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("torch_func", TORCH_FUNCS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(x=strategies.floats(min_value=-30, max_value=30))
def test_compare_activations_to_torch(ops, dtype, x, torch_func):
@given(x=strategies.floats(min_value=-30, max_value=30), dY=strategies.floats(min_value=-1, max_value=1))
def test_compare_activations_to_torch(ops, dtype, x, dY, torch_func):
import torch

func_name, pytorch_func = torch_func
@@ -1369,9 +1384,9 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
y_think_inplace = forward(x_thinc, inplace=True)
assert y_think_inplace is x_thinc
assert ops.xp.isclose(y_thinc, y_think_inplace, atol=1e-06)
assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-06)
assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-05)
x_thinc = ops.asarray([x], dtype=dtype)
dY_thinc = ops.asarray([1.0], dtype=dtype)
dY_thinc = ops.asarray([dY], dtype=dtype)
dY_thinc_inplace = dY_thinc.copy()

s = inspect.signature(backward)
@@ -1386,23 +1401,23 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
)
assert dx_thinc_inplace is dY_thinc_inplace
assert ops.xp.isclose(dx_thinc, dx_thinc_inplace)
assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06)
elif params == {"Y", "dY"}:
dx_thinc = backward(dY_thinc, Y=y_thinc)
assert dx_thinc.dtype == x_thinc.dtype
assert ops.xp.isclose(
dx_thinc,
backward(dY=dY_thinc_inplace, Y=y_thinc, inplace=True),
)
assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06)
elif params == {"dY", "X"}:
dx_thinc = backward(dY_thinc, X=x_thinc)
assert dx_thinc.dtype == x_thinc.dtype
assert ops.xp.isclose(
dx_thinc, backward(dY=dY_thinc_inplace, X=x_thinc, inplace=True)
)
assert ops.xp.isclose(
x_torch.grad.item(), float(backward(dY_thinc, X=x_thinc)), atol=1e-06
x_torch.grad.item() * dY, float(backward(dY_thinc, X=x_thinc)), atol=1e-06
)
else:
raise NotImplementedError(
41 changes: 41 additions & 0 deletions website/docs/api-backends.md
@@ -927,6 +927,47 @@ Backpropagate the Swish activation
| `inplace` | <tt>bool</tt> | If `True`, the `dY` array is modified in place. |
| **RETURNS** | <tt>FloatsXd</tt> | The gradient of the input. |

### Ops.dish {#dish tag="method" new="8.1.1"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** <i name="no"></i>
- **cupy:** <i name="yes"></i>

</inline-list>

Dish or "Daniël's Swish-like activation" is an activation function with a
similar shape to Swish or GELU. However, Dish does not rely on elementary
functions like `exp` or `erf`, making it much [faster to
compute](https://twitter.com/danieldekok/status/1484898130441166853) in most
cases.

| Argument | Type | Description |
| ----------- | ----------------- | ------------------------------------------ |
| `X` | <tt>FloatsXd</tt> | The inputs. |
| `inplace` | <tt>bool</tt> | If `True`, the array is modified in place. |
| **RETURNS** | <tt>FloatsXd</tt> | The outputs. |
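
For reference, the formula implemented by the CUDA kernel and `Ops.dish` above
can be written as a minimal NumPy sketch (the `dish_reference` helper is
illustrative, not part of Thinc's API):

```python
import numpy as np

def dish_reference(x: np.ndarray) -> np.ndarray:
    # dish(x) = 0.5 * x * (x / sqrt(1 + x^2) + 1),
    # the same formula the kernels in this PR compute.
    return 0.5 * x * (x / np.sqrt(1.0 + x * x) + 1.0)

# Small inputs are damped towards zero and large positive inputs pass through,
# giving a Swish/GELU-like shape without calling exp or erf.
print(dish_reference(np.array([-3.0, 0.0, 3.0], dtype="float32")))
```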

### Ops.backprop_dish {#backprop_dish tag="method" new="8.1.1"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** <i name="no"></i>
- **cupy:** <i name="yes"></i>

</inline-list>

Backpropagate the Dish activation.

| Argument | Type | Description |
| ----------- | ----------------- | ----------------------------------------------- |
| `dY` | <tt>FloatsXd</tt> | Gradients of the output array. |
| `X` | <tt>FloatsXd</tt> | The inputs to the forward pass. |
| `inplace` | <tt>bool</tt> | If `True`, the `dY` array is modified in place. |
| **RETURNS** | <tt>FloatsXd</tt> | The gradient of the input. |
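
For reference, the gradient applied here follows from differentiating the
forward formula: d/dx dish(x) = x / sqrt(1 + x^2) - 0.5 * x^3 / (1 + x^2)^1.5 + 0.5.
A minimal NumPy sketch (the `backprop_dish_reference` helper is illustrative,
not part of Thinc's API):

```python
import numpy as np

def backprop_dish_reference(dY: np.ndarray, X: np.ndarray) -> np.ndarray:
    # d/dx dish(x) = x / sqrt(1 + x^2) - 0.5 * x^3 / (1 + x^2)^1.5 + 0.5,
    # matching the backprop_dish CUDA kernel and Ops.backprop_dish above.
    x_sq_plus_one = X * X + 1.0
    deriv = X / np.sqrt(x_sq_plus_one) - 0.5 * X**3 / x_sq_plus_one**1.5 + 0.5
    return dY * deriv
```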

### Ops.gelu {#gelu tag="method"}

<inline-list>