diff --git a/u2net_test.py b/u2net_test.py new file mode 100644 index 0000000..54b4f07 --- /dev/null +++ b/u2net_test.py @@ -0,0 +1,199 @@ +import os +from skimage import io, transform +import torch +import torchvision +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms#, utils +# import torch.optim as optim + +import numpy as np +from PIL import Image +import glob + +from data_loader import RescaleT +from data_loader import ToTensor +from data_loader import ToTensorLab +from data_loader import SalObjDataset + +from model import U2NET # full size version u2net 173.6 MB +from model import U2NETP # small version u2net 4.7 MB +from model import U2NETM # middle size version u2net 99MB +import cv2 + +def IOU(pred, label): + Iand1 = np.sum(label[:, :]*pred[:, :]) + Ior1 = np.sum(label[:, :]) + np.sum(pred[:, :]) - Iand1 + IoU1 = Iand1/Ior1 + print(Iand1, Ior1) + return IoU1 + +# normalize the predicted SOD probability map +def normPRED(d): + ma = torch.max(d) + mi = torch.min(d) + + dn = (d-mi)/(ma-mi) + + return dn + +def save_output(image_name, pred, d_dir): + + predict = pred + predict = predict.squeeze() + predict_np = predict.cpu().data.numpy() + + im = Image.fromarray(predict_np*255).convert('RGB') + img_name = image_name.split(os.sep)[-1] + image = io.imread(image_name) + imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR) + + pb_np = np.array(imo) + + aaa = img_name.split(".") + bbb = aaa[0:-1] + imidx = bbb[0] + for i in range(1,len(bbb)): + imidx = imidx + "." + bbb[i] + + imo.save(d_dir+imidx+'.png') + +def main(): + + # --------- 1. get image path and name --------- + model_name='u2netp' # u2netm + print("Image name GOTTEN") + + # image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images') + image_dir = os.path.join(os.getcwd(), 'train_data', 'dataset1', 'img' + os.sep) + # prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep) + prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + 'iou_img' + os.sep) + model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth') + label_dir = os.path.join(os.getcwd(), 'train_data', 'dataset1', 'lbl' + os.sep) + + img_name_list = glob.glob(image_dir + os.sep + '*') + print('Image list GOTTEN:\n', img_name_list) + # lbl_name_list = [] + # for img_path in img_name_list: + # img_name = img_path.split(os.sep)[-1] + # aaa = img_name.split(".") + # bbb = aaa[0:-1] + # imidx = bbb[0] + # for i in range(1, len(bbb)): + # imidx = imidx + "." + bbb[i] + # lbl_name_list.append(label_dir + imidx + '.png') + + # --------- 2. dataloader --------- + # 1. dataloader + test_salobj_dataset = SalObjDataset(img_name_list = img_name_list, + lbl_name_list = [], + transform=transforms.Compose([RescaleT(320), + ToTensorLab(flag=0)]) + ) + # test_salobj_dataset = SalObjDataset(img_name_list = img_name_list, + # lbl_name_list = [], + # transform=transforms.Compose([RescaleT(320), + # ]) + # ) + test_salobj_dataloader = DataLoader(test_salobj_dataset, + batch_size=1, + shuffle=False, + num_workers=1) + + # --------- 3. model define --------- + if(model_name =='u2net'): + print("...load U2NET---173.6 MB") + net = U2NET(3, 1) + elif(model_name =='u2netp'): + print("...load U2NEP---4.7 MB") + net = U2NETP(3, 1) + elif(model_name == 'u2netm'): + net = U2NETM(3, 1) + print('...load U2NETM---91MB') + + if torch.cuda.is_available(): + # net.load_state_dict(torch.load(model_dir)) + net = torch.load(model_dir) + net.cuda() + else: + net = torch.load(model_dir, map_location='cpu') + net.eval() + # --------- 4. inference for each image --------- + total_iou = 0 + s = 1 + lbl_name_list = [] + for img_path in img_name_list: + img_name = img_path.split(os.sep)[-1] + aaa = img_name.split(".") + bbb = aaa[0:-1] + imidx = bbb[0] + for i in range(1, len(bbb)): + imidx = imidx + "." + bbb[i] + lbl_name_list.append(label_dir + imidx + '.png') + + for i_test, data_test in enumerate(test_salobj_dataloader): + print("inferencing:",img_name_list[i_test]) + print("number : ", s) + inputs_test = data_test['image'] + inputs_test = inputs_test.type(torch.FloatTensor) + # print(inputs_test.size()) + if torch.cuda.is_available(): + inputs_test = Variable(inputs_test.cuda()) + else: + inputs_test = Variable(inputs_test) + + d1,d2,d3,d4,d5,d6,d7= net(inputs_test) + + # normalization + # print("d1.shape", d1.shape) # (1, 1, 320, 320) + pred = d1[:,0,:,:] + pred = normPRED(pred) + pred = pred.squeeze() + pred = pred.cpu().data.numpy() + pred = pred * 255 # (320, 320) + image_name = img_name_list[i_test] + # print(image_name) + # image_name = image_name.split(os.sep)[-1] + pri_image = cv2.imread(image_name, -1) + pri_image = pri_image[:, :, 0] + # pri_image = cv2.cvtColor(pri_image, cv2.COLOR_BGR2RGB) + h, w = pri_image.shape[1], pri_image.shape[0] + # print("h, w", h, w) # (400, 300) + pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_LINEAR) + # 二值化处理 + for i in range(h): + for j in range(w): + if pred[i][j] <= 127: + pred[i][j] = 0 + if pred[i][j] > 127: + pred[i][j] = 1 + pred_cv = pred * 255 + cv2.imwrite("pred_iou.jpg", pred_cv) + # print("pred.shape", pred.shape) + label_iou = cv2.imread(lbl_name_list[i_test], 0) + # print("image name: ", image_name) + # print("label name: ", lbl_name_list[i_test]) + label_iou = cv2.resize(label_iou, (w, h)) + cv2.imwrite("label_iou.png", label_iou) + for i in range(h): + for j in range(w): + if label_iou[i][j] <= 127: + label_iou[i][j] = 0 + if label_iou[i][j] > 127: + label_iou[i][j] = 1 + # print("label_iou.shape", label_iou.shape) + iou = IOU(pred, label_iou) + total_iou += iou + miou = total_iou/s + print("IOU = %5f, MIOU = %5f" % (iou, miou)) + # # save results to test_results folder + # if not os.path.exists(prediction_dir): + # os.makedirs(prediction_dir, exist_ok=True) + # save_output(img_name_list[i_test],pred,prediction_dir) + s += 1 + del d1,d2,d3,d4,d5,d6,d7 + print("\n") +if __name__ == "__main__": + main()