import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.pvtv2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b3, pvt_v2_b4, pvt_v2_b5
from lib.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from lib.decoders import EMCAD


class EMCADNet(nn.Module):
    """Segmentation network: a PVTv2 or ResNet encoder feeding the EMCAD
    decoder, with four 1x1-conv prediction heads (deep supervision).

    Args:
        num_classes: number of output channels per prediction head.
        kernel_sizes: multi-scale depthwise kernel sizes for the decoder
            (defaults to [1, 3, 5]; ``None`` avoids a mutable default).
        expansion_factor, dw_parallel, add, lgag_ks, activation:
            forwarded unchanged to the EMCAD decoder.
        encoder: backbone name; unknown names fall back to 'pvt_v2_b2'.
        pretrain: load pretrained weights (from ``pretrained_dir`` for PVT
            backbones, via ``pretrained=`` for ResNet backbones).
        pretrained_dir: directory holding the PVT ``.pth`` checkpoints.
    """

    # Decoder channel widths per encoder, deepest stage first.
    _PVT_CHANNELS = {
        'pvt_v2_b0': [256, 160, 64, 32],
        'pvt_v2_b1': [512, 320, 128, 64],
        'pvt_v2_b2': [512, 320, 128, 64],
        'pvt_v2_b3': [512, 320, 128, 64],
        'pvt_v2_b4': [512, 320, 128, 64],
        'pvt_v2_b5': [512, 320, 128, 64],
    }
    _RESNET_CHANNELS = {
        'resnet18': [512, 256, 128, 64],
        'resnet34': [512, 256, 128, 64],
        'resnet50': [2048, 1024, 512, 256],
        'resnet101': [2048, 1024, 512, 256],
        'resnet152': [2048, 1024, 512, 256],
    }

    def __init__(self, num_classes=1, kernel_sizes=None, expansion_factor=2,
                 dw_parallel=True, add=True, lgag_ks=3, activation='relu',
                 encoder='pvt_v2_b2', pretrain=True,
                 pretrained_dir='./pretrained_pth/pvt/'):
        super(EMCADNet, self).__init__()

        # Fix: mutable default argument — resolve the documented default here.
        if kernel_sizes is None:
            kernel_sizes = [1, 3, 5]

        # 1x1 conv block lifting single-channel (grayscale) input to 3 channels
        # so the ImageNet-style backbones can consume it.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
        )

        pvt_builders = {
            'pvt_v2_b0': pvt_v2_b0, 'pvt_v2_b1': pvt_v2_b1,
            'pvt_v2_b2': pvt_v2_b2, 'pvt_v2_b3': pvt_v2_b3,
            'pvt_v2_b4': pvt_v2_b4, 'pvt_v2_b5': pvt_v2_b5,
        }
        resnet_builders = {
            'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50,
            'resnet101': resnet101, 'resnet152': resnet152,
        }

        # Backbone selection; unknown names keep the original fallback.
        if encoder not in pvt_builders and encoder not in resnet_builders:
            print('Encoder not implemented! Continuing with default encoder pvt_v2_b2.')
            encoder = 'pvt_v2_b2'

        if encoder in pvt_builders:
            self.backbone = pvt_builders[encoder]()
            channels = self._PVT_CHANNELS[encoder]
            if pretrain:
                # Fix: os.path.join avoids the double slash the old
                # string concatenation produced with the default dir;
                # map_location keeps loading CPU-safe without CUDA.
                path = os.path.join(pretrained_dir, encoder + '.pth')
                save_model = torch.load(path, map_location='cpu')
                model_dict = self.backbone.state_dict()
                # Keep only checkpoint keys the backbone actually declares.
                state_dict = {k: v for k, v in save_model.items() if k in model_dict}
                model_dict.update(state_dict)
                self.backbone.load_state_dict(model_dict)
        else:
            # ResNet constructors handle their own pretrained download.
            self.backbone = resnet_builders[encoder](pretrained=pretrain)
            channels = self._RESNET_CHANNELS[encoder]

        print('Model %s created, param count: %d' %
              (encoder + ' backbone: ',
               sum([m.numel() for m in self.backbone.parameters()])))

        # Decoder consumes the four encoder stages, deepest first.
        self.decoder = EMCAD(channels=channels, kernel_sizes=kernel_sizes,
                             expansion_factor=expansion_factor,
                             dw_parallel=dw_parallel, add=add,
                             lgag_ks=lgag_ks, activation=activation)

        print('Model %s created, param count: %d' %
              ('EMCAD decoder: ',
               sum([m.numel() for m in self.decoder.parameters()])))

        # One 1x1 prediction head per decoder stage (deep supervision).
        self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
        self.out_head3 = nn.Conv2d(channels[1], num_classes, 1)
        self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
        self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)

    def forward(self, x, mode='test'):
        """Return four prediction maps [p4, p3, p2, p1], each upsampled to
        the input resolution (coarsest head first).

        ``mode`` is kept for interface compatibility; the original code
        returned the same list on both branches, so the dead conditional
        was removed.
        """
        # Grayscale input: lift to 3 channels for the backbone.
        if x.size()[1] == 1:
            x = self.conv(x)

        # Encoder stages, shallow (x1) to deep (x4).
        x1, x2, x3, x4 = self.backbone(x)

        # Decoder takes the deepest feature plus skips, deep-to-shallow.
        dec_outs = self.decoder(x4, [x3, x2, x1])

        p4 = self.out_head4(dec_outs[0])
        p3 = self.out_head3(dec_outs[1])
        p2 = self.out_head2(dec_outs[2])
        p1 = self.out_head1(dec_outs[3])

        # Upsample each head back to input resolution; scale factors match
        # the backbone's stage strides (32, 16, 8, 4).
        p4 = F.interpolate(p4, scale_factor=32, mode='bilinear')
        p3 = F.interpolate(p3, scale_factor=16, mode='bilinear')
        p2 = F.interpolate(p2, scale_factor=8, mode='bilinear')
        p1 = F.interpolate(p1, scale_factor=4, mode='bilinear')

        return [p4, p3, p2, p1]


if __name__ == '__main__':
    model = EMCADNet().cuda()
    input_tensor = torch.randn(1, 3, 352, 352).cuda()

    P = model(input_tensor)
    print(P[0].size(), P[1].size(), P[2].size(), P[3].size())