2 changes: 1 addition & 1 deletion thinc/api.py
@@ -27,7 +27,7 @@
from .layers import CauchySimilarity, ParametricAttention, Logistic
from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear
from .layers import ClippedLinear, ReluK, HardTanh, HardSigmoid
from .layers import HardSwish, HardSwishMobilenet, Swish, Gelu
from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
from .layers import PyTorchWrapper_v2, Softmax_v2
1 change: 1 addition & 0 deletions thinc/layers/__init__.py
@@ -1,5 +1,6 @@
# Weights layers
from .cauchysimilarity import CauchySimilarity
from .dish import Dish
from .dropout import Dropout
from .embed import Embed
from .expand_window import expand_window
66 changes: 66 additions & 0 deletions thinc/layers/dish.py
@@ -0,0 +1,66 @@
from typing import Tuple, Optional, Callable, cast

from ..config import registry
from ..model import Model
from .chain import chain
from .layernorm import LayerNorm
from .dropout import Dropout
from ..types import Floats1d, Floats2d
from ..util import partial, get_width
from ..initializers import he_normal_init, zero_init


@registry.layers("Dish.v1")
def Dish(
    nO: Optional[int] = None,
    nI: Optional[int] = None,
    *,
    init_W: Callable = he_normal_init,
    init_b: Callable = zero_init,
    dropout: Optional[float] = None,
    normalize: bool = False,
) -> Model[Floats2d, Floats2d]:
    model: Model[Floats2d, Floats2d] = Model(
        "dish",
        forward,
        init=partial(init, init_W, init_b),
        dims={"nO": nO, "nI": nI},
        params={"W": None, "b": None},
    )
    if normalize:
        model = chain(model, LayerNorm(nI=nO))
    if dropout is not None:
        model = chain(model, cast(Model[Floats2d, Floats2d], Dropout(dropout)))
    return model


def forward(
    model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool
) -> Tuple[Floats2d, Callable]:
    W = cast(Floats2d, model.get_param("W"))
    b = cast(Floats1d, model.get_param("b"))
    Y_preact = model.ops.affine(X, W, b)
    Y = model.ops.dish(Y_preact)

    def backprop(dY: Floats2d) -> Floats2d:
        # Backprop through the activation at the pre-activation values,
        # not at the layer input X.
        dY = model.ops.backprop_dish(dY, Y_preact, inplace=False)
        model.inc_grad("b", dY.sum(axis=0))
        model.inc_grad("W", model.ops.gemm(dY, X, trans1=True))
        return model.ops.gemm(dY, W)

    return Y, backprop


def init(
    init_W: Callable,
    init_b: Callable,
    model: Model[Floats2d, Floats2d],
    X: Optional[Floats2d] = None,
    Y: Optional[Floats2d] = None,
) -> None:
    if X is not None:
        model.set_dim("nI", get_width(X))
    if Y is not None:
        model.set_dim("nO", get_width(Y))
    model.set_param("W", init_W(model.ops, (model.get_dim("nO"), model.get_dim("nI"))))
    model.set_param("b", init_b(model.ops, (model.get_dim("nO"),)))
2 changes: 2 additions & 0 deletions thinc/tests/layers/test_layers_api.py
@@ -57,6 +57,8 @@ def assert_data_match(Y, out_data):

TEST_CASES_SUMMABLE = [
    # Array to array
    ("Dish.v1", {}, array2d, array2d),
    ("Dish.v1", {"nO": 4, "nI": 4}, array2d, array2d),
    ("Dropout.v1", {}, array2d, array2d),
    ("LayerNorm.v1", {}, array2d, array2d),
    ("Linear.v1", {}, array2d, array2d),
11 changes: 6 additions & 5 deletions website/docs/api-backends.md
@@ -937,11 +937,12 @@ Backpropagate the Swish activation

</inline-list>

Dish or "Daniël's Swish-like activation" is an activation function with a
similar shape to Swish or GELU. However, Dish does not rely on elementary
functions like `exp` or `erf`, making it much [faster to
compute](https://twitter.com/danieldekok/status/1484898130441166853) in most
cases.
Dish or "Daniël's Swish-like activation" is an activation function with a
non-monotonic shape similar to
[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on
elementary functions like `exp` or `erf`, making it much
[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853)
in most cases.
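
For reference, here is a small NumPy sketch of a Dish-style activation built
only from algebraic operations (no `exp` or `erf`). The exact formula is an
assumption for illustration and should be checked against the `Ops.dish`
implementation, but it shows why the computation stays cheap:

```python
import numpy as np

def dish(x):
    # Swish-like shape built from an algebraic sigmoid approximation:
    # 0.5 * (x / sqrt(1 + x**2) + 1) in place of 1 / (1 + exp(-x)).
    return 0.5 * x * (x / np.sqrt(1.0 + x * x) + 1.0)

def dish_prime(x):
    # Derivative of the expression above, as used in the backward pass.
    return (x + 0.5 * x ** 3) / (1.0 + x * x) ** 1.5 + 0.5

x = np.linspace(-4.0, 4.0, 9)
# Finite-difference check that dish_prime matches dish.
eps = 1e-3
approx = (dish(x + eps) - dish(x - eps)) / (2 * eps)
assert np.allclose(dish_prime(x), approx, atol=1e-5)
```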

| Argument | Type | Description |
| ----------- | ----------------- | ------------------------------------------ |
37 changes: 35 additions & 2 deletions website/docs/api-layers.md
@@ -44,6 +44,39 @@ Primarily used within [`siamese`](#siamese) neural networks.
https://github.com/explosion/thinc/blob/master/thinc/layers/cauchysimilarity.py
```

### Dish {#dish tag="function"}

<inline-list>

- **Input:** <ndarray shape="batch_size, nI">Floats2d</ndarray>
- **Output:** <ndarray shape="batch_size, nO">Floats2d</ndarray>
- **Parameters:** <ndarray shape="nO, nI">W</ndarray>,
<ndarray shape="nO,">b</ndarray>

</inline-list>

A dense layer with the Dish activation function. Dish or "Daniël's Swish-like
activation" is an activation function with a non-monotinic shape similar to
[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on
elementary functions like `exp` or `erf`, making it much
[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853)
in most cases.

| Argument | Type | Description |
| -------------- | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------ |
| `nO` | <tt>Optional[int]</tt> | The size of the output vectors. |
| `nI` | <tt>Optional[int]</tt> | The size of the input vectors. |
| _keyword-only_ | | |
| `init_W` | <tt>Callable</tt> | A function to initialize the weights matrix. Defaults to [`he_normal_init`](/docs/api-initializers#he_normal_init). |
| `init_b` | <tt>Callable</tt> | A function to initialize the bias vector. Defaults to [`zero_init`](/docs/api-initializers#zero_init). |
| `dropout` | <tt>Optional[float]</tt> | Dropout rate to avoid overfitting. |
| `normalize` | <tt>bool</tt> | Whether or not to apply [layer normalization](#layernorm). Defaults to `False`. |
| **RETURNS** | <tt>Model[Floats2d, Floats2d]</tt> | The created dense layer. |

```python
https://github.com/explosion/thinc/blob/master/thinc/layers/dish.py
```
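
A minimal usage sketch (the sizes and sample data below are made up for
illustration; note that with `normalize=True` and a `dropout` rate set, the
factory returns a small chain, as in the source above):

```python
import numpy as np
from thinc.api import Dish

# Map inputs of inferred width nI to 8 outputs, followed by layer
# normalization and dropout.
model = Dish(nO=8, dropout=0.1, normalize=True)

X = np.random.uniform(-1, 1, (4, 16)).astype("float32")
model.initialize(X=X)  # infers nI=16 from the sample batch

Yh, backprop = model.begin_update(X)  # forward pass in training mode
dX = backprop(np.ones_like(Yh))       # backward pass; gradients accumulate on the params
print(Yh.shape, dX.shape)             # (4, 8) (4, 16)
```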

### Dropout {#dropout tag="function"}

<inline-list>
@@ -835,8 +868,8 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_last.py
</inline-list>

Pooling layer that reduces the dimensions of the data by selecting the maximum
value for each feature. A `ValueError` is raised if any element in `lengths`
is zero.
value for each feature. A `ValueError` is raised if any element in `lengths` is
zero.

| Argument | Type | Description |
| ----------- | -------------------------------- | -------------------------- |