parent
5a94f1aeb5
commit
fb2464d3cf
@ -0,0 +1,177 @@
|
||||
from openvino.inference_engine import IECore
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from skimage import transform, io
|
||||
from PIL import Image
|
||||
from PyQt5 import QtGui
|
||||
from PyQt5.QtWidgets import *
|
||||
from PyQt5.QtCore import *
|
||||
# import openvino.inference_engine.constants
|
||||
import cv2
|
||||
|
||||
def preprocess(image_dir):
|
||||
|
||||
image = cv2.imread(image_dir) # PIL格式 RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (320, 320))
|
||||
image = np.asarray(image)
|
||||
image = np.swapaxes(image, 1, 2) # 转换维度
|
||||
image = np.swapaxes(image, 0, 1)
|
||||
# 归一化处理
|
||||
image[:, :, 0] = (image[:, :, 0] - np.min(image[:, :, 0])) / (np.max(image[:, :, 0]) - np.min(image[:, :, 0]))
|
||||
image[:, :, 0] = (image[:, :, 0] - np.mean(image[:, :, 0])) / (np.std(image[:, :, 0]))
|
||||
image[:, :, 1] = (image[:, :, 1] - np.min(image[:, :, 1])) / (np.max(image[:, :, 1]) - np.min(image[:, :, 1]))
|
||||
image[:, :, 1] = (image[:, :, 1] - np.mean(image[:, :, 1])) / (np.std(image[:, :, 1]))
|
||||
image[:, :, 2] = (image[:, :, 2] - np.min(image[:, :, 2])) / (np.max(image[:, :, 2]) - np.min(image[:, :, 2]))
|
||||
image[:, :, 2] = (image[:, :, 2] - np.mean(image[:, :, 2])) / (np.std(image[:, :, 2]))
|
||||
image = np.expand_dims(image, axis=0) # 在零维度加1
|
||||
# print("image_shape has been resized to : ", image.shape)
|
||||
image = torch.from_numpy(image) # 转成tensor
|
||||
return image
|
||||
|
||||
|
||||
def reprocess(image, image_dir, save_dir):
|
||||
image = image[:, 0, :, :]
|
||||
ma = np.max(image)
|
||||
mi = np.min(image)
|
||||
image = (image - mi) / (ma - mi)
|
||||
image = image.squeeze()
|
||||
# print("image_np shape is ", image_np.shape) 320*320
|
||||
# 二值化处理
|
||||
# for h in range(320):
|
||||
# for w in range(320):
|
||||
# if(image_np[h][w] < 30.0/255.0):
|
||||
# image_np[h][w] = 0
|
||||
# else:
|
||||
# image_np[h][w] = 1
|
||||
# image = Image.fromarray(image * 255).convert('RGB') # 从numpy转成PIL格式
|
||||
# image_dir is initialized in PyQt where the user open the image
|
||||
# 打开原图
|
||||
image = image * 255
|
||||
pri_image = cv2.imread(image_dir) # PIL格式 RGB
|
||||
h, w = pri_image.shape[0], pri_image.shape[1]
|
||||
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
# 保存到本地
|
||||
cv2.imwrite(save_dir, image)
|
||||
# print("inferenced image has been saved at ", save_dir)
|
||||
return image
|
||||
|
||||
|
||||
def openimage():
|
||||
imgName, imgType = QFileDialog.getOpenFileName()
|
||||
jpg = QtGui.QPixmap(imgName).scaled(label.width(), label.height())
|
||||
label.setPixmap(jpg)
|
||||
# pri_image_dir = imgName
|
||||
image_dir = imgName
|
||||
ie = IECore()
|
||||
save_dir = os.path.join(os.getcwd() + os.sep + "inference_show.png") # 保存路径
|
||||
model_name = 'u2netp(canny)'
|
||||
xml_dir = os.path.join(os.getcwd(), 'saved_models', model_name + '.xml')
|
||||
bin_dir = os.path.join(os.getcwd(), 'saved_models', model_name + '.bin')
|
||||
net = ie.read_network(xml_dir, bin_dir)
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
net.batch_size = 1
|
||||
# n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
# print("Model input data shape = %d, %d, %d, %d" % (n, c, h, w))
|
||||
# print("image dir = ", image_dir)
|
||||
image = preprocess(image_dir=image_dir)
|
||||
exec_net = ie.load_network(network=net, device_name=("CUDA" if torch.cuda.is_available() else "CPU"))
|
||||
# 写界面的同学记得在选择图片后加一句
|
||||
# pri_image_dir = "用户选择打开图片的路径"
|
||||
start_time = time.time()
|
||||
res = exec_net.infer(inputs={input_blob: image})
|
||||
res = res[out_blob]
|
||||
# print("res : ", res.shape)
|
||||
# res = torch.from_numpy(res)
|
||||
# 后处理
|
||||
final_img = reprocess(res, image_dir=image_dir, save_dir=save_dir)
|
||||
# compute the execute time
|
||||
# ----------------------------------------------------------------
|
||||
init_image = Image.open(image_dir) # RGB PIL [n, c, h, w]
|
||||
|
||||
init_image = np.array(init_image)
|
||||
|
||||
init_image = np.squeeze(init_image)
|
||||
mask = np.squeeze(final_img)
|
||||
# print("mask shape : ", mask.shape)
|
||||
|
||||
img_r = init_image[:, :, 0]
|
||||
img_g = init_image[:, :, 1]
|
||||
img_b = init_image[:, :, 2]
|
||||
# print("img_r/g/b shape : ", img_b.shape)
|
||||
|
||||
mask_r = mask[:, :, 0]
|
||||
mask_g = mask[:, :, 1]
|
||||
mask_b = mask[:, :, 2]
|
||||
# print("mask_r/g/b : ", mask_r.shape)
|
||||
|
||||
height = img_r.shape[0]
|
||||
weight = img_r.shape[1]
|
||||
for h in range(height):
|
||||
for w in range(weight):
|
||||
if mask_r[h][w] <= 50:
|
||||
img_r[h][w] = 255
|
||||
if mask_g[h][w] <= 50:
|
||||
img_g[h][w] = 255
|
||||
if mask_b[h][w] <= 50:
|
||||
img_b[h][w] = 255
|
||||
|
||||
re_img = np.stack((img_r, img_g, img_b), axis=-1)
|
||||
# print("re_img shape : ", re_img.shape)
|
||||
final_img = Image.fromarray(re_img)
|
||||
final_save_dir = os.path.join(os.getcwd() + os.sep + "final_img.png")
|
||||
final_img.save(final_save_dir)
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
exe_time = time.time() - start_time
|
||||
print("---total execute time = ", exe_time)
|
||||
ccc = save_dir # 图片路径
|
||||
jpg = QtGui.QPixmap(ccc).scaled(label.width(), label.height())
|
||||
label_2.setPixmap(jpg)
|
||||
ddd = final_save_dir # 图片路径
|
||||
jpg = QtGui.QPixmap(ddd).scaled(label.width(), label.height())
|
||||
label_3.setPixmap(jpg)
|
||||
# 显示时间
|
||||
a = exe_time # time为float类型变量
|
||||
label_4.setText("所用时间为:" + str(a) + "s")
|
||||
|
||||
# 窗口
|
||||
app = QApplication([])
|
||||
|
||||
window = QMainWindow()
|
||||
window.resize(1200, 600)
|
||||
window.move(300, 310)
|
||||
window.setWindowTitle('前景分割')
|
||||
|
||||
#三张图片加时间显示的label
|
||||
label = QLabel(window)
|
||||
label.move(10,80)
|
||||
label.resize(320,320)
|
||||
|
||||
label_2 = QLabel(window)
|
||||
label_2.move(450,80)
|
||||
label_2.resize(320,320)
|
||||
|
||||
label_3 = QLabel(window)
|
||||
label_3.move(800,80)
|
||||
label_3.resize(320,320)
|
||||
|
||||
label_4 = QLabel(window)
|
||||
label_4.move(450, 500)
|
||||
label_4.resize(300, 50)
|
||||
|
||||
button = QPushButton('RUN', window)
|
||||
button.move(500, 20)
|
||||
button.clicked.connect(openimage)
|
||||
|
||||
button_2 = QPushButton('STOP', window)
|
||||
button_2.move(1000,500)
|
||||
button_2.clicked.connect(QCoreApplication.quit)
|
||||
|
||||
window.show()
|
||||
|
||||
app.exec_()
|
||||
|
Loading…
Reference in new issue