1- from typing import Tuple , Callable , Optional , TypeVar , Union , cast , List
1+ from typing import Tuple , Callable , Optional , TypeVar , Union , cast , List , Any
22
33from ..types import Padded , Ragged , Floats3d , Ints1d , Floats2d , Array2d , List2d
44from ..model import Model
1818
1919
2020@registry .layers ("with_padded.v1" )
21- def with_padded (layer : Model [Padded , Padded ]) -> Model [SeqT_co , SeqT_co ]:
21+ def with_padded (layer : Model [Any , Padded ]) -> Model [Any , SeqT_co ]:
2222 return Model (
2323 f"with_padded({ layer .name } )" ,
2424 forward ,
@@ -29,7 +29,7 @@ def with_padded(layer: Model[Padded, Padded]) -> Model[SeqT_co, SeqT_co]:
2929
3030
3131def forward (
32- model : Model [SeqT_co , SeqT_co ], Xseq : SeqT , is_train : bool
32+ model : Model [Any , SeqT_co ], Xseq : SeqT , is_train : bool
3333) -> Tuple [SeqT , Callable ]:
3434 layer : Model [Padded , Padded ] = model .layers [0 ]
3535 Y : SeqT
@@ -48,7 +48,7 @@ def forward(
4848
4949
5050def init (
51- model : Model [SeqT_co , SeqT_co ], X : Optional [SeqT ] = None , Y : Optional [SeqT ] = None
51+ model : Model [Any , SeqT_co ], X : Optional [SeqT ] = None , Y : Optional [SeqT ] = None
5252) -> None :
5353 model .layers [0 ].initialize (
5454 X = _get_padded (model , X ) if X is not None else None ,
@@ -60,7 +60,7 @@ def _is_padded_data(seq: SeqT) -> bool:
6060 return isinstance (seq , tuple ) and len (seq ) == 4 and all (map (is_xp_array , seq ))
6161
6262
63- def _get_padded (model : Model [SeqT_co , SeqT_co ], seq : SeqT ) -> Padded :
63+ def _get_padded (model : Model [Any , SeqT_co ], seq : SeqT ) -> Padded :
6464 if isinstance (seq , Padded ):
6565 return seq
6666 elif isinstance (seq , Ragged ):
@@ -81,7 +81,7 @@ def _get_padded(model: Model[SeqT_co, SeqT_co], seq: SeqT) -> Padded:
8181
8282
8383def _array_forward (
84- layer : Model [Padded , Padded ], X : Floats3d , is_train
84+ layer : Model [Any , Padded ], X : Floats3d , is_train
8585) -> Tuple [SeqT , Callable ]:
8686 # Create bogus metadata for Padded.
8787 Xp = _get_padded (layer , X )
@@ -99,7 +99,7 @@ def backprop(dY: Floats3d) -> Floats3d:
9999
100100
101101def _tuple_forward (
102- layer : Model [Padded , Padded ], X : PaddedData , is_train : bool
102+ layer : Model [Any , Padded ], X : PaddedData , is_train : bool
103103) -> Tuple [SeqT , Callable ]:
104104 Yp , get_dXp = layer (Padded (* X ), is_train )
105105
@@ -111,7 +111,7 @@ def backprop(dY):
111111
112112
113113def _ragged_forward (
114- layer : Model [Padded , Padded ], Xr : Ragged , is_train : bool
114+ layer : Model [Any , Padded ], Xr : Ragged , is_train : bool
115115) -> Tuple [SeqT , Callable ]:
116116 # Assign these to locals, to keep code a bit shorter.
117117 list2padded = layer .ops .list2padded
@@ -141,7 +141,7 @@ def backprop(dYr: Ragged):
141141
142142
143143def _list_forward (
144- layer : Model [Padded , Padded ], Xs : List [Array2d ], is_train : bool
144+ layer : Model [Any , Padded ], Xs : List [Array2d ], is_train : bool
145145) -> Tuple [SeqT , Callable ]:
146146 # Assign these to locals, to keep code a bit shorter.
147147 list2padded = layer .ops .list2padded
0 commit comments