You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

73 lines
2.9 KiB

from typing import List, Union
import torch
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def roi_pool(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
) -> Tensor:
    """
    Apply the Region of Interest (RoI) Pool operator from Fast R-CNN.

    Args:
        input (Tensor[N, C, H, W]): batched input of ``N`` elements, each holding
            ``C`` feature maps of size ``H x W``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): regions to pool from, given as
            (x1, y1, x2, y2) coordinates satisfying ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            When a single Tensor is given, its first column holds the batch index of
            each box, i.e. a value in ``[0, N - 1]``. When a list of Tensors is given,
            the i-th Tensor holds the boxes for the i-th element of the batch.
        output_size (int or Tuple[int, int]): (height, width) of each pooled region;
            a single int is used for both dimensions.
        spatial_scale (float): factor mapping box coordinates onto ``input``
            coordinates. For example, boxes expressed on a 224x224 image combined
            with a 112x112 feature map (a 0.5x downscale of that image) call for
            0.5. Default: 1.0

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: the pooled RoIs.
    """
    # Kept as a direct `if` on is_scripting()/is_tracing(): TorchScript resolves
    # this condition statically, so the logging call is never compiled.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_pool)
    _assert_has_ops()
    check_roi_boxes_shape(boxes)

    # An int output size becomes an (h, w) pair; a pair passes through.
    pooled_size = _pair(output_size)

    # The native op takes a single rois tensor, so a per-image list of boxes
    # is converted first; a tensor argument is used as-is.
    if isinstance(boxes, torch.Tensor):
        rois = boxes
    else:
        rois = convert_boxes_to_roi_format(boxes)

    # The op returns a pair; only the first element is the pooled output.
    pooled, _ = torch.ops.torchvision.roi_pool(
        input, rois, spatial_scale, pooled_size[0], pooled_size[1]
    )
    return pooled
class RoIPool(nn.Module):
    """
    Module form of :func:`roi_pool`.

    The pooling configuration (``output_size`` and ``spatial_scale``) is fixed at
    construction time and passed to :func:`roi_pool` on every forward call.
    See :func:`roi_pool` for the meaning of each argument.
    """

    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
        super().__init__()
        _log_api_usage_once(self)
        # Stored unchanged; roi_pool itself expands an int output_size to a pair.
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        # Delegate to the functional implementation with the stored settings.
        return roi_pool(input, rois, self.output_size, self.spatial_scale)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"output_size={self.output_size}, "
            f"spatial_scale={self.spatial_scale})"
        )