from openvino.inference_engine import IECore
import os
import torch
import numpy as np
import time
from PIL import Image
from PyQt5 import QtGui
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
# import openvino.inference_engine.constants
import cv2
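# Note: this script targets the legacy OpenVINO Inference Engine Python API
# (openvino.inference_engine.IECore, i.e. OpenVINO releases up to 2021.x); newer
# releases expose openvino.runtime.Core instead, so it would need porting for 2022+.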

def preprocess(image_dir):

    image = cv2.imread(image_dir)  # OpenCV reads images as BGR
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (320, 320))
    image = image.astype(np.float32)  # float so the normalization below is not truncated to uint8
    # Per-channel normalization: rescale to [0, 1], then zero-mean / unit-variance
    for c in range(3):
        channel = image[:, :, c]
        channel = (channel - np.min(channel)) / (np.max(channel) - np.min(channel))
        image[:, :, c] = (channel - np.mean(channel)) / np.std(channel)
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    image = np.expand_dims(image, axis=0)   # add the batch dimension: (1, 3, 320, 320)
    image = torch.from_numpy(image)         # convert to a tensor
    return image
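
# Usage sketch (the path below is only an example):
#   inp = preprocess("test.jpg")
#   print(inp.shape, inp.dtype)   # torch.Size([1, 3, 320, 320]) torch.float32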


def reprocess(image, image_dir, save_dir):
    image = image[:, 0, :, :]          # keep the single output channel
    ma = np.max(image)
    mi = np.min(image)
    image = (image - mi) / (ma - mi)   # rescale the prediction to [0, 1]
    image = image.squeeze()            # (320, 320) saliency map
    # Optional binarization (disabled): pixels below 30/255 -> 0, otherwise -> 1
    # image = (image >= 30.0 / 255.0).astype(image.dtype)
    # image_dir is set in the PyQt handler when the user opens an image
    image = (image * 255).astype(np.uint8)
    pri_image = cv2.imread(image_dir)  # original image, read only to recover its size
    h, w = pri_image.shape[0], pri_image.shape[1]
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)  # back to the original resolution
    cv2.imwrite(save_dir, image)       # save the mask to disk
    return image
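
# Usage sketch (paths are only examples): `reprocess` expects the raw network output
# with shape (1, 1, 320, 320) and returns / saves an 8-bit mask resized to the
# original image resolution:
#   mask = reprocess(res, image_dir="test.jpg", save_dir="inference_show.png")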


def openimage():
    imgName, imgType = QFileDialog.getOpenFileName()
    if not imgName:  # the user cancelled the file dialog
        return
    jpg = QtGui.QPixmap(imgName).scaled(label.width(), label.height())
    label.setPixmap(jpg)
    # pri_image_dir = imgName
    image_dir = imgName
    ie = IECore()
    save_dir = os.path.join(os.getcwd(), "inference_show.png")  # where the predicted mask is saved
    model_name = 'u2netp(canny)'
    xml_dir = os.path.join(os.getcwd(), 'saved_models', model_name + '.xml')
    bin_dir = os.path.join(os.getcwd(), 'saved_models', model_name + '.bin')
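    # The IR files above are assumed to have been produced offline, e.g. by exporting
    # the U2NetP PyTorch checkpoint to ONNX and converting it with OpenVINO's Model
    # Optimizer (the file names below are assumptions, not part of this repository):
    #   torch.onnx.export(u2netp, torch.randn(1, 3, 320, 320), "u2netp.onnx", opset_version=11)
    #   python mo.py --input_model u2netp.onnx --output_dir saved_models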
    net = ie.read_network(xml_dir, bin_dir)
    input_blob = next(iter(net.input_info))
    out_blob = next(iter(net.outputs))
    net.batch_size = 1
    # n, c, h, w = net.input_info[input_blob].input_data.shape
    # print("Model input data shape = %d, %d, %d, %d" % (n, c, h, w))
    # print("image dir = ", image_dir)
    image = preprocess(image_dir=image_dir)
    # OpenVINO has no "CUDA" device; use the Intel "GPU" plugin when it is available,
    # otherwise fall back to the CPU
    device_name = "GPU" if "GPU" in ie.available_devices else "CPU"
    exec_net = ie.load_network(network=net, device_name=device_name)
    # Note for whoever writes the UI: after the user picks an image, also set
    # pri_image_dir = "<path of the image the user opened>"
    start_time = time.time()
    res = exec_net.infer(inputs={input_blob: image})
    res = res[out_blob]
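    # U2NetP produces several side outputs; the first output blob is assumed here to
    # be the fused saliency map (d0), which is what the post-processing below uses.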
    # print("res : ", res.shape)
    # res = torch.from_numpy(res)
    # Post-processing: normalize, resize and save the predicted mask
    final_img = reprocess(res, image_dir=image_dir, save_dir=save_dir)
    # Composite the original image with the predicted mask: background pixels go white
    # ----------------------------------------------------------------
    init_image = Image.open(image_dir)  # PIL image in RGB; becomes an (h, w, 3) array below

    init_image = np.array(init_image)

    init_image = np.squeeze(init_image)
    mask = np.squeeze(final_img)  # single-channel mask from reprocess(), same size as the original image
    # print("mask shape : ", mask.shape)

    img_r = init_image[:, :, 0]
    img_g = init_image[:, :, 1]
    img_b = init_image[:, :, 2]
    # print("img_r/g/b shape : ", img_b.shape)

    # Paint background pixels (low mask response) white
    background = mask <= 50
    img_r[background] = 255
    img_g[background] = 255
    img_b[background] = 255

    re_img = np.stack((img_r, img_g, img_b), axis=-1)
    # print("re_img shape : ", re_img.shape)
    final_img = Image.fromarray(re_img)
    final_save_dir = os.path.join(os.getcwd(), "final_img.png")
    final_img.save(final_save_dir)
    # ----------------------------------------------------------------

    exe_time = time.time() - start_time
    print("--- total execution time =", exe_time)
    # Show the predicted mask and the composited result
    jpg = QtGui.QPixmap(save_dir).scaled(label.width(), label.height())
    label_2.setPixmap(jpg)
    jpg = QtGui.QPixmap(final_save_dir).scaled(label.width(), label.height())
    label_3.setPixmap(jpg)
    # Show the elapsed time
    label_4.setText("Elapsed time: " + str(exe_time) + " s")

# Main window
app = QApplication([])

window = QMainWindow()
window.resize(1200, 600)
window.move(300, 310)
window.setWindowTitle('Foreground Segmentation')

# Labels for the three images and the elapsed-time display
label = QLabel(window)
label.move(10, 80)
label.resize(320, 320)

label_2 = QLabel(window)
label_2.move(450, 80)
label_2.resize(320, 320)

label_3 = QLabel(window)
label_3.move(800, 80)
label_3.resize(320, 320)

label_4 = QLabel(window)
label_4.move(450, 500)
label_4.resize(300, 50)

button = QPushButton('RUN', window)
button.move(500, 20)
button.clicked.connect(openimage)

button_2 = QPushButton('STOP', window)
button_2.move(1000, 500)
button_2.clicked.connect(QCoreApplication.quit)

window.show()

app.exec_()