You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
212 lines
6.9 KiB
212 lines
6.9 KiB
import torch
|
|
import torch.nn as nn
|
|
import math
|
|
from yolov6.layers.common import *
|
|
|
|
|
|
class Detect(nn.Module):
    '''Efficient Decoupled Head

    With hardware-aware design, the decoupled head is optimized with
    hybrid-channel methods.

    For every detection level ``i`` the input feature map goes through
    ``stems[i]`` and is then split into two branches:
    a classification branch (``cls_convs[i]`` -> ``cls_preds[i]``) and a
    regression branch (``reg_convs[i]`` -> ``reg_preds[i]`` and
    ``obj_preds[i]``).
    '''

    def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None):  # detection layer
        '''
        Args:
            num_classes: number of object classes.
            anchors: an int giving the number of anchors per grid cell, or a
                list/tuple, in which case the anchor count is
                ``len(anchors[0]) // 2``.
            num_layers: number of detection levels.
            inplace: if True, eval-mode box decoding writes into the
                prediction tensor; otherwise new tensors are concatenated.
            head_layers: indexable container with 6 modules per level, in the
                order stem, cls_conv, reg_conv, cls_pred, reg_pred, obj_pred
                (the layout produced by ``build_effidehead_layer``).
        '''
        super().__init__()
        assert head_layers is not None
        self.nc = num_classes  # number of classes
        self.no = num_classes + 5  # outputs per anchor: reg (4) + obj (1) + cls (nc)
        self.nl = num_layers  # number of detection layers
        if isinstance(anchors, (list, tuple)):
            self.na = len(anchors[0]) // 2
        else:
            self.na = anchors
        self.anchors = anchors
        # Placeholder grids; the real cell-offset grids are built lazily in forward().
        self.grid = [torch.zeros(1)] * num_layers
        self.prior_prob = 1e-2  # prior probability used by initialize_biases()
        self.inplace = inplace
        stride = [8, 16, 32]  # strides computed during build
        self.stride = torch.tensor(stride)

        # Init decouple head
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()

        # Efficient decoupled head layers: head_layers is laid out as
        # 6 consecutive modules per level.
        for i in range(num_layers):
            idx = i * 6
            self.stems.append(head_layers[idx])
            self.cls_convs.append(head_layers[idx + 1])
            self.reg_convs.append(head_layers[idx + 2])
            self.cls_preds.append(head_layers[idx + 3])
            self.reg_preds.append(head_layers[idx + 4])
            self.obj_preds.append(head_layers[idx + 5])

    def initialize_biases(self):
        '''Set every cls_pred and obj_pred bias to the logit of ``prior_prob``.

        The bias value is ``-log((1 - p) / p)``, so ``sigmoid(bias) == p``.
        Each bias is re-wrapped as a fresh trainable ``nn.Parameter``;
        weights are left untouched.
        '''
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for preds in (self.cls_preds, self.obj_preds):
            for conv in preds:
                b = conv.bias.view(self.na, -1)
                b.data.fill_(bias_value)
                conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def forward(self, x):
        '''Run the head on a list of per-level feature maps.

        Note: the list ``x`` is modified in place. Each entry is replaced by
        its stem output and, in training mode, by the raw prediction.

        Training mode returns ``x``, where ``x[i]`` has shape
        ``(bs, na, ny, nx, no)`` and holds raw outputs in the order
        reg, obj, cls.

        Eval mode returns a tensor of shape ``(bs, sum(na * ny * nx), no)``.
        Each row holds the decoded box, then the sigmoid objectness and the
        sigmoid class scores. The box is decoded as
        ``xy = (offset + cell) * stride`` and ``wh = exp(offset) * stride``.
        '''
        z = []
        for i in range(self.nl):
            x[i] = self.stems[i](x[i])
            # Both branches start from the same stem output.
            cls_feat = self.cls_convs[i](x[i])
            cls_output = self.cls_preds[i](cls_feat)
            reg_feat = self.reg_convs[i](x[i])
            reg_output = self.reg_preds[i](reg_feat)
            obj_output = self.obj_preds[i](reg_feat)

            if self.training:
                x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
                bs, _, ny, nx = x[i].shape
                x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            else:
                y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
                bs, _, ny, nx = y.shape
                y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
                # Rebuild the cell-offset grid when the feature-map size changes
                # or the cached grid lives on a different device than the
                # predictions. self.stride is a plain attribute (not a buffer),
                # so its device is not a reliable reference.
                grid = self.grid[i]
                if grid.shape[2:4] != y.shape[2:4] or grid.device != y.device:
                    d = y.device
                    yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
                    self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i]  # wh
                else:
                    xy = (y[..., 0:2] + self.grid[i]) * self.stride[i]  # xy
                    wh = torch.exp(y[..., 2:4]) * self.stride[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))
        return x if self.training else torch.cat(z, 1)
|
|
|
|
|
|
def build_effidehead_layer(channels_list, num_anchors, num_classes, channel_indices=(6, 8, 10)):
    '''Build the flat module sequence consumed by ``Detect``.

    For each detection level, six modules are appended in the order
    stem, cls_conv, reg_conv, cls_pred, reg_pred, obj_pred. This matches the
    ``i * 6 + k`` indexing in ``Detect.__init__``. Modules are constructed in
    the same order as the original hand-written list, so the
    ``nn.Sequential`` indices are unchanged.

    Args:
        channels_list: sequence of channel widths. The entries selected by
            ``channel_indices`` give the input/output width of each level.
        num_anchors: multiplier on the output channels of the three
            prediction convs.
        num_classes: number of object classes.
        channel_indices: indices into ``channels_list``, one per detection
            level, in level order. Defaults to (6, 8, 10).

    Returns:
        nn.Sequential holding ``6 * len(channel_indices)`` modules.
    '''
    layers = []
    for idx in channel_indices:
        ch = channels_list[idx]
        layers.extend([
            # stem (kernel_size=1)
            Conv(in_channels=ch, out_channels=ch, kernel_size=1, stride=1),
            # cls_conv
            Conv(in_channels=ch, out_channels=ch, kernel_size=3, stride=1),
            # reg_conv
            Conv(in_channels=ch, out_channels=ch, kernel_size=3, stride=1),
            # cls_pred
            nn.Conv2d(in_channels=ch, out_channels=num_classes * num_anchors, kernel_size=1),
            # reg_pred
            nn.Conv2d(in_channels=ch, out_channels=4 * num_anchors, kernel_size=1),
            # obj_pred
            nn.Conv2d(in_channels=ch, out_channels=1 * num_anchors, kernel_size=1),
        ])
    return nn.Sequential(*layers)
|