diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py index 6b2eb36a9..c9fb10aae 100644 --- a/thinc/backends/ops.py +++ b/thinc/backends/ops.py @@ -1,6 +1,6 @@ import math -from typing import Optional, List, Tuple, Sequence, Union, cast, TypeVar +from typing import Optional, List, Tuple, Sequence, Type, Union, cast, TypeVar from typing import Iterator, overload import numpy import itertools @@ -9,6 +9,7 @@ from ..types import Floats1d, Floats2d, Floats3d, Floats4d from ..types import Array1d, Array2d, Array3d, Array4d, ListXd from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats +from ..types import FloatsXdT from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator from ..util import get_array_module, is_xp_array, to_numpy @@ -16,7 +17,6 @@ ArrayT = TypeVar("ArrayT", bound=ArrayXd) FloatsT = TypeVar("FloatsT", bound=_Floats) -FloatsType = TypeVar("FloatsType", bound=FloatsXd) SQRT2PI = math.sqrt(2.0 / math.pi) INV_SQRT2 = 1.0 / math.sqrt(2.0) INV_SQRT_2PI = 1.0 / math.sqrt(2.0 * math.pi) @@ -721,29 +721,29 @@ def as_contig(self, data: ArrayT, dtype: Optional[DTypes] = None) -> ArrayT: kwargs = {"dtype": dtype} if dtype is not None else {} return self.xp.ascontiguousarray(data, **kwargs) - def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType: + def sigmoid(self, X: FloatsXdT, *, inplace: bool = False) -> FloatsXdT: if inplace: # To prevent overflows and help with regularization/numerical stability X = self.xp.clip(X, -20.0, 20.0, out=X) self.xp.exp(-X, out=X) - X += 1.0 # type: ignore[assignment] - X **= -1.0 # type: ignore[assignment] - return cast(FloatsType, X) + X += 1.0 + X **= -1.0 + return X else: X = self.xp.clip(X, -20.0, 20.0) - return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X))) + return 1.0 / (1.0 + self.xp.exp(-X)) def backprop_sigmoid( - self, dY: FloatsType, Y: FloatsType, *, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, Y: FloatsXdT, *, inplace: bool = False + ) -> FloatsXdT: if inplace: self.dsigmoid(Y, inplace=True) - Y *= dY # type: ignore + Y *= dY return Y else: - return dY * self.dsigmoid(Y, inplace=inplace) # type: ignore + return dY * self.dsigmoid(Y, inplace=inplace) - def dsigmoid(self, Y: FloatsType, *, inplace: bool = False) -> FloatsType: + def dsigmoid(self, Y: FloatsXdT, *, inplace: bool = False) -> FloatsXdT: if inplace: Y *= 1 - Y return Y @@ -864,30 +864,30 @@ def backprop_relu( def clipped_linear( self, - X: FloatsType, + X: FloatsXdT, slope: float = 1.0, offset: float = 0.0, min_val: float = 0.0, max_val: float = 1.0, inplace: bool = False, - ) -> FloatsType: + ) -> FloatsXdT: if inplace: - X *= slope # type: ignore[assignment] - X += offset # type: ignore[assignment] - return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X)) - out = X * slope + offset # type: ignore[assignment] - return cast(FloatsType, self.xp.clip(out, min_val, max_val)) + X *= slope + X += offset + return self.xp.clip(X, min_val, max_val, out=X) + out = X * slope + offset + return self.xp.clip(out, min_val, max_val) def backprop_clipped_linear( self, - dY: FloatsType, - X: FloatsType, + dY: FloatsXdT, + X: FloatsXdT, slope: float = 1.0, offset: float = 0.0, min_val: float = 0.0, max_val: float = 1.0, inplace: bool = False, - ) -> FloatsType: + ) -> FloatsXdT: low = (min_val - offset) / slope high = (max_val - offset) / slope slope = self.xp.float64(slope).astype(X.dtype) @@ -898,60 +898,58 @@ def backprop_clipped_linear( return dY return dY * dX - def relu_k( - self, X: FloatsType, n: float = 6.0, inplace: bool = False - ) -> FloatsType: + def relu_k(self, X: FloatsXdT, n: float = 6.0, inplace: bool = False) -> FloatsXdT: return self.clipped_linear(X, max_val=n, inplace=inplace) def backprop_relu_k( - self, dY: FloatsType, X: FloatsType, n: float = 6.0, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, n: float = 6.0, inplace: bool = False + ) -> FloatsXdT: return self.backprop_clipped_linear(dY, X, max_val=n, inplace=inplace) - def hard_sigmoid(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def hard_sigmoid(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: return self.clipped_linear(X, slope=0.2, offset=0.5, inplace=inplace) def backprop_hard_sigmoid( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: return self.backprop_clipped_linear(dY, X, slope=0.2, offset=0.5) - def hard_tanh(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def hard_tanh(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: return self.clipped_linear(X, min_val=-1.0, max_val=1.0, inplace=inplace) def backprop_hard_tanh( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: return self.backprop_clipped_linear(dY, X, min_val=-1.0, max_val=1.0) - def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def swish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: if inplace: - X *= self.sigmoid(X) # type: ignore[operator, assignment] - return cast(FloatsType, X) - out = X * self.sigmoid(X) # type: ignore[operator] - return cast(FloatsType, out) + X *= self.sigmoid(X) + return X + out = X * self.sigmoid(X) + return out def backprop_swish( - self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False - ) -> FloatsType: - Y = Y + self.sigmoid(X) * (1 - Y) # type: ignore[operator] + self, dY: FloatsXdT, X: FloatsXdT, Y: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: + Y = Y + self.sigmoid(X) * (1 - Y) if inplace: - dY *= Y # type: ignore[operator, assignment] - return cast(FloatsType, dY) - out = dY * Y # type: ignore[operator] - return cast(FloatsType, out) + dY *= Y + return dY + out = dY * Y + return out # Following https://www.scitepress.org/Papers/2019/74696/74696.pdf - def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def hard_swish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: if inplace: - X *= self.hard_sigmoid(X) # type: ignore[operator, assignment] - return cast(FloatsType, X) - out = X * self.hard_sigmoid(X) # type: ignore[operator] - return cast(FloatsType, out) + X *= self.hard_sigmoid(X) + return X + out = X * self.hard_sigmoid(X) + return out def backprop_hard_swish( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: dX = X * 0.4 + 0.5 dX[X > 2.5] = 1.0 dX[X < -2.5] = 0 @@ -961,15 +959,15 @@ def backprop_hard_swish( return dY * dX # From https://arxiv.org/pdf/1905.02244v5.pdf - def hard_swish_mobilenet(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def hard_swish_mobilenet(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: if inplace: X *= self.relu_k(X + 3) / 6 return X return X * (self.relu_k(X + 3) / 6) def backprop_hard_swish_mobilenet( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: dX = (1 / 6) * (X * 2.0 + 3.0) dX[X > 3.0] = 1.0 dX[X < -3.0] = 0 @@ -980,7 +978,7 @@ def backprop_hard_swish_mobilenet( # Code snippet taken from: # https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/ - def erf(self, X: FloatsType) -> FloatsType: + def erf(self, X: FloatsXdT) -> FloatsXdT: # save the sign of x sign = self.xp.sign(X) X = self.xp.abs(X) @@ -1000,12 +998,12 @@ def erf(self, X: FloatsType) -> FloatsType: out = out.astype(X.dtype) return out - def sechsq(self, X: FloatsType) -> FloatsType: + def sechsq(self, X: FloatsXdT) -> FloatsXdT: # Avoid overflow in cosh. Clipping at |20| has an error of 1.7e-17. X = self.xp.clip(X, -20.0, 20.0) return (1 / self.xp.cosh(X)) ** 2 - def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def gelu_approx(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: tmp = 1.0 + self.xp.tanh(SQRT2PI * (X + 0.044715 * self.xp.power(X, 3))) tmp *= 0.5 tmp = tmp.astype(X.dtype) @@ -1018,9 +1016,9 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType: return Y def backprop_gelu_approx( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: - dX = self.alloc_f(X.shape) + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: + dX = cast(FloatsXdT, self.alloc_f(X.shape)) Xp3 = self.xp.power(X, 3) tmp = 0.5 * self.xp.tanh(0.0356774 * Xp3 + 0.797885 * X) tmp += (0.0535161 * Xp3 + 0.398942 * X) * self.sechsq( @@ -1033,27 +1031,27 @@ def backprop_gelu_approx( return dY return dY * dX - def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType: + def gelu(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT: # GELU(x) = x · Φ(x) cdf = gaussian_cdf(self, X) if inplace: - X *= cdf # type: ignore[operator, assignment] + X *= cdf return X - return X * cdf # type: ignore[operator, return-value] + return X * cdf def backprop_gelu( - self, dY: FloatsType, X: FloatsType, inplace: bool = False - ) -> FloatsType: + self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False + ) -> FloatsXdT: # GELU'(x) = Φ(x) + x · PDF(x) - dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) # type: ignore[operator] + dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X) if inplace: dY *= dX return dY return dY * dX def mish( - self, X: FloatsType, threshold: float = 20.0, inplace: bool = False - ) -> FloatsType: + self, X: FloatsXdT, threshold: float = 20.0, inplace: bool = False + ) -> FloatsXdT: tmp = X * self.xp.tanh(self.xp.log(1.0 + self.xp.exp(X))) Y = self.xp.where(X >= threshold, X, tmp) if inplace: @@ -1064,11 +1062,11 @@ def mish( def backprop_mish( self, - dY: FloatsType, + dY: FloatsXdT, X: Floats2d, threshold: float = 20.0, inplace: bool = False, - ) -> FloatsType: + ) -> FloatsXdT: if dY.shape != X.shape: msg = f"arrays have incompatible shapes: {dY.shape} and {X.shape}" raise ValueError(msg) @@ -1614,12 +1612,12 @@ def dtanh(Y: ArrayT) -> ArrayT: return 1 - Y**2 -def gaussian_cdf(ops: Ops, X: FloatsType) -> FloatsType: +def gaussian_cdf(ops: Ops, X: FloatsXdT) -> FloatsXdT: """Gaussian CDF for distribution with mean 0 and stdev 1.""" return 0.5 * (1.0 + ops.erf(INV_SQRT2 * X)) -def gaussian_pdf(ops: Ops, X: FloatsType) -> FloatsType: +def gaussian_pdf(ops: Ops, X: FloatsXdT) -> FloatsXdT: """Gaussian PDF for distribution with mean 0 and stdev 1.""" return INV_SQRT_2PI * ops.xp.exp(-0.5 * X * X) diff --git a/thinc/layers/sigmoid_activation.py b/thinc/layers/sigmoid_activation.py index 8b3982aea..b87261075 100644 --- a/thinc/layers/sigmoid_activation.py +++ b/thinc/layers/sigmoid_activation.py @@ -2,23 +2,23 @@ from ..model import Model from ..config import registry -from ..types import FloatsXd - - -InT = TypeVar("InT", bound=FloatsXd) +from ..types import FloatsXdT @registry.layers("sigmoid_activation.v1") -def sigmoid_activation() -> Model[InT, InT]: +def sigmoid_activation() -> Model[FloatsXdT, FloatsXdT]: return Model("sigmoid_activation", forward) -def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]: +def forward( + model: Model[FloatsXdT, FloatsXdT], X: FloatsXdT, is_train: bool +) -> Tuple[FloatsXdT, Callable]: Y = model.ops.sigmoid(X, inplace=False) - def backprop(dY: InT) -> InT: + def backprop(dY: FloatsXdT) -> FloatsXdT: return cast( - InT, dY * model.ops.dsigmoid(Y, inplace=False) # type:ignore[operator] + FloatsXdT, + dY * model.ops.dsigmoid(Y, inplace=False), # type:ignore[operator] ) return Y, backprop diff --git a/thinc/types.py b/thinc/types.py index 04b81946a..41a3d3fb7 100644 --- a/thinc/types.py +++ b/thinc/types.py @@ -46,6 +46,7 @@ ArrayT = TypeVar("ArrayT") SelfT = TypeVar("SelfT") Array1dT = TypeVar("Array1dT", bound="Array1d") +FloatsXdT = TypeVar("FloatsXdT", "Floats1d", "Floats2d", "Floats3d", "Floats4d") # These all behave the same as far as indexing is concerned Slicish = Union[slice, List[int], "ArrayXd"]