From 4c3bbde89a376523d3cbc7e26ae537a40a11b739 Mon Sep 17 00:00:00 2001
From: Minecraft_zzz <2806700029@qq.com>
Date: Tue, 2 Jul 2024 15:13:53 +0800
Subject: [PATCH] Unet System
---
.gitignore | 2 +
.idea/.gitignore | 3 +
.idea/DigitalImageFinal.iml | 8 +
.../inspectionProfiles/profiles_settings.xml | 6 +
.idea/misc.xml | 4 +
.idea/modules.xml | 8 +
MainWindow.py | 517 ++++++++++
MyUi.py | 456 +++++++++
MyWindow.py | 61 ++
README.md | 1 +
main.py | 18 +
out.png | Bin 0 -> 2389 bytes
output.py | 10 +
predict.py | 215 +++++
ui/MainWindow.ui | 789 ++++++++++++++++
ui/aa.ui | 622 ++++++++++++
ui/bb.ui | 882 ++++++++++++++++++
ui/main.ui | 779 ++++++++++++++++
unet/__init__.py | 1 +
unet/swin_transformer.py | 721 ++++++++++++++
unet/unet_model.py | 336 +++++++
unet/unet_parts.py | 77 ++
utils/__init__.py | 0
utils/data_loading.py | 146 +++
utils/dice_score.py | 28 +
utils/utils.py | 13 +
26 files changed, 5703 insertions(+)
create mode 100644 .idea/.gitignore
create mode 100644 .idea/DigitalImageFinal.iml
create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
create mode 100644 .idea/misc.xml
create mode 100644 .idea/modules.xml
create mode 100644 MainWindow.py
create mode 100644 MyUi.py
create mode 100644 MyWindow.py
create mode 100644 main.py
create mode 100644 out.png
create mode 100644 output.py
create mode 100644 predict.py
create mode 100644 ui/MainWindow.ui
create mode 100644 ui/aa.ui
create mode 100644 ui/bb.ui
create mode 100644 ui/main.ui
create mode 100644 unet/__init__.py
create mode 100644 unet/swin_transformer.py
create mode 100644 unet/unet_model.py
create mode 100644 unet/unet_parts.py
create mode 100644 utils/__init__.py
create mode 100644 utils/data_loading.py
create mode 100644 utils/dice_score.py
create mode 100644 utils/utils.py
diff --git a/.gitignore b/.gitignore
index 5d381cc..efbcd2e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+*.pth
+
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..359bb53
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/DigitalImageFinal.iml b/.idea/DigitalImageFinal.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/DigitalImageFinal.iml
@@ -0,0 +1,8 @@
+<!-- XML module definition lost in extraction; not recoverable -->
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<!-- XML inspection-profile settings lost in extraction; not recoverable -->
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b69999c
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+<!-- XML project settings lost in extraction; not recoverable -->
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..c86f2f4
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<!-- XML module list lost in extraction; not recoverable -->
\ No newline at end of file
diff --git a/MainWindow.py b/MainWindow.py
new file mode 100644
index 0000000..ccc27b7
--- /dev/null
+++ b/MainWindow.py
@@ -0,0 +1,517 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'ui/bb.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.9
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again. Do not edit this file unless you know what you are doing.
+
+
+from PyQt5 import QtCore, QtGui, QtWidgets
+
+
+class Ui_MainWindow(object):
+ def setupUi(self, MainWindow):
+ MainWindow.setObjectName("MainWindow")
+ MainWindow.resize(1092, 799)
+ MainWindow.setMinimumSize(QtCore.QSize(0, 0))
+ MainWindow.setMaximumSize(QtCore.QSize(10000, 10000))
+ self.centralwidget = QtWidgets.QWidget(MainWindow)
+ self.centralwidget.setObjectName("centralwidget")
+ self.horizontalLayout = QtWidgets.QHBoxLayout(self.centralwidget)
+ self.horizontalLayout.setObjectName("horizontalLayout")
+ self.horizontalLayout_12 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_12.setObjectName("horizontalLayout_12")
+ self.verticalLayout_20 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_20.setObjectName("verticalLayout_20")
+ spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_20.addItem(spacerItem)
+ self.verticalLayout_21 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_21.setObjectName("verticalLayout_21")
+ self.horizontalLayout_13 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_13.setObjectName("horizontalLayout_13")
+ self.line_17 = QtWidgets.QFrame(self.centralwidget)
+ self.line_17.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_17.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_17.setObjectName("line_17")
+ self.horizontalLayout_13.addWidget(self.line_17)
+ self.label_8 = QtWidgets.QLabel(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Preferred)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.label_8.sizePolicy().hasHeightForWidth())
+ self.label_8.setSizePolicy(sizePolicy)
+ self.label_8.setObjectName("label_8")
+ self.horizontalLayout_13.addWidget(self.label_8)
+ self.line_18 = QtWidgets.QFrame(self.centralwidget)
+ self.line_18.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_18.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_18.setObjectName("line_18")
+ self.horizontalLayout_13.addWidget(self.line_18)
+ self.verticalLayout_21.addLayout(self.horizontalLayout_13)
+ self.verticalLayout_22 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_22.setObjectName("verticalLayout_22")
+ self.horizontalLayout_21 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_21.setObjectName("horizontalLayout_21")
+ self.corrosionBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.corrosionBtn.setObjectName("corrosionBtn")
+ self.horizontalLayout_21.addWidget(self.corrosionBtn)
+ self.expansionBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.expansionBtn.setObjectName("expansionBtn")
+ self.horizontalLayout_21.addWidget(self.expansionBtn)
+ self.verticalLayout_22.addLayout(self.horizontalLayout_21)
+ self.horizontalLayout_22 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_22.setObjectName("horizontalLayout_22")
+ self.OpenBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.OpenBtn.setObjectName("OpenBtn")
+ self.horizontalLayout_22.addWidget(self.OpenBtn)
+ self.CloseBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.CloseBtn.setObjectName("CloseBtn")
+ self.horizontalLayout_22.addWidget(self.CloseBtn)
+ self.verticalLayout_22.addLayout(self.horizontalLayout_22)
+ self.verticalLayout_21.addLayout(self.verticalLayout_22)
+ self.verticalLayout_20.addLayout(self.verticalLayout_21)
+ spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_20.addItem(spacerItem1)
+ self.verticalLayout_23 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_23.setObjectName("verticalLayout_23")
+ self.horizontalLayout_14 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_14.setObjectName("horizontalLayout_14")
+ self.line_19 = QtWidgets.QFrame(self.centralwidget)
+ self.line_19.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_19.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_19.setObjectName("line_19")
+ self.horizontalLayout_14.addWidget(self.line_19)
+ self.label_11 = QtWidgets.QLabel(self.centralwidget)
+ self.label_11.setObjectName("label_11")
+ self.horizontalLayout_14.addWidget(self.label_11)
+ self.line_20 = QtWidgets.QFrame(self.centralwidget)
+ self.line_20.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_20.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_20.setObjectName("line_20")
+ self.horizontalLayout_14.addWidget(self.line_20)
+ self.verticalLayout_23.addLayout(self.horizontalLayout_14)
+ self.verticalLayout_24 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_24.setObjectName("verticalLayout_24")
+ self.horizontalLayout_23 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_23.setObjectName("horizontalLayout_23")
+ self.verticalLayout_25 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_25.setObjectName("verticalLayout_25")
+ self.histogramBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.histogramBtn.setObjectName("histogramBtn")
+ self.verticalLayout_25.addWidget(self.histogramBtn)
+ self.MeanBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.MeanBtn.setObjectName("MeanBtn")
+ self.verticalLayout_25.addWidget(self.MeanBtn)
+ self.medianBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.medianBtn.setObjectName("medianBtn")
+ self.verticalLayout_25.addWidget(self.medianBtn)
+ self.PrewittBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.PrewittBtn.setObjectName("PrewittBtn")
+ self.verticalLayout_25.addWidget(self.PrewittBtn)
+ self.LowpassBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.LowpassBtn.setObjectName("LowpassBtn")
+ self.verticalLayout_25.addWidget(self.LowpassBtn)
+ self.horizontalLayout_23.addLayout(self.verticalLayout_25)
+ self.verticalLayout_26 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_26.setObjectName("verticalLayout_26")
+ self.EqualizeBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.EqualizeBtn.setObjectName("EqualizeBtn")
+ self.verticalLayout_26.addWidget(self.EqualizeBtn)
+ self.vagueBtn = QtWidgets.QPushButton(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.vagueBtn.sizePolicy().hasHeightForWidth())
+ self.vagueBtn.setSizePolicy(sizePolicy)
+ self.vagueBtn.setObjectName("vagueBtn")
+ self.verticalLayout_26.addWidget(self.vagueBtn)
+ self.RobertsBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.RobertsBtn.setObjectName("RobertsBtn")
+ self.verticalLayout_26.addWidget(self.RobertsBtn)
+ self.SobelBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.SobelBtn.setObjectName("SobelBtn")
+ self.verticalLayout_26.addWidget(self.SobelBtn)
+ self.HighpassBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.HighpassBtn.setObjectName("HighpassBtn")
+ self.verticalLayout_26.addWidget(self.HighpassBtn)
+ self.horizontalLayout_23.addLayout(self.verticalLayout_26)
+ self.verticalLayout_24.addLayout(self.horizontalLayout_23)
+ self.BoxBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.BoxBtn.setObjectName("BoxBtn")
+ self.verticalLayout_24.addWidget(self.BoxBtn)
+ self.verticalLayout_23.addLayout(self.verticalLayout_24)
+ self.verticalLayout_20.addLayout(self.verticalLayout_23)
+ spacerItem2 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_20.addItem(spacerItem2)
+ self.verticalLayout_27 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_27.setObjectName("verticalLayout_27")
+ self.horizontalLayout_24 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_24.setObjectName("horizontalLayout_24")
+ self.line_21 = QtWidgets.QFrame(self.centralwidget)
+ self.line_21.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_21.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_21.setObjectName("line_21")
+ self.horizontalLayout_24.addWidget(self.line_21)
+ self.label_13 = QtWidgets.QLabel(self.centralwidget)
+ self.label_13.setObjectName("label_13")
+ self.horizontalLayout_24.addWidget(self.label_13)
+ self.line_22 = QtWidgets.QFrame(self.centralwidget)
+ self.line_22.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_22.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_22.setObjectName("line_22")
+ self.horizontalLayout_24.addWidget(self.line_22)
+ self.verticalLayout_27.addLayout(self.horizontalLayout_24)
+ self.horizontalLayout_25 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_25.setObjectName("horizontalLayout_25")
+ self.ImportBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.ImportBtn.setObjectName("ImportBtn")
+ self.horizontalLayout_25.addWidget(self.ImportBtn)
+ self.SaveBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.SaveBtn.setObjectName("SaveBtn")
+ self.horizontalLayout_25.addWidget(self.SaveBtn)
+ self.verticalLayout_27.addLayout(self.horizontalLayout_25)
+ self.horizontalLayout_26 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_26.setObjectName("horizontalLayout_26")
+ self.verticalLayout_28 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_28.setObjectName("verticalLayout_28")
+ self.label_15 = QtWidgets.QLabel(self.centralwidget)
+ self.label_15.setObjectName("label_15")
+ self.verticalLayout_28.addWidget(self.label_15)
+ self.label_17 = QtWidgets.QLabel(self.centralwidget)
+ self.label_17.setObjectName("label_17")
+ self.verticalLayout_28.addWidget(self.label_17)
+ self.label_18 = QtWidgets.QLabel(self.centralwidget)
+ self.label_18.setObjectName("label_18")
+ self.verticalLayout_28.addWidget(self.label_18)
+ self.label_19 = QtWidgets.QLabel(self.centralwidget)
+ self.label_19.setObjectName("label_19")
+ self.verticalLayout_28.addWidget(self.label_19)
+ self.horizontalLayout_26.addLayout(self.verticalLayout_28)
+ self.verticalLayout_29 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_29.setObjectName("verticalLayout_29")
+ self.Label_H = QtWidgets.QLabel(self.centralwidget)
+ self.Label_H.setObjectName("Label_H")
+ self.verticalLayout_29.addWidget(self.Label_H)
+ self.Label_W = QtWidgets.QLabel(self.centralwidget)
+ self.Label_W.setObjectName("Label_W")
+ self.verticalLayout_29.addWidget(self.Label_W)
+ self.Label_T = QtWidgets.QLabel(self.centralwidget)
+ self.Label_T.setObjectName("Label_T")
+ self.verticalLayout_29.addWidget(self.Label_T)
+ self.Label_Type = QtWidgets.QLabel(self.centralwidget)
+ self.Label_Type.setObjectName("Label_Type")
+ self.verticalLayout_29.addWidget(self.Label_Type)
+ self.horizontalLayout_26.addLayout(self.verticalLayout_29)
+ self.verticalLayout_27.addLayout(self.horizontalLayout_26)
+ self.verticalLayout_20.addLayout(self.verticalLayout_27)
+ spacerItem3 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_20.addItem(spacerItem3)
+ self.horizontalLayout_12.addLayout(self.verticalLayout_20)
+ self.vertical_line1_2 = QtWidgets.QFrame(self.centralwidget)
+ self.vertical_line1_2.setFrameShape(QtWidgets.QFrame.VLine)
+ self.vertical_line1_2.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.vertical_line1_2.setObjectName("vertical_line1_2")
+ self.horizontalLayout_12.addWidget(self.vertical_line1_2)
+ self.verticalLayout_30 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_30.setObjectName("verticalLayout_30")
+ self.horizontalLayout_27 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_27.setObjectName("horizontalLayout_27")
+ self.line_23 = QtWidgets.QFrame(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.line_23.sizePolicy().hasHeightForWidth())
+ self.line_23.setSizePolicy(sizePolicy)
+ self.line_23.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_23.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_23.setObjectName("line_23")
+ self.horizontalLayout_27.addWidget(self.line_23)
+ self.label_20 = QtWidgets.QLabel(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.label_20.sizePolicy().hasHeightForWidth())
+ self.label_20.setSizePolicy(sizePolicy)
+ self.label_20.setObjectName("label_20")
+ self.horizontalLayout_27.addWidget(self.label_20)
+ self.line_24 = QtWidgets.QFrame(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.line_24.sizePolicy().hasHeightForWidth())
+ self.line_24.setSizePolicy(sizePolicy)
+ self.line_24.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_24.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_24.setObjectName("line_24")
+ self.horizontalLayout_27.addWidget(self.line_24)
+ self.verticalLayout_30.addLayout(self.horizontalLayout_27)
+ self.PicBefore = QtWidgets.QLabel(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.PicBefore.sizePolicy().hasHeightForWidth())
+ self.PicBefore.setSizePolicy(sizePolicy)
+ self.PicBefore.setText("")
+ self.PicBefore.setObjectName("PicBefore")
+ self.verticalLayout_30.addWidget(self.PicBefore)
+ self.horizontalLayout_28 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_28.setObjectName("horizontalLayout_28")
+ self.line_25 = QtWidgets.QFrame(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.line_25.sizePolicy().hasHeightForWidth())
+ self.line_25.setSizePolicy(sizePolicy)
+ self.line_25.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_25.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_25.setObjectName("line_25")
+ self.horizontalLayout_28.addWidget(self.line_25)
+ self.label_21 = QtWidgets.QLabel(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.label_21.sizePolicy().hasHeightForWidth())
+ self.label_21.setSizePolicy(sizePolicy)
+ self.label_21.setObjectName("label_21")
+ self.horizontalLayout_28.addWidget(self.label_21)
+ self.line_26 = QtWidgets.QFrame(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.line_26.sizePolicy().hasHeightForWidth())
+ self.line_26.setSizePolicy(sizePolicy)
+ self.line_26.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_26.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_26.setObjectName("line_26")
+ self.horizontalLayout_28.addWidget(self.line_26)
+ self.verticalLayout_30.addLayout(self.horizontalLayout_28)
+ self.PicAfter = QtWidgets.QLabel(self.centralwidget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.PicAfter.sizePolicy().hasHeightForWidth())
+ self.PicAfter.setSizePolicy(sizePolicy)
+ self.PicAfter.setText("")
+ self.PicAfter.setObjectName("PicAfter")
+ self.verticalLayout_30.addWidget(self.PicAfter)
+ self.verticalLayout_30.setStretch(0, 1)
+ self.verticalLayout_30.setStretch(1, 10)
+ self.verticalLayout_30.setStretch(2, 1)
+ self.verticalLayout_30.setStretch(3, 10)
+ self.horizontalLayout_12.addLayout(self.verticalLayout_30)
+ self.vertical_line2_2 = QtWidgets.QFrame(self.centralwidget)
+ self.vertical_line2_2.setFrameShape(QtWidgets.QFrame.VLine)
+ self.vertical_line2_2.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.vertical_line2_2.setObjectName("vertical_line2_2")
+ self.horizontalLayout_12.addWidget(self.vertical_line2_2)
+ self.verticalLayout_31 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_31.setObjectName("verticalLayout_31")
+ spacerItem4 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_31.addItem(spacerItem4)
+ self.verticalLayout_32 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_32.setObjectName("verticalLayout_32")
+ self.horizontalLayout_29 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_29.setObjectName("horizontalLayout_29")
+ self.line_27 = QtWidgets.QFrame(self.centralwidget)
+ self.line_27.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_27.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_27.setObjectName("line_27")
+ self.horizontalLayout_29.addWidget(self.line_27)
+ self.label_22 = QtWidgets.QLabel(self.centralwidget)
+ self.label_22.setObjectName("label_22")
+ self.horizontalLayout_29.addWidget(self.label_22)
+ self.line_28 = QtWidgets.QFrame(self.centralwidget)
+ self.line_28.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_28.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_28.setObjectName("line_28")
+ self.horizontalLayout_29.addWidget(self.line_28)
+ self.verticalLayout_32.addLayout(self.horizontalLayout_29)
+ self.verticalLayout_33 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_33.setObjectName("verticalLayout_33")
+ self.horizontalLayout_30 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_30.setObjectName("horizontalLayout_30")
+ self.LOGBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.LOGBtn.setObjectName("LOGBtn")
+ self.horizontalLayout_30.addWidget(self.LOGBtn)
+ self.ScharrBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.ScharrBtn.setObjectName("ScharrBtn")
+ self.horizontalLayout_30.addWidget(self.ScharrBtn)
+ self.verticalLayout_33.addLayout(self.horizontalLayout_30)
+ self.CannyBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.CannyBtn.setObjectName("CannyBtn")
+ self.verticalLayout_33.addWidget(self.CannyBtn)
+ self.verticalLayout_32.addLayout(self.verticalLayout_33)
+ self.verticalLayout_31.addLayout(self.verticalLayout_32)
+ spacerItem5 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_31.addItem(spacerItem5)
+ self.verticalLayout_34 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_34.setObjectName("verticalLayout_34")
+ self.horizontalLayout_31 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_31.setObjectName("horizontalLayout_31")
+ self.line_29 = QtWidgets.QFrame(self.centralwidget)
+ self.line_29.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_29.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_29.setObjectName("line_29")
+ self.horizontalLayout_31.addWidget(self.line_29)
+ self.label_23 = QtWidgets.QLabel(self.centralwidget)
+ self.label_23.setObjectName("label_23")
+ self.horizontalLayout_31.addWidget(self.label_23)
+ self.line_30 = QtWidgets.QFrame(self.centralwidget)
+ self.line_30.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_30.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_30.setObjectName("line_30")
+ self.horizontalLayout_31.addWidget(self.line_30)
+ self.verticalLayout_34.addLayout(self.horizontalLayout_31)
+ self.verticalLayout_35 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_35.setObjectName("verticalLayout_35")
+ self.horizontalLayout_32 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_32.setObjectName("horizontalLayout_32")
+ self.GrayscaleBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.GrayscaleBtn.setObjectName("GrayscaleBtn")
+ self.horizontalLayout_32.addWidget(self.GrayscaleBtn)
+ self.BinarizationBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.BinarizationBtn.setObjectName("BinarizationBtn")
+ self.horizontalLayout_32.addWidget(self.BinarizationBtn)
+ self.verticalLayout_35.addLayout(self.horizontalLayout_32)
+ self.geometryBtn = QtWidgets.QPushButton(self.centralwidget)
+ self.geometryBtn.setObjectName("geometryBtn")
+ self.verticalLayout_35.addWidget(self.geometryBtn)
+ self.verticalLayout_34.addLayout(self.verticalLayout_35)
+ self.verticalLayout_31.addLayout(self.verticalLayout_34)
+ spacerItem6 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_31.addItem(spacerItem6)
+ self.verticalLayout_36 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_36.setObjectName("verticalLayout_36")
+ self.verticalLayout_37 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_37.setObjectName("verticalLayout_37")
+ self.horizontalLayout_33 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_33.setObjectName("horizontalLayout_33")
+ self.verticalLayout_38 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_38.setObjectName("verticalLayout_38")
+ self.horizontalLayout_34 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_34.setObjectName("horizontalLayout_34")
+ self.line_31 = QtWidgets.QFrame(self.centralwidget)
+ self.line_31.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_31.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_31.setObjectName("line_31")
+ self.horizontalLayout_34.addWidget(self.line_31)
+ self.label_24 = QtWidgets.QLabel(self.centralwidget)
+ self.label_24.setObjectName("label_24")
+ self.horizontalLayout_34.addWidget(self.label_24)
+ self.line_32 = QtWidgets.QFrame(self.centralwidget)
+ self.line_32.setFrameShape(QtWidgets.QFrame.HLine)
+ self.line_32.setFrameShadow(QtWidgets.QFrame.Sunken)
+ self.line_32.setObjectName("line_32")
+ self.horizontalLayout_34.addWidget(self.line_32)
+ self.verticalLayout_38.addLayout(self.horizontalLayout_34)
+ self.feiyan = QtWidgets.QPushButton(self.centralwidget)
+ self.feiyan.setObjectName("feiyan")
+ self.verticalLayout_38.addWidget(self.feiyan)
+ self.horizontalLayout_33.addLayout(self.verticalLayout_38)
+ self.verticalLayout_37.addLayout(self.horizontalLayout_33)
+ self.verticalLayout_36.addLayout(self.verticalLayout_37)
+ self.verticalLayout_31.addLayout(self.verticalLayout_36)
+ spacerItem7 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_31.addItem(spacerItem7)
+ self.horizontalLayout_35 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_35.setObjectName("horizontalLayout_35")
+ self.verticalLayout_39 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_39.setObjectName("verticalLayout_39")
+ self.label_16 = QtWidgets.QLabel(self.centralwidget)
+ self.label_16.setObjectName("label_16")
+ self.verticalLayout_39.addWidget(self.label_16)
+ self.label_25 = QtWidgets.QLabel(self.centralwidget)
+ self.label_25.setObjectName("label_25")
+ self.verticalLayout_39.addWidget(self.label_25)
+ self.label_26 = QtWidgets.QLabel(self.centralwidget)
+ self.label_26.setObjectName("label_26")
+ self.verticalLayout_39.addWidget(self.label_26)
+ self.label_27 = QtWidgets.QLabel(self.centralwidget)
+ self.label_27.setObjectName("label_27")
+ self.verticalLayout_39.addWidget(self.label_27)
+ self.horizontalLayout_35.addLayout(self.verticalLayout_39)
+ self.verticalLayout_40 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_40.setObjectName("verticalLayout_40")
+ self.class0 = QtWidgets.QLabel(self.centralwidget)
+ self.class0.setObjectName("class0")
+ self.verticalLayout_40.addWidget(self.class0)
+ self.class3 = QtWidgets.QLabel(self.centralwidget)
+ self.class3.setObjectName("class3")
+ self.verticalLayout_40.addWidget(self.class3)
+ self.class1 = QtWidgets.QLabel(self.centralwidget)
+ self.class1.setObjectName("class1")
+ self.verticalLayout_40.addWidget(self.class1)
+ self.class2 = QtWidgets.QLabel(self.centralwidget)
+ self.class2.setObjectName("class2")
+ self.verticalLayout_40.addWidget(self.class2)
+ self.horizontalLayout_35.addLayout(self.verticalLayout_40)
+ self.verticalLayout_31.addLayout(self.horizontalLayout_35)
+ spacerItem8 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
+ self.verticalLayout_31.addItem(spacerItem8)
+ self.horizontalLayout_12.addLayout(self.verticalLayout_31)
+ self.horizontalLayout.addLayout(self.horizontalLayout_12)
+ MainWindow.setCentralWidget(self.centralwidget)
+ self.menubar = QtWidgets.QMenuBar(MainWindow)
+ self.menubar.setGeometry(QtCore.QRect(0, 0, 1092, 26))
+ self.menubar.setObjectName("menubar")
+ MainWindow.setMenuBar(self.menubar)
+ self.statusbar = QtWidgets.QStatusBar(MainWindow)
+ self.statusbar.setObjectName("statusbar")
+ MainWindow.setStatusBar(self.statusbar)
+
+ self.retranslateUi(MainWindow)
+ QtCore.QMetaObject.connectSlotsByName(MainWindow)
+
+ def retranslateUi(self, MainWindow):
+ _translate = QtCore.QCoreApplication.translate
+ MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
+        self.label_8.setText(_translate("MainWindow", "Morphological processing"))
+        self.corrosionBtn.setText(_translate("MainWindow", "Erosion"))
+        self.expansionBtn.setText(_translate("MainWindow", "Dilation"))
+        self.OpenBtn.setText(_translate("MainWindow", "Opening"))
+        self.CloseBtn.setText(_translate("MainWindow", "Closing"))
+        self.label_11.setText(_translate("MainWindow", "Image enhancement"))
+        self.histogramBtn.setText(_translate("MainWindow", "Grayscale histogram"))
+        self.MeanBtn.setText(_translate("MainWindow", "Mean filter"))
+        self.medianBtn.setText(_translate("MainWindow", "Median filter"))
+        self.PrewittBtn.setText(_translate("MainWindow", "Prewitt operator"))
+        self.LowpassBtn.setText(_translate("MainWindow", "Low-pass filter"))
+        self.EqualizeBtn.setText(_translate("MainWindow", "Histogram equalization"))
+        self.vagueBtn.setText(_translate("MainWindow", "Gaussian blur filter"))
+        self.RobertsBtn.setText(_translate("MainWindow", "Roberts operator"))
+        self.SobelBtn.setText(_translate("MainWindow", "Sobel operator"))
+        self.HighpassBtn.setText(_translate("MainWindow", "High-pass filter"))
+        self.BoxBtn.setText(_translate("MainWindow", "Box filter"))
+        self.label_13.setText(_translate("MainWindow", "Image information"))
+        self.ImportBtn.setText(_translate("MainWindow", "Import"))
+        self.SaveBtn.setText(_translate("MainWindow", "Save"))
+        self.label_15.setText(_translate("MainWindow", "Height:"))
+        self.label_17.setText(_translate("MainWindow", "Width:"))
+        self.label_18.setText(_translate("MainWindow", "Channels:"))
+        self.label_19.setText(_translate("MainWindow", "Image type:"))
+        self.Label_H.setText(_translate("MainWindow", "Label_H"))
+        self.Label_W.setText(_translate("MainWindow", "Label_W"))
+        self.Label_T.setText(_translate("MainWindow", "Label_T"))
+        self.Label_Type.setText(_translate("MainWindow", "Label_Type"))
+        self.label_20.setText(_translate("MainWindow", "Before"))
+        self.label_21.setText(_translate("MainWindow", "After"))
+        self.label_22.setText(_translate("MainWindow", "Edge detection"))
+        self.LOGBtn.setText(_translate("MainWindow", "LoG detection"))
+        self.ScharrBtn.setText(_translate("MainWindow", "Scharr operator"))
+        self.CannyBtn.setText(_translate("MainWindow", "Canny operator"))
+        self.label_23.setText(_translate("MainWindow", "Basic processing"))
+        self.GrayscaleBtn.setText(_translate("MainWindow", "Grayscale"))
+        self.BinarizationBtn.setText(_translate("MainWindow", "Binarization"))
+        self.geometryBtn.setText(_translate("MainWindow", "Geometric transform"))
+        self.label_24.setText(_translate("MainWindow", "Deep learning"))
+        self.feiyan.setText(_translate("MainWindow", "Pneumonia image processing"))
+        self.label_16.setText(_translate("MainWindow", "Black"))
+        self.label_25.setText(_translate("MainWindow", "White"))
+        self.label_26.setText(_translate("MainWindow", "Blue"))
+        self.label_27.setText(_translate("MainWindow", "Green"))
+        self.class0.setText(_translate("MainWindow", "Background"))
+        self.class3.setText(_translate("MainWindow", "Lung field"))
+        self.class1.setText(_translate("MainWindow", "Ground-glass opacity"))
+        self.class2.setText(_translate("MainWindow", "Consolidation"))
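
Note: MainWindow.py above is generated code. Per its own header it was produced from ui/bb.ui by the pyuic5 code generator, with a command along these lines:

    pyuic5 ui/bb.ui -o MainWindow.py

Manual changes belong in the .ui file (or in a wrapper class such as MyWindow below), since rerunning pyuic5 overwrites this module.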
diff --git a/MyUi.py b/MyUi.py
new file mode 100644
index 0000000..7d9750c
--- /dev/null
+++ b/MyUi.py
@@ -0,0 +1,456 @@
+import os
+import sys
+import cv2
+from PyQt5.QtGui import QPixmap, QImage, qRed, qGreen, qBlue
+from PyQt5.QtWidgets import QMainWindow, QApplication, QFileDialog, QMessageBox
+import matplotlib.pyplot as plt
+import numpy as np
+import traceback
+from PIL import Image
+from PyQt5.QtCore import Qt
+from predict import solve
+
+
+def cvImgtoQtImg(cvImg):  # convert an OpenCV image (grayscale or BGR) to a Qt image
+    if cvImg.ndim == 2:
+        # single-channel input (most filter results); BGR2BGRA would reject it
+        QtImgBuf = cv2.cvtColor(cvImg, cv2.COLOR_GRAY2BGRA)
+    else:
+        QtImgBuf = cv2.cvtColor(cvImg, cv2.COLOR_BGR2BGRA)
+    QtImg = QImage(QtImgBuf.data, QtImgBuf.shape[1], QtImgBuf.shape[0], QImage.Format_RGB32)
+    return QtImg
+
+
+def QImage2CV(qimg):
+ tmp = qimg
+
+    # create an empty image with numpy, then copy pixel by pixel
+ cv_image = np.zeros((tmp.height(), tmp.width(), 3), dtype=np.uint8)
+
+ for row in range(0, tmp.height()):
+ for col in range(0, tmp.width()):
+ r = qRed(tmp.pixel(col, row))
+ g = qGreen(tmp.pixel(col, row))
+ b = qBlue(tmp.pixel(col, row))
+ cv_image[row, col, 0] = r
+ cv_image[row, col, 1] = g
+ cv_image[row, col, 2] = b
+
+ cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
+
+ return cv_image
+
+
+def QPixmap2cv(qtpixmap):
+    # view the pixmap's underlying 32-bit image buffer as a numpy array
+    qimg = qtpixmap.toImage()
+    temp_shape = (qimg.height(), qimg.bytesPerLine() * 8 // qimg.depth(), 4)
+    ptr = qimg.bits()
+    ptr.setsize(qimg.byteCount())
+    result = np.array(ptr, dtype=np.uint8).reshape(temp_shape)
+    result = result[..., :3]
+    result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+    return result
+
+def FFT2(img):
+    # Fourier transform; log-scale the magnitude spectrum (+1 avoids log(0))
+    f = np.fft.fft2(img)
+    fshift = np.fft.fftshift(f)
+    res = 20 * np.log(np.abs(fshift) + 1)
+ res = res - res.min()
+ res = res / res.max() * 255
+ res = np.array(res, np.uint8)
+
+ # plt.imshow(res, cmap='gray')
+ plt.imshow(res)
+ plt.axis('off')
+ plt.savefig('Img.png', bbox_inches='tight', pad_inches=0.0)
+ plt.close()
+ result = cv2.imread('Img.png', 0)
+ os.remove('Img.png')
+
+ return result
+
+
+class MyWindow(QMainWindow):
+    def __init__(self, ui):
+        super().__init__()
+        self.ui = ui
+        # Build the generated UI on a window this object keeps a reference to.
+        # The single QApplication and its event loop are owned by main.py;
+        # creating a second QApplication here would raise RuntimeError.
+        self.main_window = QMainWindow()
+        self.ui.setupUi(self.main_window)
+        self.picpath = ''
+        self.openfile_name = ''
+        self.pixmapBefore = QPixmap()
+        self.pixmapAfter = QPixmap()
+
+ self.ui.PicBefore.setScaledContents(False)
+ self.ui.PicBefore.setAlignment(Qt.AlignCenter)
+ self.ui.PicAfter.setScaledContents(False)
+ self.ui.PicAfter.setAlignment(Qt.AlignCenter)
+
+ self.ui.ImportBtn.clicked.connect(lambda: self.Import())
+ self.ui.GrayscaleBtn.clicked.connect(lambda: self.Grayscale())
+ self.ui.BinarizationBtn.clicked.connect(lambda: self.Binarization())
+ self.ui.geometryBtn.clicked.connect(lambda: self.Geometry())
+ self.ui.histogramBtn.clicked.connect(lambda: self.Histogram())
+ self.ui.EqualizeBtn.clicked.connect(lambda: self.Equalize())
+ self.ui.MeanBtn.clicked.connect(lambda: self.Mean())
+ self.ui.BoxBtn.clicked.connect(lambda: self.Box())
+ self.ui.vagueBtn.clicked.connect(lambda: self.Vague())
+ self.ui.medianBtn.clicked.connect(lambda: self.Median())
+ self.ui.RobertsBtn.clicked.connect(lambda: self.Roberts())
+ self.ui.PrewittBtn.clicked.connect(lambda: self.Prewitt())
+ self.ui.SobelBtn.clicked.connect(lambda: self.Sobel())
+ self.ui.LowpassBtn.clicked.connect(lambda: self.Lowpass())
+ self.ui.HighpassBtn.clicked.connect(lambda: self.Highpass())
+ self.ui.corrosionBtn.clicked.connect(lambda: self.Corrosion())
+ self.ui.expansionBtn.clicked.connect(lambda: self.Expansion())
+ self.ui.OpenBtn.clicked.connect(lambda: self.Open())
+ self.ui.CloseBtn.clicked.connect(lambda: self.Close())
+ self.ui.LOGBtn.clicked.connect(lambda: self.LOG())
+ self.ui.ScharrBtn.clicked.connect(lambda: self.Scharr())
+ self.ui.CannyBtn.clicked.connect(lambda: self.Canny())
+ self.ui.SaveBtn.clicked.connect(lambda: self.Save())
+ self.ui.feiyan.clicked.connect(lambda:self.Feiyan())
+
+        self.main_window.show()
+
+
+    def Import(self):
+        # let the user choose an image file
+        self.openfile_name = QFileDialog.getOpenFileName(self, 'Select file', '', "Image Files (*.png *.jpg *.bmp)")[0]
+        if not self.openfile_name:
+            # no file selected, nothing to do
+            return
+
+        try:
+            self.pixmapBefore = QPixmap(self.openfile_name)
+            if self.pixmapBefore.isNull():
+                raise ValueError("Unable to load the image file")
+
+            self.picpath = self.openfile_name
+
+            image = cv2.imread(self.picpath)
+            if image is None:
+                raise ValueError("Unable to read the image file")
+
+            self.ui.Label_H.setText(str(image.shape[0]))
+            self.ui.Label_W.setText(str(image.shape[1]))
+            self.ui.Label_T.setText(str(image.shape[2]))
+            self.ui.Label_Type.setText(str(image.dtype))
+
+            self.resizeImage(self.ui.PicBefore, self.pixmapBefore)
+
+        except Exception as e:
+            traceback.print_exc()
+            QMessageBox.critical(self, "Error", f"Failed to load the image file: {e}")
+
+    def Save(self):
+        if self.pixmapAfter.isNull():
+            QMessageBox.about(self, 'Save failed', 'There is no processed image to save')
+            return
+        SaveName = QFileDialog.getSaveFileName(self, 'Select file', '', "Image Files (*.png *.jpg *.bmp)")[0]
+
+        if not SaveName:
+            return
+
+        try:
+            if isinstance(self.pixmapAfter, QImage):
+                result = cv2.imwrite(SaveName, QImage2CV(self.pixmapAfter))
+            else:
+                result = cv2.imwrite(SaveName, QPixmap2cv(self.pixmapAfter))
+            if result:
+                QMessageBox.about(self, 'Saved', 'Image saved successfully')
+            else:
+                QMessageBox.about(self, 'Save failed', 'The path must not contain Chinese characters or spaces')
+        except Exception as e:
+            traceback.print_exc()
+
+ def resizeImage(self, label, pixmap):
+ if pixmap:
+ label.setPixmap(pixmap.scaled(label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
+
+ def resizeEvent(self, event):
+ self.resizeImage(self.ui.PicBefore,self.pixmapBefore)
+ self.resizeImage(self.ui.PicAfter,self.pixmapAfter)
+ super(MyWindow, self).resizeEvent(event)
+
+    def check(self):
+        if self.pixmapBefore.isNull():
+            QMessageBox.about(self, 'Operation failed', 'Please import an image first')
+            return True
+        img = cv2.imread(self.picpath)
+        if img is None:
+            QMessageBox.about(self, 'Operation failed', 'Unable to read the image')
+            return True
+        return False
+
+ def Grayscale(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        grayImg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # cv2.imread returns BGR, not RGB
+ qt_img = cvImgtoQtImg(grayImg)
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter,self.pixmapAfter)
+
+ def Binarization(self):
+ if self.check():
+ return
+ try:
+ img = cv2.imread(self.picpath)
+            grayImg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ qt_img = cvImgtoQtImg(cv2.threshold(grayImg, 127, 255, cv2.THRESH_BINARY)[1])
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+ except Exception as e:
+ traceback.print_exc()
+
+ def Geometry(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ self.ui.PicAfter.setScaledContents(False)
+ qt_img = cvImgtoQtImg(cv2.resize(img, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Histogram(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.calcHist([cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)], [0], None, [256], [0, 256])
+ plt.plot(img)
+ plt.savefig('img.jpg')
+ plt.close()
+ self.pixmapAfter = QPixmap('img.jpg')
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+ os.remove('img.jpg')
+
+ def Equalize(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ qt_img = cvImgtoQtImg(cv2.equalizeHist(img))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Mean(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ qt_img = cvImgtoQtImg(cv2.blur(img, (3, 5)))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Box(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ qt_img = cvImgtoQtImg(cv2.boxFilter(img, -1, (3, 5)))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Vague(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ qt_img = cvImgtoQtImg(cv2.GaussianBlur(img, (3, 3), 0))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Median(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ qt_img = cvImgtoQtImg(cv2.medianBlur(img, 5))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Roberts(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
+ kernelx_Robert = np.array([[-1, 0], [0, 1]], dtype=int)
+ kernely_Robert = np.array([[0, -1], [1, 0]], dtype=int)
+ x_Robert = cv2.filter2D(img_binary, cv2.CV_16S, kernelx_Robert)
+ y_Robert = cv2.filter2D(img_binary, cv2.CV_16S, kernely_Robert)
+ absX_Robert = cv2.convertScaleAbs(x_Robert)
+ absY_Robert = cv2.convertScaleAbs(y_Robert)
+ qt_img = cvImgtoQtImg(cv2.addWeighted(absX_Robert, 0.5, absY_Robert, 0.5, 0))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Prewitt(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
+ kernelx_Prewitt = np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]], dtype=int)
+ kernely_Prewitt = np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]], dtype=int)
+ x_Prewitt = cv2.filter2D(img_binary, -1, kernelx_Prewitt)
+ y_Prewitt = cv2.filter2D(img_binary, -1, kernely_Prewitt)
+ absX_Prewitt = cv2.convertScaleAbs(x_Prewitt)
+ absY_Prewitt = cv2.convertScaleAbs(y_Prewitt)
+ qt_img = cvImgtoQtImg(cv2.addWeighted(absX_Prewitt, 0.5, absY_Prewitt, 0.5, 0))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Sobel(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
+ x_Sobel = cv2.Sobel(img_binary, cv2.CV_16S, 1, 0)
+ y_Sobel = cv2.Sobel(img_binary, cv2.CV_16S, 0, 1)
+ absX_Sobel = cv2.convertScaleAbs(x_Sobel)
+ absY_Sobel = cv2.convertScaleAbs(y_Sobel)
+ qt_img = cvImgtoQtImg(cv2.addWeighted(absX_Sobel, 0.5, absY_Sobel, 0.5, 0))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Lowpass(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_dft = np.fft.fft2(img)
+ dft_shift_low = np.fft.fftshift(img_dft)
+ h, w = dft_shift_low.shape[0:2]
+ h_center, w_center = int(h / 2), int(w / 2)
+ img_black = np.zeros((h, w), np.uint8)
+ img_black[h_center - int(100 / 2):h_center + int(100 / 2), w_center - int(100 / 2):w_center + int(100 / 2)] = 1
+ dft_shift_low = dft_shift_low * img_black
+ idft_shift = np.fft.ifftshift(dft_shift_low)
+ ifimg = np.fft.ifft2(idft_shift)
+ ifimg = np.abs(ifimg)
+        ifimg = np.clip(ifimg, 0, 255).astype(np.uint8)  # int8 would wrap values above 127
+ cv2.imwrite('img.jpg', ifimg)
+ self.pixmapAfter = QPixmap('img.jpg')
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+ os.remove('img.jpg')
+
+ def Highpass(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_dft = np.fft.fft2(img)
+ dft_shift = np.fft.fftshift(img_dft)
+ h, w = dft_shift.shape[0:2]
+ h_center, w_center = int(h / 2), int(w / 2)
+ dft_shift[h_center - int(50 / 2):h_center + int(50 / 2),
+ w_center - int(50 / 2):w_center + int(50 / 2)] = 0
+ idft_shift = np.fft.ifftshift(dft_shift)
+ img_idft = np.fft.ifft2(idft_shift)
+ img_idft = np.abs(img_idft)
+        img_idft = np.clip(img_idft, 0, 255).astype(np.uint8)
+ cv2.imwrite('img.jpg', img_idft)
+ self.pixmapAfter = QPixmap('img.jpg')
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+ os.remove('img.jpg')
+
+ def Corrosion(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 55, 255, cv2.THRESH_BINARY)
+ img_binary = np.ones(img_binary.shape, np.uint8) * 255 - img_binary
+ kernel = np.ones((3, 3), np.uint8)
+ qt_img = cvImgtoQtImg(cv2.erode(img_binary, kernel))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Expansion(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 55, 255, cv2.THRESH_BINARY)
+ img_binary = np.ones(img_binary.shape, np.uint8) * 255 - img_binary
+ kernel = np.ones((3, 3), np.uint8)
+ qt_img = cvImgtoQtImg(cv2.dilate(img_binary, kernel, iterations=1))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Open(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 55, 255, cv2.THRESH_BINARY)
+ img_binary = np.ones(img_binary.shape, np.uint8) * 255 - img_binary
+ kernel = np.ones((3, 3), np.uint8)
+ qt_img = cvImgtoQtImg(cv2.morphologyEx(img_binary, cv2.MORPH_OPEN, kernel))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Close(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ ret, img_binary = cv2.threshold(img, 55, 255, cv2.THRESH_BINARY)
+ img_binary = np.ones(img_binary.shape, np.uint8) * 255 - img_binary
+ kernel = np.ones((3, 3), np.uint8)
+ qt_img = cvImgtoQtImg(cv2.morphologyEx(img_binary, cv2.MORPH_CLOSE, kernel))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def LOG(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_blur = cv2.GaussianBlur(img_gray, (3, 3), 1, 1)
+ LOG_result = cv2.Laplacian(img_blur, cv2.CV_16S, ksize=1)
+ qt_img = cvImgtoQtImg(cv2.convertScaleAbs(LOG_result))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Scharr(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ Scharr_result = cv2.Scharr(img_gray, cv2.CV_16S, dx=1, dy=0)
+ qt_img = cvImgtoQtImg(cv2.convertScaleAbs(Scharr_result))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+ def Canny(self):
+ if self.check():
+ return
+ img = cv2.imread(self.picpath)
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_blur_canny = cv2.GaussianBlur(img_gray, (7, 7), 1, 1)
+ qt_img = cvImgtoQtImg(cv2.Canny(img_blur_canny, 50, 150))
+ self.pixmapAfter = QPixmap.fromImage(qt_img)
+ self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+    def Feiyan(self):
+        if self.check():
+            return
+        # Run the UNet-based pneumonia segmentation; the weight file
+        # checkpoint_epoch16.pth must be supplied separately (see README).
+        filename = str(self.convert_path(self.picpath))
+        _ = solve("checkpoint_epoch16.pth", filename, "out.png", num_classes=4)
+        self.pixmapAfter = QPixmap("out.png")
+        self.resizeImage(self.ui.PicAfter, self.pixmapAfter)
+
+    def convert_path(self, path):
+        # normalise to Windows-style path separators for the predictor
+        return path.replace("/", "\\")
\ No newline at end of file
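
The Lowpass/Highpass handlers above filter in the frequency domain. A minimal standalone sketch of the same low-pass technique (names here are illustrative; `gray` is assumed to be a 2-D uint8 array):

    import numpy as np

    def lowpass_sketch(gray, size=100):
        f = np.fft.fftshift(np.fft.fft2(gray))      # centre the zero frequency
        h, w = gray.shape
        hc, wc = h // 2, w // 2
        mask = np.zeros((h, w), np.uint8)
        mask[hc - size // 2:hc + size // 2, wc - size // 2:wc + size // 2] = 1
        f = f * mask                                # keep only low frequencies
        out = np.abs(np.fft.ifft2(np.fft.ifftshift(f)))
        return np.clip(out, 0, 255).astype(np.uint8)

Keeping only a small centred window suppresses high-frequency detail and blurs the image; zeroing that window instead (as Highpass does) keeps edges and fine texture.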
diff --git a/MyWindow.py b/MyWindow.py
new file mode 100644
index 0000000..5ae6a22
--- /dev/null
+++ b/MyWindow.py
@@ -0,0 +1,61 @@
+import os
+import sys
+import cv2
+from PyQt5.QtGui import QPixmap, QImage, qRed, qGreen, qBlue
+from PyQt5.uic import loadUiType
+from PyQt5.QtWidgets import QMainWindow, QApplication, QFileDialog, QMessageBox
+import matplotlib.pyplot as plt
+import numpy as np
+from MainWindow import Ui_MainWindow
+
+
+def cvImgtoQtImg(cvImg):  # convert an OpenCV image (grayscale or BGR) to a Qt image
+    if cvImg.ndim == 2:
+        QtImgBuf = cv2.cvtColor(cvImg, cv2.COLOR_GRAY2BGRA)
+    else:
+        QtImgBuf = cv2.cvtColor(cvImg, cv2.COLOR_BGR2BGRA)
+    QtImg = QImage(QtImgBuf.data, QtImgBuf.shape[1], QtImgBuf.shape[0], QImage.Format_RGB32)
+    return QtImg
+
+def QImage2CV(qimg):
+ tmp = qimg
+
+    # create an empty image with numpy, then copy pixel by pixel
+ cv_image = np.zeros((tmp.height(), tmp.width(), 3), dtype=np.uint8)
+
+ for row in range(0, tmp.height()):
+ for col in range(0, tmp.width()):
+ r = qRed(tmp.pixel(col, row))
+ g = qGreen(tmp.pixel(col, row))
+ b = qBlue(tmp.pixel(col, row))
+ cv_image[row, col, 0] = r
+ cv_image[row, col, 1] = g
+ cv_image[row, col, 2] = b
+
+ cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
+ return cv_image
+
+
+def QPixmap2cv(qtpixmap):
+ qimg = qtpixmap.toImage()
+ temp_shape = (qimg.height(), qimg.bytesPerLine() * 8 // qimg.depth())
+ temp_shape += (4,)
+ ptr = qimg.bits()
+ ptr.setsize(qimg.byteCount())
+ result = np.array(ptr, dtype=np.uint8).reshape(temp_shape)
+ result = result[..., :3]
+ result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+
+ return result
+
+main_ui, _ = loadUiType('ui/bb.ui')
+
+
+class My_UI(QMainWindow, main_ui):
+    def __init__(self):
+        QMainWindow.__init__(self)
+        self.setupUi(self)  # build the UI loaded from ui/bb.ui
+
+        self.picpath = ''
+        self.openfile_name = ''
+
+
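
MyWindow.py illustrates the runtime alternative to pyuic5: PyQt5.uic.loadUiType parses the .ui file on import and returns a (form class, Qt base class) pair to inherit from. A minimal self-contained sketch of that pattern (assuming ui/bb.ui exists, as it does in this repository):

    import sys
    from PyQt5.QtWidgets import QApplication
    from PyQt5.uic import loadUiType

    form_class, base_class = loadUiType('ui/bb.ui')

    class Window(base_class, form_class):
        def __init__(self):
            super().__init__()
            self.setupUi(self)  # widgets come straight from the .ui file

    if __name__ == '__main__':
        app = QApplication(sys.argv)
        w = Window()
        w.show()
        sys.exit(app.exec_())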
diff --git a/README.md b/README.md
index 66cf405..335758a 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,3 @@
# Seg_lung
+Run main.py directly to start the application. The trained weight file checkpoint_epoch16.pth is too large to upload, so it is not included in this repository.
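
For reference, predict.py can also be run on its own; judging from its argparse definition, a call along these lines (the file names are placeholders) segments one image into the four classes:

    python predict.py -m checkpoint_epoch16.pth -i chest.png -o out.png -c 4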
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..81f00a6
--- /dev/null
+++ b/main.py
@@ -0,0 +1,18 @@
+import sys
+from PyQt5 import QtCore
+from PyQt5.QtWidgets import QApplication
+from MainWindow import Ui_MainWindow
+from MyUi import MyWindow
+
+
+if __name__ == "__main__":
+    # QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
+    app = QApplication(sys.argv)
+    ui = Ui_MainWindow()
+    window = MyWindow(ui)  # builds, wires up, and shows the main window
+    sys.exit(app.exec_())
+
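
One design point worth noting: PyQt5 allows only a single QApplication per process, and constructing a second one raises RuntimeError, which is why main.py owns the application and the event loop while MyWindow only builds and shows the widgets. A minimal sketch of that ownership split:

    import sys
    from PyQt5.QtWidgets import QApplication, QMainWindow

    if __name__ == "__main__":
        app = QApplication(sys.argv)   # created exactly once per process
        window = QMainWindow()         # stands in for the generated UI here
        window.show()
        sys.exit(app.exec_())          # one event loop owns shutdown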
diff --git a/out.png b/out.png
new file mode 100644
index 0000000000000000000000000000000000000000..fa544ebd2f264b746595471f083ade369b73decf
GIT binary patch
literal 2389
zcmb7``!|#e7suy$JfoRmdIvME#Uz)>ZRA>xc!W$%A~oa~895Yk$)(&q(<`~=k`v+7
zCV#LS^T_p25|G>M}`~9x9KYQ)9_ILlf(_N1^5ecdU388`ae6Pj$QNa^`*$)Yr_)XT*1hvNO({-K|yFV~a{H
zVJ@e`!gWyJP|A;3qszq&Tkf%6QvecaerUwrQP|xt$7mwrOHENFoY%cR6(KIl
z&@l~GwG*0U*!hc|f80Sb*|I3lkv;n~END^95nhui4+V1SaWt?Ivr}Z8wE->yOY)@F
zZz|@rL`XUN307Ix92cWxck21r6y^N__=Z(vwH(7-74k`=kVENv%oX4PR@l
zGC!B1D=vvQl4``zIrrdqOx8#Djxuf7~49R)Fi*>j}7|xafUt
z+X1hnw%QZiMJApRK>u!|Wx1f>c>tZSmVdST94_Yh)BT{PD-yg&&XtiVHv-PfRSuDv
zh&NP?2%{rsPAj#a0`qs}CWL`4Nm-W*%4DpH;WK+BF!=7WB10K$VZ4sZ&_&tOlfPkV
z)trz<*@mDD8uby_JZ4I0;j1aWHv#h%<`VZ|{F{6n5Xy(9XUPY+;;Gf1Q(%kF75)bS
z7}Bz)TTilC^m=T~
zK_K4FyY-qg-rAe(UR#eXCVG#hZ38{%zoTX98G0G$?`vg#Dlh$Bz2q!c$E}bDLMzmO
zL}>wQp9Z=aIFWo!-K;}k{!dc^84y=^A8DB|{ED+2S_2>GLqb#Mk#O8dbyBu|f*Y<<9nurwMT;{WK%r2aC
zg4Kq~ZQTSrd0X4V%LnM4OdQRBk1Tx6R-kE2w5pk`$7TI%AJHh2e-QWP~l5o~#f
zDWi2&>NKh7vl`@fORNcY=QSV7ZcIUAREnX>hl9}cfb)~o($DV_a+bwwB|mm)a(3odCEkjL0`E|ISOB~uLQq@P(rVRbEF=7ILn^a
z`Jtrl?%In(O4MP$BGfEJ)77G%QT=w8iWhaIl`Yp*e0OL_6c}yG{u}BrONnUMV^b@6
zN>=Vmr5h9TXim)<*=wDGTR9p`_{e5GJz2XKpB%+4@ZC_1?x;
z#y49;h_B|?CBOPeQcmn9=NfTa6h^)i%B_KR$|G+EOf17M=ztoj=g&b6JC8#}ovJK$2#aZob0#Y#+@6BS&9M1S@ug?ARCY-n{=^wYLvA7h4zb
z(()1Y9J2C5z!oqmvPaIk3v}UO=c=$21Nmg2t|!!yTM+c?{3Idf(Qug#{7`Qt40~7;
zPx4`2lQ({o6a7=!q)o>O>1uetxn5eWHWTl<{rv6P)S!5!Y5XX(qWJ#iR9L_uHiao%
znP~A!+nl#bu*h8e+n$4{X`L&mG2kVKD?uv>Up$v*0q(aDM^5995KFlj0)^2vF3Wu({`&V3
z;fo!|QR^)o5B#mzy#}Ql@O+MNeP(ZN+}K>WxchCb+t#<>Uwf%gfw}4WaZ0$h1$H~h
z!OxK=xtcP08mByqPCOhLdtJBtIrpt6J-t+$QzT_lStjz2&_J`N8agY4+S3bmkj5+Sf8yJtAwdeOzeJqOA6+5QB~iGP
zCv#PV0H2hKrB^VsJ8-%;^(Rc`Ig{Me-N}PP?ce_M&%g8kDv7fpCB-IPUJDNs{j+<}
M@rXmse$K`J0_lbmzyJUM
literal 0
HcmV?d00001
diff --git a/output.py b/output.py
new file mode 100644
index 0000000..3ca17d8
--- /dev/null
+++ b/output.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'input.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.9
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again. Do not edit this file unless you know what you are doing.
+
+
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..7953ecc
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,215 @@
+import argparse
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+import matplotlib.pyplot as plt
+
+from utils.data_loading import BasicDataset
+from unet import UNet, SEUNet, UResnet34, UResnet50, UResnet101, UResnet152
+from utils.utils import plot_img_and_mask
+
+
+# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+def predict_img(net,
+ full_img,
+ device,
+ scale_factor=1,
+ out_threshold=0.5):
+ net.eval()
+ img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
+ img = img.unsqueeze(0)
+ img = img.to(device=device, dtype=torch.float32)
+
+ with torch.no_grad():
+ output = net(img).cpu()
+ output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
+ if net.n_classes > 1:
+ output = F.softmax(output, dim=1)
+            # individual channels could be dropped here if needed,
+            # e.g. output = output[:, 1:, :, :] to discard the background
+ mask = output.argmax(dim=1)
+ else:
+ mask = torch.sigmoid(output) > out_threshold
+
+ return mask[0].long().squeeze().numpy()
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='Predict masks from input images')
+ parser.add_argument('--model', '-m', default='checkpoints/checkpoint_epoch24.pth', metavar='FILE',
+ help='Specify the file in which the model is stored')
+ parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
+ parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
+ parser.add_argument('--viz', '-v', action='store_true',
+ help='Visualize the images as they are processed')
+ parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
+ parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
+ help='Minimum probability value to consider a mask pixel white')
+ parser.add_argument('--scale', '-s', type=float, default=0.5,
+ help='Scale factor for the input images')
+ parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
+ parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
+
+ return parser.parse_args()
+
+
+def get_output_filenames(args):
+ def _generate_name(fn):
+ return f'{os.path.splitext(fn)[0]}_OUT.png'
+
+ return args.output or list(map(_generate_name, args.input))
+
+
+def mask_to_image(mask: np.ndarray, mask_values):
+ if isinstance(mask_values[0], list):
+ out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
+ elif mask_values == [0, 1]:
+ out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
+ else:
+ out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)
+
+ if mask.ndim == 3:
+ mask = np.argmax(mask, axis=0)
+
+ for i, v in enumerate(mask_values):
+ out[mask == i] = v
+
+ return Image.fromarray(out)
+
+
+def MainSolve(model_path, input_files, output_files=None, visualize=False, no_save=False, mask_threshold=0.5, scale=0.5,
+ bilinear=False, num_classes=2):
+ logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+ in_files = input_files
+ out_files = output_files if output_files else [f"{os.path.splitext(fn)[0]}_OUT.png" for fn in input_files]
+
+ net = UResnet34(n_channels=3, n_classes=num_classes, bilinear=bilinear)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ logging.info(f'Loading model {model_path}')
+ logging.info(f'Using device {device}')
+
+ net.to(device=device)
+ state_dict = torch.load(model_path, map_location=device)
+ mask_values = state_dict.pop('mask_values', [0, 1])
+ net.load_state_dict(state_dict, strict=False)
+
+ logging.info('Model loaded!')
+
+ for i, filename in enumerate(in_files):
+ logging.info(f'Predicting image {filename} ...')
+ img = Image.open(filename)
+        # the ground-truth mask is assumed to sit beside the image, with
+        # 'image_1' replaced by 'mask' in the path
+        true_mask = Image.open(filename.replace('image_1', 'mask'))
+
+ mask = predict_img(net=net,
+ full_img=img,
+ scale_factor=scale,
+ out_threshold=mask_threshold,
+ device=device)
+
+ if not no_save:
+ out_filename = out_files[i]
+ result = mask_to_image(mask, mask_values)
+ result.putpalette([
+ 0, 0, 0, # Black background
+ 255, 255, 255, # Class 1
+ 0, 0, 255, # Class 2
+ 0, 255, 0, # Class 3
+ ])
+ result.save(out_filename)
+ logging.info(f'Mask saved to {out_filename}')
+
+ plt.figure(figsize=(10, 10))
+ plt.subplot(1, 3, 1)
+ plt.imshow(img)
+ plt.axis('off')
+ plt.title('Original Image')
+ plt.subplot(1, 3, 2)
+ plt.imshow(result)
+ plt.axis('off')
+ plt.title('Predicted Mask')
+ plt.subplot(1, 3, 3)
+
+ true_mask = mask_to_image(np.asarray(true_mask), mask_values)
+ true_mask.putpalette([
+ 0, 0, 0, # Black background
+ 255, 255, 255, # Class 1
+ 0, 0, 255, # Class 2
+ 0, 255, 0, # Class 3
+ ])
+ plt.imshow(true_mask)
+ plt.axis('off')
+ plt.title('True Mask')
+            plt.savefig('res_comparison.png')  # save before show(), which clears the figure
+            plt.show()
+
+ if visualize:
+ logging.info(f'Visualizing results for image {filename}, close to continue...')
+ plot_img_and_mask(img, mask)
+
+def solve(model_path, input_file, output_file=None, visualize=False, no_save=False, mask_threshold=0.5, scale=0.5,
+ bilinear=False, num_classes=2):
+ logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+
+ net = UResnet34(n_channels=3, n_classes=num_classes, bilinear=bilinear)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ logging.info(f'Loading model {model_path}')
+ logging.info(f'Using device {device}')
+
+ net.to(device=device)
+ state_dict = torch.load(model_path, map_location=device)
+ mask_values = state_dict.pop('mask_values', [0, 1])
+ net.load_state_dict(state_dict, strict=False)
+
+ logging.info('Model loaded!')
+
+ filename = input_file
+ logging.info(f'Predicting image {filename} ...')
+ img = Image.open(filename)
+ mask = predict_img(net=net,
+ full_img=img,
+ scale_factor=scale,
+ out_threshold=mask_threshold,
+ device=device)
+
+ if not no_save:
+ out_filename = output_file
+ result = mask_to_image(mask, mask_values)
+ result.putpalette([
+ 0, 0, 0, # Black background
+ 255, 255, 255, # Class 1
+ 0, 0, 255, # Class 2
+ 0, 255, 0, # Class 3
+ ])
+ result.save(out_filename)
+ logging.info(f'Mask saved to {out_filename}')
+ return result
+
+
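+# Example CLI invocation (a sketch: apart from --bilinear/--classes above, the flag
+# names are assumed from the attribute names used below, e.g. args.model, args.input):
+#   python predict.py --model MODEL.pth --input image.png --output image_OUT.png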
+if __name__ == '__main__':
+ args = get_args()
+ MainSolve(model_path=args.model,
+ input_files=args.input,
+ output_files=args.output,
+ visualize=args.viz,
+ no_save=args.no_save,
+ mask_threshold=args.mask_threshold,
+ scale=args.scale,
+ bilinear=args.bilinear,
+ num_classes=args.classes)
diff --git a/ui/MainWindow.ui b/ui/MainWindow.ui
new file mode 100644
index 0000000..95d6d2b
--- /dev/null
+++ b/ui/MainWindow.ui
@@ -0,0 +1,789 @@
+
+
+ MainWindow
+
+
+
+ 0
+ 0
+ 1246
+ 826
+
+
+
+ MainWindow
+
+
+
+ pic/icon.jpgpic/icon.jpg
+
+
+
+
+
+ 10
+ 10
+ 180
+ 772
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
-
+
+
-
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 基本处理
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ 二值化处理
+
+
+
+ -
+
+
+ 几何变换
+
+
+
+ -
+
+
+ 灰度处理
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 图像增强
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ Sobel算子
+
+
+
+ -
+
+
+ 低通滤波
+
+
+
+ -
+
+
+ Roberts算子
+
+
+
+ -
+
+
+ Prewitt算子
+
+
+
+ -
+
+
+ 灰度直方图
+
+
+
+ -
+
+
+ 高通滤波
+
+
+
+ -
+
+
+ 中值滤波
+
+
+
+ -
+
+
+ 均衡化
+
+
+
+ -
+
+
+ 均值滤波
+
+
+
+ -
+
+
+ 高斯模糊滤波
+
+
+
+ -
+
+
+ 方框滤波
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 形态学处理
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ 腐蚀运算
+
+
+
+ -
+
+
+ 闭运算
+
+
+
+ -
+
+
+ 开运算
+
+
+
+ -
+
+
+ 膨胀运算
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 边缘检测
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ Canny算子
+
+
+
+ -
+
+
+ Scharr算子
+
+
+
+ -
+
+
+ LoG检测
+
+
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 图像信息
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ 导入
+
+
+
+ -
+
+
+ 保存
+
+
+
+
+
+ -
+
+
-
+
+
+ 高度:
+
+
+
+ -
+
+
+ Label_H
+
+
+
+
+
+ -
+
+
-
+
+
+ 宽度:
+
+
+
+ -
+
+
+ Label_W
+
+
+
+
+
+ -
+
+
-
+
+
+ 通道数:
+
+
+
+ -
+
+
+ Label_T
+
+
+
+
+
+ -
+
+
-
+
+
+ 图片类型:
+
+
+
+ -
+
+
+ Label_Type
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+
+
+
+
+
+
+ 230
+ 20
+ 471
+ 16
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 处理前
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+
+
+
+ 230
+ 50
+ 471
+ 331
+
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+
+
+ 230
+ 450
+ 471
+ 331
+
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+
+
+ 230
+ 410
+ 471
+ 16
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 处理后
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+
+
+
+ 740
+ 20
+ 471
+ 16
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 频谱图
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+
+
+
+ 740
+ 50
+ 471
+ 331
+
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+
+
+ 740
+ 410
+ 471
+ 16
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 频谱图
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+
+
+
+ 740
+ 450
+ 471
+ 331
+
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ui/aa.ui b/ui/aa.ui
new file mode 100644
index 0000000..eff4f01
--- /dev/null
+++ b/ui/aa.ui
@@ -0,0 +1,622 @@
+
+
+ mainWindow
+
+
+
+ 0
+ 0
+ 980
+ 730
+
+
+
+
+ 980
+ 730
+
+
+
+
+ 980
+ 730
+
+
+
+ 数字图像处理
+
+
+
+
+
+ 860
+ 30
+ 111
+ 131
+
+
+
+ -
+
+
+
+ 40
+ 40
+
+
+
+
+ 40
+ 40
+
+
+
+
+
+
+
+ icon/YY_fangda_1.pngicon/YY_fangda_1.png
+
+
+
+ 25
+ 25
+
+
+
+
+ -
+
+
+
+ 40
+ 40
+
+
+
+
+ 40
+ 40
+
+
+
+
+
+
+
+ icon/镜像.pngicon/镜像.png
+
+
+
+ 25
+ 50
+
+
+
+
+ -
+
+
+
+ 40
+ 40
+
+
+
+
+ 40
+ 40
+
+
+
+
+
+
+
+ icon/缩小.pngicon/缩小.png
+
+
+
+ 35
+ 35
+
+
+
+
+ -
+
+
+
+ 40
+ 40
+
+
+
+
+ 40
+ 40
+
+
+
+
+
+
+
+ icon/3.1-旋转.pngicon/3.1-旋转.png
+
+
+
+ 40
+ 30
+
+
+
+
+
+
+
+
+
+ 880
+ 0
+ 71
+ 31
+
+
+
+ 几何变换:
+
+
+
+
+
+ 870
+ 480
+ 91
+ 21
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ 880
+ 270
+ 61
+ 21
+
+
+
+ 平滑处理:
+
+
+
+
+
+ 150
+ 10
+ 700
+ 680
+
+
+
+
+ 700
+ 680
+
+
+
+
+ 700
+ 680
+
+
+
+
+
+
+ 0
+ 10
+ 141
+ 221
+
+
+
+
+ -1
+
+
+ 12
+
+ -
+
+
+ 亮度调节:
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 对比度调节:
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 饱和度调节:
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+
+
+
+ 10
+ 240
+ 118
+ 3
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ 10
+ 250
+ 118
+ 3
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ 10
+ 260
+ 121
+ 201
+
+
+
+ -
+
+
+ 恢复原图
+
+
+
+ -
+
+
+ 灰度化
+
+
+
+ -
+
+
+ 二值化
+
+
+
+
+
+
+
+
+ 880
+ 430
+ 91
+ 21
+
+
+
+ 直方图处理:
+
+
+
+
+
+ 870
+ 260
+ 91
+ 21
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ 860
+ 450
+ 111
+ 41
+
+
+
+ 直方图均衡化
+
+
+
+
+
+ 870
+ 410
+ 91
+ 21
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ 890
+ 490
+ 61
+ 21
+
+
+
+ 锐化处理:
+
+
+
+
+
+ 880
+ 170
+ 71
+ 21
+
+
+
+ 加性噪声:
+
+
+
+
+
+ 870
+ 190
+ 91
+ 41
+
+
+
+ 高斯噪声
+
+
+
+
+
+ 870
+ 230
+ 91
+ 41
+
+
+
+ 椒盐噪声
+
+
+
+
+
+ 870
+ 150
+ 91
+ 21
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ 870
+ 330
+ 91
+ 41
+
+
+
+ 中值滤波
+
+
+
+
+
+ 870
+ 290
+ 91
+ 41
+
+
+
+ 均值滤波
+
+
+
+
+
+ 870
+ 370
+ 91
+ 41
+
+
+
+ 高斯滤波
+
+
+
+
+
+ 20
+ 470
+ 113
+ 32
+
+
+
+ 二维码定位
+
+
+
+
+
+ 870
+ 510
+ 91
+ 41
+
+
+
+ Prewitt
+
+
+
+
+
+ 870
+ 560
+ 91
+ 41
+
+
+
+ Canny
+
+
+
+
+
+
+
+ icon/open.pngicon/open.png
+
+
+ 打开文件
+
+
+
+
+
+ icon/保存.pngicon/保存.png
+
+
+ 保存文件
+
+
+
+
+ 灰度
+
+
+
+
+ 彩色
+
+
+
+
+
+
\ No newline at end of file
diff --git a/ui/bb.ui b/ui/bb.ui
new file mode 100644
index 0000000..071c52a
--- /dev/null
+++ b/ui/bb.ui
@@ -0,0 +1,882 @@
+
+
+ MainWindow
+
+
+
+ 0
+ 0
+ 1092
+ 799
+
+
+
+
+ 0
+ 0
+
+
+
+
+ 10000
+ 10000
+
+
+
+ MainWindow
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 形态学处理
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 腐蚀运算
+
+
+
+ -
+
+
+ 膨胀运算
+
+
+
+
+
+ -
+
+
-
+
+
+ 开运算
+
+
+
+ -
+
+
+ 闭运算
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 20
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 图像增强
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
-
+
+
+ 灰度直方图
+
+
+
+ -
+
+
+ 均值滤波
+
+
+
+ -
+
+
+ 中值滤波
+
+
+
+ -
+
+
+ Prewitt算子
+
+
+
+ -
+
+
+ 低通滤波
+
+
+
+
+
+ -
+
+
-
+
+
+ 均衡化
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 高斯模糊滤波
+
+
+
+ -
+
+
+ Roberts算子
+
+
+
+ -
+
+
+ Sobel算子
+
+
+
+ -
+
+
+ 高通滤波
+
+
+
+
+
+
+
+ -
+
+
+ 方框滤波
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 图像信息
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ 导入
+
+
+
+ -
+
+
+ 保存
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 高度:
+
+
+
+ -
+
+
+ 宽度:
+
+
+
+ -
+
+
+ 通道数:
+
+
+
+ -
+
+
+ 图片类型:
+
+
+
+
+
+ -
+
+
-
+
+
+ Label_H
+
+
+
+ -
+
+
+ Label_W
+
+
+
+ -
+
+
+ Label_T
+
+
+
+ -
+
+
+ Label_Type
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ -
+
+
-
+
+
-
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 处理前
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 处理后
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ -
+
+
-
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 边缘检测
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ LoG检测
+
+
+
+ -
+
+
+ Scharr算子
+
+
+
+
+
+ -
+
+
+ Canny算子
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 基本处理
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 灰度处理
+
+
+
+ -
+
+
+ 二值化处理
+
+
+
+
+
+ -
+
+
+ 几何变换
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
-
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 深度学习
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
+ 肺炎图像处理
+
+
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Preferred
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 黑色
+
+
+
+ -
+
+
+ 白色
+
+
+
+ -
+
+
+ 蓝色
+
+
+
+ -
+
+
+ 绿色
+
+
+
+
+
+ -
+
+
-
+
+
+ 背景
+
+
+
+ -
+
+
+ 肺区
+
+
+
+ -
+
+
+ 磨玻璃影
+
+
+
+ -
+
+
+ 实变
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Expanding
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ui/main.ui b/ui/main.ui
new file mode 100644
index 0000000..3a86d6b
--- /dev/null
+++ b/ui/main.ui
@@ -0,0 +1,779 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 1321
+ 905
+
+
+
+ Form
+
+
+
+
+ 9
+ 9
+ 1301
+ 881
+
+
+
+ -
+
+
-
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 形态学处理
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 腐蚀运算
+
+
+
+ -
+
+
+ 膨胀运算
+
+
+
+
+
+ -
+
+
-
+
+
+ 开运算
+
+
+
+ -
+
+
+ 闭运算
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 20
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 图像增强
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
-
+
+
+ 灰度直方图
+
+
+
+ -
+
+
+ 均值滤波
+
+
+
+ -
+
+
+ 中值滤波
+
+
+
+ -
+
+
+ Prewitt算子
+
+
+
+ -
+
+
+ 低通滤波
+
+
+
+
+
+ -
+
+
-
+
+
+ 均衡化
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 高斯模糊滤波
+
+
+
+ -
+
+
+ Roberts算子
+
+
+
+ -
+
+
+ Sobel算子
+
+
+
+ -
+
+
+ 高通滤波
+
+
+
+
+
+
+
+ -
+
+
+ 方框滤波
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 图像信息
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
+ 导入
+
+
+
+ -
+
+
+ 保存
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 高度:
+
+
+
+ -
+
+
+ 宽度:
+
+
+
+ -
+
+
+ 通道数:
+
+
+
+ -
+
+
+ 图片类型:
+
+
+
+
+
+ -
+
+
-
+
+
+ Label_H
+
+
+
+ -
+
+
+ Label_W
+
+
+
+ -
+
+
+ Label_T
+
+
+
+ -
+
+
+ Label_Type
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ -
+
+
-
+
+
-
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 处理前
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ 处理后
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ -
+
+
-
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 边缘检测
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ LoG检测
+
+
+
+ -
+
+
+ Scharr算子
+
+
+
+
+
+ -
+
+
+ Canny算子
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 基本处理
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
-
+
+
-
+
+
+ 灰度处理
+
+
+
+ -
+
+
+ 二值化处理
+
+
+
+
+
+ -
+
+
+ 几何变换
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
-
+
+
-
+
+
-
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
+ 深度学习
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+
+
+ -
+
+
+ 肺炎图像处理
+
+
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+ QSizePolicy::Fixed
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/unet/__init__.py b/unet/__init__.py
new file mode 100644
index 0000000..3bbcd52
--- /dev/null
+++ b/unet/__init__.py
@@ -0,0 +1 @@
+from .unet_model import UNet, SEUNet, UResnet34, UResnet50, UResnet101, UResnet152
diff --git a/unet/swin_transformer.py b/unet/swin_transformer.py
new file mode 100644
index 0000000..9fbae12
--- /dev/null
+++ b/unet/swin_transformer.py
@@ -0,0 +1,721 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def autopad(k, p=None, d=1): # kernel, padding, dilation
+ """Pad to 'same' shape outputs."""
+ if d > 1:
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
+ if p is None:
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
+ return p
+
+
+class Conv(nn.Module):
+ """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
+ default_act = nn.SiLU() # default activation
+
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
+ """Initialize Conv layer with given arguments including activation."""
+ super().__init__()
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
+
+ def forward(self, x):
+ """Apply convolution, batch normalization and activation to input tensor."""
+ return self.act(self.bn(self.conv(x)))
+
+ def forward_fuse(self, x):
+ """Perform transposed convolution of 2D data."""
+ return self.act(self.conv(x))
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+    def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path_f(x, self.drop_prob, self.training)
+
+def drop_path_f(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+##### swin transformer #####
+class WindowAttention(nn.Module):
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ nn.init.normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
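+        # x: (num_windows*B, N, C) with N = window_size*window_size;
+        # mask: (num_windows, N, N) additive attention mask, or None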
+
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+        try:
+            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        except RuntimeError:
+            # dtype mismatch under mixed precision: cast attention to half and retry
+            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Mlp(nn.Module):
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
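+    # split (B, H, W, C) into non-overlapping windows -> (num_windows*B, window_size, window_size, C)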
+ B, H, W, C = x.shape
+    assert H % window_size == 0 and W % window_size == 0, \
+        'feature map height and width must be divisible by the window size'
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
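+    # inverse of window_partition: merge windows back into a (B, H, W, C) feature map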
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class SwinTransformerLayer(nn.Module):
+
+ def __init__(self, dim, num_heads, window_size=8, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ # if min(self.input_resolution) <= self.window_size:
+ # # if window size is larger than input resolution, we don't partition windows
+ # self.shift_size = 0
+ # self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def create_mask(self, H, W):
+ # calculate attention mask for SW-MSA
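+        # the cyclic shift makes some windows contain tokens from different image
+        # regions; tokens get region ids below, and pairs with differing ids are
+        # masked with -100 so softmax assigns them (near-)zero attention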
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x):
+ # reshape x[b c h w] to x[b l c]
+ _, _, H_, W_ = x.shape
+
+ Padding = False
+ if min(H_, W_) < self.window_size or H_ % self.window_size != 0 or W_ % self.window_size != 0:
+ Padding = True
+ # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
+ pad_r = (self.window_size - W_ % self.window_size) % self.window_size
+ pad_b = (self.window_size - H_ % self.window_size) % self.window_size
+ x = F.pad(x, (0, pad_r, 0, pad_b))
+
+ # print('2', x.shape)
+ B, C, H, W = x.shape
+ L = H * W
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
+
+ # create mask from init to forward
+ if self.shift_size > 0:
+ attn_mask = self.create_mask(H, W).to(x.device)
+ else:
+ attn_mask = None
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
+
+ if Padding:
+ x = x[:, :, :H_, :W_] # reverse padding
+
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
+ super().__init__()
+ self.conv = None
+ if c1 != c2:
+ self.conv = Conv(c1, c2)
+
+ # remove input_resolution
+ self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2) for i in
+ range(num_layers)])
+
+ def forward(self, x):
+ if self.conv is not None:
+ x = self.conv(x)
+ x = self.blocks(x)
+ return x
+
+
+class STCSPA(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super(STCSPA, self).__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
+ num_heads = c_ // 32
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
+ # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+ def forward(self, x):
+ y1 = self.m(self.cv1(x))
+ y2 = self.cv2(x)
+ return self.cv3(torch.cat((y1, y2), dim=1))
+
+
+class STCSPB(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super(STCSPB, self).__init__()
+ c_ = int(c2) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
+ num_heads = c_ // 32
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
+ # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+ def forward(self, x):
+ x1 = self.cv1(x)
+ y1 = self.m(x1)
+ y2 = self.cv2(x1)
+ return self.cv3(torch.cat((y1, y2), dim=1))
+
+
+class STCSPC(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super(STCSPC, self).__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(c_, c_, 1, 1)
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
+ num_heads = c_ // 32
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
+ # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+ def forward(self, x):
+ y1 = self.cv3(self.m(self.cv1(x)))
+ y2 = self.cv2(x)
+ return self.cv4(torch.cat((y1, y2), dim=1))
+
+
+##### end of swin transformer #####
+
+
+##### swin transformer v2 #####
+
+class WindowAttention_v2(nn.Module):
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+        # note: calling .to('cuda') on a Parameter returns a plain tensor and would
+        # unregister it; leave device placement to the enclosing model's .to(device)
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+        try:
+            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        except RuntimeError:
+            # dtype mismatch under mixed precision: cast attention to half and retry
+            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class Mlp_v2(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition_v2(x, window_size):
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse_v2(windows, window_size, H, W):
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class SwinTransformerLayer_v2(nn.Module):
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ # self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ # if min(self.input_resolution) <= self.window_size:
+ # # if window size is larger than input resolution, we don't partition windows
+ # self.shift_size = 0
+ # self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention_v2(
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=(pretrained_window_size, pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def create_mask(self, H, W):
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+        mask_windows = window_partition_v2(img_mask, self.window_size)  # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x):
+ # reshape x[b c h w] to x[b l c]
+ _, _, H_, W_ = x.shape
+
+ Padding = False
+ if min(H_, W_) < self.window_size or H_ % self.window_size != 0 or W_ % self.window_size != 0:
+ Padding = True
+ # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
+ pad_r = (self.window_size - W_ % self.window_size) % self.window_size
+ pad_b = (self.window_size - H_ % self.window_size) % self.window_size
+ x = F.pad(x, (0, pad_r, 0, pad_b))
+
+ # print('2', x.shape)
+ B, C, H, W = x.shape
+ L = H * W
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
+
+ # create mask from init to forward
+ if self.shift_size > 0:
+ attn_mask = self.create_mask(H, W).to(x.device)
+ else:
+ attn_mask = None
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition_v2(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+ x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
+
+ if Padding:
+ x = x[:, :, :H_, :W_] # reverse padding
+
+ return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self, H, W):
+        # this layer does not store an input resolution, so the caller supplies H and W
+        flops = 0
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class SwinTransformer2Block(nn.Module):
+ def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
+ super().__init__()
+ self.conv = None
+ if c1 != c2:
+ self.conv = Conv(c1, c2)
+
+ # remove input_resolution
+ self.blocks = nn.Sequential(*[SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2) for i
+ in range(num_layers)])
+
+ def forward(self, x):
+ if self.conv is not None:
+ x = self.conv(x)
+ x = self.blocks(x)
+ return x
+
+
+class ST2CSPA(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super(ST2CSPA, self).__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
+ num_heads = c_ // 32
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
+ # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+ def forward(self, x):
+ y1 = self.m(self.cv1(x))
+ y2 = self.cv2(x)
+ return self.cv3(torch.cat((y1, y2), dim=1))
+
+
+class ST2CSPB(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super(ST2CSPB, self).__init__()
+ c_ = int(c2) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
+ num_heads = c_ // 32
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
+ # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+ def forward(self, x):
+ x1 = self.cv1(x)
+ y1 = self.m(x1)
+ y2 = self.cv2(x1)
+ return self.cv3(torch.cat((y1, y2), dim=1))
+
+
+class ST2CSPC(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super(ST2CSPC, self).__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(c_, c_, 1, 1)
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
+ num_heads = c_ // 32
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
+ # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+ def forward(self, x):
+ y1 = self.cv3(self.m(self.cv1(x)))
+ y2 = self.cv2(x)
+ return self.cv4(torch.cat((y1, y2), dim=1))
+
+##### end of swin transformer v2 #####
\ No newline at end of file
diff --git a/unet/unet_model.py b/unet/unet_model.py
new file mode 100644
index 0000000..e4d02c6
--- /dev/null
+++ b/unet/unet_model.py
@@ -0,0 +1,336 @@
+""" Full assembly of the parts to form the complete network """
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+from .unet_parts import DoubleConv, Down, Up, OutConv
+from .swin_transformer import SwinTransformerBlock, Conv
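+from torch.utils.checkpoint import checkpoint
+
+
+class _Checkpointed(nn.Module):
+    """Run a wrapped stage under activation checkpointing (a helper sketch backing
+    use_checkpointing() below). Kept as an nn.Module because nn.Module.__setattr__
+    rejects plain callables for registered submodules."""
+
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, *args):
+        # recompute this stage's activations during the backward pass to save memory
+        return checkpoint(self.module, *args, use_reentrant=False)
+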
+
+class UNet(nn.Module):
+ def __init__(self, n_channels, n_classes, bilinear=False):
+ super(UNet, self).__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+ self.bilinear = bilinear
+
+ self.inc = (DoubleConv(n_channels, 64))
+ self.down1 = (Down(64, 128))
+ self.down2 = (Down(128, 256))
+ self.down3 = (Down(256, 512))
+ factor = 2 if bilinear else 1
+ self.down4 = (Down(512, 1024 // factor))
+ self.up1 = (Up(1024, 512 // factor, bilinear))
+ self.up2 = (Up(512, 256 // factor, bilinear))
+ self.up3 = (Up(256, 128 // factor, bilinear))
+ self.up4 = (Up(128, 64, bilinear))
+ self.outc = (OutConv(64, n_classes))
+
+ def forward(self, x):
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x = self.up1(x5, x4)
+ x = self.up2(x, x3)
+ x = self.up3(x, x2)
+ x = self.up4(x, x1)
+ logits = self.outc(x)
+ return logits
+
+    def use_checkpointing(self):
+        # wrap every stage so its activations are recomputed in backward (memory saving);
+        # torch.utils.checkpoint is a module, not a callable, hence the _Checkpointed wrapper
+        self.inc = _Checkpointed(self.inc)
+        self.down1 = _Checkpointed(self.down1)
+        self.down2 = _Checkpointed(self.down2)
+        self.down3 = _Checkpointed(self.down3)
+        self.down4 = _Checkpointed(self.down4)
+        self.up1 = _Checkpointed(self.up1)
+        self.up2 = _Checkpointed(self.up2)
+        self.up3 = _Checkpointed(self.up3)
+        self.up4 = _Checkpointed(self.up4)
+        self.outc = _Checkpointed(self.outc)
+
+class SEUNet(nn.Module):
+ def __init__(self, n_channels, n_classes, bilinear=False):
+ super(SEUNet, self).__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+ self.bilinear = bilinear
+
+        self.inc = (DoubleConv(n_channels, 64))
+        self.down1 = (Down(64, 128))
+        self.se1 = SEBlock(128)  # SE block after down1
+        self.down2 = (Down(128, 256))
+        self.se2 = SEBlock(256)  # SE block after down2
+        self.down3 = (Down(256, 512))
+        self.se3 = SEBlock(512)  # SE block after down3
+        factor = 2 if bilinear else 1
+        self.down4 = (Down(512, 1024 // factor))
+        self.se4 = SEBlock(1024 // factor)  # SE block after down4
+        self.up1 = (Up(1024, 512 // factor, bilinear))
+        self.se_up1 = SEBlock(512 // factor)  # SE block after up1
+        self.up2 = (Up(512, 256 // factor, bilinear))
+        self.se_up2 = SEBlock(256 // factor)  # SE block after up2
+        self.up3 = (Up(256, 128 // factor, bilinear))
+        self.se_up3 = SEBlock(128 // factor)  # SE block after up3
+        self.up4 = (Up(128, 64, bilinear))
+        self.outc = (OutConv(64, n_classes))
+
+ def forward(self, x):
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+        x2 = self.se1(x2)  # apply SE block
+        x3 = self.down2(x2)
+        x3 = self.se2(x3)  # apply SE block
+        x4 = self.down3(x3)
+        x4 = self.se3(x4)  # apply SE block
+        x5 = self.down4(x4)
+        x5 = self.se4(x5)  # apply SE block
+        x = self.up1(x5, x4)
+        x = self.se_up1(x)  # apply SE block
+        x = self.up2(x, x3)
+        x = self.se_up2(x)  # apply SE block
+        x = self.up3(x, x2)
+        x = self.se_up3(x)  # apply SE block
+ x = self.up4(x, x1)
+ logits = self.outc(x)
+
+ return logits
+
+    def use_checkpointing(self):
+        # same memory-saving wrapping as in UNet.use_checkpointing above
+        self.inc = _Checkpointed(self.inc)
+        self.down1 = _Checkpointed(self.down1)
+        self.down2 = _Checkpointed(self.down2)
+        self.down3 = _Checkpointed(self.down3)
+        self.down4 = _Checkpointed(self.down4)
+        self.up1 = _Checkpointed(self.up1)
+        self.up2 = _Checkpointed(self.up2)
+        self.up3 = _Checkpointed(self.up3)
+        self.up4 = _Checkpointed(self.up4)
+        self.outc = _Checkpointed(self.outc)
+
+# SE (squeeze-and-excitation) channel attention
+class SEBlock(nn.Module):
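+    """Squeeze-and-Excitation: a global-average-pooled channel descriptor is passed
+    through a bottleneck MLP with sigmoid gating to re-weight the feature channels."""
+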
+ def __init__(self, in_channels, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, in_channels // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(in_channels // reduction, in_channels, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ # print(f"Input size to SEBlock: {x.size()}")
+ y = self.avg_pool(x).view(b, c)
+ # print(f"Size after avg_pool: {y.size()}")
+ y = self.fc(y).view(b, c, 1, 1)
+ # print(f"Size after fc: {y.size()}")
+ return x * y.expand_as(x)
+
+class BasicBlock(nn.Module):
+    expansion = 1  # channel expansion ratio
+
+ def __init__(self, in_channels, out_channels, stride=1):
+ super().__init__()
+
+ self.residual_function = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+ )
+
+ self.shortcut = nn.Sequential()
+
+ if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+ )
+
+ def forward(self, x):
+ return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+
+class BottleNeck(nn.Module):
+ expansion = 4
+    '''
+    expansion is the channel expansion ratio;
+    note the actual output channels = middle_channels * BottleNeck.expansion
+    '''
+ def __init__(self, in_channels, middle_channels, stride=1):
+ super().__init__()
+ self.residual_function = nn.Sequential(
+ nn.Conv2d(in_channels, middle_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(middle_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(middle_channels, middle_channels, stride=stride, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(middle_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(middle_channels, middle_channels * BottleNeck.expansion, kernel_size=1, bias=False),
+ nn.BatchNorm2d(middle_channels * BottleNeck.expansion),
+ )
+
+ self.shortcut = nn.Sequential()
+
+ if stride != 1 or in_channels != middle_channels * BottleNeck.expansion:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_channels, middle_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
+ nn.BatchNorm2d(middle_channels * BottleNeck.expansion)
+ )
+
+ def forward(self, x):
+ return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+
+class Bottleneck(nn.Module):
+ """Standard bottleneck."""
+
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
+ """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
+ expansion.
+ """
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, k[0], 1)
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+        """Apply the bottleneck, adding the residual shortcut when enabled and channels match."""
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class C2f(nn.Module):
+ """Faster Implementation of CSP Bottleneck with 2 convolutions."""
+
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
+ """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
+ expansion.
+ """
+ super().__init__()
+ self.c = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
+
+ def forward(self, x):
+ """Forward pass through C2f layer."""
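+        # cv1's output is split in half; each bottleneck consumes the newest half and
+        # appends its output, so cv2 fuses (2 + n) * c channels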
+ y = list(self.cv1(x).chunk(2, 1))
+ y.extend(m(y[-1]) for m in self.m)
+ return self.cv2(torch.cat(y, 1))
+
+ def forward_split(self, x):
+ """Forward pass using split() instead of chunk()."""
+ y = list(self.cv1(x).split((self.c, self.c), 1))
+ y.extend(m(y[-1]) for m in self.m)
+ return self.cv2(torch.cat(y, 1))
+
+
+
+class C2fST(C2f):
+ """C2f module with Swin TransformerBlock()."""
+
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
+ super().__init__(c1, c2, n, shortcut, g, e)
+ num_heads = self.c // 32
+ self.m = nn.ModuleList(SwinTransformerBlock(self.c, self.c, num_heads, n) for _ in range(n))
+
+
+class UResnet(nn.Module):
+ def __init__(self, block, layers, n_channels, n_classes, bilinear):
+ super().__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+ self.bilinear = bilinear
+ nb_filter = [64, 128, 256, 512, 1024]
+
+ self.in_channel = nb_filter[0]
+
+ self.pool = nn.MaxPool2d(2, 2)
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+ self.conv0_0 = DoubleConv(n_channels, nb_filter[0], nb_filter[0])
+ self.conv1_0 = self._make_layer(block, nb_filter[1], layers[0], 1)
+ self.trans1 = SwinTransformerBlock(nb_filter[1], nb_filter[1], num_heads = 4, num_layers = 2)
+ self.conv2_0 = self._make_layer(block, nb_filter[2], layers[1], 1)
+ self.trans2 = SwinTransformerBlock(nb_filter[2], nb_filter[2], num_heads = 8, num_layers = 2)
+ self.conv3_0 = self._make_layer(block, nb_filter[3], layers[2], 1)
+ self.trans3 = SwinTransformerBlock(nb_filter[3], nb_filter[3], num_heads = 16, num_layers = 2)
+ self.conv4_0 = self._make_layer(block, nb_filter[4], layers[3], 1)
+ self.trans4 = SwinTransformerBlock(nb_filter[4], nb_filter[4], num_heads = 32, num_layers = 2)
+
+ self.conv3_1 = DoubleConv((nb_filter[3] + nb_filter[4]) * block.expansion, nb_filter[3] * block.expansion, nb_filter[3])
+ self.conv2_2 = DoubleConv((nb_filter[2] + nb_filter[3]) * block.expansion, nb_filter[2] * block.expansion, nb_filter[2])
+ self.conv1_3 = DoubleConv((nb_filter[1] + nb_filter[2]) * block.expansion, nb_filter[1] * block.expansion, nb_filter[1])
+ self.conv0_4 = DoubleConv(nb_filter[0] + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
+
+ self.final = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)
+
+ def _make_layer(self, block, middle_channel, num_blocks, stride):
+        """
+        middle_channel: intermediate width; actual output channels = middle_channel * block.expansion
+        num_blocks: number of blocks contained in this layer
+        """
+
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_channel, middle_channel, stride))
+            self.in_channel = middle_channel * block.expansion  # update input width to this layer's output
+        return nn.Sequential(*layers)
+
+ def forward(self, input):
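+        # encoder: ResNet stages, each refined by a Swin Transformer block;
+        # decoder: upsample, concatenate the skip connection, then DoubleConv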
+ x0_0 = self.conv0_0(input)
+ x1_0 = self.conv1_0(self.pool(x0_0))
+ x1_0 = self.trans1(x1_0)
+ # print("conv1_0:", x1_0.shape)
+
+ x2_0 = self.conv2_0(self.pool(x1_0))
+ x2_0 = self.trans2(x2_0)
+ # print("conv2_0:", x2_0.shape)
+
+ x3_0 = self.conv3_0(self.pool(x2_0))
+ x3_0 = self.trans3(x3_0)
+ # print("conv3_0:", x3_0.shape)
+
+ x4_0 = self.conv4_0(self.pool(x3_0))
+ x4_0 = self.trans4(x4_0)
+ # print("conv4_0:", x4_0.shape)
+
+ # x4_0 = self.trans1(self.pool(x3_0))
+
+ x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
+ x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
+ x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
+ x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
+
+ output = self.final(x0_4)
+ return output
+
+
+class UResnet34(UResnet):
+ def __init__(self, n_channels, n_classes=2, bilinear=False):
+ super(UResnet34, self).__init__(block=BasicBlock,layers=[3,4,6,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)
+
+
+class UResnet50(UResnet):
+ def __init__(self, n_channels, n_classes=2, bilinear=False):
+ super(UResnet50, self).__init__(block=BottleNeck,layers=[3,4,6,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)
+
+
+class UResnet101(UResnet):
+ def __init__(self, n_channels, n_classes=2, bilinear=False):
+ super(UResnet101, self).__init__(block=BottleNeck,layers=[3,4,23,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)
+
+
+class UResnet152(UResnet):
+ def __init__(self, n_channels, n_classes=2, bilinear=False):
+ super(UResnet152, self).__init__(block=BottleNeck,layers=[3,8,36,3], n_channels=n_channels, n_classes=n_classes, bilinear=bilinear)
\ No newline at end of file
diff --git a/unet/unet_parts.py b/unet/unet_parts.py
new file mode 100644
index 0000000..986ba25
--- /dev/null
+++ b/unet/unet_parts.py
@@ -0,0 +1,77 @@
+""" Parts of the U-Net model """
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DoubleConv(nn.Module):
+ """(convolution => [BN] => ReLU) * 2"""
+
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.Sequential(
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Down(nn.Module):
+ """Downscaling with maxpool then double conv"""
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.maxpool_conv = nn.Sequential(
+ nn.MaxPool2d(2),
+ DoubleConv(in_channels, out_channels)
+ )
+
+ def forward(self, x):
+ return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+ """Upscaling then double conv"""
+
+ def __init__(self, in_channels, out_channels, bilinear=True):
+ super().__init__()
+
+ # if bilinear, use the normal convolutions to reduce the number of channels
+ if bilinear:
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+ else:
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+ self.conv = DoubleConv(in_channels, out_channels)
+
+ def forward(self, x1, x2):
+ x1 = self.up(x1)
+ # input is CHW
+ diffY = x2.size()[2] - x1.size()[2]
+ diffX = x2.size()[3] - x1.size()[3]
+
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+ diffY // 2, diffY - diffY // 2])
+ # if you have padding issues, see
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+ x = torch.cat([x2, x1], dim=1)
+ return self.conv(x)
+
+
+class OutConv(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(OutConv, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ return self.conv(x)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/data_loading.py b/utils/data_loading.py
new file mode 100644
index 0000000..2d23617
--- /dev/null
+++ b/utils/data_loading.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+from torch.utils.data import Dataset
+from PIL import Image
+import os
+import numpy as np
+import random
+from pathlib import Path
+import cv2  # OpenCV, used by the optional mask-erosion path below
+import torch.nn.functional as F
+
+
+
+def load_image(filename):
+ return Image.open(filename)
+
+class BasicDataset(Dataset):
+ def __init__(self, images_dir: Path, mask_dir: Path, scale: float = 1.0, crop_size: int = None, use_erosion=False, erosion_size=3):
+ self.images_dir = images_dir
+ self.mask_dir = mask_dir
+ assert 0 < scale <= 1, 'Scale must be between 0 and 1'
+ self.scale = scale
+        self.crop_size = crop_size if crop_size is not None else 224  # default crop size
+ self.use_erosion = use_erosion
+ self.erosion_size = erosion_size
+
+ self.files = []
+ for dirpath, _, filenames in os.walk(images_dir):
+ for filename in [f for f in filenames if f.endswith(".png") or f.endswith(".jpg")]:
+ image_path = Path(dirpath) / filename
+ relative_path = image_path.relative_to(images_dir)
+ mask_file = relative_path.with_suffix('.png')
+ mask_path = mask_dir / mask_file
+
+ if mask_path.exists():
+ self.files.append((image_path, mask_path))
+ else:
+ print(f"Mask file {mask_path} not found for image {image_path}")
+
+ if not self.files:
+ raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
+
+ self.mask_values = self.calculate_unique_mask_values()
+ self.transform = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomVerticalFlip(),
+ transforms.RandomRotation(10)
+ ])
+
+ def calculate_unique_mask_values(self):
+ unique_values = set()
+ for _, mask_path in self.files:
+ mask = np.asarray(Image.open(mask_path))
+ unique_values.update(np.unique(mask))
+ return sorted(list(unique_values))
+
+ def __len__(self):
+ return len(self.files)
+
+ @staticmethod
+ def preprocess(mask_values, pil_img, scale, is_mask):
+ w, h = pil_img.size
+ newW, newH = int(scale * w), int(scale * h)
+ assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
+ pil_img = pil_img.resize((newW, newH))
+ img = np.asarray(pil_img)
+
+ if is_mask:
+            # if self.use_erosion:
+            #     # convert to grayscale if the mask is not already single-channel
+            #     if img.ndim == 3:
+            #         img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+            #     # erosion kernel
+            #     kernel = np.ones((self.erosion_size, self.erosion_size), np.uint8)
+            #     # apply erosion (preprocess is a @staticmethod, so `self` would have
+            #     # to be passed in before this branch could be enabled)
+            #     img = cv2.erode(img, kernel, iterations=1)
+
+ mask = np.zeros((newH, newW), dtype=np.int64)
+ for i, v in enumerate(mask_values):
+ if img.ndim == 2:
+ mask[img == v] = i
+ else:
+ mask[(img == v).all(-1)] = i
+ return mask
+ else:
+ if img.ndim == 2:
+ img = img[np.newaxis, ...]
+ else:
+ img = img.transpose((2, 0, 1)) # Convert to C, H, W
+
+ if (img > 1).any():
+ img = img / 255.0
+ return img
+
+ def random_crop(self, img, mask):
+ assert img.size == mask.size, \
+ f'Image and mask should be the same size, but are {img.size} and {mask.size}'
+
+ w, h = img.size
+ th, tw = self.crop_size, self.crop_size
+
+ if w == tw and h == th:
+ return img, mask
+
+ x1 = random.randint(0, w - tw)
+ y1 = random.randint(0, h - th)
+ img = img.crop((x1, y1, x1 + tw, y1 + th))
+ mask = mask.crop((x1, y1, x1 + tw, y1 + th))
+
+ return img, mask
+
+ def __getitem__(self, idx):
+ try:
+ image_path, mask_path = self.files[idx]
+ img = load_image(image_path)
+ mask = load_image(mask_path)
+ assert img.size == mask.size, \
+ f'Image and mask should be the same size, but are {img.size} and {mask.size}'
+
+ if self.crop_size:
+ img, mask = self.random_crop(img, mask)
+
+ # Apply additional data augmentations to both image and mask
+ # seed = np.random.randint(2147483647) # Make a seed with numpy generator
+ # random.seed(seed) # Apply this seed to img transforms
+ # img = self.transform(img)
+ # random.seed(seed) # Apply the same seed to mask transforms
+ # mask = self.transform(mask)
+
+ img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
+ mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)
+
+ return {
+ 'image': torch.as_tensor(img.copy()).float().contiguous().clone(),
+ 'mask': torch.as_tensor(mask.copy()).long().contiguous().clone()
+ }
+
+ except Exception as e:
+ print(f"Error processing file {self.files[idx]}: {e}")
+ raise
+
+class LungDataset(BasicDataset):
+ def __init__(self, images_dir, mask_dir, scale=1, crop_size=None, use_erosion=False, erosion_size=None):
+ super().__init__(images_dir, mask_dir, scale, crop_size, use_erosion, erosion_size)
diff --git a/utils/dice_score.py b/utils/dice_score.py
new file mode 100644
index 0000000..c89eebe
--- /dev/null
+++ b/utils/dice_score.py
@@ -0,0 +1,28 @@
+import torch
+from torch import Tensor
+
+
+def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
+ # Average of Dice coefficient for all batches, or for a single mask
+ assert input.size() == target.size()
+ assert input.dim() == 3 or not reduce_batch_first
+
+ sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
+
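+    # Dice = (2*|A∩B| + eps) / (|A| + |B| + eps), reduced over the spatial dims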
+ inter = 2 * (input * target).sum(dim=sum_dim)
+ sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
+ sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
+
+ dice = (inter + epsilon) / (sets_sum + epsilon)
+ return dice.mean()
+
+
+def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
+ # Average of Dice coefficient for all classes
+ return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)
+
+
+def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
+ # Dice loss (objective to minimize) between 0 and 1
+ fn = multiclass_dice_coeff if multiclass else dice_coeff
+ return 1 - fn(input, target, reduce_batch_first=True)
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..6ed8d8d
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,13 @@
+import matplotlib.pyplot as plt
+
+
+def plot_img_and_mask(img, mask):
+ classes = mask.max() + 1
+ fig, ax = plt.subplots(1, classes + 1)
+ ax[0].set_title('Input image')
+ ax[0].imshow(img)
+ for i in range(classes):
+ ax[i + 1].set_title(f'Mask (class {i + 1})')
+ ax[i + 1].imshow(mask == i)
+ plt.xticks([]), plt.yticks([])
+ plt.show()