"""Evaluate a trained U2-Net saliency model by per-image IoU and running mean IoU.

For every image in ``train_data/dataset1/img`` the script predicts a saliency
map, binarises it at the mid-grey level, and compares it with the label of the
same base name in ``train_data/dataset1/lbl``.
"""
import glob
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from skimage import io, transform
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from data_loader import RescaleT
from data_loader import SalObjDataset
from data_loader import ToTensor
from data_loader import ToTensorLab
from model import U2NET  # full size version u2net 173.6 MB
from model import U2NETM  # middle size version u2net 99MB
from model import U2NETP  # small version u2net 4.7 MB


def IOU(pred, label):
    """Return the intersection-over-union of two binary (0/1) masks.

    Args:
        pred: array-like predicted mask containing 0/1 values.
        label: array-like ground-truth mask with the same shape as ``pred``.

    Returns:
        ``intersection / union`` as a float. When both masks are empty the
        union is zero; the masks agree perfectly, so 1.0 is returned instead
        of the NaN a 0/0 division would produce.
    """
    # Work in float64 so small integer dtypes (e.g. uint8) cannot overflow
    # in the element-wise product.
    pred = np.asarray(pred, dtype=np.float64)
    label = np.asarray(label, dtype=np.float64)

    intersection = np.sum(label * pred)
    union = np.sum(label) + np.sum(pred) - intersection
    print(intersection, union)

    if union == 0:
        return 1.0
    return intersection / union


def normPRED(d):
    """Min-max normalise the predicted SOD probability map to [0, 1].

    Args:
        d: tensor holding the raw prediction.

    Returns:
        A tensor of the same shape scaled so its minimum is 0 and its maximum
        is 1. A constant map has no range to stretch; zeros are returned in
        that case rather than dividing by zero.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span


def _stem(path):
    """Return the file name of ``path`` without directory and last extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _binarize(mask, threshold=127):
    """Return ``mask`` with values > ``threshold`` set to 1 and the rest to 0.

    The dtype of the input array is preserved.
    """
    return (mask > threshold).astype(mask.dtype)


def _image_size(image_path):
    """Return ``(height, width)`` of the image stored at ``image_path``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError("cannot read image: %s" % image_path)
    # shape is (rows, cols[, channels]); slicing also covers grayscale files.
    height, width = image.shape[:2]
    return height, width


def _load_network(model_name, model_path, use_cuda):
    """Build the requested U2-Net variant and load its checkpoint.

    Args:
        model_name: one of ``'u2net'``, ``'u2netp'`` or ``'u2netm'``.
        model_path: path of the ``.pth`` checkpoint.
        use_cuda: move the network to the GPU when true.

    Returns:
        The network in evaluation mode.

    Raises:
        ValueError: if the checkpoint is a state dict but ``model_name`` does
            not name a known architecture to load it into.
    """
    net = None
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    elif model_name == 'u2netm':
        net = U2NETM(3, 1)
        print('...load U2NETM---91MB')

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    if use_cuda:
        checkpoint = torch.load(model_path)
    else:
        checkpoint = torch.load(model_path, map_location='cpu')

    if isinstance(checkpoint, dict):
        # The file holds weights only (a state dict).
        if net is None:
            raise ValueError("unknown model name: %r" % (model_name,))
        net.load_state_dict(checkpoint)
    else:
        # The file holds a whole pickled module; use it as-is.
        net = checkpoint

    if use_cuda:
        net.cuda()
    net.eval()
    return net


def save_output(image_name, pred, d_dir):
    """Save a predicted saliency map as a PNG sized like its source image.

    Args:
        image_name: path of the source image; its size and base name are used.
        pred: tensor with the prediction; it is squeezed and scaled by 255
            before being converted to an image.
        d_dir: output location prefix; the file is written to
            ``d_dir + <base name> + '.png'``, so a directory must end with a
            path separator.
    """
    predict_np = pred.squeeze().detach().cpu().numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    # The source image is read only to recover its original size.
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    imo.save(d_dir + _stem(image_name) + '.png')


def main():
    """Run inference over the dataset and print per-image IoU and running mIoU.

    Side effects: overwrites ``pred_iou.jpg`` and ``label_iou.png`` in the
    current directory with the masks of the most recently processed image.

    Raises:
        FileNotFoundError: if an image or its label cannot be read.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2netm
    print("Image name GOTTEN")

    cwd = os.getcwd()
    image_dir = os.path.join(cwd, 'train_data', 'dataset1', 'img')
    label_dir = os.path.join(cwd, 'train_data', 'dataset1', 'lbl')
    model_dir = os.path.join(cwd, 'saved_models', model_name, model_name + '.pth')

    # Sorted so that runs are reproducible; glob order is arbitrary.
    img_name_list = sorted(glob.glob(os.path.join(image_dir, '*')))
    print('Image list GOTTEN:\n', img_name_list)

    # Labels share the image's base name and always use the .png extension.
    lbl_name_list = [
        os.path.join(label_dir, _stem(img_path) + '.png')
        for img_path in img_name_list
    ]

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]),
    )
    # shuffle=False keeps batch i aligned with img_name_list[i].
    test_salobj_dataloader = DataLoader(
        test_salobj_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    # --------- 3. model define ---------
    use_cuda = torch.cuda.is_available()
    net = _load_network(model_name, model_dir, use_cuda)

    # --------- 4. inference for each image ---------
    total_iou = 0.0
    with torch.no_grad():  # inference only; no autograd graph is needed
        for i_test, data_test in enumerate(test_salobj_dataloader):
            count = i_test + 1
            image_path = img_name_list[i_test]
            print("inferencing:", image_path)
            print("number : ", count)

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            # The first output of the network is used as the prediction.
            outputs = net(inputs_test)
            pred = normPRED(outputs[0][:, 0, :, :])
            pred = pred.squeeze().cpu().numpy() * 255
            del outputs

            # Bring the prediction back to the source image size.
            # cv2.resize takes dsize as (width, height).
            height, width = _image_size(image_path)
            pred = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LINEAR)

            # Binarization.
            pred = _binarize(pred)
            cv2.imwrite("pred_iou.jpg", pred * 255)

            label_path = lbl_name_list[i_test]
            label_iou = cv2.imread(label_path, 0)
            if label_iou is None:
                raise FileNotFoundError("cannot read label: %s" % label_path)
            label_iou = cv2.resize(label_iou, (width, height))
            cv2.imwrite("label_iou.png", label_iou)
            label_iou = _binarize(label_iou)

            iou = IOU(pred, label_iou)
            total_iou += iou
            miou = total_iou / count
            print("IOU = %5f, MIOU = %5f" % (iou, miou))
            print("\n")


if __name__ == "__main__":
    main()