Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,87 @@

from __future__ import annotations

import warnings
from typing import Sequence

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import FLOAT, INT64

# Largest signed 64-bit integer (2**63 - 1). torchvision_nms passes it as the
# third argument of op.NonMaxSuppression.
# NOTE(review): presumably this is the operator's max_output_boxes_per_class
# input, i.e. "no practical cap" -- confirm against the ONNX operator spec.
_INT64_MAX = 0x7FFFFFFFFFFFFFFF


@torch_op("torchvision::nms", trace_only=True)
def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
    """nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor

    Lower ``torchvision::nms`` to the ONNX ``NonMaxSuppression`` operator.

    ``boxes`` and ``scores`` are given leading singleton axes so they match the
    batched, per-class layout the ONNX operator works on; only the box-index
    column of the operator's output is returned.

    Args:
        boxes: Candidate boxes, one row of four coordinates per box.
        scores: One score per box.
        iou_threshold: Overlap threshold forwarded to ``NonMaxSuppression``.

    Returns:
        A 1-D INT64 tensor holding the indices of the selected boxes.
    """
    # boxes: [num_batches, spatial_dimension, 4]
    batched_boxes = op.Unsqueeze(boxes, [0])
    # scores: [num_batches, num_classes, spatial_dimension]
    batched_scores = op.Unsqueeze(scores, [0, 1])
    # nms_out: [num_selected_indices, 3] where each column is [batch_index, class_index, box_index]
    nms_out = op.NonMaxSuppression(batched_boxes, batched_scores, _INT64_MAX, iou_threshold)
    # Keep only column 2 (box_index) and flatten it to a 1-D tensor.
    box_index_column = op.Slice(nms_out, axes=[1], starts=[2], ends=[3])
    return op.Reshape(box_index_column, [-1])


def _process_batch_indices_for_roi_align(rois):
    """Return column 0 of ``rois`` (the batch indices) as a 1-D INT64 tensor."""
    # Slice keeps the column axis, so it is squeezed away before the cast.
    first_column = op.Slice(rois, axes=[1], starts=[0], ends=[1])
    flat_column = op.Squeeze(first_column, axes=[1])
    return op.Cast(flat_column, to=INT64.dtype)


def _process_rois_for_roi_align(rois):
    """Return columns 1 through 4 of ``rois``: the (x1, y1, x2, y2) coordinates."""
    # Column 0 carries the batch index and is deliberately left out.
    box_coordinates = op.Slice(rois, axes=[1], starts=[1], ends=[5])
    return box_coordinates


def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
    """Clamp a negative ``sampling_ratio`` to 0 for export, warning when doing so.

    Args:
        sampling_ratio: The sampling ratio requested by the caller.

    Returns:
        ``sampling_ratio`` unchanged when it is zero or positive; otherwise 0.

    Warns:
        UserWarning: When ``sampling_ratio`` is negative and is replaced by 0.
    """
    if sampling_ratio < 0:
        # Only negative values are rewritten, so the message says "negative";
        # positive ratios are passed through to the exported model untouched.
        warnings.warn(
            "ONNX export for RoIAlign with a negative sampling_ratio is not supported. "
            "The model will be exported with a sampling_ratio of 0.",
            stacklevel=2,
        )
        sampling_ratio = 0
    return sampling_ratio


@torch_op("torchvision::roi_align", trace_only=True)
def torchvision_roi_align(
    input,
    boxes,
    output_size: Sequence[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
):
    """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor

    Lower ``torchvision::roi_align`` to the ONNX ``RoiAlign`` operator.

    Args:
        input: Feature map forwarded unchanged to ``RoiAlign``.
        boxes: 2-D tensor whose column 0 is the batch index and whose columns
            1-4 are the box coordinates (x1, y1, x2, y2).
        output_size: Two integers, ``(pooled_height, pooled_width)``.
        spatial_scale: Forwarded to ``RoiAlign`` as ``spatial_scale``.
        sampling_ratio: Forwarded to ``RoiAlign``; a negative value is
            replaced by 0 (with a warning) before the call.
        aligned: Selects the ``"half_pixel"`` coordinate transformation mode
            when true and ``"output_half_pixel"`` otherwise.

    Returns:
        The output of ``op.RoiAlign``.
    """
    pooled_height, pooled_width = output_size
    # ONNX RoiAlign takes the batch indices and the box coordinates as two
    # separate inputs, so the combined ``boxes`` tensor is split column-wise.
    batch_indices = _process_batch_indices_for_roi_align(boxes)
    rois_coords = _process_rois_for_roi_align(boxes)
    coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
    sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio)

    return op.RoiAlign(
        input,
        rois_coords,
        batch_indices,
        coordinate_transformation_mode=coordinate_transformation_mode,
        spatial_scale=spatial_scale,
        output_height=pooled_height,
        output_width=pooled_width,
        sampling_ratio=sampling_ratio,
    )


@torch_op("torchvision::roi_pool", trace_only=True)
def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale: float = 1.0):
    """roi_pool(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0) -> torch.Tensor"""
    # Unpacking enforces that output_size holds exactly two values before
    # they are handed to the operator as (height, width).
    height, width = output_size
    pooled_shape = (height, width)
    # MaxRoiPool expects boxes in format [batch_index, x1, y1, x2, y2],
    # so ``boxes`` is passed through without any reshaping.
    return op.MaxRoiPool(input, boxes, pooled_shape=pooled_shape, spatial_scale=spatial_scale)
106 changes: 106 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,98 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
yield opinfo_core.SampleInput(make_inp(shape), args=(pad,))


def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
    """Yield ``SampleInput`` cases for ``torchvision.ops.roi_align``.

    Every case uses a freshly generated ``1 x 1 x 10 x 10`` input and a single
    box row in ``[batch_index, x1, y1, x2, y2]`` layout. The cases vary the
    output size, ``spatial_scale``, ``sampling_ratio`` and ``aligned``.

    Args:
        op_info: Unused.
        device: Device on which the tensors are created.
        dtype: dtype of the input and of the boxes tensor.
        requires_grad: Whether the input tensor requires grad.
        **kwargs: Unused.
    """
    del op_info
    del kwargs
    # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)

    # One row per sample:
    # (input factory, box rows, output_size, spatial_scale, sampling_ratio, aligned)
    cases = (
        # Test 1: spatial_scale=1, sampling_ratio=2
        (torch.rand, [[0, 1.5, 1.5, 3, 3]], (5, 5), 1.0, 2, True),
        # Test 2: spatial_scale=0.5, sampling_ratio=3
        (torch.rand, [[0, 0.2, 0.3, 4.5, 3.5]], (5, 5), 0.5, 3, True),
        # Test 3: spatial_scale=1.8, sampling_ratio=2
        (torch.rand, [[0, 0.2, 0.3, 4.5, 3.5]], (5, 5), 1.8, 2, True),
        # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
        (torch.rand, [[0, 0.2, 0.3, 4.5, 3.5]], (2, 2), 2.5, 0, True),
        # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
        (torch.rand, [[0, 0.2, 0.3, 4.5, 3.5]], (2, 2), 2.5, -1, True),
        # Test 6: malformed boxes (test_roi_align_malformed_boxes); x1 > x2,
        # and the input is drawn with randn rather than rand.
        (torch.randn, [[0, 2, 0.3, 1.5, 1.5]], (5, 5), 1.0, 1, True),
        # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
        (torch.rand, [[0, 0, 0, 4, 4]], (5, 5), 1.0, 2, False),
        # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
        (torch.rand, [[0, 0, 0, 4, 4]], (5, 5), 1.0, -1, False),
    )

    for make_input, box_rows, output_size, spatial_scale, sampling_ratio, aligned in cases:
        # Tensors are built lazily, right before each yield, in the order
        # input-then-boxes.
        x = make_input(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
        roi = torch.tensor(box_rows, dtype=dtype, device=device)
        yield opinfo_core.SampleInput(
            x,
            args=(roi, output_size),
            kwargs={
                "spatial_scale": spatial_scale,
                "sampling_ratio": sampling_ratio,
                "aligned": aligned,
            },
        )


def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
    """Yield the single ``SampleInput`` case used for ``torchvision.ops.roi_pool``."""
    del op_info, kwargs
    # roi_pool signature: (input, boxes, output_size, spatial_scale=1.0)

    feature_map = torch.rand(
        1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad
    )
    # One box row; created after the input tensor.
    box_rows = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
    pooled_size = (5, 5)
    yield opinfo_core.SampleInput(
        feature_map,
        args=(box_rows, pooled_size),
        kwargs={"spatial_scale": 2.0},
    )


def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs
Expand Down Expand Up @@ -3038,4 +3130,18 @@ def __init__(self):
sample_inputs_func=sample_inputs_non_max_suppression,
supports_out=False,
),
opinfo_core.OpInfo(
"torchvision.ops.roi_align",
op=torchvision.ops.roi_align,
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs_roi_align,
supports_out=False,
),
opinfo_core.OpInfo(
"torchvision.ops.roi_pool",
op=torchvision.ops.roi_pool,
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs_roi_pool,
supports_out=False,
),
]
2 changes: 2 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,8 @@ def _where_input_wrangler(
),
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like),
TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align),
TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool),
)

ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
Expand Down
Loading