You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
434 lines
16 KiB
434 lines
16 KiB
from typing import Tuple
|
|
|
|
import torch
|
|
import torchvision
|
|
from torch import Tensor
|
|
from torchvision.extension import _assert_has_ops
|
|
|
|
from ..utils import _log_api_usage_once
|
|
from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh
|
|
from ._utils import _upcast
|
|
|
|
|
|
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than ``iou_threshold`` with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(nms)
    # The actual kernel is the registered ``torchvision::nms`` operator; check that
    # the compiled ops are available before dispatching (presumably raises otherwise).
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
|
|
|
|
|
|
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Args:
        boxes (Tensor[N, 4]): boxes where NMS will be performed. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        idxs (Tensor[N]): indices of the categories for each one of the boxes.
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(batched_nms)
    # Benchmarks that drove the following thresholds are at
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    # Large inputs use one nms() call per class; small inputs (and any traced run)
    # use a single nms() call on class-offset boxes. The per-class path loops over
    # torch.unique(idxs), i.e. its control flow depends on the data, which is
    # presumably why it is avoided while tracing.
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    else:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
|
|
|
|
|
|
@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    # Per-class NMS in a single nms() call: each box is translated by an offset
    # that depends only on its class index. The per-class stride is one more than
    # the largest coordinate, so (for non-negative coordinates) boxes belonging to
    # different classes end up disjoint and can never suppress one another.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    class_stride = boxes.max() + torch.tensor(1).to(boxes)
    per_box_shift = idxs.to(boxes) * class_stride
    shifted_boxes = boxes + per_box_shift[:, None]
    return nms(shifted_boxes, scores, iou_threshold)
|
|
|
|
|
|
@torch.jit._script_if_tracing
def _batched_nms_vanilla(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    # Based on the Detectron2 implementation: run nms() independently for every
    # distinct class in ``idxs`` and merge the survivors through a boolean mask.
    survived = torch.zeros_like(scores, dtype=torch.bool)
    for category in torch.unique(idxs):
        members = torch.where(idxs == category)[0]
        local_keep = nms(boxes[members], scores[members], iou_threshold)
        survived[members[local_keep]] = True
    survivor_indices = torch.where(survived)[0]
    # Order the surviving indices by decreasing score.
    by_score = scores[survivor_indices].sort(descending=True)[1]
    return survivor_indices[by_score]
|
|
|
|
|
|
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove every box from ``boxes`` which contains at least one side length
    that is smaller than ``min_size``.

    .. note::
        For sanitizing a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        min_size (float): minimum size

    Returns:
        Tensor[K]: indices of the boxes whose width and height are both
        greater than or equal to ``min_size``
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(remove_small_boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    # Inclusive comparison: a side exactly equal to ``min_size`` is kept.
    keep = (widths >= min_size) & (heights >= min_size)
    return torch.where(keep)[0]
|
|
|
|
|
|
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clip boxes so that they lie inside an image of size ``size``.

    .. note::
        For clipping a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        size (Tuple[height, width]): size of the image

    Returns:
        Tensor[N, 4]: clipped boxes
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(clip_boxes_to_image)
    height, width = size
    ndim = boxes.dim()
    # Along the last axis, even positions hold x coordinates and odd positions hold y.
    xs = boxes[..., 0::2]
    ys = boxes[..., 1::2]

    if not torchvision._is_tracing():
        xs = xs.clamp(min=0, max=width)
        ys = ys.clamp(min=0, max=height)
    else:
        # NOTE(review): while tracing, bounds are applied with tensor-valued
        # torch.max/torch.min instead of clamp; presumably for export compatibility.
        xs = torch.max(xs, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        xs = torch.min(xs, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        ys = torch.max(ys, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        ys = torch.min(ys, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))

    # Stacking on a new trailing axis pairs each x with its y; reshaping back to the
    # input shape re-interleaves them into (x1, y1, x2, y2) order.
    interleaved = torch.stack((xs, ys), dim=ndim)
    return interleaved.reshape(boxes.shape)
|
|
|
|
|
|
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Converts :class:`torch.Tensor` boxes from a given ``in_fmt`` to ``out_fmt``.

    .. note::
        For converting a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.BoundingBoxes` object
        between different formats,
        consider using :func:`~torchvision.transforms.v2.functional.convert_bounding_box_format` instead.
        Or see the corresponding transform :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat`.

    Supported ``in_fmt`` and ``out_fmt`` strings are:

    ``'xyxy'``: boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
    This is the format that torchvision utilities expect.

    ``'xywh'``: boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height.

    ``'cxcywh'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.

    Args:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']

    Returns:
        Tensor[N, 4]: Boxes in the converted format.

    Raises:
        ValueError: if ``in_fmt`` or ``out_fmt`` is not one of the supported formats.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_convert)
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
        raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

    if in_fmt == out_fmt:
        return boxes.clone()

    # Every conversion goes through 'xyxy': first bring the input to xyxy
    # (no-op when it already is), then convert xyxy to the requested output
    # (no-op when the output is xyxy). in_fmt != out_fmt here, so at least
    # one of the two steps runs and the result is always a new tensor.
    if in_fmt == "xywh":
        boxes = _box_xywh_to_xyxy(boxes)
    elif in_fmt == "cxcywh":
        boxes = _box_cxcywh_to_xyxy(boxes)

    if out_fmt == "xywh":
        boxes = _box_xyxy_to_xywh(boxes)
    elif out_fmt == "cxcywh":
        boxes = _box_xyxy_to_cxcywh(boxes)
    return boxes
|
|
|
|
|
|
def box_area(boxes: Tensor) -> Tensor:
    """
    Compute the area of each box in a set of ``(x1, y1, x2, y2)`` boxes.

    Args:
        boxes (Tensor[N, 4]): boxes whose area is computed. They are expected
            to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        Tensor[N]: the area for each box
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area)
    # _upcast is a project helper; presumably it widens low-precision dtypes so the
    # product below does not overflow.
    wide = _upcast(boxes)
    widths = wide[:, 2] - wide[:, 0]
    heights = wide[:, 3] - wide[:, 1]
    return widths * heights
|
|
|
|
|
|
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    # Pairwise intersection and union areas of boxes1 [N, 4] against boxes2 [M, 4];
    # both returned tensors are [N, M].
    areas1 = box_area(boxes1)
    areas2 = box_area(boxes2)

    overlap_tl = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    overlap_br = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs yield negative extents; clamp them to zero area.
    overlap_wh = _upcast(overlap_br - overlap_tl).clamp(min=0)  # [N,M,2]
    intersection = overlap_wh[:, :, 0] * overlap_wh[:, :, 1]  # [N,M]

    return intersection, areas1[:, None] + areas2 - intersection
|
|
|
|
|
|
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two box sets.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou)
    intersection, union = _box_inter_union(boxes1, boxes2)
    return intersection / union
|
|
|
|
|
|
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise generalized intersection-over-union (GIoU) of two box sets.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(generalized_box_iou)

    intersection, union = _box_inter_union(boxes1, boxes2)
    plain_iou = intersection / union

    # Smallest axis-aligned box enclosing each pair.
    hull_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = _upcast(hull_br - hull_tl).clamp(min=0)  # [N,M,2]
    hull_area = hull_wh[:, :, 0] * hull_wh[:, :, 1]

    # GIoU subtracts the fraction of the enclosing box not covered by the union.
    return plain_iou - (hull_area - union) / hull_area
|
|
|
|
|
|
def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Compute the pairwise complete intersection-over-union (CIoU) of two box sets.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    diou, iou = _box_diou_iou(boxes1, boxes2, eps)

    # Extents of boxes1 are [N, 1] columns and those of boxes2 are [M] rows,
    # so the aspect-ratio comparison broadcasts to [N, M].
    widths1 = boxes1[:, None, 2] - boxes1[:, None, 0]
    heights1 = boxes1[:, None, 3] - boxes1[:, None, 1]
    widths2 = boxes2[:, 2] - boxes2[:, 0]
    heights2 = boxes2[:, 3] - boxes2[:, 1]

    angle_gap = torch.atan(widths1 / heights1) - torch.atan(widths2 / heights2)
    v = (4 / (torch.pi**2)) * torch.pow(angle_gap, 2)
    # The trade-off weight alpha is computed without tracking gradients.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return diou - alpha * v
|
|
|
|
|
|
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Compute the pairwise distance intersection-over-union (DIoU) of two box sets.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou)

    wide1 = _upcast(boxes1)
    wide2 = _upcast(boxes2)
    # _box_diou_iou returns (diou, iou); only the distance IoU is needed here.
    return _box_diou_iou(wide1, wide2, eps=eps)[0]
|
|
|
|
|
|
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
    # Return (distance IoU, plain IoU) for every pair of boxes1 [N, 4] x boxes2 [M, 4].
    iou = box_iou(boxes1, boxes2)

    # Squared diagonal of the smallest box enclosing each pair; eps keeps the
    # later division finite.
    hull_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = _upcast(hull_br - hull_tl).clamp(min=0)  # [N,M,2]
    hull_diag_sq = (hull_wh[:, :, 0] ** 2) + (hull_wh[:, :, 1] ** 2) + eps

    # Box centers.
    cx1 = (boxes1[:, 0] + boxes1[:, 2]) / 2
    cy1 = (boxes1[:, 1] + boxes1[:, 3]) / 2
    cx2 = (boxes2[:, 0] + boxes2[:, 2]) / 2
    cy2 = (boxes2[:, 1] + boxes2[:, 3]) / 2

    # Squared distance between the centers of every pair.
    dx = _upcast(cx1[:, None] - cx2[None, :])
    dy = _upcast(cy1[:, None] - cy2[None, :])
    center_dist_sq = (dx**2) + (dy**2)

    # DIoU penalizes IoU by the center distance normalized by the enclosing diagonal.
    return iou - (center_dist_sq / hull_diag_sq), iou
|
|
|
|
|
|
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 <= x2`` and ``0 <= y1 <= y2`` (a mask with a single non-zero pixel column/row
    yields ``x1 == x2`` / ``y1 == y2``).

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes. Coordinates are the (inclusive) indices of the
        extreme non-zero elements of each mask. A mask with no non-zero element
        yields the all-zero box ``(0, 0, 0, 0)``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(masks_to_boxes)
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    n = masks.shape[0]

    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

    for index, mask in enumerate(masks):
        y, x = torch.where(mask != 0)

        # An all-zero mask has no extent. torch.min/torch.max on the resulting empty
        # tensors would raise, so leave this row at its zero initialization instead.
        if x.numel() == 0:
            continue

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes
|