From 41d7e55393d0cd73a6ec0168027e5a91e579c38e Mon Sep 17 00:00:00 2001 From: pfo49kjfx <2512671328@qq.com> Date: Mon, 4 Dec 2023 16:57:31 +0800 Subject: [PATCH] ADD file via upload --- Unet_model.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 Unet_model.py diff --git a/Unet_model.py b/Unet_model.py new file mode 100644 index 0000000..92eb98a --- /dev/null +++ b/Unet_model.py @@ -0,0 +1,39 @@ +""" Full assembly of the parts to form the complete network """ +"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py""" + +import torch.nn.functional as F +from Unet_parts import * + + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + self.down4 = Down(512, 512) + self.up1 = Up(1024, 256, bilinear) + self.up2 = Up(512, 128, bilinear) + self.up3 = Up(256, 64, bilinear) + self.up4 = Up(128, 64, bilinear) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + out = self.outc(x) + return out + +