You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
2.7 KiB
103 lines
2.7 KiB
from torch import nn
|
|
from yolov6.layers.common import RepVGGBlock, RepBlock, SimSPPF
|
|
|
|
|
|
class EfficientRep(nn.Module):
    '''EfficientRep Backbone

    EfficientRep is handcrafted by hardware-aware neural network design.
    With rep-style struct, EfficientRep is friendly to high-computation hardware (e.g. GPU).

    Args:
        in_channels: Number of channels of the input tensor (default 3).
        channels_list: Output channel counts; entries 0..4 are used for the
            stem and the four ERBlock stages. Must have at least 5 entries.
        num_repeats: Repeat counts passed as ``n`` to each stage's RepBlock;
            entries 1..4 are used (entry 0 is not read by this class).
            Must have at least 5 entries.

    Raises:
        ValueError: If ``channels_list`` or ``num_repeats`` is missing or has
            fewer than 5 entries.
    '''

    # Number of entries of channels_list / num_repeats this backbone indexes.
    _NUM_LEVELS = 5

    def __init__(
        self,
        in_channels=3,
        channels_list=None,
        num_repeats=None,
    ):
        super().__init__()

        # Explicit checks instead of `assert`: asserts vanish under `python -O`,
        # and a short list would otherwise fail with a bare IndexError halfway
        # through building the stages.
        if channels_list is None or num_repeats is None:
            raise ValueError(
                "EfficientRep requires both channels_list and num_repeats"
            )
        if len(channels_list) < self._NUM_LEVELS or len(num_repeats) < self._NUM_LEVELS:
            raise ValueError(
                f"EfficientRep needs at least {self._NUM_LEVELS} entries in "
                f"channels_list and num_repeats, got {len(channels_list)} "
                f"and {len(num_repeats)}"
            )

        self.stem = RepVGGBlock(
            in_channels=in_channels,
            out_channels=channels_list[0],
            kernel_size=3,
            stride=2,
        )

        # Attribute names, Sequential child order and construction order are
        # kept identical to the original layout so checkpoint keys
        # (e.g. "ERBlock_2.1....") and seeded initialization do not change.
        self.ERBlock_2 = nn.Sequential(
            *self._stage_layers(channels_list[0], channels_list[1], num_repeats[1])
        )
        self.ERBlock_3 = nn.Sequential(
            *self._stage_layers(channels_list[1], channels_list[2], num_repeats[2])
        )
        self.ERBlock_4 = nn.Sequential(
            *self._stage_layers(channels_list[2], channels_list[3], num_repeats[3])
        )
        # The last stage additionally ends with a SimSPPF module (index 2).
        self.ERBlock_5 = nn.Sequential(
            *self._stage_layers(channels_list[3], channels_list[4], num_repeats[4]),
            SimSPPF(
                in_channels=channels_list[4],
                out_channels=channels_list[4],
                kernel_size=5,
            ),
        )

    @staticmethod
    def _stage_layers(in_channels, out_channels, n):
        """Return [stride-2 3x3 RepVGGBlock, RepBlock with ``n`` repeats] for one stage."""
        return [
            RepVGGBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
            ),
            RepBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                n=n,
            ),
        ]

    def forward(self, x):
        """Run the backbone and return the outputs of ERBlock_3, ERBlock_4 and ERBlock_5.

        Returns:
            A 3-tuple ``(c3, c4, c5)`` of feature maps, configured with
            ``channels_list[2]``, ``[3]`` and ``[4]`` output channels. Each
            stage starts with a stride-2 block; the resulting overall
            downsampling (8/16/32) assumes RepBlock and SimSPPF keep spatial
            size — NOTE(review): confirm against yolov6.layers.common.
        """
        x = self.stem(x)
        x = self.ERBlock_2(x)
        c3 = self.ERBlock_3(x)
        c4 = self.ERBlock_4(c3)
        c5 = self.ERBlock_5(c4)
        return (c3, c4, c5)
|