diff --git a/main.py b/main.py new file mode 100644 index 0000000..d80e73b --- /dev/null +++ b/main.py @@ -0,0 +1,115 @@ +import sys + +import cv2 +from PyQt5.QtCore import pyqtSignal +from PyQt5.QtCore import QThread +from PyQt5.QtWidgets import QFileDialog, QMessageBox +from PyQt5 import QtWidgets, QtGui +import numpy as np + +import model +import utils +from GUI import Ui_MainWindow +import tensorflow.compat.v1 as tf + + +def cnn(picture_name): + picture_name = cv2.imread(picture_name) + image = cv2.resize(picture_name, (utils.IMAGE_WIDTH, utils.IMAGE_HEIGHT)) + + normalizer_image = image / 255.0 - 0.5 + + image = np.array(normalizer_image) + + x, _, _, result = model.get_model(is_train=False, keep_prob=1) + + with tf.Session() as sess: + pics = [] + saver = tf.train.Saver() + saver.restore(sess, "./result/result.ckpt") + pics.append(image) + # 概率 + prediction = sess.run(result, feed_dict={x: pics}) + # 类别 + pred_result = np.argmax(prediction) + return pred_result, prediction + + +class dpv_thread(QThread): # 检测的线程 + img_breakSignal = pyqtSignal(int) + + def __init__(self, parent=None): + super(dpv_thread, self).__init__() + + def run(self): + global picture_name, pred_result, prediction + pred_result, prediction = cnn(picture_name) + self.img_breakSignal.emit(prediction) + + +class HelmetWindow(QtWidgets.QMainWindow, Ui_MainWindow): + def __init__(self): + super(HelmetWindow, self).__init__() + self.setupUi(self) + self.setWindowTitle("花花检测") # 设置窗口程序标题 + self.setStyleSheet("#MainWindow{background-color:lightskyblue}") + + self.read_img.clicked.connect(self.Read_img) + self.start_detect.clicked.connect(self.Start_detect) + self.show_img.setScaledContents(True) + self.DPV_thread = dpv_thread() # 实例化检测 + self.DPV_thread.img_breakSignal.connect(self.Show_text) + + def Read_img(self): + global picture_name + picture_name, imgType = QFileDialog.getOpenFileName(self, "打开图片", "", "*.jpg;;*.png;;All Files(*)") + jpg = QtGui.QPixmap(picture_name).scaled(self.show_img.width(), self.show_img.height()) + self.show_img.setPixmap(jpg) + print(picture_name) + if picture_name == '': + return '图片打开失败' + + def Start_detect(self): + self.DPV_thread.start() + + def Show_text(self): + global pred_result, prediction + if pred_result == 0: + show = ('桃花: %.2f' % prediction[:, 0]) + elif pred_result == 1: + show = ('梅花: %.2f' % prediction[:, 1]) + elif pred_result == 2: + show = ('牡丹: %.2f' % prediction[:, 2]) + elif pred_result == 3: + show = ('牵牛花: %.2f' % prediction[:, 3]) + elif pred_result == 4: + show = ('玫瑰: %.2f' % prediction[:, 4]) + elif pred_result == 5: + show = ('茉莉: %.2f' % prediction[:, 5]) + else: + show = ('蒲公英: %.2f' % prediction[:, 6]) + self.show_text.append(show) + + def closeEvent(self, event): # 关闭窗口响应函数,在这里面弹出确认框,并销毁线程 + reply = QMessageBox.question(self, + '本程序', + "是否要退出程序?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.No) + if reply == QMessageBox.Yes: + event.accept() + else: + event.ignore() + + def mousePressEvent(self, QMouseEvent): + if self.isMaximized(): + self.showNormal() + else: + self.showMaximized() + + +if __name__ == '__main__': + app = QtWidgets.QApplication(sys.argv) + ui = HelmetWindow() + ui.show() + sys.exit(app.exec_())