You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
178 lines
6.2 KiB
178 lines
6.2 KiB
from openvino.inference_engine import IECore
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
import time
|
|
from skimage import transform, io
|
|
from PIL import Image
|
|
from PyQt5 import QtGui
|
|
from PyQt5.QtWidgets import *
|
|
from PyQt5.QtCore import *
|
|
# import openvino.inference_engine.constants
|
|
import cv2
|
|
|
|
def preprocess(image_dir):
    """Load an image file and turn it into the network's input tensor.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    320x320, transposed to channel-first (C, H, W) layout, standardized per
    channel to zero mean / unit variance, and given a leading batch axis.

    Args:
        image_dir: Path of the image file to load.

    Returns:
        A float32 ``torch.Tensor`` of shape (1, 3, 320, 320).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_dir``.
    """
    image = cv2.imread(image_dir)  # numpy uint8 array in BGR channel order
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"cannot read image: {image_dir!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (320, 320))

    # (H, W, C) -> (C, H, W). Work in float32: writing normalized values back
    # into the uint8 array returned by imread would truncate/wrap them.
    image = np.transpose(image, (2, 0, 1)).astype(np.float32)

    # Per-channel standardization over the spatial axes. A min-max rescale in
    # front of this would be redundant: the z-score is invariant to any
    # positive affine rescaling of the channel.
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    std[std == 0] = 1.0  # a flat channel becomes all zeros instead of NaN
    image = (image - mean) / std

    image = np.expand_dims(image, axis=0)  # add batch axis -> (1, C, H, W)
    return torch.from_numpy(np.ascontiguousarray(image))
|
|
|
|
|
|
def reprocess(image, image_dir, save_dir):
    """Convert the raw network output into a mask at the original image size.

    Channel 0 of the output is min-max scaled to [0, 255], resized with
    bilinear interpolation to the height/width of the image at
    ``image_dir``, written to ``save_dir`` as an 8-bit image, and returned.

    Args:
        image: 4-D network output array, indexed as ``image[:, 0, :, :]``;
            presumably (N, C, H, W) with N == 1 -- TODO confirm.
        image_dir: Path of the original input image; only its size is used.
        save_dir: Output file path passed to ``cv2.imwrite``.

    Returns:
        The resized float32 mask with values in [0, 255] (2-D when N == 1).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_dir``.
        OSError: If ``cv2.imwrite`` fails to write ``save_dir``.
    """
    mask = np.asarray(image[:, 0, :, :], dtype=np.float32)

    lo, hi = mask.min(), mask.max()
    if hi > lo:
        mask = (mask - lo) / (hi - lo)
    else:
        # A constant output would give 0/0 = NaN; treat it as an empty mask.
        mask = np.zeros_like(mask)
    mask = mask.squeeze() * 255

    original = cv2.imread(image_dir)  # BGR uint8; only its shape is needed
    if original is None:
        raise FileNotFoundError(f"cannot read image: {image_dir!r}")
    height, width = original.shape[:2]
    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)

    # Save an explicit 8-bit image rather than relying on imwrite's implicit
    # float -> uint8 conversion; imwrite reports failure via its return value.
    if not cv2.imwrite(save_dir, np.clip(np.rint(mask), 0, 255).astype(np.uint8)):
        raise OSError(f"failed to write mask to {save_dir!r}")
    return mask
|
|
|
|
|
|
def openimage():
    """Slot for the RUN button: pick an image, segment it and show the results.

    Opens a file dialog and shows the chosen image in ``label``. It then runs
    the OpenVINO IR model ``saved_models/u2netp(canny).xml/.bin`` on it.
    The resized mask is saved to ``./inference_show.png`` and shown in
    ``label_2``. A copy of the input whose pixels have mask value <= 50 set
    to white is saved to ``./final_img.png`` and shown in ``label_3``. The
    elapsed time is shown in ``label_4``. Relies on the module-level widgets
    created by the script at the bottom of the file.
    """
    img_name, _ = QFileDialog.getOpenFileName()
    if not img_name:
        # The dialog was cancelled: getOpenFileName returns an empty path.
        return

    label.setPixmap(QtGui.QPixmap(img_name).scaled(label.width(), label.height()))
    try:
        _segment_and_show(img_name)
    except Exception as err:
        # PyQt5 (>= 5.5) aborts the whole application when a slot raises, so
        # report the failure and keep the window alive instead.
        import traceback
        traceback.print_exc()
        QMessageBox.critical(window, "Error", f"{type(err).__name__}: {err}")


def _load_network(model_name):
    """Read ``saved_models/<model_name>.xml/.bin`` and compile it for inference.

    Returns:
        Tuple ``(exec_net, input_blob, out_blob)``: the executable network and
        the names of its first input and first output.
    """
    ie = IECore()
    model_base = os.path.join(os.getcwd(), 'saved_models', model_name)
    net = ie.read_network(model_base + '.xml', model_base + '.bin')
    input_blob = next(iter(net.input_info))
    out_blob = next(iter(net.outputs))
    net.batch_size = 1

    # torch seeing a GPU does not mean OpenVINO has a "CUDA" plugin registered;
    # only pick that device when OpenVINO itself reports it.
    has_cuda_plugin = any(dev.split('.')[0] == 'CUDA' for dev in ie.available_devices)
    device = "CUDA" if torch.cuda.is_available() and has_cuda_plugin else "CPU"
    exec_net = ie.load_network(network=net, device_name=device)
    return exec_net, input_blob, out_blob


def _whiten_background(image_dir, mask, save_path):
    """Save a copy of ``image_dir`` with pixels whose mask value is <= 50 set to white.

    The image is read with OpenCV so that its geometry (including EXIF
    orientation handling) matches the size ``reprocess`` used for the mask.
    ``mask`` may be 2-D (H, W), which whitens whole pixels, or (H, W, 3),
    which whitens per channel.
    """
    original = cv2.imread(image_dir)
    if original is None:
        raise FileNotFoundError(f"cannot read image: {image_dir!r}")
    composite = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
    composite[mask <= 50] = 255
    Image.fromarray(composite).save(save_path)


def _segment_and_show(image_dir):
    """Run the model on ``image_dir``, save both outputs and update the labels."""
    save_dir = os.path.join(os.getcwd(), "inference_show.png")  # mask output
    final_save_dir = os.path.join(os.getcwd(), "final_img.png")  # composite output

    exec_net, input_blob, out_blob = _load_network('u2netp(canny)')
    image = preprocess(image_dir=image_dir)

    # Timed section: inference, post-processing and compositing.
    start_time = time.perf_counter()
    res = exec_net.infer(inputs={input_blob: image})[out_blob]
    mask = np.squeeze(reprocess(res, image_dir=image_dir, save_dir=save_dir))
    _whiten_background(image_dir, mask, final_save_dir)
    exe_time = time.perf_counter() - start_time
    print("---total execute time = ", exe_time)

    label_2.setPixmap(QtGui.QPixmap(save_dir).scaled(label_2.width(), label_2.height()))
    label_3.setPixmap(QtGui.QPixmap(final_save_dir).scaled(label_3.width(), label_3.height()))
    # Shows "time used: <seconds>s".
    label_4.setText("所用时间为:" + str(exe_time) + "s")
|
|
|
|
# Main window: the RUN button picks an image and segments it; the first three
# labels show the input, the mask and the composited result, and the fourth
# label shows the elapsed time.
import sys

app = QApplication(sys.argv)

window = QMainWindow()
window.resize(1200, 600)
window.move(300, 310)
window.setWindowTitle('前景分割')  # "Foreground segmentation"

# Three 320x320 image labels plus one label for the timing text.
label = QLabel(window)
label.move(10, 80)
label.resize(320, 320)

label_2 = QLabel(window)
label_2.move(450, 80)
label_2.resize(320, 320)

label_3 = QLabel(window)
label_3.move(800, 80)
label_3.resize(320, 320)

label_4 = QLabel(window)
label_4.move(450, 500)
label_4.resize(300, 50)

button = QPushButton('RUN', window)
button.move(500, 20)
button.clicked.connect(openimage)

button_2 = QPushButton('STOP', window)
button_2.move(1000, 500)
button_2.clicked.connect(QCoreApplication.quit)

window.show()

# Propagate the Qt event loop's exit status to the calling process.
sys.exit(app.exec_())
|
|
|