from typing import Tuple

import torch
import torchvision
from torch import Tensor
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once
from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh
from ._utils import _upcast


def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than ``iou_threshold`` with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(nms)
    # The actual kernel lives in the compiled torchvision extension.
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Args:
        boxes (Tensor[N, 4]): boxes where NMS will be performed. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        idxs (Tensor[N]): indices of the categories for each one of the boxes.
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(batched_nms)
    # Benchmarks that drove the following thresholds are at
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    # Large inputs use the per-category loop; small inputs, and any input while
    # tracing, use the single-NMS-call coordinate trick.
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    else:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)


@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """Run category-aware NMS with one ``nms`` call by shifting each category apart."""
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    min_coordinate = boxes.min()
    # The spacing between consecutive categories must exceed the full coordinate
    # range (max - min), not just max: otherwise boxes with negative coordinates
    # from neighbouring categories can still overlap after shifting and wrongly
    # suppress each other. For non-negative inputs this spacing is <= max + 1.
    offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep


@torch.jit._script_if_tracing
def _batched_nms_vanilla(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """Run category-aware NMS by calling ``nms`` once per distinct category."""
    # Based on Detectron2 implementation, just manually call nms() on each class independently
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
        # Map the per-class kept positions back to indices into the full input.
        keep_mask[curr_indices[curr_keep_indices]] = True
    keep_indices = torch.where(keep_mask)[0]
    # Re-sort globally so the output order matches nms(): decreasing score.
    return keep_indices[scores[keep_indices].sort(descending=True)[1]]


def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove every box from ``boxes`` which contains at least one side length
    that is smaller than ``min_size``.

    .. note::
        For sanitizing a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        min_size (float): minimum size

    Returns:
        Tensor[K]: indices of the boxes whose width and height are both
        greater than or equal to ``min_size``
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(remove_small_boxes)
    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    keep = (ws >= min_size) & (hs >= min_size)
    keep = torch.where(keep)[0]
    return keep


def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clip boxes so that they lie inside an image of size ``size``.

    .. note::
        For clipping a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        size (Tuple[height, width]): size of the image

    Returns:
        Tensor[N, 4]: clipped boxes
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(clip_boxes_to_image)
    dim = boxes.dim()
    # Even columns are x coordinates, odd columns are y coordinates.
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]
    height, width = size

    if torchvision._is_tracing():
        # Tensor bounds (rather than Python scalars passed to clamp) while tracing.
        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    else:
        boxes_x = boxes_x.clamp(min=0, max=width)
        boxes_y = boxes_y.clamp(min=0, max=height)

    # Re-interleave x/y pairs and restore the original (x1, y1, x2, y2) layout.
    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)


def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Converts :class:`torch.Tensor` boxes from a given ``in_fmt`` to ``out_fmt``.

    .. note::
        For converting a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.BoundingBoxes` object
        between different formats,
        consider using :func:`~torchvision.transforms.v2.functional.convert_bounding_box_format` instead.
        Or see the corresponding transform :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat`.

    Supported ``in_fmt`` and ``out_fmt`` strings are:

    ``'xyxy'``: boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
    This is the format that torchvision utilities expect.

    ``'xywh'``: boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height.

    ``'cxcywh'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.

    Args:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']

    Returns:
        Tensor[N, 4]: Boxes in the converted format. When ``in_fmt == out_fmt``
        a copy of ``boxes`` is returned.

    Raises:
        ValueError: if ``in_fmt`` or ``out_fmt`` is not a supported format.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_convert)
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
        raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

    if in_fmt == out_fmt:
        return boxes.clone()

    if in_fmt != "xyxy" and out_fmt != "xyxy":
        # Neither side is xyxy: go through xyxy as the intermediate format.
        if in_fmt == "xywh":
            boxes = _box_xywh_to_xyxy(boxes)
        elif in_fmt == "cxcywh":
            boxes = _box_cxcywh_to_xyxy(boxes)
        in_fmt = "xyxy"

    if in_fmt == "xyxy":
        if out_fmt == "xywh":
            boxes = _box_xyxy_to_xywh(boxes)
        elif out_fmt == "cxcywh":
            boxes = _box_xyxy_to_cxcywh(boxes)
    elif out_fmt == "xyxy":
        if in_fmt == "xywh":
            boxes = _box_xywh_to_xyxy(boxes)
        elif in_fmt == "cxcywh":
            boxes = _box_cxcywh_to_xyxy(boxes)
    return boxes


def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        Tensor[N]: the area for each box
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area)
    # _upcast presumably widens low-precision dtypes before multiplying -- see ._utils.
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    """Return the pairwise ``[N, M]`` intersection and union areas of two box sets."""
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # Broadcasting [N, 1, 2] against [M, 2] gives the [N, M, 2] pairwise corners.
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs have a negative extent; clamp it to zero area.
    wh = _upcast(rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter, union


def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou)
    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union
    return iou


# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return generalized intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(generalized_box_iou)

    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union

    # Smallest enclosing box of each pair.
    lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    whi = _upcast(rbi - lti).clamp(min=0)  # [N,M,2]
    areai = whi[:, :, 0] * whi[:, :, 1]

    # GIoU = IoU - (enclosing area not covered by the union) / enclosing area.
    return iou - (areai - union) / areai


def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Return complete intersection-over-union (Jaccard index) between two sets of boxes.
    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    diou, iou = _box_diou_iou(boxes1, boxes2, eps)

    w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
    h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]

    w_gt = boxes2[:, 2] - boxes2[:, 0]
    h_gt = boxes2[:, 3] - boxes2[:, 1]

    # Aspect-ratio consistency term.
    v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
    # The trade-off weight alpha is computed without tracking gradients.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return diou - alpha * v


def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Return distance intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)
    diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
    return diou


def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
    """Return the pairwise ``(diou, iou)`` matrices, both ``[N, M]``."""
    iou = box_iou(boxes1, boxes2)
    # Smallest enclosing box of each pair; its squared diagonal normalises the distance.
    lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    whi = _upcast(rbi - lti).clamp(min=0)  # [N,M,2]
    diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
    # centers of boxes
    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
    # The distance between boxes' centers squared.
    centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + (
        _upcast((y_p[:, None] - y_g[None, :])) ** 2
    )
    # The distance IoU is the IoU penalized by a normalized
    # distance between boxes' centers squared.
    return iou - (centers_distance_squared / diagonal_distance_squared), iou


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format
    with ``0 <= x1 <= x2`` and ``0 <= y1 <= y2``. Coordinates are inclusive pixel indices
    of the outermost non-zero pixels, so a mask that is one pixel wide yields ``x1 == x2``.
    A mask with no non-zero pixel yields the box ``(0, 0, 0, 0)``.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes (float32)
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(masks_to_boxes)
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    n = masks.shape[0]

    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

    for index, mask in enumerate(masks):
        y, x = torch.where(mask != 0)

        if x.numel() == 0:
            # An all-zero mask has nothing to bound; torch.min/torch.max would raise on
            # an empty tensor, so keep the zero box this row was initialised to.
            continue

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes