From c3908abc416bc8721d84aeff8fa49271eeb751ce Mon Sep 17 00:00:00 2001 From: pfo49kjfx <2512671328@qq.com> Date: Mon, 4 Dec 2023 16:57:40 +0800 Subject: [PATCH] ADD file via upload --- Unet_parts.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 Unet_parts.py diff --git a/Unet_parts.py b/Unet_parts.py new file mode 100644 index 0000000..8619b33 --- /dev/null +++ b/Unet_parts.py @@ -0,0 +1,79 @@ +""" Parts of the U-Net model """ +"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2 + + + (element-wise add)""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + self.element_wise_add = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1), + nn.BatchNorm2d(out_channels) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)# //为整数除法 + + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) + diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x)