@@ -5,9 +5,9 @@
 import numpy
 import itertools

-from .. import registry
 from ..types import Xp, Shape, DTypes, DTypesInt, DTypesFloat, List2d, ArrayXd
-from ..types import Array3d, Floats1d, Floats2d, Floats3d, Floats4d
+from ..types import Floats1d, Floats2d, Floats3d, Floats4d
+from ..types import Array1d, Array2d, Array3d, Array4d, ListXd
 from ..types import FloatsXd, Ints1d, Ints2d, Ints3d, Ints4d, IntsXd, _Floats
 from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator
 from ..util import get_array_module, is_xp_array, to_numpy
@@ -135,13 +135,11 @@ def _get_batch(self, sequence, indices):
         if isinstance(sequence, list):
             subseq = [sequence[i] for i in indices]
         elif isinstance(sequence, tuple):
-            subseq = tuple(sequence[i] for i in indices)  # type: ignore
+            subseq = tuple(sequence[i] for i in indices)
         else:
-            subseq = sequence[indices]  # type: ignore
+            subseq = sequence[indices]
         if is_xp_array(subseq):
-            subseq = self.as_contig(
-                cast(ArrayXd, self.xp.asarray(subseq))
-            )  # type: ignore
+            subseq = self.as_contig(self.xp.asarray(subseq))
         return subseq

     def _get_batch_sizes(self, length: int, sizes: Iterator[int]):
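As a quick illustration of the three branches above (values are hypothetical): lists and tuples are rebuilt element by element, while arrays go through numpy-style fancy indexing, which is why only the array case needs the asarray/as_contig step.

import numpy

seq_list = [10, 20, 30, 40]
seq_arr = numpy.arange(8).reshape((4, 2))
indices = [0, 2]
# list/tuple branches: rebuild the container element by element
assert [seq_list[i] for i in indices] == [10, 30]
# array branch: fancy indexing copies the selected rows into a new array
assert (seq_arr[indices] == numpy.asarray([[0, 1], [4, 5]])).all()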
@@ -225,13 +223,65 @@ def affine(self, X: Floats2d, W: Floats2d, b: Floats1d) -> Floats2d:
         Y += b
         return Y

+    @overload
     def flatten(
         self,
-        X: Sequence[ArrayT],
+        X: List[Floats2d],
         dtype: Optional[DTypes] = None,
         pad: int = 0,
         ndim_if_empty: int = 2,
-    ) -> ArrayT:
+    ) -> Floats2d:
+        ...
+
+    @overload
+    def flatten(
+        self,
+        X: List[Ints1d],
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> Ints1d:
+        ...
+
+    @overload
+    def flatten(
+        self,
+        X: List2d,
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> Array2d:
+        ...
+
+    # further specific typed signatures can be added as necessary
+
+    @overload
+    def flatten(
+        self,
+        X: ListXd,
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> ArrayXd:
+        ...
+
+    @overload
+    def flatten(
+        self,
+        X: Sequence[ArrayXd],
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> ArrayXd:
+        ...
+
+    def flatten(
+        self,
+        X: Sequence[ArrayXd],
+        dtype: Optional[DTypes] = None,
+        pad: int = 0,
+        ndim_if_empty: int = 2,
+    ) -> ArrayXd:
         """Flatten a list of arrays into one large array."""
         if X is None or len(X) == 0:
             return self.alloc((0,) * ndim_if_empty, dtype=dtype or "f")
@@ -252,7 +302,25 @@ def flatten(
             result = xp.asarray(result, dtype=dtype)
         return result

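A minimal usage sketch for flatten, assuming this is thinc's Ops base class instantiated via a concrete backend such as NumpyOps; with the new overloads, a List[Floats2d] input is typed as returning Floats2d rather than the loose ArrayXd.

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = [numpy.zeros((2, 4), dtype="f"), numpy.ones((3, 4), dtype="f")]
flat = ops.flatten(X)
# sequences are concatenated along the first axis
assert flat.shape == (5, 4)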
+    @overload
     def unflatten(self, X: Floats2d, lengths: Ints1d, pad: int = 0) -> List[Floats2d]:
+        ...
+
+    @overload
+    def unflatten(self, X: Ints1d, lengths: Ints1d, pad: int = 0) -> List[Ints1d]:
+        ...
+
+    @overload
+    def unflatten(self, X: Array2d, lengths: Ints1d, pad: int = 0) -> List2d:
+        ...
+
+    # further specific typed signatures can be added as necessary
+
+    @overload
+    def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd:
+        ...
+
+    def unflatten(self, X: ArrayXd, lengths: Ints1d, pad: int = 0) -> ListXd:
         """The reverse/backward operation of the `flatten` function: unflatten
         a large array into a list of arrays according to the given lengths.
         """
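The round trip the docstring describes, sketched under the same NumpyOps assumption as above; lengths gives the number of rows to slice back out for each sequence.

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
flat = ops.flatten([numpy.zeros((2, 4), dtype="f"), numpy.ones((3, 4), dtype="f")])
seqs = ops.unflatten(flat, ops.asarray1i([2, 3]))
# unflatten inverts flatten for matching lengths
assert [s.shape for s in seqs] == [(2, 4), (3, 4)]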
@@ -302,7 +370,7 @@ def pad( # noqa: F811
         output: Array3d = self.alloc(final_shape, dtype=seqs[0].dtype)
         for i, arr in enumerate(seqs):
             # It's difficult to convince this that the dtypes will match.
-            output[i, : arr.shape[0]] = arr  # type: ignore
+            output[i, : arr.shape[0]] = arr  # type: ignore[assignment, call-overload]
         return output

     def unpad(self, padded: Array3d, lengths: List[int]) -> List2d:
@@ -314,14 +382,14 @@ def unpad(self, padded: Array3d, lengths: List[int]) -> List2d:
             output.append(padded[i, :length])
         return cast(List2d, output)

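For reference, pad stacks variable-length sequences into one (batch, longest, width) array, zero-filling the tails, and unpad slices them back out. A sketch with hypothetical shapes (NumpyOps assumed as above):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
seqs = [numpy.ones((1, 3), dtype="f"), numpy.ones((4, 3), dtype="f")]
padded = ops.pad(seqs)
assert padded.shape == (2, 4, 3)  # (batch, longest length, width)
assert not padded[0, 1:].any()  # the shorter sequence is zero-padded
unpadded = ops.unpad(padded, [1, 4])
assert [a.shape for a in unpadded] == [(1, 3), (4, 3)]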
-    def list2padded(self, seqs: List[Floats2d]) -> Padded:
+    def list2padded(self, seqs: List2d) -> Padded:
         """Pack a sequence of 2d arrays into a Padded datatype."""
         if not seqs:
             return Padded(
                 self.alloc3f(0, 0, 0), self.alloc1i(0), self.alloc1i(0), self.alloc1i(0)
             )
         elif len(seqs) == 1:
-            data = self.reshape3f(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1])
+            data = self.reshape3(seqs[0], seqs[0].shape[0], 1, seqs[0].shape[1])
             size_at_t = self.asarray1i([1] * data.shape[0])
             lengths = self.asarray1i([data.shape[0]])
             indices = self.asarray1i([0])
@@ -336,8 +404,8 @@ def list2padded(self, seqs: List[Floats2d]) -> Padded:
         # Reorder the sequences, by length. This looks the same in either
         # direction: you're swapping elements between their original and sorted
         # position.
-        seqs = [seqs[i] for i in indices_]
-        arr: Floats3d = self.pad(seqs)
+        seqs = cast(List2d, [seqs[i] for i in indices_])
+        arr: Array3d = self.pad(seqs)
         assert arr.shape == (nB, nS, nO), (nB, nS, nO)
         arr = self.as_contig(arr.transpose((1, 0, 2)))
         assert arr.shape == (nS, nB, nO)
@@ -350,7 +418,7 @@ def list2padded(self, seqs: List[Floats2d]) -> Padded:
             batch_size_at_t_[t] = current_size
         assert sum(lengths_) == sum(batch_size_at_t_)
         return Padded(
-            cast(Floats3d, arr),
+            arr,
             self.asarray1i(batch_size_at_t_),
             self.asarray1i(lengths_),
             self.asarray1i(indices_),
@@ -361,7 +429,7 @@ def padded2list(self, padded: Padded) -> List2d:
         data = padded.data
         indices = to_numpy(padded.indices)
         lengths = to_numpy(padded.lengths)
-        unpadded: List[Optional[Floats2d]] = [None] * len(lengths)
+        unpadded: List[Optional[Array2d]] = [None] * len(lengths)
         # Transpose from (length, batch, data) to (batch, length, data)
         data = self.as_contig(data.transpose((1, 0, 2)))
         for i in range(data.shape[0]):
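To make the layout concrete: Padded stores the batch time-major and sorted by decreasing length, with indices recording how to restore the original order, so padded2list inverts list2padded. A sketch (NumpyOps assumed):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
seqs = [numpy.ones((2, 3), dtype="f"), numpy.ones((5, 3), dtype="f")]
p = ops.list2padded(seqs)
assert p.data.shape == (5, 2, 3)  # (time, batch, width), longest sequence first
assert [a.shape for a in ops.padded2list(p)] == [(2, 3), (5, 3)]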
@@ -500,6 +568,18 @@ def alloc(
         else:
             return self.xp.empty(shape, dtype=dtype)

+    def reshape1(self, array: ArrayXd, d0: int) -> Array1d:
+        return cast(Array1d, self.reshape(array, (d0,)))
+
+    def reshape2(self, array: ArrayXd, d0: int, d1: int) -> Array2d:
+        return cast(Array2d, self.reshape(array, (d0, d1)))
+
+    def reshape3(self, array: ArrayXd, d0: int, d1: int, d2: int) -> Array3d:
+        return cast(Array3d, self.reshape(array, (d0, d1, d2)))
+
+    def reshape4(self, array: ArrayXd, d0: int, d1: int, d2: int, d3: int) -> Array4d:
+        return cast(Array4d, self.reshape(array, (d0, d1, d2, d3)))
+
     def reshape1f(self, array: FloatsXd, d0: int) -> Floats1d:
         return cast(Floats1d, self.reshape(array, (d0,)))

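The new reshape1..reshape4 helpers mirror the existing dtype-specific variants such as reshape1f but accept any array dtype, which is what lets list2padded above call reshape3 on List2d input. A small sketch (NumpyOps assumed):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
arr = numpy.arange(6)  # integer array; the generic helpers don't care about dtype
assert ops.reshape2(arr, 2, 3).shape == (2, 3)
assert ops.reshape3(arr, 1, 2, 3).shape == (1, 2, 3)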
@@ -619,7 +699,7 @@ def asarray(
             return self.xp.asarray(data, dtype=dtype)
         elif hasattr(data, "numpy"):
             # Handles PyTorch Tensor
-            return data.numpy()  # type: ignore
+            return data.numpy()  # type: ignore[union-attr]
         elif dtype is not None:
             return self.xp.array(data, dtype=dtype)
         else:
@@ -641,8 +721,8 @@ def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:

         if inplace:
             self.xp.exp(-X, out=X)
-            X += 1.0  # type: ignore
-            X **= -1.0  # type: ignore
+            X += 1.0  # type: ignore[assignment]
+            X **= -1.0  # type: ignore[assignment]
             return cast(FloatsType, X)
         else:
             return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))
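Both branches compute the same logistic function; the in-place branch rewrites sigmoid(x) = (exp(-x) + 1) ** -1 as three destructive updates so no temporary is allocated. A quick numeric check (NumpyOps assumed):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([-2.0, 0.0, 2.0], dtype="f")
expected = 1.0 / (1.0 + numpy.exp(-X))
assert numpy.allclose(ops.sigmoid(X.copy(), inplace=True), expected)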
@@ -786,10 +866,10 @@ def clipped_linear(
         inplace: bool = False,
     ) -> FloatsType:
         if inplace:
-            X *= slope  # type: ignore
-            X += offset  # type: ignore
+            X *= slope  # type: ignore[assignment]
+            X += offset  # type: ignore[assignment]
             return cast(FloatsType, self.xp.clip(X, min_val, max_val, out=X))
-        out = X * slope + offset  # type: ignore
+        out = X * slope + offset  # type: ignore[assignment]
         return cast(FloatsType, self.xp.clip(out, min_val, max_val))

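clipped_linear computes clip(X * slope + offset, min_val, max_val); activations like hard_sigmoid are thin wrappers over it with fixed parameters. A sketch using the common hard-sigmoid parameterisation (the slope/offset values here are an assumption, not taken from this diff):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([-5.0, 0.0, 5.0], dtype="f")
# assumed hard-sigmoid parameters: slope 0.2, offset 0.5, clipped to [0, 1]
Y = ops.clipped_linear(X, slope=0.2, offset=0.5, min_val=0.0, max_val=1.0)
assert numpy.allclose(Y, [0.0, 0.5, 1.0])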
     def backprop_clipped_linear(
@@ -840,27 +920,27 @@ def backprop_hard_tanh(

     def swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         if inplace:
-            X *= self.sigmoid(X)  # type: ignore
+            X *= self.sigmoid(X)  # type: ignore[operator, assignment]
             return cast(FloatsType, X)
-        out = X * self.sigmoid(X)  # type: ignore
+        out = X * self.sigmoid(X)  # type: ignore[operator]
         return cast(FloatsType, out)

     def backprop_swish(
         self, dY: FloatsType, X: FloatsType, Y: FloatsType, inplace: bool = False
     ) -> FloatsType:
-        Y = Y + self.sigmoid(X) * (1 - Y)  # type: ignore
+        Y = Y + self.sigmoid(X) * (1 - Y)  # type: ignore[operator]
         if inplace:
-            dY *= Y  # type: ignore
+            dY *= Y  # type: ignore[operator, assignment]
             return cast(FloatsType, dY)
-        out = dY * Y  # type: ignore
+        out = dY * Y  # type: ignore[operator]
         return cast(FloatsType, out)

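backprop_swish relies on the identity swish'(x) = swish(x) + sigmoid(x) * (1 - swish(x)), which is why it receives both X and the forward output Y. A finite-difference sanity check (NumpyOps assumed; float32 inputs, hence the loose tolerance):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([0.5], dtype="f")
Y = ops.swish(X.copy())
dX = ops.backprop_swish(numpy.ones_like(X), X, Y)
eps = 1e-2
# central difference approximation of the derivative
numeric = (ops.swish(X + eps) - ops.swish(X - eps)) / (2 * eps)
assert numpy.allclose(dX, numeric, atol=1e-2)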
     # Following https://www.scitepress.org/Papers/2019/74696/74696.pdf
     def hard_swish(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         if inplace:
-            X *= self.hard_sigmoid(X)  # type: ignore
+            X *= self.hard_sigmoid(X)  # type: ignore[operator, assignment]
             return cast(FloatsType, X)
-        out = X * self.hard_sigmoid(X)  # type: ignore
+        out = X * self.hard_sigmoid(X)  # type: ignore[operator]
         return cast(FloatsType, out)

     def backprop_hard_swish(
@@ -927,7 +1007,7 @@ def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         else:
             Y = self.xp.array(X)
         Y *= tmp
-        return cast(FloatsType, Y)
+        return Y

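gelu_approx is, to the best of my knowledge, the usual tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), with tmp holding everything except the leading x. A check against the exact erf-based GELU, with a tolerance loose enough for any close approximation (NumpyOps assumed):

import math
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([-1.0, 0.0, 1.0], dtype="f")
approx = ops.gelu_approx(X.copy())
# exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
exact = numpy.asarray([0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))) for x in X])
assert numpy.allclose(approx, exact, atol=1e-3)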
     def backprop_gelu_approx(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
@@ -949,15 +1029,15 @@ def gelu(self, X: FloatsType, inplace: bool = False) -> FloatsType:
         # GELU(x) = x · Φ(x)
         cdf = gaussian_cdf(self, X)
         if inplace:
-            X *= cdf  # type: ignore
+            X *= cdf  # type: ignore[operator, assignment]
             return X
-        return X * cdf  # type: ignore
+        return X * cdf  # type: ignore[operator, return-value]

     def backprop_gelu(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
     ) -> FloatsType:
         # GELU'(x) = Φ(x) + x · PDF(x)
-        dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X)  # type: ignore
+        dX = gaussian_cdf(self, X) + X * gaussian_pdf(self, X)  # type: ignore[operator]
         if inplace:
             dY *= dX
             return dY
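The commented formula GELU'(x) = Φ(x) + x · φ(x) can be sanity-checked the same way as swish above (NumpyOps assumed; float32 and a loose tolerance):

import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
X = numpy.asarray([0.3], dtype="f")
dX = ops.backprop_gelu(numpy.ones_like(X), X)
eps = 1e-2
# central difference approximation of the derivative
numeric = (ops.gelu(X + eps) - ops.gelu(X - eps)) / (2 * eps)
assert numpy.allclose(dX, numeric, atol=1e-2)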
@@ -1239,8 +1319,8 @@ def lstm_forward_training(
         for d in range(dirs):
             # The inits are shaped (depth, dirs, nO). We add the internal dimension
             # to make them set correctly.
-            Yt2 = h_init[i, d].reshape((1, nO))  # type: ignore
-            Ct2 = c_init[i, d].reshape((1, nO))  # type: ignore
+            Yt2 = h_init[i, d].reshape((1, nO))  # type: ignore[assignment]
+            Ct2 = c_init[i, d].reshape((1, nO))  # type: ignore[assignment]
             layer_params, params_i = _split_weights(params, i, nO, nI, params_i)
             Wx, Wh, bias = _transpose_weights(layer_params)
             G[i, d] += xp.dot(X, Wx.T)