You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

107 lines
3.5 KiB

5 months ago
from typing import List, Optional, Tuple, Union
import torch
from torch import nn, Tensor
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """
    Concatenate ``tensors`` along ``dim``.

    When the list holds exactly one tensor, that tensor object is returned as-is,
    so no copy is made. Otherwise this defers to ``torch.cat``.
    """
    # TODO: restore an isinstance(tensors, (list, tuple)) sanity check.
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    return tensors[0]
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Pack a list of box tensors into one ROI tensor with a leading index column.

    Every row of the result is the position (in ``boxes``) of the tensor the row came
    from, followed by that row's original columns. For 2-D inputs of shape ``[L_i, C]``
    the output has shape ``[sum(L_i), C + 1]``. The index column is created with
    ``torch.full_like``, so it shares each box tensor's dtype and device.
    """
    stacked_boxes = _cat([box for box in boxes], dim=0)
    index_columns = []
    for img_idx, box in enumerate(boxes):
        # One [L_i, 1] column holding this tensor's position in ``boxes``.
        index_columns.append(torch.full_like(box[:, :1], img_idx))
    return torch.cat([_cat(index_columns, dim=0), stacked_boxes], dim=1)
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    """
    Validate the layout of ROI boxes, raising ``AssertionError`` via ``torch._assert`` on mismatch.

    Accepted forms:
      * a list/tuple of 2-D tensors, each of shape ``[L, 4]``;
      * a single 2-D tensor of shape ``[K, 5]``.

    Any other type fails with a message describing both accepted forms.
    """
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            # Check rank before size(1): calling size(1) on a 1-D tensor would raise an
            # unrelated IndexError, and a 3-D tensor would otherwise slip through.
            torch._assert(
                _tensor.dim() == 2 and _tensor.size(1) == 4,
                "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]",
            )
    elif isinstance(boxes, torch.Tensor):
        torch._assert(
            boxes.dim() == 2 and boxes.size(1) == 5,
            "The boxes tensor shape is not correct as Tensor[K, 5]",
        )
    else:
        # Placeholders match the two messages above (K for the tensor, L for list entries).
        torch._assert(False, "boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]")
    return
def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    """
    Split the trainable parameters of ``model`` into normalization and other parameters.

    Args:
        model: module whose parameters (those with ``requires_grad=True``) are split.
        norm_classes: module classes treated as normalization layers. When ``None`` or
            empty, defaults to ``_BatchNorm``, ``LayerNorm``, ``GroupNorm``,
            ``_InstanceNorm`` and ``LocalResponseNorm``.

    Returns:
        ``(norm_params, other_params)``. ``norm_params`` holds the parameters of childless
        modules that are instances of ``norm_classes``. ``other_params`` holds every other
        trainable parameter: those registered directly on modules that have children, and
        those of childless non-normalization modules.

    Raises:
        ValueError: if an entry of ``norm_classes`` is not a subclass of ``nn.Module``.
    """
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [
            nn.modules.batchnorm._BatchNorm,
            nn.LayerNorm,
            nn.GroupNorm,
            nn.modules.instancenorm._InstanceNorm,
            nn.LocalResponseNorm,
        ]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params: List[Tensor] = []
    other_params: List[Tensor] = []
    for module in model.modules():
        # Compare with None explicitly. Containers such as an empty nn.Sequential or
        # nn.ModuleList define __len__ and are therefore falsy. A truthiness test would
        # misclassify a parent whose first child is such a container as a leaf. It would
        # then collect the parent's whole subtree, duplicating parameters that are
        # collected again when the children themselves are visited.
        if next(module.children(), None) is not None:
            # Non-leaf: only parameters registered directly on this module. Its children
            # are visited separately by model.modules().
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params
def _upcast(t: Tensor) -> Tensor:
    """
    Widen ``t`` so that multiplications on it are less likely to overflow.

    Floating-point tensors other than float32/float64 become float32. Non-floating
    tensors other than int32/int64 become int32. Tensors already in one of those four
    dtypes are returned unchanged (same object).
    """
    if not t.is_floating_point():
        if t.dtype == torch.int32 or t.dtype == torch.int64:
            return t
        return t.int()
    if t.dtype == torch.float32 or t.dtype == torch.float64:
        return t
    return t.float()
def _upcast_non_float(t: Tensor) -> Tensor:
    """
    Return ``t`` unchanged when it is float32/float64; otherwise return ``t.float()``.

    Every other dtype, integer ones included, is converted to float32. Note that this
    can lose precision for large integer values.
    """
    return t if t.dtype in (torch.float32, torch.float64) else t.float()
def _loss_inter_union(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute elementwise intersection and union areas for pairs of boxes.

    Both inputs hold ``(x1, y1, x2, y2)`` along their last dimension. Boxes at the same
    position are compared. The intersection is 0 wherever the overlap rectangle does
    not have strictly positive width and height. The union is
    ``area1 + area2 - intersection``, with the individual areas not clamped.

    Returns:
        ``(intersection, union)``.
    """
    a_x1, a_y1, a_x2, a_y2 = boxes1.unbind(dim=-1)
    b_x1, b_y1, b_x2, b_y2 = boxes2.unbind(dim=-1)

    # Corners of the overlap rectangle.
    left = torch.max(a_x1, b_x1)
    top = torch.max(a_y1, b_y1)
    right = torch.min(a_x2, b_x2)
    bottom = torch.min(a_y2, b_y2)

    # Only write areas where the overlap is non-empty; everything else stays 0.
    inter = torch.zeros_like(a_x1)
    overlaps = (bottom > top) & (right > left)
    inter[overlaps] = (right[overlaps] - left[overlaps]) * (bottom[overlaps] - top[overlaps])

    area_a = (a_x2 - a_x1) * (a_y2 - a_y1)
    area_b = (b_x2 - b_x1) * (b_y2 - b_y1)
    union = area_a + area_b - inter
    return inter, union