import torch
from torch import nn

from yolov6.layers.common import RepBlock, SimConv, Transpose


class RepPANNeck(nn.Module):
    """RepPANNeck Module

    EfficientRep is the default backbone of this model.
    RepPANNeck balances feature-fusion ability and hardware efficiency.
    """

    def __init__(
        self,
        channels_list=None,
        num_repeats=None
    ):
        super().__init__()

        assert channels_list is not None
        assert num_repeats is not None

        # Top-down (FPN) fusion blocks.
        self.Rep_p4 = RepBlock(
            in_channels=channels_list[3] + channels_list[5],
            out_channels=channels_list[5],
            n=num_repeats[5],
        )

        self.Rep_p3 = RepBlock(
            in_channels=channels_list[2] + channels_list[6],
            out_channels=channels_list[6],
            n=num_repeats[6]
        )

        # Bottom-up (PAN) fusion blocks.
        self.Rep_n3 = RepBlock(
            in_channels=channels_list[6] + channels_list[7],
            out_channels=channels_list[8],
            n=num_repeats[7],
        )

        self.Rep_n4 = RepBlock(
            in_channels=channels_list[5] + channels_list[9],
            out_channels=channels_list[10],
            n=num_repeats[8]
        )

        # 1x1 convolutions that reduce channels before each upsampling step.
        self.reduce_layer0 = SimConv(
            in_channels=channels_list[4],
            out_channels=channels_list[5],
            kernel_size=1,
            stride=1
        )

        # Transposed convolutions that upsample features by 2x.
        self.upsample0 = Transpose(
            in_channels=channels_list[5],
            out_channels=channels_list[5],
        )

        self.reduce_layer1 = SimConv(
            in_channels=channels_list[5],
            out_channels=channels_list[6],
            kernel_size=1,
            stride=1
        )

        self.upsample1 = Transpose(
            in_channels=channels_list[6],
            out_channels=channels_list[6]
        )

        # Strided 3x3 convolutions that downsample features on the bottom-up path.
        self.downsample2 = SimConv(
            in_channels=channels_list[6],
            out_channels=channels_list[7],
            kernel_size=3,
            stride=2
        )

        self.downsample1 = SimConv(
            in_channels=channels_list[8],
            out_channels=channels_list[9],
            kernel_size=3,
            stride=2
        )

    def forward(self, input):

        # Backbone features, ordered from shallowest (highest resolution) to
        # deepest (lowest resolution): x2, x1, x0.
        (x2, x1, x0) = input

        # Top-down (FPN) path: reduce channels, upsample, and fuse with the
        # next shallower backbone feature.
        fpn_out0 = self.reduce_layer0(x0)
        upsample_feat0 = self.upsample0(fpn_out0)
        f_concat_layer0 = torch.cat([upsample_feat0, x1], 1)
        f_out0 = self.Rep_p4(f_concat_layer0)

        fpn_out1 = self.reduce_layer1(f_out0)
        upsample_feat1 = self.upsample1(fpn_out1)
        f_concat_layer1 = torch.cat([upsample_feat1, x2], 1)
        pan_out2 = self.Rep_p3(f_concat_layer1)

        # Bottom-up (PAN) path: downsample and fuse with the stored top-down features.
        down_feat1 = self.downsample2(pan_out2)
        p_concat_layer1 = torch.cat([down_feat1, fpn_out1], 1)
        pan_out1 = self.Rep_n3(p_concat_layer1)

        down_feat0 = self.downsample1(pan_out1)
        p_concat_layer2 = torch.cat([down_feat0, fpn_out0], 1)
        pan_out0 = self.Rep_n4(p_concat_layer2)

        # Multi-scale outputs for the detection head, from high to low resolution.
        outputs = [pan_out2, pan_out1, pan_out0]

        return outputs
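

# --- Usage sketch (illustrative, not part of the original file) ---
# The channels_list / num_repeats values below are assumptions modelled on a
# YOLOv6-style configuration, where the first five entries of channels_list
# come from the backbone and the remaining six from the neck; adjust them to
# match your own config. Requires the yolov6 package on PYTHONPATH.
if __name__ == "__main__":
    channels_list = [64, 128, 256, 512, 1024, 256, 128, 128, 256, 256, 512]
    num_repeats = [1, 6, 12, 18, 6, 12, 12, 12, 12]

    neck = RepPANNeck(channels_list=channels_list, num_repeats=num_repeats)

    # Dummy backbone features at strides 8, 16 and 32 for a 640x640 input.
    x2 = torch.randn(1, channels_list[2], 80, 80)
    x1 = torch.randn(1, channels_list[3], 40, 40)
    x0 = torch.randn(1, channels_list[4], 20, 20)

    pan_out2, pan_out1, pan_out0 = neck((x2, x1, x0))
    # With the assumed lists: [1, 128, 80, 80], [1, 256, 40, 40], [1, 512, 20, 20]
    print(pan_out2.shape, pan_out1.shape, pan_out0.shape)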