diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py
index 57e2f8bd0c..5c1b1fda6b 100644
--- a/onnxscript/function_libs/torch_lib/ops/vision.py
+++ b/onnxscript/function_libs/torch_lib/ops/vision.py
@@ -7,6 +7,9 @@
 
 from __future__ import annotations
 
+import warnings
+from typing import Sequence
+
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import FLOAT, INT64
@@ -14,8 +17,9 @@
 _INT64_MAX = 0x7FFFFFFFFFFFFFFF
 
 
-@torch_op("torchvision::nms")
+@torch_op("torchvision::nms", trace_only=True)
 def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
+    """nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor"""
     # boxes: [num_batches, spatial_dimension, 4]
     boxes = op.Unsqueeze(boxes, [0])
     # scores: [num_batches, num_classes, spatial_dimension]
@@ -23,3 +27,69 @@ def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
     # nms_out: [num_selected_indices, 3] where each column is [batch_index, class_index, box_index]
     nms_out = op.NonMaxSuppression(boxes, scores, _INT64_MAX, iou_threshold)
     return op.Reshape(op.Slice(nms_out, axes=[1], starts=[2], ends=[3]), [-1])
+
+
+def _process_batch_indices_for_roi_align(rois):
+    """Return column 0 of `rois` (the batch indices), squeezed on axis 1 and cast to INT64."""
+    indices = op.Slice(rois, axes=[1], starts=[0], ends=[1])
+    indices = op.Squeeze(indices, axes=[1])
+    return op.Cast(indices, to=INT64.dtype)
+
+
+def _process_rois_for_roi_align(rois):
+    """Return columns 1, 2, 3, 4 of `rois`, i.e. the (x1, y1, x2, y2) roi coordinates."""
+    return op.Slice(rois, axes=[1], starts=[1], ends=[5])
+
+
+def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
+    """Return `sampling_ratio`, replacing a negative value with 0 after emitting a warning."""
+    if sampling_ratio < 0:
+        warnings.warn(
+            "ONNX export for RoIAlign with a negative sampling_ratio is not supported. "
+            "The model will be exported with a sampling_ratio of 0.",
+            stacklevel=2,
+        )
+        sampling_ratio = 0
+    return sampling_ratio
+
+
+@torch_op("torchvision::roi_align", trace_only=True)
+def torchvision_roi_align(
+    input,
+    boxes,
+    output_size: Sequence[int],
+    spatial_scale: float = 1.0,
+    sampling_ratio: int = -1,
+    aligned: bool = False,
+):
+    """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
+    pooled_height, pooled_width = output_size
+    # boxes rows are [batch_index, x1, y1, x2, y2]; RoiAlign takes the two parts separately.
+    batch_indices = _process_batch_indices_for_roi_align(boxes)
+    rois_coords = _process_rois_for_roi_align(boxes)
+    coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
+    sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio)
+
+    return op.RoiAlign(
+        input,
+        rois_coords,
+        batch_indices,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+        spatial_scale=spatial_scale,
+        output_height=pooled_height,
+        output_width=pooled_width,
+        sampling_ratio=sampling_ratio,
+    )
+
+
+@torch_op("torchvision::roi_pool", trace_only=True)
+def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale: float = 1.0):
+    """roi_pool(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0) -> torch.Tensor"""
+    # MaxRoiPool expects boxes in format [batch_index, x1, y1, x2, y2]
+    pooled_height, pooled_width = output_size
+    return op.MaxRoiPool(
+        input,
+        boxes,
+        pooled_shape=(pooled_height, pooled_width),
+        spatial_scale=spatial_scale,
+    )
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index 2ce015b363..a28a6c9cd9 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -1470,6 +1470,98 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
         yield opinfo_core.SampleInput(make_inp(shape), args=(pad,))
 
 
+def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)
+
+    # Test 1: spatial_scale=1, sampling_ratio=2
+    x1 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi1 = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x1,
+        args=(roi1, (5, 5)),
+        kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": True},
+    )
+
+    # Test 2: spatial_scale=0.5, sampling_ratio=3
+    x2 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi2 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x2,
+        args=(roi2, (5, 5)),
+        kwargs={"spatial_scale": 0.5, "sampling_ratio": 3, "aligned": True},
+    )
+
+    # Test 3: spatial_scale=1.8, sampling_ratio=2
+    x3 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi3 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x3,
+        args=(roi3, (5, 5)),
+        kwargs={"spatial_scale": 1.8, "sampling_ratio": 2, "aligned": True},
+    )
+
+    # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
+    x4 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi4 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x4,
+        args=(roi4, (2, 2)),
+        kwargs={"spatial_scale": 2.5, "sampling_ratio": 0, "aligned": True},
+    )
+
+    # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
+    x5 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi5 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x5,
+        args=(roi5, (2, 2)),
+        kwargs={"spatial_scale": 2.5, "sampling_ratio": -1, "aligned": True},
+    )
+
+    # Test 6: malformed box with x1 > x2 (test_roi_align_malformed_boxes)
+    x6 = torch.randn(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi6 = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x6,
+        args=(roi6, (5, 5)),
+        kwargs={"spatial_scale": 1.0, "sampling_ratio": 1, "aligned": True},
+    )
+
+    # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
+    x7 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi7 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x7,
+        args=(roi7, (5, 5)),
+        kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": False},
+    )
+
+    # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
+    x8 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    roi8 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x8,
+        args=(roi8, (5, 5)),
+        kwargs={"spatial_scale": 1.0, "sampling_ratio": -1, "aligned": False},
+    )
+
+
+def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    # roi_pool signature: (input, boxes, output_size, spatial_scale=1.0)
+
+    x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
+    rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
+    yield opinfo_core.SampleInput(
+        x,
+        args=(rois, (5, 5)),
+        kwargs={"spatial_scale": 2.0},
+    )
+
+
 def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -3038,4 +3130,18 @@ def __init__(self):
         sample_inputs_func=sample_inputs_non_max_suppression,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "torchvision.ops.roi_align",
+        op=torchvision.ops.roi_align,
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs_roi_align,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "torchvision.ops.roi_pool",
+        op=torchvision.ops.roi_pool,
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs_roi_pool,
+        supports_out=False,
+    ),
 ]
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index e5041dedd0..bb00e42a1a 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1872,6 +1872,8 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like),
     TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
+    TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align),
+    TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool),
 )
 
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))