import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version u2net 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
from model import U2NETM # middle size version u2net 99MB
import cv2

def IOU(pred, label, verbose=False):
    """Return the intersection-over-union of two binary masks.

    Args:
        pred: predicted mask with values in {0, 1} (any shape).
        label: ground-truth mask with values in {0, 1}, same shape as ``pred``.
        verbose: if True, print the intersection and union pixel counts.

    Returns:
        ``intersection / union`` as a float. When both masks are empty
        (union == 0) they agree perfectly, so 1.0 is returned instead of the
        NaN that a 0/0 division would produce (which would otherwise poison
        any running mean of IoU values).

    Raises:
        ValueError: if the two masks have different shapes.
    """
    # float64 avoids overflow when integer masks (e.g. uint8 from cv2) are
    # multiplied, and keeps the sums exact for any realistic image size.
    pred = np.asarray(pred, dtype=np.float64)
    label = np.asarray(label, dtype=np.float64)
    if pred.shape != label.shape:
        raise ValueError("IOU: shape mismatch %s vs %s" % (pred.shape, label.shape))

    intersection = float(np.sum(label * pred))
    union = float(np.sum(label) + np.sum(pred)) - intersection
    if verbose:
        print(intersection, union)
    if union == 0:
        return 1.0
    return intersection / union

# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalize a prediction tensor to the range [0, 1].

    Args:
        d: tensor of raw prediction values (any shape).

    Returns:
        ``(d - min(d)) / (max(d) - min(d))``. If ``d`` is constant
        (max == min), that formula is 0/0 and yields NaN everywhere; a tensor
        of zeros with the same shape, dtype and device is returned instead.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    value_range = ma - mi

    if value_range == 0:
        # Constant map: nothing stands out, so report zero saliency.
        return torch.zeros_like(d)

    return (d - mi) / value_range

def save_output(image_name, pred, d_dir):
    """Save a prediction map as a PNG resized to its source image's size.

    Args:
        image_name: path of the original input image. Its dimensions set the
            output size, and its base name minus the final extension names
            the output file.
        pred: prediction tensor, squeezed to 2-D before saving. Presumably
            values lie in [0, 1] (e.g. the output of normPRED); confirm for
            other callers.
        d_dir: output directory. The file is written to
            ``<d_dir>/<stem>.png``.
    """
    predict_np = pred.squeeze().cpu().data.numpy()

    # Scale to [0, 255]. PIL builds a float image from this array and it is
    # converted to RGB for saving.
    im = Image.fromarray(predict_np * 255).convert('RGB')

    # Only the source size is needed. The numpy shape is (rows, cols[, ch]),
    # while PIL's resize expects (width, height).
    image = io.imread(image_name)
    height, width = image.shape[:2]
    imo = im.resize((width, height), resample=Image.BILINEAR)

    # splitext drops only the final extension ("a.b.jpg" -> "a.b"). Unlike a
    # manual split/join, it does not fail on names without a dot.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, stem + '.png'))

def _strip_extension(path):
    """Return the base name of *path* with only its final extension removed.

    ``"dir/a.b.jpg"`` -> ``"a.b"``; names without a dot are returned as-is.
    """
    return os.path.splitext(os.path.basename(path))[0]


def _binarize(mask, threshold=127):
    """Return a float32 copy of *mask* with 1 where mask > threshold, else 0."""
    return (np.asarray(mask) > threshold).astype(np.float32)


def _load_net(model_name, model_dir, device):
    """Construct the network named *model_name* and load weights from *model_dir*.

    The checkpoint may be a state dict, which is loaded into a freshly built
    network, or a whole pickled model object, which is used as-is (this is
    how the script originally loaded it via ``torch.load``). The network is
    moved to *device* and put in eval mode.

    Raises:
        ValueError: if *model_name* is not 'u2net', 'u2netp' or 'u2netm'.
    """
    builders = {
        'u2net': (U2NET, "...load U2NET---173.6 MB"),
        'u2netp': (U2NETP, "...load U2NEP---4.7 MB"),
        'u2netm': (U2NETM, '...load U2NETM---91MB'),
    }
    try:
        model_cls, message = builders[model_name]
    except KeyError:
        raise ValueError("unknown model_name %r; expected one of %s"
                         % (model_name, sorted(builders))) from None
    print(message)
    net = model_cls(3, 1)

    # NOTE: on PyTorch 2.6+, torch.load defaults to weights_only=True, so
    # loading a whole pickled model (not a state dict) needs weights_only=False.
    checkpoint = torch.load(model_dir, map_location=device)
    if isinstance(checkpoint, dict):
        net.load_state_dict(checkpoint)
    else:
        net = checkpoint
    net.to(device)
    net.eval()
    return net


def main():
    """Evaluate a saved U^2-Net variant on train_data/dataset1 and print IoU.

    For each image in ``img/``:
      1. Take channel 0 of the network's first output and min-max normalize it.
      2. Resize it to the image's size and binarize it at 127.
      3. Resize and binarize the matching ``lbl/<stem>.png`` mask the same way.
      4. Print the per-image IoU and the running mean IoU.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # or 'u2net' / 'u2netm'
    save_predictions = False  # True: also write probability maps to prediction_dir
    print("Image name GOTTEN")

    cwd = os.getcwd()
    image_dir = os.path.join(cwd, 'train_data', 'dataset1', 'img' + os.sep)
    prediction_dir = os.path.join(cwd, 'test_data', model_name + 'iou_img' + os.sep)
    model_dir = os.path.join(cwd, 'saved_models', model_name, model_name + '.pth')
    label_dir = os.path.join(cwd, 'train_data', 'dataset1', 'lbl' + os.sep)

    # Sorted for a deterministic order. Directories are skipped because they
    # cannot be read as images.
    img_name_list = sorted(p for p in glob.glob(os.path.join(image_dir, '*'))
                           if os.path.isfile(p))
    print('Image list GOTTEN:\n', img_name_list)
    # Each ground-truth mask shares its image's stem: img/<stem>.<ext> -> lbl/<stem>.png
    lbl_name_list = [os.path.join(label_dir, _strip_extension(p) + '.png')
                     for p in img_name_list]

    # --------- 2. dataloader ---------
    # Labels are read separately with cv2 below, so the dataset gets none.
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,  # keeps batch i aligned with img_name_list[i]
                                        num_workers=1)

    # --------- 3. model define ---------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = _load_net(model_name, model_dir, device)

    # --------- 4. inference for each image ---------
    total_iou = 0.0
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            count = i_test + 1
            image_name = img_name_list[i_test]
            label_name = lbl_name_list[i_test]
            print("inferencing:", image_name)
            print("number : ", count)

            inputs_test = data_test['image'].type(torch.FloatTensor).to(device)
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Channel 0 of the first output, normalized to [0, 1] and then
            # scaled to [0, 255].
            pred_map = normPRED(d1[:, 0, :, :])
            pred = pred_map.squeeze().cpu().numpy() * 255

            # Only the source image's size is needed. shape[:2] is
            # (height, width) for grayscale and multi-channel images alike.
            pri_image = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
            if pri_image is None:
                raise FileNotFoundError("cannot read image: %s" % image_name)
            h, w = pri_image.shape[:2]

            # cv2.resize takes dsize as (width, height). Then binarize at 127.
            pred = _binarize(cv2.resize(pred, (w, h), interpolation=cv2.INTER_LINEAR))
            # Debug snapshot of the latest binarized prediction (overwritten per image).
            cv2.imwrite("pred_iou.jpg", (pred * 255).astype(np.uint8))

            label_iou = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
            if label_iou is None:
                raise FileNotFoundError("cannot read label: %s" % label_name)
            label_iou = cv2.resize(label_iou, (w, h))
            # Debug snapshot of the latest resized, not yet binarized, label.
            cv2.imwrite("label_iou.png", label_iou)
            label_iou = _binarize(label_iou)

            iou = IOU(pred, label_iou)
            total_iou += iou
            miou = total_iou / count
            print("IOU = %5f, MIOU = %5f" % (iou, miou))

            if save_predictions:
                os.makedirs(prediction_dir, exist_ok=True)
                save_output(image_name, pred_map, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
            print("\n")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()