""" Full assembly of the parts to form the complete network """
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .unet_parts import DoubleConv, Down, Up, OutConv
from .swin_transformer import SwinTransformerBlock, Conv
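

# Helper for the use_checkpointing() methods below: torch.utils.checkpoint is
# a module, not a callable, so each stage is wrapped in a small nn.Module that
# recomputes its activations during backward. A minimal sketch, assuming the
# standard checkpoint() API with use_reentrant (PyTorch >= 1.11).
class _CheckpointWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args):
        # Trade compute for memory: intermediate activations are not stored.
        return checkpoint(self.module, *args, use_reentrant=False)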


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        # torch.utils.checkpoint is a module, not a callable, so each stage is
        # wrapped in _CheckpointWrapper (defined above) to recompute its
        # activations during backward instead of storing them.
        for name in ('inc', 'down1', 'down2', 'down3', 'down4',
                     'up1', 'up2', 'up3', 'up4', 'outc'):
            setattr(self, name, _CheckpointWrapper(getattr(self, name)))
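

# Minimal shape check for UNet; a sketch assuming the standard unet_parts
# blocks (DoubleConv/Down/Up/OutConv), so H and W must be divisible by 16
# (four 2x downsamplings). _demo_unet_shapes is an illustrative helper only,
# invoked from the __main__ block at the end of this file.
def _demo_unet_shapes():
    net = UNet(n_channels=3, n_classes=2).eval()
    x = torch.randn(1, 3, 64, 64)    # (B, C, H, W)
    with torch.no_grad():
        logits = net(x)
    assert logits.shape == (1, 2, 64, 64)    # per-pixel class logits at input resolution
    return logits.shape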


class SEUNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(SEUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.se1 = SEBlock(128)    # SE block after each encoder stage
        self.down2 = Down(128, 256)
        self.se2 = SEBlock(256)
        self.down3 = Down(256, 512)
        self.se3 = SEBlock(512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.se4 = SEBlock(1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.se_up1 = SEBlock(512 // factor)    # SE block after each decoder stage
        self.up2 = Up(512, 256 // factor, bilinear)
        self.se_up2 = SEBlock(256 // factor)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.se_up3 = SEBlock(128 // factor)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x2 = self.se1(x2)    # channel-wise recalibration
        x3 = self.down2(x2)
        x3 = self.se2(x3)
        x4 = self.down3(x3)
        x4 = self.se3(x4)
        x5 = self.down4(x4)
        x5 = self.se4(x5)
        x = self.up1(x5, x4)
        x = self.se_up1(x)
        x = self.up2(x, x3)
        x = self.se_up2(x)
        x = self.up3(x, x2)
        x = self.se_up3(x)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        # Same memory-saving wrapping as UNet.use_checkpointing above; the SE
        # blocks are cheap and are left unwrapped.
        for name in ('inc', 'down1', 'down2', 'down3', 'down4',
                     'up1', 'up2', 'up3', 'up4', 'outc'):
            setattr(self, name, _CheckpointWrapper(getattr(self, name)))


# Squeeze-and-Excitation (SE) channel attention mechanism.
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)    # squeeze: global average pool to (B, C)
        y = self.fc(y).view(b, c, 1, 1)    # excite: per-channel gates in (0, 1)
        return x * y.expand_as(x)          # rescale the feature maps channel-wise
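

# Quick sanity check for SEBlock: the output keeps the input shape, and each
# channel is only rescaled by a sigmoid gate in (0, 1). _demo_seblock is an
# illustrative helper only, invoked from the __main__ block below.
def _demo_seblock():
    se = SEBlock(in_channels=128).eval()
    x = torch.randn(2, 128, 32, 32)
    with torch.no_grad():
        y = se(x)
    assert y.shape == x.shape    # attention rescales, never reshapes
    return y.shape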


class BasicBlock(nn.Module):
    expansion = 1    # channel expansion ratio

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )
        self.shortcut = nn.Sequential()
        # Project the shortcut when the spatial size or channel count changes.
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class BottleNeck(nn.Module):
    """ResNet-style bottleneck.

    expansion is the channel expansion ratio; note that the actual output
    channel count is middle_channels * BottleNeck.expansion.
    """
    expansion = 4

    def __init__(self, in_channels, middle_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_channels * BottleNeck.expansion),
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != middle_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, middle_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(middle_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
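

# Quick check of the residual blocks' channel arithmetic: BasicBlock keeps
# out_channels as-is (expansion 1), while BottleNeck emits
# middle_channels * 4. _demo_residual_blocks is an illustrative helper only,
# invoked from the __main__ block below.
def _demo_residual_blocks():
    x = torch.randn(1, 64, 32, 32)
    with torch.no_grad():
        basic = BasicBlock(64, 128, stride=2).eval()(x)
        bottle = BottleNeck(64, 64, stride=1).eval()(x)
    assert basic.shape == (1, 128, 16, 16)     # stride 2 halves H and W
    assert bottle.shape == (1, 256, 32, 32)    # 64 * expansion(4) output channels
    return basic.shape, bottle.shape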


class Bottleneck(nn.Module):
    """Standard YOLO-style bottleneck (distinct from the ResNet BottleNeck above)."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a bottleneck module with given input/output channels, shortcut option,
        groups, kernel sizes, and expansion.
        """
        super().__init__()
        c_ = int(c2 * e)    # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Applies the two convolutions, with a residual add when input and output channels match."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C2f(nn.Module):
    """Faster implementation of CSP bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out,
        number, shortcut, groups, expansion.
        """
        super().__init__()
        self.c = int(c2 * e)    # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)    # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
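

# Shape sketch for C2f. This assumes Conv (imported from .swin_transformer) is
# the YOLO-style Conv2d + BatchNorm + activation block that preserves spatial
# size at stride 1; _demo_c2f is an illustrative helper only, invoked from the
# __main__ block below.
def _demo_c2f():
    m = C2f(c1=64, c2=64, n=2, shortcut=True).eval()
    x = torch.randn(1, 64, 32, 32)
    with torch.no_grad():
        y = m(x)                        # chunk()-based path
        y_split = m.forward_split(x)    # split()-based path, same result
    assert y.shape == (1, 64, 32, 32)
    assert torch.allclose(y, y_split)
    return y.shape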


class C2fST(C2f):
    """C2f module with SwinTransformerBlock()."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        # Guarantee at least one attention head; self.c // 32 is 0 when c2 * e < 32.
        num_heads = max(1, self.c // 32)
        self.m = nn.ModuleList(SwinTransformerBlock(self.c, self.c, num_heads, n) for _ in range(n))


class UResnet(nn.Module):
    def __init__(self, block, layers, n_channels, n_classes, bilinear):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        nb_filter = [64, 128, 256, 512, 1024]
        self.in_channel = nb_filter[0]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv0_0 = DoubleConv(n_channels, nb_filter[0], nb_filter[0])
        # NOTE: the Swin blocks below take nb_filter[i] channels, which matches
        # the residual stages only when block.expansion == 1 (BasicBlock); with
        # BottleNeck (expansion == 4) the stages emit nb_filter[i] * expansion
        # channels and would need the widths scaled accordingly.
        self.conv1_0 = self._make_layer(block, nb_filter[1], layers[0], 1)
        self.trans1 = SwinTransformerBlock(nb_filter[1], nb_filter[1], num_heads=4, num_layers=2)
        self.conv2_0 = self._make_layer(block, nb_filter[2], layers[1], 1)
        self.trans2 = SwinTransformerBlock(nb_filter[2], nb_filter[2], num_heads=8, num_layers=2)
        self.conv3_0 = self._make_layer(block, nb_filter[3], layers[2], 1)
        self.trans3 = SwinTransformerBlock(nb_filter[3], nb_filter[3], num_heads=16, num_layers=2)
        self.conv4_0 = self._make_layer(block, nb_filter[4], layers[3], 1)
        self.trans4 = SwinTransformerBlock(nb_filter[4], nb_filter[4], num_heads=32, num_layers=2)

        self.conv3_1 = DoubleConv((nb_filter[3] + nb_filter[4]) * block.expansion, nb_filter[3] * block.expansion, nb_filter[3])
        self.conv2_2 = DoubleConv((nb_filter[2] + nb_filter[3]) * block.expansion, nb_filter[2] * block.expansion, nb_filter[2])
        self.conv1_3 = DoubleConv((nb_filter[1] + nb_filter[2]) * block.expansion, nb_filter[1] * block.expansion, nb_filter[1])
        self.conv0_4 = DoubleConv(nb_filter[0] + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
        self.final = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)

    def _make_layer(self, block, middle_channel, num_blocks, stride):
        """Stack num_blocks residual blocks; middle_channel is the intermediate
        width, so the layer's actual output channel count is
        middle_channel * block.expansion.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channel, middle_channel, stride))
            self.in_channel = middle_channel * block.expansion    # update to this layer's output channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x1_0 = self.trans1(x1_0)
        x2_0 = self.conv2_0(self.pool(x1_0))
        x2_0 = self.trans2(x2_0)
        x3_0 = self.conv3_0(self.pool(x2_0))
        x3_0 = self.trans3(x3_0)
        x4_0 = self.conv4_0(self.pool(x3_0))
        x4_0 = self.trans4(x4_0)

        # Decoder: upsample, concatenate the skip connection, then fuse.
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
        output = self.final(x0_4)
        return output


class UResnet34(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet34, self).__init__(block=BasicBlock, layers=[3, 4, 6, 3],
                                        n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)


class UResnet50(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet50, self).__init__(block=BottleNeck, layers=[3, 4, 6, 3],
                                        n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)


class UResnet101(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet101, self).__init__(block=BottleNeck, layers=[3, 4, 23, 3],
                                         n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)


class UResnet152(UResnet):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UResnet152, self).__init__(block=BottleNeck, layers=[3, 8, 36, 3],
                                         n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)
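

# Smoke-test entry point; a minimal sketch, assuming this file lives in a
# package (run e.g. `python -m <package>.<this_module>` so the relative
# imports resolve) and that SwinTransformerBlock preserves (C, H, W) at these
# feature-map sizes. UResnet34 uses BasicBlock (expansion == 1), which matches
# the Swin block widths; see the NOTE in UResnet.__init__.
if __name__ == "__main__":
    print("UNet:", _demo_unet_shapes())
    print("SEBlock:", _demo_seblock())
    print("Residual blocks:", _demo_residual_blocks())
    print("C2f:", _demo_c2f())
    net = UResnet34(n_channels=3, n_classes=2).eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 64, 64))
    print("UResnet34:", out.shape)    # expected: torch.Size([1, 2, 64, 64])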