parent
8fe18057e1
commit
5a94f1aeb5
@ -0,0 +1,199 @@
|
|||||||
|
import os
|
||||||
|
from skimage import io, transform
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms#, utils
|
||||||
|
# import torch.optim as optim
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import glob
|
||||||
|
|
||||||
|
from data_loader import RescaleT
|
||||||
|
from data_loader import ToTensor
|
||||||
|
from data_loader import ToTensorLab
|
||||||
|
from data_loader import SalObjDataset
|
||||||
|
|
||||||
|
from model import U2NET # full size version u2net 173.6 MB
|
||||||
|
from model import U2NETP # small version u2net 4.7 MB
|
||||||
|
from model import U2NETM # middle size version u2net 99MB
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def IOU(pred, label):
|
||||||
|
Iand1 = np.sum(label[:, :]*pred[:, :])
|
||||||
|
Ior1 = np.sum(label[:, :]) + np.sum(pred[:, :]) - Iand1
|
||||||
|
IoU1 = Iand1/Ior1
|
||||||
|
print(Iand1, Ior1)
|
||||||
|
return IoU1
|
||||||
|
|
||||||
|
# normalize the predicted SOD probability map
|
||||||
|
def normPRED(d):
|
||||||
|
ma = torch.max(d)
|
||||||
|
mi = torch.min(d)
|
||||||
|
|
||||||
|
dn = (d-mi)/(ma-mi)
|
||||||
|
|
||||||
|
return dn
|
||||||
|
|
||||||
|
def save_output(image_name, pred, d_dir):
|
||||||
|
|
||||||
|
predict = pred
|
||||||
|
predict = predict.squeeze()
|
||||||
|
predict_np = predict.cpu().data.numpy()
|
||||||
|
|
||||||
|
im = Image.fromarray(predict_np*255).convert('RGB')
|
||||||
|
img_name = image_name.split(os.sep)[-1]
|
||||||
|
image = io.imread(image_name)
|
||||||
|
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
|
||||||
|
|
||||||
|
pb_np = np.array(imo)
|
||||||
|
|
||||||
|
aaa = img_name.split(".")
|
||||||
|
bbb = aaa[0:-1]
|
||||||
|
imidx = bbb[0]
|
||||||
|
for i in range(1,len(bbb)):
|
||||||
|
imidx = imidx + "." + bbb[i]
|
||||||
|
|
||||||
|
imo.save(d_dir+imidx+'.png')
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
# --------- 1. get image path and name ---------
|
||||||
|
model_name='u2netp' # u2netm
|
||||||
|
print("Image name GOTTEN")
|
||||||
|
|
||||||
|
# image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
|
||||||
|
image_dir = os.path.join(os.getcwd(), 'train_data', 'dataset1', 'img' + os.sep)
|
||||||
|
# prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
|
||||||
|
prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + 'iou_img' + os.sep)
|
||||||
|
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
|
||||||
|
label_dir = os.path.join(os.getcwd(), 'train_data', 'dataset1', 'lbl' + os.sep)
|
||||||
|
|
||||||
|
img_name_list = glob.glob(image_dir + os.sep + '*')
|
||||||
|
print('Image list GOTTEN:\n', img_name_list)
|
||||||
|
# lbl_name_list = []
|
||||||
|
# for img_path in img_name_list:
|
||||||
|
# img_name = img_path.split(os.sep)[-1]
|
||||||
|
# aaa = img_name.split(".")
|
||||||
|
# bbb = aaa[0:-1]
|
||||||
|
# imidx = bbb[0]
|
||||||
|
# for i in range(1, len(bbb)):
|
||||||
|
# imidx = imidx + "." + bbb[i]
|
||||||
|
# lbl_name_list.append(label_dir + imidx + '.png')
|
||||||
|
|
||||||
|
# --------- 2. dataloader ---------
|
||||||
|
# 1. dataloader
|
||||||
|
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
|
||||||
|
lbl_name_list = [],
|
||||||
|
transform=transforms.Compose([RescaleT(320),
|
||||||
|
ToTensorLab(flag=0)])
|
||||||
|
)
|
||||||
|
# test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
|
||||||
|
# lbl_name_list = [],
|
||||||
|
# transform=transforms.Compose([RescaleT(320),
|
||||||
|
# ])
|
||||||
|
# )
|
||||||
|
test_salobj_dataloader = DataLoader(test_salobj_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=1)
|
||||||
|
|
||||||
|
# --------- 3. model define ---------
|
||||||
|
if(model_name =='u2net'):
|
||||||
|
print("...load U2NET---173.6 MB")
|
||||||
|
net = U2NET(3, 1)
|
||||||
|
elif(model_name =='u2netp'):
|
||||||
|
print("...load U2NEP---4.7 MB")
|
||||||
|
net = U2NETP(3, 1)
|
||||||
|
elif(model_name == 'u2netm'):
|
||||||
|
net = U2NETM(3, 1)
|
||||||
|
print('...load U2NETM---91MB')
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
# net.load_state_dict(torch.load(model_dir))
|
||||||
|
net = torch.load(model_dir)
|
||||||
|
net.cuda()
|
||||||
|
else:
|
||||||
|
net = torch.load(model_dir, map_location='cpu')
|
||||||
|
net.eval()
|
||||||
|
# --------- 4. inference for each image ---------
|
||||||
|
total_iou = 0
|
||||||
|
s = 1
|
||||||
|
lbl_name_list = []
|
||||||
|
for img_path in img_name_list:
|
||||||
|
img_name = img_path.split(os.sep)[-1]
|
||||||
|
aaa = img_name.split(".")
|
||||||
|
bbb = aaa[0:-1]
|
||||||
|
imidx = bbb[0]
|
||||||
|
for i in range(1, len(bbb)):
|
||||||
|
imidx = imidx + "." + bbb[i]
|
||||||
|
lbl_name_list.append(label_dir + imidx + '.png')
|
||||||
|
|
||||||
|
for i_test, data_test in enumerate(test_salobj_dataloader):
|
||||||
|
print("inferencing:",img_name_list[i_test])
|
||||||
|
print("number : ", s)
|
||||||
|
inputs_test = data_test['image']
|
||||||
|
inputs_test = inputs_test.type(torch.FloatTensor)
|
||||||
|
# print(inputs_test.size())
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs_test = Variable(inputs_test.cuda())
|
||||||
|
else:
|
||||||
|
inputs_test = Variable(inputs_test)
|
||||||
|
|
||||||
|
d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
|
||||||
|
|
||||||
|
# normalization
|
||||||
|
# print("d1.shape", d1.shape) # (1, 1, 320, 320)
|
||||||
|
pred = d1[:,0,:,:]
|
||||||
|
pred = normPRED(pred)
|
||||||
|
pred = pred.squeeze()
|
||||||
|
pred = pred.cpu().data.numpy()
|
||||||
|
pred = pred * 255 # (320, 320)
|
||||||
|
image_name = img_name_list[i_test]
|
||||||
|
# print(image_name)
|
||||||
|
# image_name = image_name.split(os.sep)[-1]
|
||||||
|
pri_image = cv2.imread(image_name, -1)
|
||||||
|
pri_image = pri_image[:, :, 0]
|
||||||
|
# pri_image = cv2.cvtColor(pri_image, cv2.COLOR_BGR2RGB)
|
||||||
|
h, w = pri_image.shape[1], pri_image.shape[0]
|
||||||
|
# print("h, w", h, w) # (400, 300)
|
||||||
|
pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
# 二值化处理
|
||||||
|
for i in range(h):
|
||||||
|
for j in range(w):
|
||||||
|
if pred[i][j] <= 127:
|
||||||
|
pred[i][j] = 0
|
||||||
|
if pred[i][j] > 127:
|
||||||
|
pred[i][j] = 1
|
||||||
|
pred_cv = pred * 255
|
||||||
|
cv2.imwrite("pred_iou.jpg", pred_cv)
|
||||||
|
# print("pred.shape", pred.shape)
|
||||||
|
label_iou = cv2.imread(lbl_name_list[i_test], 0)
|
||||||
|
# print("image name: ", image_name)
|
||||||
|
# print("label name: ", lbl_name_list[i_test])
|
||||||
|
label_iou = cv2.resize(label_iou, (w, h))
|
||||||
|
cv2.imwrite("label_iou.png", label_iou)
|
||||||
|
for i in range(h):
|
||||||
|
for j in range(w):
|
||||||
|
if label_iou[i][j] <= 127:
|
||||||
|
label_iou[i][j] = 0
|
||||||
|
if label_iou[i][j] > 127:
|
||||||
|
label_iou[i][j] = 1
|
||||||
|
# print("label_iou.shape", label_iou.shape)
|
||||||
|
iou = IOU(pred, label_iou)
|
||||||
|
total_iou += iou
|
||||||
|
miou = total_iou/s
|
||||||
|
print("IOU = %5f, MIOU = %5f" % (iou, miou))
|
||||||
|
# # save results to test_results folder
|
||||||
|
# if not os.path.exists(prediction_dir):
|
||||||
|
# os.makedirs(prediction_dir, exist_ok=True)
|
||||||
|
# save_output(img_name_list[i_test],pred,prediction_dir)
|
||||||
|
s += 1
|
||||||
|
del d1,d2,d3,d4,d5,d6,d7
|
||||||
|
print("\n")
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in new issue