You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

127 lines
5.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.pvtv2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b3, pvt_v2_b4, pvt_v2_b5
from lib.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from lib.decoders import EMCAD
class EMCADNet(nn.Module):
    """Segmentation network: a PVTv2 or ResNet encoder followed by the EMCAD decoder.

    Four 1x1 convolution heads map the decoder's four outputs to ``num_classes``
    channels. Each prediction is bilinearly resized to the input's spatial size.
    """

    # Encoder feature channels, deepest stage first. This matches the order
    # (x4, x3, x2, x1) in which ``forward`` feeds the decoder and the heads.
    _PVT_CHANNELS = {
        'pvt_v2_b0': (256, 160, 64, 32),
        'pvt_v2_b1': (512, 320, 128, 64),
        'pvt_v2_b2': (512, 320, 128, 64),
        'pvt_v2_b3': (512, 320, 128, 64),
        'pvt_v2_b4': (512, 320, 128, 64),
        'pvt_v2_b5': (512, 320, 128, 64),
    }
    _RESNET_CHANNELS = {
        'resnet18': (512, 256, 128, 64),
        'resnet34': (512, 256, 128, 64),
        'resnet50': (2048, 1024, 512, 256),
        'resnet101': (2048, 1024, 512, 256),
        'resnet152': (2048, 1024, 512, 256),
    }
    # Encoder used when the requested name is not recognised.
    _DEFAULT_ENCODER = 'pvt_v2_b2'

    def __init__(self, num_classes=1, kernel_sizes=None, expansion_factor=2, dw_parallel=True, add=True, lgag_ks=3,
                 activation='relu', encoder='pvt_v2_b2', pretrain=True, pretrained_dir='./pretrained_pth/pvt/'):
        """Build the encoder, the EMCAD decoder and the four prediction heads.

        Args:
            num_classes: Output channels of every prediction head.
            kernel_sizes: Forwarded to ``EMCAD``. ``None`` (the default) means
                ``[1, 3, 5]``; a sentinel avoids sharing one mutable default list.
            expansion_factor, dw_parallel, add, lgag_ks, activation: Forwarded
                unchanged to ``EMCAD``.
            encoder: One of ``pvt_v2_b0`` .. ``pvt_v2_b5`` or ``resnet18`` ..
                ``resnet152``. Unknown names fall back to ``pvt_v2_b2``, and the
                fallback is treated exactly like an explicit ``pvt_v2_b2``
                (including pretrained-weight loading).
            pretrain: For PVTv2 encoders, load weights from
                ``<pretrained_dir>/<encoder>.pth``. For ResNets, passed to the
                factory as ``pretrained=``.
            pretrained_dir: Directory holding the PVTv2 checkpoints.
        """
        super().__init__()
        if kernel_sizes is None:
            kernel_sizes = [1, 3, 5]

        # 1x1 conv block that lifts single-channel (grayscale) input to 3 channels.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
        )

        # Resolve the name first so that the fallback encoder is also the name
        # used for the pretrained-weight decision and path below.
        name = self._resolve_encoder(encoder)
        self.backbone = self._build_backbone(name, pretrain)
        if name in self._PVT_CHANNELS:
            channels = list(self._PVT_CHANNELS[name])
            # ResNet factories handle `pretrain` themselves; PVTv2 weights are loaded here.
            if pretrain:
                self._load_pvt_weights(self._pretrained_path(pretrained_dir, name))
        else:
            channels = list(self._RESNET_CHANNELS[name])
        print('Model %s created, param count: %d' %
              (name + ' backbone: ', sum(m.numel() for m in self.backbone.parameters())))

        self.decoder = EMCAD(channels=channels, kernel_sizes=kernel_sizes, expansion_factor=expansion_factor,
                             dw_parallel=dw_parallel, add=add, lgag_ks=lgag_ks, activation=activation)
        print('Model %s created, param count: %d' %
              ('EMCAD decoder: ', sum(m.numel() for m in self.decoder.parameters())))

        # One 1x1 head per decoder output; channels[i] is the width of decoder output i.
        self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
        self.out_head3 = nn.Conv2d(channels[1], num_classes, 1)
        self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
        self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)

    @classmethod
    def _resolve_encoder(cls, encoder):
        """Return the encoder name that will actually be built, falling back to the default."""
        if encoder in cls._PVT_CHANNELS or encoder in cls._RESNET_CHANNELS:
            return encoder
        print('Encoder not implemented! Continuing with default encoder pvt_v2_b2.')
        return cls._DEFAULT_ENCODER

    @staticmethod
    def _build_backbone(name, pretrain):
        """Instantiate the backbone for an already-resolved encoder ``name``."""
        pvt_factories = {
            'pvt_v2_b0': pvt_v2_b0,
            'pvt_v2_b1': pvt_v2_b1,
            'pvt_v2_b2': pvt_v2_b2,
            'pvt_v2_b3': pvt_v2_b3,
            'pvt_v2_b4': pvt_v2_b4,
            'pvt_v2_b5': pvt_v2_b5,
        }
        if name in pvt_factories:
            return pvt_factories[name]()
        resnet_factories = {
            'resnet18': resnet18,
            'resnet34': resnet34,
            'resnet50': resnet50,
            'resnet101': resnet101,
            'resnet152': resnet152,
        }
        return resnet_factories[name](pretrained=pretrain)

    @staticmethod
    def _pretrained_path(pretrained_dir, name):
        """Return the checkpoint path ``<pretrained_dir>/<name>.pth``."""
        # os.path.join avoids a doubled separator when the dir ends with '/'.
        # It also keeps an empty dir relative instead of pointing at the root.
        return os.path.join(pretrained_dir, name + '.pth')

    def _load_pvt_weights(self, path):
        """Copy checkpoint tensors whose keys exist in ``self.backbone`` into it.

        Checkpoint keys unknown to the backbone are ignored. Backbone keys missing
        from the checkpoint keep their freshly initialised values.
        """
        # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        model_dict.update({k: v for k, v in save_model.items() if k in model_dict})
        self.backbone.load_state_dict(model_dict)

    def forward(self, x, mode='test'):
        """Run the encoder, decoder and prediction heads.

        Args:
            x: Input batch, assumed (B, C, H, W). Inputs with C == 1 are lifted to
                3 channels by ``self.conv``.
            mode: Unused. The same list is returned for every value; the parameter
                is kept so existing callers passing ``mode`` keep working.

        Returns:
            ``[p4, p3, p2, p1]``: one logit map per decoder output, deepest first.
            Each has ``num_classes`` channels and is bilinearly resized to the
            input's H x W.
        """
        # Capture the target size before the optional channel lift.
        out_size = x.shape[-2:]
        if x.size(1) == 1:
            x = self.conv(x)

        x1, x2, x3, x4 = self.backbone(x)
        dec_outs = self.decoder(x4, [x3, x2, x1])

        heads = (self.out_head4, self.out_head3, self.out_head2, self.out_head1)
        # Resizing to the input size (rather than by fixed 32/16/8/4 factors) gives
        # identical results for sizes divisible by 32. For other sizes it still
        # yields full-resolution maps.
        return [
            F.interpolate(head(dec_outs[i]), size=out_size, mode='bilinear', align_corners=False)
            for i, head in enumerate(heads)
        ]
if __name__ == '__main__':
    # CPU-only smoke test: build the default model and print the four output shapes.
    # The default EMCADNet() uses pretrain=True, so it expects the PVTv2 checkpoint
    # under ./pretrained_pth/pvt/.
    model = EMCADNet()
    # Use BatchNorm running statistics instead of updating them during the check.
    model.eval()
    input_tensor = torch.randn(1, 3, 352, 352)
    # Shape check only: no autograd graph is needed.
    with torch.no_grad():
        P = model(input_tensor)
    print(P[0].size(), P[1].size(), P[2].size(), P[3].size())