
Commit d4a57d6

Merge branch 'dev' into temp-fix-windows-test
2 parents 7c7d4e1 + 2fef7ff commit d4a57d6

File tree: 1 file changed (+50, -13 lines)

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 50 additions & 13 deletions
@@ -82,7 +82,9 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
     return result
 
 
-def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]:
+def percentile(
+    x: NdarrayOrTensor, q, dim: Optional[int] = None, keepdim: bool = False, **kwargs
+) -> Union[NdarrayOrTensor, float, int]:
     """`np.percentile` with equivalent implementation for torch.
 
     Pytorch uses `quantile`, but this functionality is only available from v1.7.
@@ -97,6 +99,9 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
         q: percentile to compute (should in range 0 <= q <= 100)
         dim: the dim along which the percentiles are computed. default is to compute the percentile
             along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
+        keepdim: whether the output data has dim retained or not.
+        kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
+            https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
 
     Returns:
         Resulting value (scalar)
@@ -108,11 +113,11 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
         raise ValueError
     result: Union[NdarrayOrTensor, float, int]
     if isinstance(x, np.ndarray):
-        result = np.percentile(x, q, axis=dim)
+        result = np.percentile(x, q, axis=dim, keepdims=keepdim, **kwargs)
     else:
         q = torch.tensor(q, device=x.device)
         if hasattr(torch, "quantile"):  # `quantile` is new in torch 1.7.0
-            result = torch.quantile(x, q / 100.0, dim=dim)
+            result = torch.quantile(x, q / 100.0, dim=dim, keepdim=keepdim)
         else:
             # Note that ``kthvalue()`` works one-based, i.e., the first sorted value
             # corresponds to k=1, not k=0. Thus, we need the `1 +`.
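A rough usage sketch for the updated `percentile` (illustration only, not part of the diff): it assumes the module imports as `monai.transforms.utils_pytorch_numpy_unification` and PyTorch >= 1.7 is installed; `overwrite_input` is a standard `np.percentile` keyword, shown here only to demonstrate that `**kwargs` reach the numpy branch.

import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import percentile

x_np = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
x_pt = torch.as_tensor(x_np)

# keepdim retains the reduced axis on both backends: result shape is (2, 3, 1)
p_np = percentile(x_np, 95, dim=2, keepdim=True)
p_pt = percentile(x_pt, 95, dim=2, keepdim=True)

# extra kwargs are forwarded to np.percentile only; the torch branch does not receive them
p_kw = percentile(x_np, 95, dim=2, keepdim=True, overwrite_input=False)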
@@ -282,13 +287,23 @@ def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> N
     return torch.cat(to_cat, dim=axis, out=out)  # type: ignore
 
 
-def cumsum(a: NdarrayOrTensor, axis=None):
-    """`np.cumsum` with equivalent implementation for torch."""
+def cumsum(a: NdarrayOrTensor, axis=None, **kwargs):
+    """
+    `np.cumsum` with equivalent implementation for torch.
+
+    Args:
+        a: input data to compute cumsum.
+        axis: expected axis to compute cumsum.
+        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
+            https://pytorch.org/docs/stable/generated/torch.cumsum.html.
+
+    """
+
     if isinstance(a, np.ndarray):
         return np.cumsum(a, axis)
     if axis is None:
-        return torch.cumsum(a[:], 0)
-    return torch.cumsum(a, dim=axis)
+        return torch.cumsum(a[:], 0, **kwargs)
+    return torch.cumsum(a, dim=axis, **kwargs)
 
 
 def isfinite(x):
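Similarly for `cumsum`, a minimal sketch of the new `**kwargs` forwarding; `dtype` is an existing `torch.cumsum` keyword used purely as an example of what can be passed through.

import torch

from monai.transforms.utils_pytorch_numpy_unification import cumsum

mask = torch.tensor([[0, 1, 1], [1, 0, 1]], dtype=torch.uint8)

# kwargs reach torch.cumsum, e.g. to accumulate in a wider dtype
counts = cumsum(mask, axis=1, dtype=torch.int64)

# the numpy branch is still called as np.cumsum(a, axis), without the extra kwargs
counts_np = cumsum(mask.numpy(), axis=1)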
@@ -298,18 +313,40 @@ def isfinite(x):
     return torch.isfinite(x)
 
 
-def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None):
+def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs):
+    """
+    `np.searchsorted` with equivalent implementation for torch.
+
+    Args:
+        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
+        v: containing the search values.
+        right: if False, return the first suitable location that is found, if True, return the last such index.
+        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
+        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
+            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.
+
+    """
     side = "right" if right else "left"
     if isinstance(a, np.ndarray):
         return np.searchsorted(a, v, side, sorter)  # type: ignore
-    return torch.searchsorted(a, v, right=right)  # type: ignore
+    return torch.searchsorted(a, v, right=right, **kwargs)  # type: ignore
+
 
+def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs):
+    """
+    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).
+
+    Args:
+        a: input data to repeat.
+        repeats: number of repetitions for each element, repeats is broadcasted to fit the shape of the given axis.
+        axis: axis along which to repeat values.
+        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
+            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.
 
-def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None):
-    """`np.repeat` with equivalent implementation for torch (`repeat_interleave`)."""
+    """
     if isinstance(a, np.ndarray):
         return np.repeat(a, repeats, axis)
-    return torch.repeat_interleave(a, repeats, dim=axis)
+    return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)
 
 
 def isnan(x: NdarrayOrTensor):
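For `searchsorted` and `repeat`, a hedged sketch under the same assumptions; `out_int32` is an existing `torch.searchsorted` keyword and `output_size` is a `torch.repeat_interleave` keyword in recent PyTorch releases, both chosen only to show the forwarded `**kwargs`.

import torch

from monai.transforms.utils_pytorch_numpy_unification import repeat, searchsorted

edges = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
values = torch.tensor([0.1, 0.6, 0.9])

# out_int32 is forwarded to torch.searchsorted, so the bin indices come back as int32
bins = searchsorted(edges, values, right=True, out_int32=True)

labels = torch.tensor([1, 2, 3])
# output_size is forwarded to torch.repeat_interleave -> tensor([1, 1, 2, 2, 3, 3])
tiled = repeat(labels, 2, axis=0, output_size=6)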
@@ -330,7 +367,7 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
     Args:
         x: array/tensor
         kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
-            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.
+            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.
 
     """
     if isinstance(x, np.ndarray):
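The `ascontiguousarray` change above is only a link fix, but for completeness here is a small sketch of how its existing `**kwargs` reach `Tensor.contiguous`, as the docstring describes; `memory_format` is a standard `Tensor.contiguous` keyword, not something added by this commit.

import torch

from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray

t = torch.zeros(1, 3, 8, 8).permute(0, 2, 3, 1)  # a non-contiguous view

# kwargs go to Tensor.contiguous, e.g. to request an explicit memory format
c = ascontiguousarray(t, memory_format=torch.contiguous_format)
assert c.is_contiguous()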
