You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

178 lines
6.2 KiB

from openvino.inference_engine import IECore
import os
import torch
import numpy as np
import time
from skimage import transform, io
from PIL import Image
from PyQt5 import QtGui
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
# import openvino.inference_engine.constants
import cv2
def preprocess(image_dir, size=(320, 320)):
    """Load an image file and convert it into a normalized 1xCxHxW tensor.

    The image is read with OpenCV (BGR order), converted to RGB and resized
    to ``size`` (``(width, height)``, as ``cv2.resize`` expects).  Each colour
    channel is then standardized to zero mean and unit standard deviation.
    The original code first min-max scaled every channel to [0, 1] and then
    z-scored it.  Z-scoring is invariant to that per-channel affine rescale,
    so standardizing directly produces the same values.

    Args:
        image_dir: path of the image file to load.
        size: output ``(width, height)``; defaults to the original 320x320.

    Returns:
        ``torch.Tensor`` of dtype float32 and shape ``(1, 3, height, width)``.

    Raises:
        ValueError: if OpenCV cannot read ``image_dir``.
    """
    image = cv2.imread(image_dir)  # numpy array in BGR channel order (not PIL)
    if image is None:
        raise ValueError(f"cannot read image: {image_dir!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Work in float: the uint8 array returned by OpenCV would truncate the
    # normalized values on assignment.
    image = cv2.resize(image, size).astype(np.float32)

    # Normalize per channel while the layout is still HWC, so that axis 2
    # really is the colour channel.
    mean = image.mean(axis=(0, 1), keepdims=True)
    std = image.std(axis=(0, 1), keepdims=True)
    # A flat channel has std == 0; map it to all zeros instead of NaN.
    image = (image - mean) / np.where(std > 0, std, 1.0)

    image = np.transpose(image, (2, 0, 1))[np.newaxis]  # HWC -> 1xCxHxW
    return torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
def reprocess(image, image_dir, save_dir):
    """Convert raw network output into a mask the size of the original image.

    Channel 0 is taken from ``image`` via ``image[:, 0, :, :]``, so a 4-D
    array is required (presumably (batch, channel, H, W) — TODO confirm
    against the model).  It is min-max scaled to [0, 255] and resized to the
    height/width of the image stored at ``image_dir``.  The result is written
    to ``save_dir`` as an 8-bit image.

    Args:
        image: network output array.
        image_dir: path of the original input image (used only for its size).
        save_dir: destination path of the mask image.

    Returns:
        The resized mask as a 2-D float array with values in [0, 255].

    Raises:
        ValueError: if ``image_dir`` cannot be read by OpenCV.
        OSError: if OpenCV fails to write ``save_dir``.
    """
    mask = np.asarray(image)[:, 0, :, :].astype(np.float32)
    lowest, highest = mask.min(), mask.max()
    span = highest - lowest
    # A constant prediction would otherwise divide by zero and yield NaNs.
    mask = (mask - lowest) / span if span > 0 else np.zeros_like(mask)
    mask = mask.squeeze() * 255

    original = cv2.imread(image_dir)
    if original is None:
        raise ValueError(f"cannot read image: {image_dir!r}")
    height, width = original.shape[:2]
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)

    # cv2.imwrite only documents 8-bit data for PNG/JPEG, so convert
    # explicitly (rounded and saturated) instead of relying on an implicit cast.
    mask_u8 = np.clip(np.rint(mask), 0, 255).astype(np.uint8)
    if not cv2.imwrite(save_dir, mask_u8):
        raise OSError(f"cannot write mask image: {save_dir!r}")
    return mask
def _select_device(ie):
    """Return the OpenVINO device name to compile the model for."""
    # torch.cuda.is_available() says nothing about OpenVINO plugins; asking
    # for "CUDA" when OpenVINO does not expose it makes load_network fail.
    if torch.cuda.is_available() and "CUDA" in ie.available_devices:
        return "CUDA"
    return "CPU"


def _load_exec_net(ie, model_name="u2netp(canny)"):
    """Read ./saved_models/<model_name>.xml/.bin and compile it; return (exec_net, input_blob, out_blob)."""
    model_dir = os.path.join(os.getcwd(), "saved_models")
    net = ie.read_network(
        os.path.join(model_dir, model_name + ".xml"),
        os.path.join(model_dir, model_name + ".bin"),
    )
    input_blob = next(iter(net.input_info))
    out_blob = next(iter(net.outputs))
    net.batch_size = 1
    exec_net = ie.load_network(network=net, device_name=_select_device(ie))
    return exec_net, input_blob, out_blob


def _whiten_background(rgb, mask, threshold=50):
    """Return a uint8 copy of rgb (HxWx3) with pixels whose mask value is <= threshold set to 255."""
    mask = np.asarray(mask)
    if mask.ndim == 2:
        # reprocess() yields a single-channel mask; broadcast it over R, G, B.
        mask = mask[:, :, np.newaxis]
    return np.where(mask <= threshold, 255, rgb).astype(np.uint8)


def _run_segmentation(image_dir):
    """Segment image_dir, save the mask and composite images; return (mask_path, composite_path, seconds)."""
    ie = IECore()
    save_dir = os.path.join(os.getcwd(), "inference_show.png")
    final_save_dir = os.path.join(os.getcwd(), "final_img.png")
    exec_net, input_blob, out_blob = _load_exec_net(ie)
    image = preprocess(image_dir=image_dir)

    start_time = time.time()
    res = exec_net.infer(inputs={input_blob: image})[out_blob]
    # Post-processing: writes the mask to save_dir and returns it, resized to
    # the original image size.
    mask = np.squeeze(reprocess(res, image_dir=image_dir, save_dir=save_dir))

    # Read the original with OpenCV, like reprocess() does when sizing the
    # mask, so both arrays have matching dimensions. IMREAD_COLOR always
    # gives 3 channels, also for grayscale or alpha images.
    original = cv2.imread(image_dir)
    if original is None:
        raise ValueError(f"cannot read image: {image_dir!r}")
    rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
    Image.fromarray(_whiten_background(rgb, mask)).save(final_save_dir)
    return save_dir, final_save_dir, time.time() - start_time


def openimage():
    """Slot for the RUN button: pick an image, segment it and show the results.

    Shows the chosen image in ``label``, the mask in ``label_2``, the image
    with low-mask pixels painted white in ``label_3`` and the elapsed time in
    ``label_4``.  Failures are reported in a message box instead of
    propagating out of the Qt slot (which would abort the application).
    """
    img_name, _ = QFileDialog.getOpenFileName()
    if not img_name:
        # The user cancelled the dialog; there is nothing to process.
        return
    label.setPixmap(QtGui.QPixmap(img_name).scaled(label.width(), label.height()))

    try:
        save_dir, final_save_dir, exe_time = _run_segmentation(img_name)
    except Exception as err:  # top-level GUI boundary: report, don't crash
        QMessageBox.critical(window, "错误", str(err))
        return

    print("---total execute time = ", exe_time)
    label_2.setPixmap(QtGui.QPixmap(save_dir).scaled(label_2.width(), label_2.height()))
    label_3.setPixmap(QtGui.QPixmap(final_save_dir).scaled(label_3.width(), label_3.height()))
    # Show the elapsed time (seconds, float)
    label_4.setText("所用时间为:" + str(exe_time) + "s")
# Main window
app = QApplication([])
window = QMainWindow()
window.resize(1200, 600)
window.move(300, 310)
window.setWindowTitle('前景分割')


def _make_label(x, y, width, height):
    """Create a QLabel child of the main window at (x, y) with the given size."""
    widget = QLabel(window)
    widget.move(x, y)
    widget.resize(width, height)
    return widget


# Three image labels (input image, mask, composite) plus one for the timing text
label = _make_label(10, 80, 320, 320)
label_2 = _make_label(450, 80, 320, 320)
label_3 = _make_label(800, 80, 320, 320)
label_4 = _make_label(450, 500, 300, 50)

# RUN opens an image and runs the segmentation pipeline
button = QPushButton('RUN', window)
button.move(500, 20)
button.clicked.connect(openimage)

# STOP quits the Qt event loop
button_2 = QPushButton('STOP', window)
button_2.move(1000, 500)
button_2.clicked.connect(QCoreApplication.quit)

window.show()
app.exec_()