@@ -82,7 +82,9 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
82
82
return result
83
83
84
84
85
- def percentile (x : NdarrayOrTensor , q , dim : Optional [int ] = None ) -> Union [NdarrayOrTensor , float , int ]:
85
+ def percentile (
86
+ x : NdarrayOrTensor , q , dim : Optional [int ] = None , keepdim : bool = False , ** kwargs
87
+ ) -> Union [NdarrayOrTensor , float , int ]:
86
88
"""`np.percentile` with equivalent implementation for torch.
87
89
88
90
Pytorch uses `quantile`, but this functionality is only available from v1.7.
@@ -97,6 +99,9 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
97
99
q: percentile to compute (should in range 0 <= q <= 100)
98
100
dim: the dim along which the percentiles are computed. default is to compute the percentile
99
101
along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
102
+ keepdim: whether the output data has dim retained or not.
103
+ kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
104
+ https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
100
105
101
106
Returns:
102
107
Resulting value (scalar)
@@ -108,11 +113,11 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
108
113
raise ValueError
109
114
result : Union [NdarrayOrTensor , float , int ]
110
115
if isinstance (x , np .ndarray ):
111
- result = np .percentile (x , q , axis = dim )
116
+ result = np .percentile (x , q , axis = dim , keepdims = keepdim , ** kwargs )
112
117
else :
113
118
q = torch .tensor (q , device = x .device )
114
119
if hasattr (torch , "quantile" ): # `quantile` is new in torch 1.7.0
115
- result = torch .quantile (x , q / 100.0 , dim = dim )
120
+ result = torch .quantile (x , q / 100.0 , dim = dim , keepdim = keepdim )
116
121
else :
117
122
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
118
123
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
@@ -282,13 +287,23 @@ def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> N
282
287
return torch .cat (to_cat , dim = axis , out = out ) # type: ignore
283
288
284
289
285
def cumsum(a: NdarrayOrTensor, axis=None, **kwargs):
    """
    `np.cumsum` with equivalent implementation for torch.

    Args:
        a: input data to compute cumsum.
        axis: expected axis to compute cumsum.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
            https://pytorch.org/docs/stable/generated/torch.cumsum.html.

    """
    # numpy accepts `axis=None` directly (it flattens the input first).
    if isinstance(a, np.ndarray):
        return np.cumsum(a, axis)
    # torch needs an explicit dim; when no axis is given, run along dim 0 of a
    # view of the tensor. NOTE(review): unlike numpy, this does not flatten
    # multi-dimensional tensors first.
    if axis is None:
        a, axis = a[:], 0
    return torch.cumsum(a, dim=axis, **kwargs)
292
307
293
308
294
309
def isfinite (x ):
@@ -298,18 +313,40 @@ def isfinite(x):
298
313
return torch .isfinite (x )
299
314
300
315
301
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs):
    """
    `np.searchsorted` with equivalent implementation for torch.

    Args:
        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
        v: containing the search values.
        right: if False, return the first suitable location that is found, if True, return the last such index.
        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.

    """
    # torch path: `right` maps directly; `sorter` is only honoured for numpy.
    if not isinstance(a, np.ndarray):
        return torch.searchsorted(a, v, right=right, **kwargs)  # type: ignore
    # numpy expresses the same flag as side="left"/"right".
    return np.searchsorted(a, v, "right" if right else "left", sorter)  # type: ignore
333
+
306
334
335
def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs):
    """
    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).

    Args:
        a: input data to repeat.
        repeats: number of repetitions for each element, repeats is broadcasted to fit the shape of the given axis.
        axis: axis along which to repeat values.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.

    """
    if not isinstance(a, np.ndarray):
        # torch names the same element-wise repetition `repeat_interleave`.
        return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)
    return np.repeat(a, repeats, axis)
313
350
314
351
315
352
def isnan (x : NdarrayOrTensor ):
@@ -330,7 +367,7 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
330
367
Args:
331
368
x: array/tensor
332
369
kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
333
- https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous .
370
+ https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.
334
371
335
372
"""
336
373
if isinstance (x , np .ndarray ):
0 commit comments