
Commit f6eee9a

Add bound types throughout layers and Ops (#599)
* Fixed typing problems mostly in with... methods
* Sorted out flatten and unflatten
* Iterable and Concatenatable types
* Corrections
* More corrections
* Corrections to layers
* Moved type definitions from types to ops
* Updated documentation
* Simplified ops type declarations
* Fixed mypy backwards compatibility issue
* Correct type-ignore comment
* Updated Mypy version in azure-pipelines
* Added CI checks with Python 3.7
* Any as first parameter of with_... layers
* Revert "Any as first parameter of with_... layers". This reverts commit aa55834.
* Tidied up init methods
* Removed unnecessary imports
* Put import statement on one line
* Changes based on PR review comments
* Improvements after PR feedback
* Went through ignore statements in layers
* Removed unnecessary covariance
* Improvements based on PR review
* Remove Python 3.7 additions
* Reverted lstm_tagger.py changes
* Added ArrayTXd_co
* Final changes before review
* Cast in main rather than in type-specific forward methods
* Added empty line
* Corrections
* More corrections
* Corrections
* Returned to ListXd types
* More corrections
* Further corrections
* Corrected model typing
* Further corrections
* Corrections
* Tidying up
* Corrections
* Removed line
* Made imports clearer
* Readded line
* Reformatted
* Readded line
* Corrected residual.py
* Changed imports back to original order
* Changes in response to review comments
* Update thinc/layers/dropout.py (Co-authored-by: Sofie Van Landeghem <[email protected]>)
* Update thinc/layers/embed.py (Co-authored-by: Sofie Van Landeghem <[email protected]>)
* Changes responding to Github review
* Reversed changes to init() return types
* Reversed changes to init() return types
* Corrected embed.py and hashembed.py
* Corrections based on Github review
* Fixed chain.py
* Further correction to chain.py
* Removed unnecessary cast
* Updated documentation
* Changes based on review
* Added @overload signatures in ops
* Added comment
* Changes based on review comments
* Final corrections
* Bumped mypy version
* Changes based on review comments
* Added space to trigger CI
* Corrected Pydantic version ranges
* Fixed mypy version range
* Correct documentation for clone

Co-authored-by: Sofie Van Landeghem <[email protected]>
1 parent 4fc5ded commit f6eee9a

38 files changed, +581 -444 lines changed
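The unifying pattern in this commit is the bound TypeVar: rather than typing ops and layers against the broad ArrayXd union, a TypeVar bound to ArrayXd lets the type checker carry the caller's concrete array type through to the return value. A minimal sketch of the idea (illustrative only, not code from this diff):

    from typing import TypeVar

    from thinc.types import ArrayXd

    # Bound TypeVar: accepts any concrete array type (Floats2d, Ints1d, ...)
    # and ties the return type to whatever the caller passed in.
    ArrayT = TypeVar("ArrayT", bound=ArrayXd)

    def identity(array: ArrayT) -> ArrayT:
        # Hypothetical helper: Floats2d in gives Floats2d out,
        # instead of widening the result to ArrayXd.
        return array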

azure-pipelines.yml

Lines changed: 0 additions & 1 deletion
@@ -61,7 +61,6 @@ jobs:
     displayName: 'Build sdist'
 
   - script: |
-      python -m pip install mypy==0.910
       python -m mypy thinc
     displayName: 'Run mypy'

requirements.txt

Lines changed: 2 additions & 3 deletions
@@ -8,7 +8,7 @@ wasabi>=0.8.1,<1.1.0
 catalogue>=2.0.4,<2.1.0
 ml_datasets>=0.2.0,<0.3.0
 # Third-party dependencies
-pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
 numpy>=1.15.0
 # Backports of modern Python features
 dataclasses>=0.6,<1.0; python_version < "3.7"
@@ -22,8 +22,7 @@ pytest-cov>=2.7.0,<2.8.0
 coverage>=5.0.0,<6.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.5.0,<3.6.0
-# restricting mypy until faster 3.10 wheels are available
-mypy>=0.901,<0.920; python_version < "3.10"
+mypy>=0.901,<0.960
 types-mock>=0.1.1
 types-contextvars>=0.1.2; python_version < "3.7"
 types-dataclasses>=0.1.3; python_version < "3.7"

setup.cfg

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ python_requires = >=3.6
 setup_requires =
     cython>=0.25,<3.0
     numpy>=1.15.0
-# We also need our Cython packages here to compile against
+    # We also need our Cython packages here to compile against
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=1.0.2,<1.1.0
@@ -48,7 +48,7 @@ install_requires =
     # Third-party dependencies
     setuptools
     numpy>=1.15.0
-    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
     # Backports of modern Python features
     dataclasses>=0.6,<1.0; python_version < "3.7"
     typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8"

thinc/backends/ops.py

Lines changed: 115 additions & 35 deletions
@@ -5,9 +5,9 @@
 import numpy
 import itertools
 
-from .. import registry
 from ..types import Xp, Shape, DTypes, DTypesInt, DTypesFloat, List2d, ArrayXd
-from ..types import Array3d, Floats1d, Floats2d, Floats3d, Floats4d
+from ..types import Floats1d, Floats2d, Floats3d, Floats4d
+from ..types import Array1d, Array2d, Array3d, Array4d, ListXd
 from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats
 from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator
 from ..util import get_array_module, is_xp_array, to_numpy
@@ -135,13 +135,11 @@ def _get_batch(self, sequence, indices):
         if isinstance(sequence, list):
             subseq = [sequence[i] for i in indices]
         elif isinstance(sequence, tuple):
-            subseq = tuple(sequence[i] for i in indices)  # type: ignore
+            subseq = tuple(sequence[i] for i in indices)
         else:
-            subseq = sequence[indices]  # type: ignore
+            subseq = sequence[indices]
         if is_xp_array(subseq):
-            subseq = self.as_contig(
-                cast(ArrayXd, self.xp.asarray(subseq))
-            )  # type: ignore
+            subseq = self.as_contig(self.xp.asarray(subseq))
         return subseq
 
     def _get_batch_sizes(self, length: int, sizes: Iterator[int]):
@@ -225,13 +223,65 @@ def affine(self, X: Floats2d, W: Floats2d, b: Floats1d) -> Floats2d:
         Y += b
         return Y
 
+    @overload
     def flatten(
         self,
-        X: Sequence[ArrayT],
+        X: List[Floats2d],
         dtype: Optional[DTypes] = None,
         pad: int = 0,
         ndim_if_empty: int = 2,
-    ) -> ArrayT:
+    ) -> Floats2d:
+        ...
+
+    @overload
+    def flatten(
+        self,
+        X: List[Ints1d],
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> Ints1d:
+        ...
+
+    @overload
+    def flatten(
+        self,
+        X: List2d,
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> Array2d:
+        ...
+
+    # further specific typed signatures can be added as necessary
+
+    @overload
+    def flatten(
+        self,
+        X: ListXd,
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> ArrayXd:
+        ...
+
+    @overload
+    def flatten(
+        self,
+        X: Sequence[ArrayXd],
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> ArrayXd:
+        ...
+
+    def flatten(
+        self,
+        X: Sequence[ArrayXd],
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> ArrayXd:
         """Flatten a list of arrays into one large array."""
         if X is None or len(X) == 0:
             return self.alloc((0,) * ndim_if_empty, dtype=dtype or "f")
@@ -252,7 +302,25 @@ def flatten(
         result = xp.asarray(result, dtype=dtype)
         return result
 
+    @overload
     def unflatten(self, X: Floats2d, lengths: Ints1d, pad: int = 0) -> List[Floats2d]:
+        ...
+
+    @overload
+    def unflatten(self, X: Ints1d, lengths: Ints1d, pad: int = 0) -> List[Ints1d]:
+        ...
+
+    @overload
+    def unflatten(self, X: Array2d, lengths: Ints1d, pad: int = 0) -> List2d:
+        ...
+
+    # further specific typed signatures can be added as necessary
+
+    @overload
+    def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd:
+        ...
+
+    def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd:
         """The reverse/backward operation of the `flatten` function: unflatten
         a large array into a list of arrays according to the given lengths.
         """
@@ -302,7 +370,7 @@ def pad(  # noqa: F811
         output: Array3d = self.alloc(final_shape, dtype=seqs[0].dtype)
         for i, arr in enumerate(seqs):
             # It's difficult to convince this that the dtypes will match.
-            output[i, : arr.shape[0]] = arr  # type: ignore
+            output[i, : arr.shape[0]] = arr  # type: ignore[assignment, call-overload]
         return output
 
     def unpad(self, padded: Array3d, lengths: List[int]) -> List2d:
@@ -314,14 +382,14 @@ def unpad(self, padded: Array3d, lengths: List[int]) -> List2d:
             output.append(padded[i, :length])
         return cast(List2d, output)
 
-    def list2padded(self, seqs: List[Floats2d]) -> Padded:
+    def list2padded(self, seqs: List2d) -> Padded:
         """Pack a sequence of 2d arrays into a Padded datatype."""
         if not seqs:
             return Padded(
                 self.alloc3f(0, 0, 0), self.alloc1i(0), self.alloc1i(0), self.alloc1i(0)
             )
         elif len(seqs) == 1:
-            data = self.reshape3f(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1])
+            data = self.reshape3(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1])
             size_at_t = self.asarray1i([1] * data.shape[0])
             lengths = self.asarray1i([data.shape[0]])
             indices = self.asarray1i([0])
@@ -336,8 +404,8 @@ def list2padded(self, seqs: List2d) -> Padded:
         # Reorder the sequences, by length. This looks the same in either
         # direction: you're swapping elements between their original and sorted
         # position.
-        seqs = [seqs[i] for i in indices_]
-        arr: Floats3d = self.pad(seqs)
+        seqs = cast(List2d, [seqs[i] for i in indices_])
+        arr: Array3d = self.pad(seqs)
         assert arr.shape == (nB, nS, nO), (nB, nS, nO)
         arr = self.as_contig(arr.transpose((1, 0, 2)))
         assert arr.shape == (nS, nB, nO)
@@ -350,7 +418,7 @@ def list2padded(self, seqs: List2d) -> Padded:
             batch_size_at_t_[t] = current_size
         assert sum(lengths_) == sum(batch_size_at_t_)
         return Padded(
-            cast(Floats3d, arr),
+            arr,
             self.asarray1i(batch_size_at_t_),
             self.asarray1i(lengths_),
             self.asarray1i(indices_),
@@ -361,7 +429,7 @@ def padded2list(self, padded: Padded) -> List2d:
         data = padded.data
         indices = to_numpy(padded.indices)
         lengths = to_numpy(padded.lengths)
-        unpadded: List[Optional[Floats2d]] = [None] * len(lengths)
+        unpadded: List[Optional[Array2d]] = [None] * len(lengths)
         # Transpose from (length, batch, data) to (batch, length, data)
         data = self.as_contig(data.transpose((1, 0, 2)))
         for i in range(data.shape[0]):
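list2padded and padded2list are now typed against the generic List2d/Array3d instead of the float-specific types, so the annotations no longer assume float data. A sketch of the round trip:

    from thinc.api import NumpyOps

    ops = NumpyOps()
    padded = ops.list2padded([ops.alloc2f(3, 5), ops.alloc2f(2, 5)])
    seqs = ops.padded2list(padded)  # the 2d arrays back, in the original order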
@@ -500,6 +568,18 @@ def alloc(
         else:
             return self.xp.empty(shape, dtype=dtype)
 
+    def reshape1(self, array: ArrayXd, d0: int) -> Array1d:
+        return cast(Array1d, self.reshape(array, (d0,)))
+
+    def reshape2(self, array: ArrayXd, d0: int, d1: int) -> Array2d:
+        return cast(Array2d, self.reshape(array, (d0, d1)))
+
+    def reshape3(self, array: ArrayXd, d0: int, d1: int, d2: int) -> Array3d:
+        return cast(Array3d, self.reshape(array, (d0, d1, d2)))
+
+    def reshape4(self, array: ArrayXd, d0: int, d1: int, d2: int, d3: int) -> Array4d:
+        return cast(Array4d, self.reshape(array, (d0, d1, d2, d3)))
+
     def reshape1f(self, array: FloatsXd, d0: int) -> Floats1d:
         return cast(Floats1d, self.reshape(array, (d0,)))
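The new reshape1 through reshape4 helpers mirror the existing dtype-specific variants (reshape1f, reshape2i, and so on) but accept any ArrayXd and return the generically typed array of the requested rank. For example, a sketch:

    from thinc.api import NumpyOps

    ops = NumpyOps()
    x = ops.alloc1f(12)         # Floats1d
    y = ops.reshape2(x, 3, 4)   # typed as Array2d, works for float or int input
    z = ops.reshape2f(x, 3, 4)  # typed as Floats2d, float input only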

@@ -619,7 +699,7 @@ def asarray(
             return self.xp.asarray(data, dtype=dtype)
         elif hasattr(data, "numpy"):
             # Handles PyTorch Tensor
-            return data.numpy()  # type: ignore
+            return data.numpy()  # type: ignore[union-attr]
         elif dtype is not None:
             return self.xp.array(data, dtype=dtype)
         else:
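Throughout the diff, bare # type: ignore comments are narrowed to error-code-specific ones such as # type: ignore[union-attr]. The bracketed codes are the ones mypy prints with --show-error-codes, and the scoped form suppresses only that diagnostic, so a new, unrelated error on the same line would still be reported. A small illustration:

    x: int = 0
    x = "a"  # type: ignore[assignment]  # silences only the assignment error
    x = "b"  # type: ignore              # bare form hides every error on this line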
@@ -641,8 +721,8 @@ def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:
 
         if inplace:
             self.xp.exp(-X, out=X)
-            X += 1.0  # type: ignore
-            X **= -1.0  # type: ignore
+            X += 1.0  # type: ignore[assignment]
+            X **= -1.0  # type: ignore[assignment]
             return cast(FloatsType, X)
         else:
             return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))
@@ -786,10 +866,10 @@ def clipped_linear(
         inplace: bool = False,
     ) -> FloatsType:
         if inplace:
-            X *= slope  # type: ignore
-            X += offset  # type: ignore
+            X *= slope  # type: ignore[assignment]
+            X += offset  # type: ignore[assignment]
             return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X))
-        out = X * slope + offset  # type: ignore
+        out = X * slope + offset  # type: ignore[assignment]
         return cast(FloatsType, self.xp.clip(out, min_val, max_val))
 
     def backprop_clipped_linear(
@@ -840,27 +920,27 @@ def backprop_hard_tanh(
 
     def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         if inplace:
-            X *= self.sigmoid(X)  # type: ignore
+            X *= self.sigmoid(X)  # type: ignore[operator, assignment]
             return cast(FloatsType, X)
-        out = X * self.sigmoid(X)  # type: ignore
+        out = X * self.sigmoid(X)  # type: ignore[operator]
         return cast(FloatsType, out)
 
     def backprop_swish(
         self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False
     ) -> FloatsType:
-        Y = Y + self.sigmoid(X) * (1 - Y)  # type: ignore
+        Y = Y + self.sigmoid(X) * (1 - Y)  # type: ignore[operator]
         if inplace:
-            dY *= Y  # type: ignore
+            dY *= Y  # type: ignore[operator, assignment]
             return cast(FloatsType, dY)
-        out = dY * Y  # type: ignore
+        out = dY * Y  # type: ignore[operator]
         return cast(FloatsType, out)
 
     # Following https://www.scitepress.org/Papers/2019/74696/74696.pdf
     def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         if inplace:
-            X *= self.hard_sigmoid(X)  # type: ignore
+            X *= self.hard_sigmoid(X)  # type: ignore[operator, assignment]
             return cast(FloatsType, X)
-        out = X * self.hard_sigmoid(X)  # type: ignore
+        out = X * self.hard_sigmoid(X)  # type: ignore[operator]
         return cast(FloatsType, out)
 
     def backprop_hard_swish(
@@ -927,7 +1007,7 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         else:
             Y = self.xp.array(X)
             Y *= tmp
-            return cast(FloatsType, Y)
+            return Y
 
     def backprop_gelu_approx(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
@@ -949,15 +1029,15 @@ def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         # GELU(x) = x · Φ(x)
         cdf = gaussian_cdf(self, X)
         if inplace:
-            X *= cdf  # type: ignore
+            X *= cdf  # type: ignore[operator, assignment]
             return X
-        return X * cdf  # type: ignore
+        return X * cdf  # type: ignore[operator, return-value]
 
     def backprop_gelu(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
     ) -> FloatsType:
         # GELU'(x) = Φ(x) + x · PDF(x)
-        dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X)  # type: ignore
+        dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X)  # type: ignore[operator]
         if inplace:
             dY *= dX
             return dY
@@ -1239,8 +1319,8 @@ def lstm_forward_training(
         for d in range(dirs):
             # The inits are shaped (depth, dirs, nO). We add the internal dimension
             # to make them set correctly.
-            Yt2 = h_init[i, d].reshape((1, nO))  # type: ignore
-            Ct2 = c_init[i, d].reshape((1, nO))  # type: ignore
+            Yt2 = h_init[i, d].reshape((1, nO))  # type: ignore[assignment]
+            Ct2 = c_init[i, d].reshape((1, nO))  # type: ignore[assignment]
             layer_params, params_i = _split_weights(params, i, nO, nI, params_i)
             Wx, Wh, bias = _transpose_weights(layer_params)
             G[i, d] += xp.dot(X, Wx.T)
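The activation hunks above share one cause: with X typed as the TypeVar FloatsType, mypy cannot prove that augmented assignments like X *= cdf preserve the concrete type, so the suppressions are narrowed to the exact operator/assignment codes instead of being removed. Calling the ops is unchanged, as in this sketch:

    from thinc.api import NumpyOps

    ops = NumpyOps()
    X = ops.alloc2f(2, 3)
    Y = ops.sigmoid(X, inplace=True)  # updates and returns the same Floats2d buffer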

thinc/config.py

Lines changed: 6 additions & 5 deletions
@@ -1,4 +1,4 @@
-from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type
+from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping
 from typing import Iterable, Sequence, cast
 from types import GeneratorType
 from dataclasses import dataclass
@@ -550,7 +550,7 @@ def __init__(
         self,
         *,
         config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None,
-        errors: Iterable[Dict[str, Any]] = tuple(),
+        errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(),
         title: Optional[str] = "Config validation error",
         desc: Optional[str] = None,
         parent: Optional[str] = None,
@@ -560,9 +560,10 @@ def __init__(
 
         config (Union[Config, Dict[str, Dict[str, Any]], str]): The
             config the validation error refers to.
-        errors (Iterable[Dict[str, Any]]): A list of errors as dicts with keys
-            "loc" (list of strings describing the path of the value), "msg"
-            (validation message to show) and optional "type" (mostly internals).
+        errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]):
+            A list of errors as dicts with keys "loc" (list of strings
+            describing the path of the value), "msg" (validation message
+            to show) and optional "type" (mostly internals).
             Same format as produced by pydantic's validation error (e.errors()).
         title (str): The error title.
         desc (str): Optional error description, displayed below the title.
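The widened errors parameter accepts pydantic's e.errors() output as well as hand-built sequences of mappings. A sketch of raising the error directly (illustrative values):

    from thinc.config import ConfigValidationError

    raise ConfigValidationError(
        config={"section": {"key": "value"}},
        errors=[{"loc": ["section", "key"], "msg": "value is not allowed"}],
    )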

thinc/layers/array_getitem.py

Lines changed: 3 additions & 2 deletions
@@ -1,13 +1,14 @@
-from typing import Union, Sequence, Tuple
+from typing import Union, Sequence, Tuple, TypeVar
 from ..types import ArrayXd, FloatsXd, IntsXd
 from ..model import Model
 
 
 AxisIndex = Union[int, slice, Sequence[int]]
 Index = Union[AxisIndex, Tuple[AxisIndex, ...]]
+ArrayTXd = TypeVar("ArrayTXd", bound=ArrayXd)
 
 
-def array_getitem(index: Index) -> Model[ArrayXd, ArrayXd]:
+def array_getitem(index: Index) -> Model[ArrayTXd, ArrayTXd]:
     """Index into input arrays, and return the subarrays.
 
     index:
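With ArrayTXd bound to ArrayXd, the model built by array_getitem can be pinned to a concrete array type end to end instead of collapsing to ArrayXd. A usage sketch:

    from thinc.api import NumpyOps
    from thinc.layers.array_getitem import array_getitem
    from thinc.model import Model
    from thinc.types import Floats2d

    ops = NumpyOps()
    # Annotating the model binds the TypeVar to Floats2d:
    model: Model[Floats2d, Floats2d] = array_getitem((slice(0, 2),))
    Y = model.predict(ops.alloc2f(4, 5))  # Y is Floats2d, not ArrayXd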
