You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

192 lines
6.4 KiB

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
from PyQt5 import QtGui
from PyQt5.QtWidgets import *
import cv2
import numpy as np
from PyQt5 import QtWidgets, QtCore
import sys
from PyQt5.QtCore import *
import time
import utils
from utils import *
from models import *
import detect
import utils.general
from pathlib import Path
def init(self):
self.unit7_img = np.ndarray(())
self.unit7_img_channel = 1
self.unit7_result = np.ndarray(())
self.unit7_result_channel = 1
self.unit7_filepath = ''
self.unit7_imgpath = ''
self.unit7_savepath = ''
self.unit7_suffix = ''
self.ui.textBrowser_6.setText('')
self.ui.textBrowser_7.setText('')
def img_load(self):
fileName, tmp = QFileDialog.getOpenFileName(self, '打开图像', 'Image', '*.png *.jpg *.bmp *.jpeg')
if fileName == '':
return
self.unit7_img = np.ndarray(())
self.unit7_img_channel = 1
self.unit7_result = np.ndarray(())
self.unit7_result_channel = 1
self.unit7_img = cv2.imread(fileName, -1)
self.unit7_suffix = fileName.split('/')[-1]
print(self.unit7_suffix)
self.unit7_imgpath = fileName
if self.unit7_img.size <= 1:
return
if len(self.unit7_img.shape) == 3:
self.unit7_img_channel = 3
if self.unit7_img.shape[2] == 4:
self.unit7_img = cv2.cvtColor(self.unit7_img, cv2.COLOR_BGRA2BGR)
print(self.unit7_img.shape)
img_refresh(self)
def img_refresh(self):
array = \
[self.unit7_img,
self.unit7_result]
array2 = [self.ui.label_54,
self.ui.label_55]
channel = [self.unit7_img_channel,
self.unit7_result_channel]
height = 480
weight = 500
for index in range(len(array)):
M = np.float32([[1, 0, 0], [0, 1, 0]])
if array[index].size <= 1:
array2[index].setPixmap(QtGui.QPixmap(''))
continue
print(array[index].shape)
index_h = array[index].shape[0]
index_w = array[index].shape[1]
if index_h / index_w == height / weight:
img = array[index].tobytes()
if channel[index] == 1:
image = QtGui.QImage(img, index_w, index_h, index_w * channel[index], QtGui.QImage.Format_Grayscale8)
pix = QtGui.QPixmap.fromImage(image)
scale_pix = pix.scaled(weight, height)
array2[index].setPixmap(scale_pix)
continue
elif channel[index] == 3:
image = QtGui.QImage(img, index_w, index_h, index_w * channel[index], QtGui.QImage.Format_BGR888)
pix = QtGui.QPixmap.fromImage(image)
scale_pix = pix.scaled(weight, height)
array2[index].setPixmap(scale_pix)
continue
elif index_h / index_w > height / weight:
h_ = index_h
w_ = int(index_h * weight / height + 0.5)
M[0, 2] += (w_ - index_w) / 2
M[1, 2] += (h_ - index_h) / 2
else:
h_ = int(index_w * height / weight + 0.5)
w_ = index_w
M[0, 2] += (w_ - index_w) / 2
M[1, 2] += (h_ - index_h) / 2
img = cv2.warpAffine(array[index], M, (w_, h_))
data = img.tobytes()
if channel[index] == 1:
image = QtGui.QImage(data, w_, h_, w_ * channel[index], QtGui.QImage.Format_Grayscale8)
pix = QtGui.QPixmap.fromImage(image)
scale_pix = pix.scaled(weight, height)
array2[index].setPixmap(scale_pix)
continue
else:
image = QtGui.QImage(data, w_, h_, w_ * channel[index], QtGui.QImage.Format_BGR888)
pix = QtGui.QPixmap.fromImage(image)
scale_pix = pix.scaled(weight, height)
array2[index].setPixmap(scale_pix)
continue
return
def result_save(self):
fileName= QFileDialog.getExistingDirectory(self, '保存图像')
if fileName == '':
return
self.unit7_savepath = fileName
self.ui.textBrowser_7.setText(fileName.split('/')[-2]+'/'+fileName.split('/')[-1])
msg_box = QMessageBox(QMessageBox.Information, '成功', '选择路径成功,保存路径为:' + fileName)
msg_box.exec_()
def clear(self):
if self.unit7_img.size > 1:
init(self)
img_refresh(self)
else:
msg_box = QMessageBox(QMessageBox.Warning, '无需清空', '没有图片')
msg_box.exec_()
def result_show(self):
if self.unit7_result.size > 1:
cv2.imshow('Original pic', self.unit7_result)
cv2.waitKey(0)
else:
msg_box = QMessageBox(QMessageBox.Warning, '没有图像', '没有生成图像')
msg_box.exec_()
def object_detection(self):
if self.unit7_filepath !='' and self.unit7_img.size>1 and self.unit7_savepath!='':
modelpath = self.unit7_filepath
imgpath = self.unit7_imgpath
savepath = self.unit7_savepath
detect.main(imgpath, modelpath,savepath)
name ='exp'
z = utils.general.increment_path_num(Path(savepath) / name, exist_ok=False)
num = str(z) if z!=1 else ''
path = savepath +'/exp'+ num+'/'+self.unit7_suffix
print(path)
self.unit7_result = cv2.imread(path, -1)
if self.unit7_result.size >1:
if len(self.unit7_result.shape) == 3:
self.unit7_result_channel = 3
if self.unit7_result.shape[2] == 4:
self.unit7_result = cv2.cvtColor(self.unit7_result, cv2.COLOR_BGRA2BGR)
print(self.unit7_result.shape)
img_refresh(self)
else:
msg_box = QMessageBox(QMessageBox.Warning, 'error', 'error1')
msg_box.exec_()
else:
msg_box = QMessageBox(QMessageBox.Warning, '没有导入模型或图片', '请导入模型和图片后再进行尝试')
msg_box.exec_()
def model_load(self):
fileName, tmp = QFileDialog.getOpenFileName(self, '选择模型路径', 'Model', '*.pt')
if fileName == '':
return
self.unit7_filepath = fileName
self.ui.textBrowser_6.setText(fileName.split('/')[-2]+'/'+fileName.split('/')[-1])
print(self.unit7_filepath)
if self.unit7_filepath =='':
return
else:
msg_box = QMessageBox(QMessageBox.Information, '已检测到模型', '模型导入成功')
msg_box.exec_()