You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

527 lines
18 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms
import numpy as np
import glob
import os
from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from skimage import io, transform
from PIL import Image
from model import U2NET
from model import U2NETP
from model import U2NETM
import cv2
# ------- 1. define loss function --------
# Mean-reduced binary cross-entropy shared by every supervision term below.
# ``reduction='mean'`` is the non-deprecated spelling of ``size_average=True``.
bce_loss = nn.BCELoss(reduction='mean')
# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalise the tensor ``d`` into the range [0, 1].

    Args:
        d: tensor of predicted values (any shape); the minimum and maximum
            are taken over all of its elements.

    Returns:
        A tensor of the same shape as ``d`` with ``(d - min) / (max - min)``.
        If every element is equal (max == min), the division would be 0/0
        and produce NaN, so an all-zero tensor is returned instead.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    if ma == mi:
        # Constant map: there is no range to normalise over.
        return torch.zeros_like(d)
    return (d - mi) / (ma - mi)
def IOU(pred, label):
    """Return the intersection-over-union of two masks.

    Args:
        pred: predicted mask, array-like, same shape as ``label``. The values
            are intended to be binary (0/1); with other values the result is
            not a true IoU.
        label: ground-truth mask, array-like, same shape as ``pred``.

    Returns:
        ``sum(pred * label) / (sum(pred) + sum(label) - sum(pred * label))``
        as a float. When both masks are empty the union is 0, and the masks
        agree perfectly, so 1.0 is returned instead of NaN.
    """
    pred = np.asarray(pred)
    label = np.asarray(label)
    intersection = np.sum(label * pred)
    union = np.sum(label) + np.sum(pred) - intersection
    if union == 0:
        # Both masks empty: avoid 0/0.
        return 1.0
    return float(intersection / union)
# ----------------------------------- modified loss function: adds a Canny edge loss -----------------------------------
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v, mask_canny, label_canny):
    """Deep-supervision BCE over the seven side outputs plus one edge-map term.

    Every output ``d0`` .. ``d6`` is scored against ``labels_v`` with the
    module-level ``bce_loss``. ``mask_canny`` is scored against
    ``label_canny`` the same way. The individual terms are printed on one line.

    Note that ``BCELoss`` expects inputs in [0, 1]. If the edge maps are built
    outside autograd (e.g. with NumPy), the edge term adds to the value of
    ``loss`` but contributes no gradient.

    Returns:
        ``(loss0, loss)``: the loss of ``d0`` alone, and the sum of all seven
        side-output losses followed by the edge loss.
    """
    side_losses = [bce_loss(side, labels_v) for side in (d0, d1, d2, d3, d4, d5, d6)]
    edge_loss = bce_loss(mask_canny, label_canny)

    # Accumulate strictly left to right (loss0 + loss1 + ... + loss6 + canny),
    # matching the original summation order.
    total = sum(side_losses[1:], side_losses[0]) + edge_loss

    reported = tuple(term.item() for term in (edge_loss, *side_losses))
    print("l_canny: %3f, loss0: %3f, loss1: %3f, loss2: %3f, loss3: %3f, loss4: %3f, loss5: %3f, loss6: %3f\n" % reported)
    return side_losses[0], total
# ------- 2. set the directory of training dataset --------
model_name = 'u2net'  # or 'u2net' or 'u2netm' or 'u2netp'
data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'miou_im' + os.sep)
tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'miou_gt' + os.sep)
image_ext = '.jpg'
label_ext = '.png'
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)
epoch_num = 666
batch_size_train = 10
print("---batch size = ", batch_size_train)
batch_size_val = 1
train_num = 0
val_num = 0

# glob returns files in arbitrary directory-listing order; sort so the image
# list is reproducible across runs and platforms.
tra_img_pattern = data_dir + tra_image_dir + '*' + image_ext
tra_img_name_list = sorted(glob.glob(tra_img_pattern))
if not tra_img_name_list:
    # Fail early with a clear message rather than later with a less obvious error.
    raise FileNotFoundError('no training images match ' + tra_img_pattern)

# Each label has the same file stem as its image (the name up to the last '.'),
# lives in tra_label_dir and uses label_ext.
tra_lbl_name_list = [
    data_dir + tra_label_dir + os.path.splitext(img_path.split(os.sep)[-1])[0] + label_ext
    for img_path in tra_img_name_list
]
# Check every label up front instead of failing part-way through an epoch.
missing_lbl_list = [lbl_path for lbl_path in tra_lbl_name_list if not os.path.isfile(lbl_path)]
if missing_lbl_list:
    raise FileNotFoundError('%d training label(s) missing, e.g. %s'
                            % (len(missing_lbl_list), missing_lbl_list[0]))

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")
train_num = len(tra_img_name_list)
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
# transforms.Compose applies the listed transforms in order and returns the processed sample.
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)
# ------- 3. define model --------
# define the net: look up the network class for model_name. Without this check
# an unrecognised name would leave ``net`` undefined and fail later with a NameError.
model_classes = {'u2net': U2NET, 'u2netp': U2NETP, 'u2netm': U2NETM}
if model_name not in model_classes:
    raise ValueError("unknown model_name %r; expected one of %s"
                     % (model_name, sorted(model_classes)))
# NOTE(review): (3, 1) presumably means 3 input channels and 1 output map — confirm in model.py.
net = model_classes[model_name](3, 1)
if torch.cuda.is_available():
    net.cuda()
learning_rate = 1.e-7
# ------- 4. define optimizer --------
print("---defining optimizer...")
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
print('---optimizer defined successfully---')
# ------- 5. training process --------
print("---start TRAINING...")
# ------- 5(1). load model data --------
# Resume from the previous checkpoint if one exists. This script saves with
# torch.save(net, ...), i.e. a whole pickled module, while the older
# commented-out variants saved net.state_dict(). Accept either format and copy
# the weights INTO the existing ``net``. Keeping the same object means any
# optimizer already built on net.parameters() still points at the trained
# parameters. (Rebinding ``net`` to the loaded module, as before, orphaned it.)
checkpoint_path = model_dir + model_name + '.pth'
if os.path.isfile(checkpoint_path):
    # map_location='cpu' loads on machines with or without CUDA;
    # load_state_dict then copies the tensors onto whatever device ``net`` is on.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. On PyTorch versions where torch.load defaults to
    # weights_only=True, a whole-module checkpoint needs weights_only=False — confirm.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    net.load_state_dict(checkpoint)
    device_tag = 'GPU' if torch.cuda.is_available() else 'CPU'
    print('---%s::the former model data LOADED successfully---' % device_tag)
else:
    print('---no former model data found at %s, training from scratch---' % checkpoint_path)
# ------------------------------------------
# Build the optimizer once, after any checkpoint loading, so it is bound to the
# parameters of the ``net`` that is actually trained. Rebuilding Adam every
# iteration would throw away its running moment estimates each step.
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# torch.save does not create missing directories.
os.makedirs(model_dir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 10  # save the model every 10 iterations


def _canny_edge_batch(maps):
    """Return Canny edge maps for a batch of 2-D maps as a float32 tensor of 0/1 values.

    ``maps`` is an iterable of 2-D arrays with values in [0, 1]. Each map is
    clipped to [0, 1], scaled to uint8 and passed through cv2.Canny
    (thresholds 100/200, aperture 5). The 0/255 result is rescaled to 0/1 so it
    can be used as a BCELoss input or target. The maps are stacked along a new
    leading axis.
    """
    edges = [
        cv2.Canny(np.uint8(np.clip(m, 0.0, 1.0) * 255), 100, 200, apertureSize=5)
        for m in maps
    ]
    return torch.from_numpy(np.stack(edges).astype(np.float32) / 255.0)


for epoch in range(0, epoch_num):
    net.train()
    for i, data in enumerate(salobj_dataloader):
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1
        if ite_num % 250 == 0:
            # Halve the learning rate in place so Adam keeps its state.
            learning_rate = learning_rate / 2.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            print("---Optimizer: learning_rate REDUCED by half successfully---")
            print('---NOTE: present learning_rate = ', learning_rate)
            print("---GOOD LUCK---")
        inputs, labels = data['image'], data['label']
        inputs = inputs.type(torch.FloatTensor)
        labels = labels.type(torch.FloatTensor)
        inputs_v, labels_v = inputs.to(device), labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        # ------ Canny edge maps of the prediction and the label ------
        # Both are computed from THIS batch at network resolution, so every
        # predicted edge map is aligned with its own (randomly cropped) label.
        # Looking files up by the batch index ``i`` would be wrong: the loader
        # shuffles, and the crop means a full-size file would not line up anyway.
        # Assumes d0 and labels are (batch, 1, H, W), as bce_loss(d0, labels_v)
        # and d0[:, 0] suggest — TODO confirm.
        with torch.no_grad():
            pred_maps = [normPRED(d0[b, 0]).cpu().numpy() for b in range(d0.shape[0])]
        mask_canny = _canny_edge_batch(pred_maps).to(device)
        label_canny = _canny_edge_batch(labels[:, 0].numpy()).to(device)
        # NOTE(review): mask_canny is computed with NumPy/OpenCV outside autograd,
        # so the Canny term raises the reported loss but contributes no gradient.
        # A differentiable edge operator would be needed for it to affect training.
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v, mask_canny, label_canny)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        running_tar_loss += loss2.item()
        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d]train loss: %3f, tar: %3f" % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val,
            running_tar_loss / ite_num4val))
        # del temporary outputs and loss
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss, mask_canny, label_canny
        if ite_num % save_frq == 0:
            torch.save(net, model_dir + model_name + '.pth')
            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume train
            ite_num4val = 0
            print('Congrats: model saved successfully!')
# ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
# import os
# import torch
# import torchvision
# from torch.autograd import Variable
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms, utils
# import torch.optim as optim
# import torchvision.transforms as standard_transforms
# import numpy as np
# import glob
# import os
# from data_loader import Rescale
# from data_loader import RescaleT
# from data_loader import RandomCrop
# from data_loader import ToTensor
# from data_loader import ToTensorLab
# from data_loader import SalObjDataset
# from model import U2NET
# from model import U2NETP
# from model import U2NETM
# # print(os.getcwd()) D:\A02分割论文\U-2-Net\U-2-Net-master
# # ------- 1. define loss function --------
# bce_loss = nn.BCELoss(size_average=True)
# def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
# loss0 = bce_loss(d0,labels_v)
# loss1 = bce_loss(d1,labels_v)
# loss2 = bce_loss(d2,labels_v)
# loss3 = bce_loss(d3,labels_v)
# loss4 = bce_loss(d4,labels_v)
# loss5 = bce_loss(d5,labels_v)
# loss6 = bce_loss(d6,labels_v)
# loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
# print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(), loss6.data.item()))
# return loss0, loss
# # ------- 2. set the directory of training dataset --------
# model_name = 'u2netm' # or 'u2net' or 'u2netm' or 'u2netp'
# data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'im_aug' + os.sep) # DUTS\DUTS-TR\im_aug\
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'gt_aug' + os.sep)
# # print(tra_image_dir) DUTS\DUTS-TR\im_aug\
# # print(data_dir) D:\A02图像分割\U-2-Net\U-2-Net-master\train_data\
# image_ext = '.jpg'
# label_ext = '.png'
# model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)
# # print(model_dir) # model_dir D:\A02图像分割\U-2-Net\U-2-Net-master\saved_models\u2netm\
# epoch_num = 666
# batch_size_train = 15
# batch_size_val = 1
# train_num = 0
# val_num = 0
# tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
# # print(data_dir + tra_image_dir)
# # print(tra_img_name_list)
# tra_lbl_name_list = []
# for img_path in tra_img_name_list:
# img_name = img_path.split(os.sep)[-1]
# aaa = img_name.split(".")
# bbb = aaa[0:-1]
# imidx = bbb[0]
# for i in range(1,len(bbb)):
# imidx = imidx + "." + bbb[i]
# tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)
# print("---")
# print("train images: ", len(tra_img_name_list))
# print("train labels: ", len(tra_lbl_name_list))
# print("---")
# train_num = len(tra_img_name_list)
# salobj_dataset = SalObjDataset(
# img_name_list=tra_img_name_list,
# lbl_name_list=tra_lbl_name_list,
# transform=transforms.Compose([
# RescaleT(320),
# RandomCrop(288),
# ToTensorLab(flag=0)]))
# salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)
# # ------- 3. define model --------
# # define the net
# if(model_name == 'u2net'):
# net = U2NET(3,1)
# elif(model_name == 'u2netp'):
# net = U2NETP(3,1)
# elif(model_name == 'u2netm'):
# net = U2NETM(3,1)
# if torch.cuda.is_available():
# net.cuda()
# learning_rate = 3.e-5
# # ------- 4. define optimizer --------
# print("---defining optimizer...")
# optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# print('---optimizer defined successfully---')
# # ------- 5. training process --------
# print("---start TRAINING...")
# # ------- 5(1). load model data --------
# c = input('? Do you need to LOAD the former model data ? (y/n): ')
# if c == 'y':
# # load model parameters
# net.load_state_dict(torch.load(model_dir + model_name + '.pth'))
# print('---the former model data LOADED successfully---')
# elif c == 'n':
# print('---the former model data IGNORED successfully---')
# else:
# print('ERROR ! invalid command !!!')
# print('---continue to TRAIN the model\nGood LUCK!...')
# ite_num = 0
# running_loss = 0.0
# running_tar_loss = 0.0
# ite_num4val = 0
# save_frq = 10 # save the model every 10 iterations
# for epoch in range(0, epoch_num):
# net.train()
# for i, data in enumerate(salobj_dataloader):
# ite_num = ite_num + 1
# ite_num4val = ite_num4val + 1
# if ite_num % 250 == 0:
# learning_rate = learning_rate/2.
# print("---Optimizer: learning_rate REDUCED by half successfully---")
# print('---NOTE: present learning_rate = ', learning_rate)
# print("---GOOD LUCK---")
# else:
# learning_rate = learning_rate
# optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# inputs, labels = data['image'], data['label']
# inputs = inputs.type(torch.FloatTensor)
# labels = labels.type(torch.FloatTensor)
# # wrap them in Variable
# if torch.cuda.is_available():
# inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
# else:
# inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
# # y zero the parameter gradients
# optimizer.zero_grad()
# # forward + backward + optimize
# d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
# loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
# loss.backward()
# optimizer.step()
# # # print statistics
# running_loss += loss.data.item()
# running_tar_loss += loss2.data.item()
# # del temporary outputs and loss
# del d0, d1, d2, d3, d4, d5, d6, loss2, loss
# print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
# epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
# if ite_num % save_frq == 0:
# torch.save(net.state_dict(), model_dir + model_name + '.pth')
# running_loss = 0.0
# running_tar_loss = 0.0
# net.train() # resume train
# ite_num4val = 0
# print('Congrats: model saved successfully!')