Skip to content

Commit aa55834

Browse files
Use `Any` as the first type parameter of the `with_...` wrapper layers
1 parent d269e00 commit aa55834

File tree

11 files changed

+55
-68
lines changed

11 files changed

+55
-68
lines changed

thinc/layers/with_array.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List
1+
from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List, Any
22

33
from ..model import Model
44
from ..config import registry
@@ -62,14 +62,12 @@
6262

6363

6464
@registry.layers("with_array.v1")
65-
def with_array(
66-
layer: Model[ArrayXd_co, ArrayXd_co], pad: int = 0
67-
) -> Model[SeqT_co, SeqT_co]:
65+
def with_array(layer: Model[Any, ArrayXd_co], pad: int = 0) -> Model[Any, SeqT_co]:
6866
"""Transform sequence data into a contiguous 2d array on the way into and
6967
out of a model. Handles a variety of sequence types: lists, padded and ragged.
7068
If the input is a 2d array, it is passed through unchanged.
7169
"""
72-
model: Model[SeqT_co, SeqT_co] = Model(
70+
model: Model[Any, SeqT_co] = Model(
7371
f"with_array({layer.name})",
7472
forward,
7573
init=init,
@@ -81,7 +79,7 @@ def with_array(
8179

8280

8381
def forward(
84-
model: Model[SeqT_co, SeqT_co], Xseq: SeqT, is_train: bool
82+
model: Model[Any, SeqT_co], Xseq: SeqT, is_train: bool
8583
) -> Tuple[SeqT, Callable]:
8684
if isinstance(Xseq, Ragged):
8785
return _ragged_forward(
@@ -100,9 +98,9 @@ def forward(
10098

10199

102100
def init(
103-
model: Model[SeqT_co, SeqT_co], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
101+
model: Model[Any, SeqT_co], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
104102
) -> None:
105-
layer: Model[ArrayXd, ArrayXd] = model.layers[0]
103+
layer: Model[Any, ArrayXd] = model.layers[0]
106104
layer.initialize(
107105
X=_get_array(model, X) if X is not None else X,
108106
Y=_get_array(model, Y) if Y is not None else Y,

thinc/layers/with_array2d.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List
1+
from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List, Any
22

33
from ..model import Model
44
from ..config import registry
@@ -18,9 +18,7 @@
1818

1919

2020
@registry.layers("with_array2d.v1")
21-
def with_array2d(
22-
layer: Model[ValT_co, ValT_co], pad: int = 0
23-
) -> Model[SeqT_co, SeqT_co]:
21+
def with_array2d(layer: Model[Any, ValT_co], pad: int = 0) -> Model[Any, SeqT_co]:
2422
"""Transform sequence data into a contiguous 2d array on the way into and
2523
out of a model. Handles a variety of sequence types: lists, padded and ragged.
2624
If the input is a 2d array, it is passed through unchanged.
@@ -36,7 +34,7 @@ def with_array2d(
3634

3735

3836
def forward(
39-
model: Model[SeqT_co, SeqT_co], Xseq: SeqT, is_train: bool
37+
model: Model[Any, SeqT_co], Xseq: SeqT, is_train: bool
4038
) -> Tuple[SeqT, Callable]:
4139
if isinstance(Xseq, Ragged):
4240
return _ragged_forward(
@@ -55,9 +53,9 @@ def forward(
5553

5654

5755
def init(
58-
model: Model[SeqT, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
56+
model: Model[Any, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
5957
) -> None:
60-
layer: Model[Array2d, Array2d] = model.layers[0]
58+
layer: Model[Any, Array2d] = model.layers[0]
6159
layer.initialize(
6260
X=_get_array(model, X) if X is not None else X,
6361
Y=_get_array(model, Y) if Y is not None else Y,
@@ -82,7 +80,7 @@ def _get_array(model, X: SeqT) -> Array2d:
8280

8381

8482
def _list_forward(
85-
model: Model[List[Array2d], List[Array2d]], Xs: List[Array2d], is_train: bool
83+
model: Model[Any, List[Array2d]], Xs: List[Array2d], is_train: bool
8684
) -> Tuple[SeqT, Callable]:
8785
layer = model.layers[0]
8886
pad = model.attrs["pad"]
@@ -99,7 +97,7 @@ def backprop(dYs: List[Array2d]) -> List[Array2d]:
9997

10098

10199
def _ragged_forward(
102-
model: Model[Ragged, Ragged], Xr: Ragged, is_train: bool
100+
model: Model[Any, Ragged], Xr: Ragged, is_train: bool
103101
) -> Tuple[SeqT, Callable]:
104102
layer: Model[Array2d, Array2d] = model.layers[0]
105103
Y, get_dX = layer(Xr.data, is_train)
@@ -112,7 +110,7 @@ def backprop(dYr: Ragged) -> Ragged:
112110

113111

114112
def _padded_forward(
115-
model: Model[Padded, Padded], Xp: Padded, is_train: bool
113+
model: Model[Any, Padded], Xp: Padded, is_train: bool
116114
) -> Tuple[SeqT, Callable]:
117115
layer: Model[Array2d, Array2d] = model.layers[0]
118116
X = model.ops.reshape2(

thinc/layers/with_cpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def with_cpu_backprop(d_outputs):
3131
return gpu_outputs, with_cpu_backprop
3232

3333

34-
def init(model: Model, X: Any, Y: Any) -> Model:
35-
return model.layers[0].initialize(X, Y)
34+
def init(model: Model, X: Any, Y: Any) -> None:
35+
model.layers[0].initialize(X, Y)
3636

3737

3838
def _to_cpu(X):

thinc/layers/with_debug.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ def backprop(dY: Any) -> Any:
3434

3535
return layer_Y, backprop
3636

37-
def init(model: Model, X: Any, Y: Any) -> Model:
37+
def init(model: Model, X: Any, Y: Any) -> None:
3838
on_init(model, X, Y)
3939
if orig_init is not None:
40-
return orig_init(layer, X, Y)
41-
else:
42-
return layer
40+
orig_init(layer, X, Y)
4341

4442
layer.replace_callbacks(forward, init=init)
4543

thinc/layers/with_flatten.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212

1313

1414
@registry.layers("with_flatten.v1")
15-
def with_flatten(layer: Model[InT, InT]) -> Model[OutT, OutT]:
15+
def with_flatten(layer: Model[Any, InT]) -> Model[Any, OutT]:
1616
return Model(f"with_flatten({layer.name})", forward, layers=[layer], init=init)
1717

1818

1919
def forward(
20-
model: Model[OutT, OutT], Xnest: OutT, is_train: bool
20+
model: Model[Any, OutT], Xnest: OutT, is_train: bool
2121
) -> Tuple[OutT, Callable]:
22-
layer: Model[InT, InT] = model.layers[0]
22+
layer: Model[Any, InT] = model.layers[0]
2323
Xflat: Sequence[Any] = _flatten(Xnest)
2424
Yflat, backprop_layer = layer(Xflat, is_train)
2525
# Get the split points. We want n-1 splits for n items.
@@ -44,7 +44,7 @@ def _flatten(nested: InT) -> List[ItemT]:
4444

4545

4646
def init(
47-
model: Model[OutT, OutT], X: Optional[OutT] = None, Y: Optional[OutT] = None
47+
model: Model[Any, OutT], X: Optional[OutT] = None, Y: Optional[OutT] = None
4848
) -> None:
4949
model.layers[0].initialize(
5050
_flatten(X) if X is not None else None,

thinc/layers/with_getitem.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ def backprop(d_output: OutT) -> InT:
3737

3838
def init(
3939
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
40-
) -> Model[InT, OutT]:
40+
) -> None:
4141
idx = model.attrs["idx"]
4242
X_i = X[idx] if X is not None else X
4343
Y_i = Y[idx] if Y is not None else Y
4444
model.layers[0].initialize(X=X_i, Y=Y_i)
45-
return model

thinc/layers/with_list.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Callable, List, Optional, TypeVar, Union, cast, Sequence
1+
from typing import Tuple, Callable, List, Optional, TypeVar, Union, cast, Any
22

33
from ..types import Padded, Ragged, Floats2d, List2d, Array2d
44
from ..model import Model
@@ -12,7 +12,7 @@
1212

1313

1414
@registry.layers("with_list.v1")
15-
def with_list(layer: Model[List2d_co, List2d_co]) -> Model[SeqT_co, SeqT_co]:
15+
def with_list(layer: Model[Any, List2d_co]) -> Model[Any, SeqT_co]:
1616
return Model(
1717
f"with_list({layer.name})",
1818
forward,
@@ -23,9 +23,9 @@ def with_list(layer: Model[List2d_co, List2d_co]) -> Model[SeqT_co, SeqT_co]:
2323

2424

2525
def forward(
26-
model: Model[SeqT_co, SeqT_co], Xseq: SeqT, is_train: bool
26+
model: Model[Any, SeqT_co], Xseq: SeqT, is_train: bool
2727
) -> Tuple[SeqT, Callable]:
28-
layer: Model[List[Array2d], List[Array2d]] = model.layers[0]
28+
layer: Model[Any, List[Array2d]] = model.layers[0]
2929
Y: SeqT
3030
if isinstance(Xseq, Padded):
3131
Y, backprop = _padded_forward(layer, cast(Padded, Xseq), is_train)
@@ -38,13 +38,12 @@ def forward(
3838

3939

4040
def init(
41-
model: Model[SeqT_co, SeqT_co], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
42-
) -> Model[SeqT_co, SeqT_co]:
41+
model: Model[Any, SeqT_co], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
42+
) -> None:
4343
model.layers[0].initialize(
4444
X=_get_list(model, X) if X is not None else None,
4545
Y=_get_list(model, Y) if Y is not None else None,
4646
)
47-
return model
4847

4948

5049
def _get_list(model, seq):
@@ -57,7 +56,7 @@ def _get_list(model, seq):
5756

5857

5958
def _ragged_forward(
60-
layer: Model[List[Array2d], List[Array2d]], Xr: Ragged, is_train: bool
59+
layer: Model[Any, List[Array2d]], Xr: Ragged, is_train: bool
6160
) -> Tuple[SeqT, Callable]:
6261
# Assign these to locals, to keep code a bit shorter.
6362
unflatten = layer.ops.unflatten
@@ -81,7 +80,7 @@ def backprop(dYr: Ragged):
8180

8281

8382
def _padded_forward(
84-
layer: Model[List[Array2d], List[Array2d]], Xp: Padded, is_train: bool
83+
layer: Model[Any, List[Array2d]], Xp: Padded, is_train: bool
8584
) -> Tuple[SeqT, Callable]:
8685
# Assign these to locals, to keep code a bit shorter.
8786
padded2list = layer.ops.padded2list

thinc/layers/with_nvtx_range.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ def backprop(dY: Any) -> Any:
3535

3636
return layer_Y, backprop
3737

38-
def init(_, X: Any, Y: Any) -> Model:
38+
def init(_, X: Any, Y: Any) -> None:
3939
if orig_init is not None:
40-
return orig_init(layer, X, Y)
41-
else:
42-
return layer
40+
orig_init(layer, X, Y)
4341

4442
layer.replace_callbacks(forward, init=init)
4543

thinc/layers/with_padded.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List
1+
from typing import Tuple, Callable, Optional, TypeVar, Union, cast, List, Any
22

33
from ..types import Padded, Ragged, Floats3d, Ints1d, Floats2d, Array2d, List2d
44
from ..model import Model
@@ -18,7 +18,7 @@
1818

1919

2020
@registry.layers("with_padded.v1")
21-
def with_padded(layer: Model[Padded, Padded]) -> Model[SeqT_co, SeqT_co]:
21+
def with_padded(layer: Model[Any, Padded]) -> Model[Any, SeqT_co]:
2222
return Model(
2323
f"with_padded({layer.name})",
2424
forward,
@@ -29,7 +29,7 @@ def with_padded(layer: Model[Padded, Padded]) -> Model[SeqT_co, SeqT_co]:
2929

3030

3131
def forward(
32-
model: Model[SeqT_co, SeqT_co], Xseq: SeqT, is_train: bool
32+
model: Model[Any, SeqT_co], Xseq: SeqT, is_train: bool
3333
) -> Tuple[SeqT, Callable]:
3434
layer: Model[Padded, Padded] = model.layers[0]
3535
Y: SeqT
@@ -48,7 +48,7 @@ def forward(
4848

4949

5050
def init(
51-
model: Model[SeqT_co, SeqT_co], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
51+
model: Model[Any, SeqT_co], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
5252
) -> None:
5353
model.layers[0].initialize(
5454
X=_get_padded(model, X) if X is not None else None,
@@ -60,7 +60,7 @@ def _is_padded_data(seq: SeqT) -> bool:
6060
return isinstance(seq, tuple) and len(seq) == 4 and all(map(is_xp_array, seq))
6161

6262

63-
def _get_padded(model: Model[SeqT_co, SeqT_co], seq: SeqT) -> Padded:
63+
def _get_padded(model: Model[Any, SeqT_co], seq: SeqT) -> Padded:
6464
if isinstance(seq, Padded):
6565
return seq
6666
elif isinstance(seq, Ragged):
@@ -81,7 +81,7 @@ def _get_padded(model: Model[SeqT_co, SeqT_co], seq: SeqT) -> Padded:
8181

8282

8383
def _array_forward(
84-
layer: Model[Padded, Padded], X: Floats3d, is_train
84+
layer: Model[Any, Padded], X: Floats3d, is_train
8585
) -> Tuple[SeqT, Callable]:
8686
# Create bogus metadata for Padded.
8787
Xp = _get_padded(layer, X)
@@ -99,7 +99,7 @@ def backprop(dY: Floats3d) -> Floats3d:
9999

100100

101101
def _tuple_forward(
102-
layer: Model[Padded, Padded], X: PaddedData, is_train: bool
102+
layer: Model[Any, Padded], X: PaddedData, is_train: bool
103103
) -> Tuple[SeqT, Callable]:
104104
Yp, get_dXp = layer(Padded(*X), is_train)
105105

@@ -111,7 +111,7 @@ def backprop(dY):
111111

112112

113113
def _ragged_forward(
114-
layer: Model[Padded, Padded], Xr: Ragged, is_train: bool
114+
layer: Model[Any, Padded], Xr: Ragged, is_train: bool
115115
) -> Tuple[SeqT, Callable]:
116116
# Assign these to locals, to keep code a bit shorter.
117117
list2padded = layer.ops.list2padded
@@ -141,7 +141,7 @@ def backprop(dYr: Ragged):
141141

142142

143143
def _list_forward(
144-
layer: Model[Padded, Padded], Xs: List[Array2d], is_train: bool
144+
layer: Model[Any, Padded], Xs: List[Array2d], is_train: bool
145145
) -> Tuple[SeqT, Callable]:
146146
# Assign these to locals, to keep code a bit shorter.
147147
list2padded = layer.ops.list2padded

thinc/layers/with_ragged.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union
1+
from typing import Tuple, Callable, Optional, TypeVar, cast, List, Union, Any
22

33
from ..types import (
44
Padded,
@@ -62,12 +62,12 @@
6262

6363

6464
@registry.layers("with_ragged.v1")
65-
def with_ragged(layer: Model[Ragged, Ragged]) -> Model[SeqT_co, SeqT_co]:
65+
def with_ragged(layer: Model[Any, Ragged]) -> Model[Any, SeqT_co]:
6666
return Model(f"with_ragged({layer.name})", forward, init=init, layers=[layer])
6767

6868

6969
def forward(
70-
model: Model[SeqT_co, SeqT_co], Xseq: SeqT, is_train: bool
70+
model: Model[Any, SeqT_co], Xseq: SeqT, is_train: bool
7171
) -> Tuple[SeqT, Callable]:
7272
layer: Model[Ragged, Ragged] = model.layers[0]
7373
Y: SeqT_co
@@ -84,7 +84,7 @@ def forward(
8484

8585

8686
def init(
87-
model: Model[SeqT_co, SeqT_co],
87+
model: Model[Any, SeqT_co],
8888
X: Optional[SeqT_co] = None,
8989
Y: Optional[SeqT_co] = None,
9090
) -> None:
@@ -98,7 +98,7 @@ def _is_ragged_data(seq):
9898
return isinstance(seq, tuple) and len(seq) == 2
9999

100100

101-
def _get_ragged(model: Model[SeqT_co, SeqT_co], seq: SeqT) -> Ragged:
101+
def _get_ragged(model: Model[Any, SeqT_co], seq: SeqT) -> Ragged:
102102
if isinstance(seq, Ragged):
103103
return seq
104104
elif isinstance(seq, Padded):
@@ -115,7 +115,7 @@ def _get_ragged(model: Model[SeqT_co, SeqT_co], seq: SeqT) -> Ragged:
115115

116116

117117
def _tuple_forward(
118-
layer: Model[Ragged, Ragged], X: RaggedData, is_train: bool
118+
layer: Model[Any, Ragged], X: RaggedData, is_train: bool
119119
) -> Tuple[SeqT, Callable]:
120120
Yr, get_dXr = layer(Ragged(*X), is_train)
121121

@@ -127,7 +127,7 @@ def backprop(dY: RaggedData) -> RaggedData:
127127

128128

129129
def _padded_forward(
130-
layer: Model[Ragged, Ragged], Xp: Padded, is_train: bool
130+
layer: Model[Any, Ragged], Xp: Padded, is_train: bool
131131
) -> Tuple[SeqT, Callable]:
132132
# Assign these to locals, to keep code a bit shorter.
133133
list2padded = layer.ops.list2padded
@@ -156,7 +156,7 @@ def backprop(dYp: Padded):
156156

157157

158158
def _list_forward(
159-
layer: Model[Ragged, Ragged], Xs: List[Array2d], is_train: bool
159+
layer: Model[Any, Ragged], Xs: List[Array2d], is_train: bool
160160
) -> Tuple[SeqT, Callable]:
161161
# Assign these to locals, to keep code a bit shorter.
162162
flatten = layer.ops.flatten

0 commit comments

Comments (0)