""" Full assembly of the parts to form the complete network """
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from .unet_parts import DoubleConv, Down, Up, OutConv
from .swin_transformer import SwinTransformerBlock, Conv
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        """Recompute each stage's activations during backward to save memory."""
        for name in ("inc", "down1", "down2", "down3", "down4",
                     "up1", "up2", "up3", "up4", "outc"):
            module = getattr(self, name)
            # Wrap the registered module's forward with activation checkpointing
            # (use_reentrant=False requires PyTorch >= 1.11).
            module.forward = functools.partial(
                torch.utils.checkpoint.checkpoint, module.forward, use_reentrant=False
            )
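
# Usage sketch (for reference): a 3-channel input with 2 output classes. The four
# Down stages each halve the spatial size, so input height and width should be
# divisible by 16, e.g.:
#   model = UNet(n_channels=3, n_classes=2)
#   logits = model(torch.randn(1, 3, 256, 256))  # -> (1, 2, 256, 256)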


class SEUNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(SEUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.se1 = SEBlock(128)  # SE block after each encoder stage
        self.down2 = Down(128, 256)
        self.se2 = SEBlock(256)
        self.down3 = Down(256, 512)
        self.se3 = SEBlock(512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.se4 = SEBlock(1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.se_up1 = SEBlock(512 // factor)  # SE block after each decoder stage
        self.up2 = Up(512, 256 // factor, bilinear)
        self.se_up2 = SEBlock(256 // factor)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.se_up3 = SEBlock(128 // factor)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x2 = self.se1(x2)  # re-weight encoder channels
        x3 = self.down2(x2)
        x3 = self.se2(x3)
        x4 = self.down3(x3)
        x4 = self.se3(x4)
        x5 = self.down4(x4)
        x5 = self.se4(x5)
        x = self.up1(x5, x4)
        x = self.se_up1(x)  # re-weight decoder channels
        x = self.up2(x, x3)
        x = self.se_up2(x)
        x = self.up3(x, x2)
        x = self.se_up3(x)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        """Recompute each stage's activations during backward to save memory."""
        for name in ("inc", "down1", "down2", "down3", "down4",
                     "up1", "up2", "up3", "up4", "outc"):
            module = getattr(self, name)
            module.forward = functools.partial(
                torch.utils.checkpoint.checkpoint, module.forward, use_reentrant=False
            )


# SE attention mechanism
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention (Hu et al., 2018)."""

    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # squeeze: global average pool to (B, C)
        y = self.fc(y).view(b, c, 1, 1)  # excite: per-channel gates in (0, 1)
        return x * y.expand_as(x)
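
# Example (for reference): with in_channels=128 and the default reduction=16,
# the gate MLP is 128 -> 8 -> 128, and the output keeps the input shape:
#   se = SEBlock(128)
#   out = se(torch.randn(2, 128, 32, 32))  # -> (2, 128, 32, 32), channels re-weighted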


class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convolutions with an identity/projection shortcut."""

    expansion = 1  # channel expansion ratio

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            # Projection shortcut to match shape when stride or channel count changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class BottleNeck(nn.Module):
    """ResNet bottleneck: 1x1 reduce, 3x3, 1x1 expand.

    `expansion` is the channel expansion ratio; the actual output channel count
    is middle_channels * BottleNeck.expansion.
    """

    expansion = 4

    def __init__(self, in_channels, middle_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_channels * BottleNeck.expansion),
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != middle_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, middle_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(middle_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class Bottleneck(nn.Module):
    """Standard (YOLO-style) bottleneck; distinct from the ResNet-style `BottleNeck` above."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initialize a bottleneck module with given input/output channels, shortcut option,
        groups, kernels, and expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply the two convolutions, with a residual connection when shapes allow."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C2f(nn.Module):
    """Faster implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize a CSP bottleneck layer with two convolutions, given ch_in, ch_out,
        number, shortcut, groups, and expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        """Forward pass through the C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
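
# Channel bookkeeping (for reference): cv1 expands c1 to 2 * self.c channels, which
# chunk(2, 1) splits into two halves of self.c; each of the n bottlenecks appends
# another self.c-channel map, so cv2 fuses (2 + n) * self.c channels down to c2.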


class C2fST(C2f):
    """C2f module with SwinTransformerBlock() replacing the bottlenecks."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        num_heads = self.c // 32
        self.m = nn.ModuleList(SwinTransformerBlock(self.c, self.c, num_heads, n) for _ in range(n))
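
# Note: num_heads = self.c // 32 is 0 whenever the hidden width self.c = int(c2 * e)
# falls below 32, which would hand the Swin attention an invalid head count; choose
# c2 and e so that self.c is a multiple of 32.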


class UResnet(nn.Module):
    def __init__(self, block, layers, n_channels, n_classes, bilinear):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        nb_filter = [64, 128, 256, 512, 1024]
        self.in_channel = nb_filter[0]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = DoubleConv(n_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = self._make_layer(block, nb_filter[1], layers[0], 1)
        self.trans1 = SwinTransformerBlock(nb_filter[1], nb_filter[1], num_heads=4, num_layers=2)
        self.conv2_0 = self._make_layer(block, nb_filter[2], layers[1], 1)
        self.trans2 = SwinTransformerBlock(nb_filter[2], nb_filter[2], num_heads=8, num_layers=2)
        self.conv3_0 = self._make_layer(block, nb_filter[3], layers[2], 1)
        self.trans3 = SwinTransformerBlock(nb_filter[3], nb_filter[3], num_heads=16, num_layers=2)
        self.conv4_0 = self._make_layer(block, nb_filter[4], layers[3], 1)
        self.trans4 = SwinTransformerBlock(nb_filter[4], nb_filter[4], num_heads=32, num_layers=2)

        self.conv3_1 = DoubleConv((nb_filter[3] + nb_filter[4]) * block.expansion, nb_filter[3] * block.expansion, nb_filter[3])
        self.conv2_2 = DoubleConv((nb_filter[2] + nb_filter[3]) * block.expansion, nb_filter[2] * block.expansion, nb_filter[2])
        self.conv1_3 = DoubleConv((nb_filter[1] + nb_filter[2]) * block.expansion, nb_filter[1] * block.expansion, nb_filter[1])
        self.conv0_4 = DoubleConv(nb_filter[0] + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
        self.final = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)

    def _make_layer(self, block, middle_channel, num_blocks, stride):
        """Stack num_blocks residual blocks; only the first may be strided.

        The actual output channel count is middle_channel * block.expansion.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channel, middle_channel, stride))
            self.in_channel = middle_channel * block.expansion  # next block's input channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x1_0 = self.trans1(x1_0)
        x2_0 = self.conv2_0(self.pool(x1_0))
        x2_0 = self.trans2(x2_0)
        x3_0 = self.conv3_0(self.pool(x2_0))
        x3_0 = self.trans3(x3_0)
        x4_0 = self.conv4_0(self.pool(x3_0))
        x4_0 = self.trans4(x4_0)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
        output = self.final(x0_4)
        return output


class UResnet34(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet34, self).__init__(block=BasicBlock, layers=[3, 4, 6, 3],
                                        n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)


class UResnet50(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet50, self).__init__(block=BottleNeck, layers=[3, 4, 6, 3],
                                        n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)


class UResnet101(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet101, self).__init__(block=BottleNeck, layers=[3, 4, 23, 3],
                                         n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)


class UResnet152(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet152, self).__init__(block=BottleNeck, layers=[3, 8, 36, 3],
                                         n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)
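

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). Because this file
    # uses relative imports, run it as a module, e.g. `python -m <package>.unet_model`
    # (the module path is an assumption). 256 x 256 divides cleanly by 2**4; the Swin
    # blocks inside UResnet may impose extra size constraints, so it is left out here.
    x = torch.randn(1, 3, 256, 256)
    for model in (UNet(n_channels=3, n_classes=2), SEUNet(n_channels=3, n_classes=2)):
        print(type(model).__name__, tuple(model(x).shape))  # expect (1, 2, 256, 256)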