""" Full assembly of the parts to form the complete network """ import torch import torch.nn as nn import torchvision.transforms as transforms from .unet_parts import DoubleConv, Down, Up, OutConv from .swin_transformer import SwinTransformerBlock, Conv class UNet(nn.Module): def __init__(self, n_channels, n_classes, bilinear=False): super(UNet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = (DoubleConv(n_channels, 64)) self.down1 = (Down(64, 128)) self.down2 = (Down(128, 256)) self.down3 = (Down(256, 512)) factor = 2 if bilinear else 1 self.down4 = (Down(512, 1024 // factor)) self.up1 = (Up(1024, 512 // factor, bilinear)) self.up2 = (Up(512, 256 // factor, bilinear)) self.up3 = (Up(256, 128 // factor, bilinear)) self.up4 = (Up(128, 64, bilinear)) self.outc = (OutConv(64, n_classes)) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.outc(x) return logits def use_checkpointing(self): self.inc = torch.utils.checkpoint(self.inc) self.down1 = torch.utils.checkpoint(self.down1) self.down2 = torch.utils.checkpoint(self.down2) self.down3 = torch.utils.checkpoint(self.down3) self.down4 = torch.utils.checkpoint(self.down4) self.up1 = torch.utils.checkpoint(self.up1) self.up2 = torch.utils.checkpoint(self.up2) self.up3 = torch.utils.checkpoint(self.up3) self.up4 = torch.utils.checkpoint(self.up4) self.outc = torch.utils.checkpoint(self.outc) class SEUNet(nn.Module): def __init__(self, n_channels, n_classes, bilinear=False): super(SEUNet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = (DoubleConv(n_channels, 64)) self.down1 = (Down(64, 128)) self.se1 = SEBlock(128) # 添加SE Block self.down2 = (Down(128, 256)) self.se2 = SEBlock(256) # 添加SE Block self.down3 = (Down(256, 512)) self.se3 = SEBlock(512) # 添加SE Block factor = 2 if bilinear else 1 self.down4 = (Down(512, 1024 // factor)) self.se4 = SEBlock(1024 // factor) # 添加SE Block self.up1 = (Up(1024, 512 // factor, bilinear)) self.se_up1 = SEBlock(512 // factor) # 添加SE Block self.up2 = (Up(512, 256 // factor, bilinear)) self.se_up2 = SEBlock(256 // factor) # 添加SE Block self.up3 = (Up(256, 128 // factor, bilinear)) self.se_up3 = SEBlock(128 // factor) # 添加SE Block self.up4 = (Up(128, 64, bilinear)) self.outc = (OutConv(64, n_classes)) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x2 = self.se1(x2) # 添加SE Block x3 = self.down2(x2) x3 = self.se2(x3) # 添加SE Block x4 = self.down3(x3) x4 = self.se3(x4) # 添加SE Block x5 = self.down4(x4) x5 = self.se4(x5) # 添加SE Block x = self.up1(x5, x4) x = self.se_up1(x) # 添加SE Block x = self.up2(x, x3) x = self.se_up2(x) # 添加SE Block x = self.up3(x, x2) x = self.se_up3(x) # 添加SE Block x = self.up4(x, x1) logits = self.outc(x) return logits def use_checkpointing(self): self.inc = torch.utils.checkpoint(self.inc) self.down1 = torch.utils.checkpoint(self.down1) self.down2 = torch.utils.checkpoint(self.down2) self.down3 = torch.utils.checkpoint(self.down3) self.down4 = torch.utils.checkpoint(self.down4) self.up1 = torch.utils.checkpoint(self.up1) self.up2 = torch.utils.checkpoint(self.up2) self.up3 = torch.utils.checkpoint(self.up3) self.up4 = torch.utils.checkpoint(self.up4) self.outc = torch.utils.checkpoint(self.outc) # SE注意力机制 class SEBlock(nn.Module): def __init__(self, in_channels, reduction=16): super(SEBlock, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(in_channels // reduction, in_channels, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() # print(f"Input size to SEBlock: {x.size()}") y = self.avg_pool(x).view(b, c) # print(f"Size after avg_pool: {y.size()}") y = self.fc(y).view(b, c, 1, 1) # print(f"Size after fc: {y.size()}") return x * y.expand_as(x) class BasicBlock(nn.Module): expansion = 1 # 通道扩充比例 def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.residual_function = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) self.shortcut = nn.Sequential() if stride != 1 or in_channels != BasicBlock.expansion * out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) class BottleNeck(nn.Module): expansion = 4 ''' expansion 是通道扩充的比例 注意实际输出channel = middle_channels * BottleNeck.expansion ''' def __init__(self, in_channels, middle_channels, stride=1): super().__init__() self.residual_function = nn.Sequential( nn.Conv2d(in_channels, middle_channels, kernel_size=1, bias=False), nn.BatchNorm2d(middle_channels), nn.ReLU(inplace=True), nn.Conv2d(middle_channels, middle_channels, stride=stride, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(middle_channels), nn.ReLU(inplace=True), nn.Conv2d(middle_channels, middle_channels * BottleNeck.expansion, kernel_size=1, bias=False), nn.BatchNorm2d(middle_channels * BottleNeck.expansion), ) self.shortcut = nn.Sequential() if stride != 1 or in_channels != middle_channels * BottleNeck.expansion: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, middle_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), nn.BatchNorm2d(middle_channels * BottleNeck.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) class Bottleneck(nn.Module): """Standard bottleneck.""" def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and expansion. """ super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, k[0], 1) self.cv2 = Conv(c_, c2, k[1], 1, g=g) self.add = shortcut and c1 == c2 def forward(self, x): """'forward()' applies the YOLO FPN to input data.""" return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) class C2f(nn.Module): """Faster Implementation of CSP Bottleneck with 2 convolutions.""" def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups, expansion. """ super().__init__() self.c = int(c2 * e) # hidden channels self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) def forward(self, x): """Forward pass through C2f layer.""" y = list(self.cv1(x).chunk(2, 1)) y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) def forward_split(self, x): """Forward pass using split() instead of chunk().""" y = list(self.cv1(x).split((self.c, self.c), 1)) y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) class C2fST(C2f): """C2f module with Swin TransformerBlock().""" def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): super().__init__(c1, c2, n, shortcut, g, e) num_heads = self.c // 32 self.m = nn.ModuleList(SwinTransformerBlock(self.c, self.c, num_heads, n) for _ in range(n)) class UResnet(nn.Module): def __init__(self, block, layers, n_channels, n_classes, bilinear): super().__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear nb_filter = [64, 128, 256, 512, 1024] self.in_channel = nb_filter[0] self.pool = nn.MaxPool2d(2, 2) self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv0_0 = DoubleConv(n_channels, nb_filter[0], nb_filter[0]) self.conv1_0 = self._make_layer(block, nb_filter[1], layers[0], 1) self.trans1 = SwinTransformerBlock(nb_filter[1], nb_filter[1], num_heads = 4, num_layers = 2) self.conv2_0 = self._make_layer(block, nb_filter[2], layers[1], 1) self.trans2 = SwinTransformerBlock(nb_filter[2], nb_filter[2], num_heads = 8, num_layers = 2) self.conv3_0 = self._make_layer(block, nb_filter[3], layers[2], 1) self.trans3 = SwinTransformerBlock(nb_filter[3], nb_filter[3], num_heads = 16, num_layers = 2) self.conv4_0 = self._make_layer(block, nb_filter[4], layers[3], 1) self.trans4 = SwinTransformerBlock(nb_filter[4], nb_filter[4], num_heads = 32, num_layers = 2) self.conv3_1 = DoubleConv((nb_filter[3] + nb_filter[4]) * block.expansion, nb_filter[3] * block.expansion, nb_filter[3]) self.conv2_2 = DoubleConv((nb_filter[2] + nb_filter[3]) * block.expansion, nb_filter[2] * block.expansion, nb_filter[2]) self.conv1_3 = DoubleConv((nb_filter[1] + nb_filter[2]) * block.expansion, nb_filter[1] * block.expansion, nb_filter[1]) self.conv0_4 = DoubleConv(nb_filter[0] + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0]) self.final = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1) def _make_layer(self, block, middle_channel, num_blocks, stride): """ middle_channels中间维度,实际输出channels = middle_channels * block.expansion num_blocks,一个Layer包含block的个数 """ strides = [stride] + [1] * (num_blocks - 1) layers = [] # for stride in strides: # layers.append(block(self.in_channel, middle_channel, stride)) # self.in_channel = middle_channel * block.expansion for stride in strides: layers.append(block(self.in_channel, middle_channel, stride)) self.in_channel = middle_channel * block.expansion # 更新输入通道数为当前层的输出通道数 return nn.Sequential(*layers) def forward(self, input): x0_0 = self.conv0_0(input) x1_0 = self.conv1_0(self.pool(x0_0)) x1_0 = self.trans1(x1_0) # print("conv1_0:", x1_0.shape) x2_0 = self.conv2_0(self.pool(x1_0)) x2_0 = self.trans2(x2_0) # print("conv2_0:", x2_0.shape) x3_0 = self.conv3_0(self.pool(x2_0)) x3_0 = self.trans3(x3_0) # print("conv3_0:", x3_0.shape) x4_0 = self.conv4_0(self.pool(x3_0)) x4_0 = self.trans4(x4_0) # print("conv4_0:", x4_0.shape) # x4_0 = self.trans1(self.pool(x3_0)) x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1)) x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1)) x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1)) output = self.final(x0_4) return output class UResnet34(UResnet): def __init__(self, n_channels, n_classes=2, bilinear=False): super(UResnet34, self).__init__(block=BasicBlock,layers=[3,4,6,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear) class UResnet50(UResnet): def __init__(self, n_channels, n_classes=2, bilinear=False): super(UResnet50, self).__init__(block=BottleNeck,layers=[3,4,6,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear) class UResnet101(UResnet): def __init__(self, n_channels, n_classes=2, bilinear=False): super(UResnet101, self).__init__(block=BottleNeck,layers=[3,4,23,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear) class UResnet152(UResnet): def __init__(self, n_channels, n_classes=2, bilinear=False): super(UResnet152, self).__init__(block=BottleNeck,layers=[3,8,36,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)