You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
200 lines
7.1 KiB
200 lines
7.1 KiB
import glob
import os

from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import cv2

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET  # full size version u2net 173.6 MB
from model import U2NETP  # small version u2net 4.7 MB
from model import U2NETM  # middle size version u2net 99MB
def IOU(pred, label, *, verbose=True):
    """Return the intersection-over-union of two masks.

    Args:
        pred: predicted mask array, expected to hold 0/1 values (non-binary
            values yield a "soft" IoU from the same formula).
        label: ground-truth mask array with the same shape as ``pred``.
        verbose: when True (the default), print the intersection and union
            before returning.

    Returns:
        ``intersection / union``. When the union is 0 (both masks are empty)
        the masks agree perfectly and 1.0 is returned. A 0/0 division would
        instead produce NaN, which poisons any running mean of IoU values.
    """
    intersection = np.sum(label * pred)
    union = np.sum(label) + np.sum(pred) - intersection
    if verbose:
        print(intersection, union)
    if union == 0:
        return 1.0
    return intersection / union
# normalize the predicted SOD probability map
def normPRED(d):
ma = torch.max(d)
mi = torch.min(d)
dn = (d-mi)/(ma-mi)
return dn
def save_output(image_name, pred, d_dir):
    """Save a predicted saliency map as a PNG at the source image's resolution.

    The map is scaled by 255 and resized (bilinear) to the size of the image
    at ``image_name``. It is written to ``d_dir`` as ``<base name>.png``, where
    the base name is the file name with its last extension removed
    ("a.b.jpg" -> "a.b").

    Args:
        image_name: path of the source image. It is read only to get its size.
        pred: torch tensor holding the prediction. Singleton dimensions are
            squeezed away. Presumably already normalized to [0, 1] (e.g. by
            ``normPRED``); verify at the call site.
        d_dir: output directory. It must already exist.
    """
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    image = io.imread(image_name)
    # PIL's resize takes (width, height); numpy shape is (height, width, ...).
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # splitext only strips the last extension and copes with names that have
    # no '.' at all.
    imidx = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, imidx + '.png'))
def _label_paths(img_name_list, label_dir):
    """Map each image path to its ground-truth mask path: <label_dir>/<base name>.png."""
    return [os.path.join(label_dir, os.path.splitext(os.path.basename(p))[0] + '.png')
            for p in img_name_list]


def _load_net(model_name, model_dir, use_cuda):
    """Build the requested U^2-Net variant, load its weights and switch it to eval mode.

    Raises:
        ValueError: if ``model_name`` is not one of 'u2net', 'u2netp', 'u2netm'.
    """
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    elif model_name == 'u2netm':
        net = U2NETM(3, 1)
        print('...load U2NETM---91MB')
    else:
        raise ValueError("unknown model_name: %r" % model_name)

    # The checkpoint may be a pickled nn.Module (what this script loaded
    # before) or a plain state_dict; accept both.
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
    # rejects pickled modules. Re-save as a state_dict, or pass
    # weights_only=False for trusted files.
    checkpoint = torch.load(model_dir, map_location=None if use_cuda else 'cpu')
    if isinstance(checkpoint, dict):
        net.load_state_dict(checkpoint)
    else:
        net = checkpoint
    if use_cuda:
        net.cuda()
    # Inference mode: layers such as BatchNorm/Dropout (if present) must not
    # use training-time behavior.
    net.eval()
    return net


def _binarize(arr, threshold=127):
    """Return a 0/1 mask of the same dtype: 1 where arr > threshold, else 0."""
    return (arr > threshold).astype(arr.dtype)


def main():
    """Evaluate a trained U^2-Net on dataset1 and print per-image IoU and running mIoU."""
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2netm
    print("Image name GOTTEN")
    image_dir = os.path.join(os.getcwd(), 'train_data', 'dataset1', 'img' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
    label_dir = os.path.join(os.getcwd(), 'train_data', 'dataset1', 'lbl' + os.sep)

    # Sorted for a deterministic order. Labels are derived from this same list,
    # so images and labels stay aligned.
    img_name_list = sorted(glob.glob(os.path.join(image_dir, '*')))
    print('Image list GOTTEN:\n', img_name_list)
    lbl_name_list = _label_paths(img_name_list, label_dir)

    # --------- 2. dataloader ---------
    # NOTE(review): the transform and loader arguments were truncated in this
    # file and are reconstructed here (320 matches the d1 shape the script
    # expected). Confirm they match the training preprocessing.
    # batch_size=1 and shuffle=False are required: results are matched back to
    # img_name_list / lbl_name_list by batch index.
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    use_cuda = torch.cuda.is_available()
    net = _load_net(model_name, model_dir, use_cuda)

    # --------- 4. inference for each image ---------
    total_iou = 0
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            s = i_test + 1
            print("number : ", s)
            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Normalize the first output map to [0, 1], then scale to [0, 255].
            pred = normPRED(d1[:, 0, :, :])
            pred = pred.squeeze().cpu().numpy() * 255
            del d1, d2, d3, d4, d5, d6, d7

            image_name = img_name_list[i_test]
            pri_image = cv2.imread(image_name, -1)
            if pri_image is None:
                raise FileNotFoundError("cannot read image: %s" % image_name)
            # numpy shape is (height, width[, channels]); cv2.resize takes
            # dsize as (width, height).
            h, w = pri_image.shape[:2]
            pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_LINEAR)
            # Binarize at 127 (vectorized).
            pred = _binarize(pred)
            # Debug snapshot of the most recent prediction (overwritten per image).
            cv2.imwrite("pred_iou.jpg", (pred * 255).astype(np.uint8))

            label_path = lbl_name_list[i_test]
            label_iou = cv2.imread(label_path, 0)
            if label_iou is None:
                raise FileNotFoundError("cannot read label: %s" % label_path)
            label_iou = cv2.resize(label_iou, (w, h))
            cv2.imwrite("label_iou.png", label_iou)
            label_iou = _binarize(label_iou)

            iou = IOU(pred, label_iou)
            total_iou += iou
            miou = total_iou / s
            print("IOU = %5f, MIOU = %5f" % (iou, miou))
if __name__ == "__main__":