Skip to content

Ops: replace FloatsType by constrained typevar #720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 71 additions & 73 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math

from typing import Optional, List, Tuple, Sequence, Union, cast, TypeVar
from typing import Optional, List, Tuple, Sequence, Type, Union, cast, TypeVar
from typing import Iterator, overload
import numpy
import itertools
Expand All @@ -9,14 +9,14 @@
from ..types import Floats1d, Floats2d, Floats3d, Floats4d
from ..types import Array1d, Array2d, Array3d, Array4d, ListXd
from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats
from ..types import FloatsXdT
from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator
from ..util import get_array_module, is_xp_array, to_numpy

from .cblas import CBlas

ArrayT = TypeVar("ArrayT", bound=ArrayXd)
FloatsT = TypeVar("FloatsT", bound=_Floats)
FloatsType = TypeVar("FloatsType", bound=FloatsXd)
SQRT2PI = math.sqrt(2.0 / math.pi)
INV_SQRT2 = 1.0 / math.sqrt(2.0)
INV_SQRT_2PI = 1.0 / math.sqrt(2.0 * math.pi)
Expand Down Expand Up @@ -721,29 +721,29 @@ def as_contig(self, data: ArrayT, dtype: Optional[DTypes] = None) -> ArrayT:
kwargs = {"dtype": dtype} if dtype is not None else {}
return self.xp.ascontiguousarray(data, **kwargs)

def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:
def sigmoid(self, X: FloatsXdT, *, inplace: bool = False) -> FloatsXdT:
if inplace:
# To prevent overflows and help with regularization/numerical stability
X = self.xp.clip(X, -20.0, 20.0, out=X)
self.xp.exp(-X, out=X)
X += 1.0 # type: ignore[assignment]
X **= -1.0 # type: ignore[assignment]
return cast(FloatsType, X)
X += 1.0
X **= -1.0
return X
else:
X = self.xp.clip(X, -20.0, 20.0)
return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))
return 1.0 / (1.0 + self.xp.exp(-X))

def backprop_sigmoid(
self, dY: FloatsType, Y: FloatsType, *, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, Y: FloatsXdT, *, inplace: bool = False
) -> FloatsXdT:
if inplace:
self.dsigmoid(Y, inplace=True)
Y *= dY # type: ignore
Y *= dY
return Y
else:
return dY * self.dsigmoid(Y, inplace=inplace) # type: ignore
return dY * self.dsigmoid(Y, inplace=inplace)

def dsigmoid(self, Y: FloatsType, *, inplace: bool = False) -> FloatsType:
def dsigmoid(self, Y: FloatsXdT, *, inplace: bool = False) -> FloatsXdT:
if inplace:
Y *= 1 - Y
return Y
Expand Down Expand Up @@ -864,30 +864,30 @@ def backprop_relu(

def clipped_linear(
self,
X: FloatsType,
X: FloatsXdT,
slope: float = 1.0,
offset: float = 0.0,
min_val: float = 0.0,
max_val: float = 1.0,
inplace: bool = False,
) -> FloatsType:
) -> FloatsXdT:
if inplace:
X *= slope # type: ignore[assignment]
X += offset # type: ignore[assignment]
return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X))
out = X * slope + offset # type: ignore[assignment]
return cast(FloatsType, self.xp.clip(out, min_val, max_val))
X *= slope
X += offset
return self.xp.clip(X, min_val, max_val, out=X)
out = X * slope + offset
return self.xp.clip(out, min_val, max_val)

def backprop_clipped_linear(
self,
dY: FloatsType,
X: FloatsType,
dY: FloatsXdT,
X: FloatsXdT,
slope: float = 1.0,
offset: float = 0.0,
min_val: float = 0.0,
max_val: float = 1.0,
inplace: bool = False,
) -> FloatsType:
) -> FloatsXdT:
low = (min_val - offset) / slope
high = (max_val - offset) / slope
slope = self.xp.float64(slope).astype(X.dtype)
Expand All @@ -898,60 +898,58 @@ def backprop_clipped_linear(
return dY
return dY * dX

def relu_k(
self, X: FloatsType, n: float = 6.0, inplace: bool = False
) -> FloatsType:
def relu_k(self, X: FloatsXdT, n: float = 6.0, inplace: bool = False) -> FloatsXdT:
return self.clipped_linear(X, max_val=n, inplace=inplace)

def backprop_relu_k(
self, dY: FloatsType, X: FloatsType, n: float = 6.0, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, X: FloatsXdT, n: float = 6.0, inplace: bool = False
) -> FloatsXdT:
return self.backprop_clipped_linear(dY, X, max_val=n, inplace=inplace)

def hard_sigmoid(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def hard_sigmoid(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
return self.clipped_linear(X, slope=0.2, offset=0.5, inplace=inplace)

def backprop_hard_sigmoid(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
return self.backprop_clipped_linear(dY, X, slope=0.2, offset=0.5)

def hard_tanh(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def hard_tanh(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
return self.clipped_linear(X, min_val=-1.0, max_val=1.0, inplace=inplace)

def backprop_hard_tanh(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
return self.backprop_clipped_linear(dY, X, min_val=-1.0, max_val=1.0)

def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def swish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
if inplace:
X *= self.sigmoid(X) # type: ignore[operator, assignment]
return cast(FloatsType, X)
out = X * self.sigmoid(X) # type: ignore[operator]
return cast(FloatsType, out)
X *= self.sigmoid(X)
return X
out = X * self.sigmoid(X)
return out

def backprop_swish(
self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False
) -> FloatsType:
Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore[operator]
self, dY: FloatsXdT, X: FloatsXdT, Y: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
Y = Y + self.sigmoid(X) * (1 - Y)
if inplace:
dY *= Y # type: ignore[operator, assignment]
return cast(FloatsType, dY)
out = dY * Y # type: ignore[operator]
return cast(FloatsType, out)
dY *= Y
return dY
out = dY * Y
return out

# Following https://www.scitepress.org/Papers/2019/74696/74696.pdf
def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def hard_swish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
if inplace:
X *= self.hard_sigmoid(X) # type: ignore[operator, assignment]
return cast(FloatsType, X)
out = X * self.hard_sigmoid(X) # type: ignore[operator]
return cast(FloatsType, out)
X *= self.hard_sigmoid(X)
return X
out = X * self.hard_sigmoid(X)
return out

def backprop_hard_swish(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
dX = X * 0.4 + 0.5
dX[X > 2.5] = 1.0
dX[X < -2.5] = 0
Expand All @@ -961,15 +959,15 @@ def backprop_hard_swish(
return dY * dX

# From https://arxiv.org/pdf/1905.02244v5.pdf
def hard_swish_mobilenet(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def hard_swish_mobilenet(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
if inplace:
X *= self.relu_k(X + 3) / 6
return X
return X * (self.relu_k(X + 3) / 6)

def backprop_hard_swish_mobilenet(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
dX = (1 / 6) * (X * 2.0 + 3.0)
dX[X > 3.0] = 1.0
dX[X < -3.0] = 0
Expand All @@ -980,7 +978,7 @@ def backprop_hard_swish_mobilenet(

# Code snippet taken from:
# https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
def erf(self, X: FloatsType) -> FloatsType:
def erf(self, X: FloatsXdT) -> FloatsXdT:
# save the sign of x
sign = self.xp.sign(X)
X = self.xp.abs(X)
Expand All @@ -1000,12 +998,12 @@ def erf(self, X: FloatsType) -> FloatsType:
out = out.astype(X.dtype)
return out

def sechsq(self, X: FloatsType) -> FloatsType:
def sechsq(self, X: FloatsXdT) -> FloatsXdT:
# Avoid overflow in cosh. Clipping at |20| has an error of 1.7e-17.
X = self.xp.clip(X, -20.0, 20.0)
return (1 / self.xp.cosh(X)) ** 2

def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def gelu_approx(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
tmp = 1.0 + self.xp.tanh(SQRT2PI * (X + 0.044715 * self.xp.power(X, 3)))
tmp *= 0.5
tmp = tmp.astype(X.dtype)
Expand All @@ -1018,9 +1016,9 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType:
return Y

def backprop_gelu_approx(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
dX = self.alloc_f(X.shape)
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
dX = cast(FloatsXdT, self.alloc_f(X.shape))
Xp3 = self.xp.power(X, 3)
tmp = 0.5 * self.xp.tanh(0.0356774 * Xp3 + 0.797885 * X)
tmp += (0.0535161 * Xp3 + 0.398942 * X) * self.sechsq(
Expand All @@ -1033,27 +1031,27 @@ def backprop_gelu_approx(
return dY
return dY * dX

def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType:
def gelu(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
# GELU(x) = x · Φ(x)
cdf = gaussian_cdf(self, X)
if inplace:
X *= cdf # type: ignore[operator, assignment]
X *= cdf
return X
return X * cdf # type: ignore[operator, return-value]
return X * cdf

def backprop_gelu(
self, dY: FloatsType, X: FloatsType, inplace: bool = False
) -> FloatsType:
self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
) -> FloatsXdT:
# GELU'(x) = Φ(x) + x · PDF(x)
dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore[operator]
dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X)
if inplace:
dY *= dX
return dY
return dY * dX

def mish(
self, X: FloatsType, threshold: float = 20.0, inplace: bool = False
) -> FloatsType:
self, X: FloatsXdT, threshold: float = 20.0, inplace: bool = False
) -> FloatsXdT:
tmp = X * self.xp.tanh(self.xp.log(1.0 + self.xp.exp(X)))
Y = self.xp.where(X >= threshold, X, tmp)
if inplace:
Expand All @@ -1064,11 +1062,11 @@ def mish(

def backprop_mish(
self,
dY: FloatsType,
dY: FloatsXdT,
X: Floats2d,
threshold: float = 20.0,
inplace: bool = False,
) -> FloatsType:
) -> FloatsXdT:
if dY.shape != X.shape:
msg = f"arrays have incompatible shapes: {dY.shape} and {X.shape}"
raise ValueError(msg)
Expand Down Expand Up @@ -1614,12 +1612,12 @@ def dtanh(Y: ArrayT) -> ArrayT:
return 1 - Y**2


def gaussian_cdf(ops: Ops, X: FloatsType) -> FloatsType:
def gaussian_cdf(ops: Ops, X: FloatsXdT) -> FloatsXdT:
"""Gaussian CDF for distribution with mean 0 and stdev 1."""
return 0.5 * (1.0 + ops.erf(INV_SQRT2 * X))


def gaussian_pdf(ops: Ops, X: FloatsType) -> FloatsType:
def gaussian_pdf(ops: Ops, X: FloatsXdT) -> FloatsXdT:
"""Gaussian PDF for distribution with mean 0 and stdev 1."""
return INV_SQRT_2PI * ops.xp.exp(-0.5 * X * X)

Expand Down
16 changes: 8 additions & 8 deletions thinc/layers/sigmoid_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

from ..model import Model
from ..config import registry
from ..types import FloatsXd


InT = TypeVar("InT", bound=FloatsXd)
from ..types import FloatsXdT


@registry.layers("sigmoid_activation.v1")
def sigmoid_activation() -> Model[InT, InT]:
def sigmoid_activation() -> Model[FloatsXdT, FloatsXdT]:
return Model("sigmoid_activation", forward)


def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
def forward(
model: Model[FloatsXdT, FloatsXdT], X: FloatsXdT, is_train: bool
) -> Tuple[FloatsXdT, Callable]:
Y = model.ops.sigmoid(X, inplace=False)

def backprop(dY: InT) -> InT:
def backprop(dY: FloatsXdT) -> FloatsXdT:
return cast(
InT, dY * model.ops.dsigmoid(Y, inplace=False) # type:ignore[operator]
FloatsXdT,
dY * model.ops.dsigmoid(Y, inplace=False), # type:ignore[operator]
)

return Y, backprop
1 change: 1 addition & 0 deletions thinc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
ArrayT = TypeVar("ArrayT")
SelfT = TypeVar("SelfT")
Array1dT = TypeVar("Array1dT", bound="Array1d")
FloatsXdT = TypeVar("FloatsXdT", "Floats1d", "Floats2d", "Floats3d", "Floats4d")

# These all behave the same as far as indexing is concerned
Slicish = Union[slice, List[int], "ArrayXd"]
Expand Down