@@ -82,7 +82,9 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
8282 return result
8383
8484
85- def percentile (x : NdarrayOrTensor , q , dim : Optional [int ] = None ) -> Union [NdarrayOrTensor , float , int ]:
85+ def percentile (
86+ x : NdarrayOrTensor , q , dim : Optional [int ] = None , keepdim : bool = False , ** kwargs
87+ ) -> Union [NdarrayOrTensor , float , int ]:
8688 """`np.percentile` with equivalent implementation for torch.
8789
8890 Pytorch uses `quantile`, but this functionality is only available from v1.7.
@@ -97,6 +99,9 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
9799 q: percentile to compute (should in range 0 <= q <= 100)
98100 dim: the dim along which the percentiles are computed. default is to compute the percentile
99101 along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
102+ keepdim: whether the output data has dim retained or not.
103+ kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
104+ https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
100105
101106 Returns:
102107 Resulting value (scalar)
@@ -108,11 +113,11 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
108113 raise ValueError
109114 result : Union [NdarrayOrTensor , float , int ]
110115 if isinstance (x , np .ndarray ):
111- result = np .percentile (x , q , axis = dim )
116+ result = np .percentile (x , q , axis = dim , keepdims = keepdim , ** kwargs )
112117 else :
113118 q = torch .tensor (q , device = x .device )
114119 if hasattr (torch , "quantile" ): # `quantile` is new in torch 1.7.0
115- result = torch .quantile (x , q / 100.0 , dim = dim )
120+ result = torch .quantile (x , q / 100.0 , dim = dim , keepdim = keepdim )
116121 else :
117122 # Note that ``kthvalue()`` works one-based, i.e., the first sorted value
118123 # corresponds to k=1, not k=0. Thus, we need the `1 +`.
@@ -282,13 +287,23 @@ def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> N
282287 return torch .cat (to_cat , dim = axis , out = out ) # type: ignore
283288
284289
285- def cumsum (a : NdarrayOrTensor , axis = None ):
286- """`np.cumsum` with equivalent implementation for torch."""
290+ def cumsum (a : NdarrayOrTensor , axis = None , ** kwargs ):
291+ """
292+ `np.cumsum` with equivalent implementation for torch.
293+
294+ Args:
295+ a: input data to compute cumsum.
296+ axis: expected axis to compute cumsum.
297+ kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
298+ https://pytorch.org/docs/stable/generated/torch.cumsum.html.
299+
300+ """
301+
287302 if isinstance (a , np .ndarray ):
288303 return np .cumsum (a , axis )
289304 if axis is None :
290- return torch .cumsum (a [:], 0 )
291- return torch .cumsum (a , dim = axis )
305+ return torch .cumsum (a [:], 0 , ** kwargs )
306+ return torch .cumsum (a , dim = axis , ** kwargs )
292307
293308
294309def isfinite (x ):
@@ -298,18 +313,40 @@ def isfinite(x):
298313 return torch .isfinite (x )
299314
300315
301- def searchsorted (a : NdarrayOrTensor , v : NdarrayOrTensor , right = False , sorter = None ):
316+ def searchsorted (a : NdarrayOrTensor , v : NdarrayOrTensor , right = False , sorter = None , ** kwargs ):
317+ """
318+ `np.searchsorted` with equivalent implementation for torch.
319+
320+ Args:
321+ a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
322+ v: containing the search values.
323+ right: if False, return the first suitable location that is found, if True, return the last such index.
324+ sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
325+ kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
326+ https://pytorch.org/docs/stable/generated/torch.searchsorted.html.
327+
328+ """
302329 side = "right" if right else "left"
303330 if isinstance (a , np .ndarray ):
304331 return np .searchsorted (a , v , side , sorter ) # type: ignore
305- return torch .searchsorted (a , v , right = right ) # type: ignore
332+ return torch .searchsorted (a , v , right = right , ** kwargs ) # type: ignore
333+
306334
335+ def repeat (a : NdarrayOrTensor , repeats : int , axis : Optional [int ] = None , ** kwargs ):
336+ """
337+ `np.repeat` with equivalent implementation for torch (`repeat_interleave`).
338+
339+ Args:
340+ a: input data to repeat.
341+ repeats: number of repetitions for each element, repeats is broadcasted to fit the shape of the given axis.
342+ axis: axis along which to repeat values.
343+ kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
344+ https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.
307345
308- def repeat (a : NdarrayOrTensor , repeats : int , axis : Optional [int ] = None ):
309- """`np.repeat` with equivalent implementation for torch (`repeat_interleave`)."""
346+ """
310347 if isinstance (a , np .ndarray ):
311348 return np .repeat (a , repeats , axis )
312- return torch .repeat_interleave (a , repeats , dim = axis )
349+ return torch .repeat_interleave (a , repeats , dim = axis , ** kwargs )
313350
314351
315352def isnan (x : NdarrayOrTensor ):
@@ -330,7 +367,7 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
330367 Args:
331368 x: array/tensor
332369 kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
333- https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous .
370+ https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.
334371
335372 """
336373 if isinstance (x , np .ndarray ):
0 commit comments