import os
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

import numpy as np
import cv2
from PIL import Image
from skimage import io

from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET
from model import U2NETP
from model import U2NETM

# ------- 1. define loss function --------

bce_loss = nn.BCELoss(reduction='mean')  # size_average is deprecated; 'mean' is equivalent


# normalize the predicted SOD probability map to [0, 1] (min-max normalization)
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn


# intersection-over-union between two binary masks (values in {0, 1})
def IOU(pred, label):
    Iand1 = np.sum(label[:, :] * pred[:, :])
    Ior1 = np.sum(label[:, :]) + np.sum(pred[:, :]) - Iand1
    IoU1 = Iand1 / Ior1
    return IoU1

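
# A minimal sanity check for IOU with hypothetical 2x2 masks
# (intersection 2, union 4, so the expected result is 0.5):
_pred_demo = np.array([[1, 0], [1, 1]], dtype=np.float32)
_label_demo = np.array([[1, 1], [0, 1]], dtype=np.float32)
assert abs(IOU(_pred_demo, _label_demo) - 0.5) < 1e-6
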

# ---------------- modified loss function: add a Canny edge loss term ----------------
# d0 is the fused output of U^2-Net and d1..d6 are the side outputs (deep
# supervision); loss0 on the fused map is also returned separately and is
# reported as "tar" in the training log.
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v, mask_canny, label_canny):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss_canny = bce_loss(mask_canny, label_canny)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss_canny

    print("l_canny: %3f, loss0: %3f, loss1: %3f, loss2: %3f, loss3: %3f, loss4: %3f, loss5: %3f, loss6: %3f\n" % (
        loss_canny.item(), loss0.item(), loss1.item(), loss2.item(),
        loss3.item(), loss4.item(), loss5.item(), loss6.item()))

    return loss0, loss

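
# NOTE (a sketch, not what this script does): because the Canny maps below are
# produced with cv2 on numpy arrays, loss_canny carries no gradient back to the
# network. A differentiable alternative would extract soft edges from d0 with a
# fixed Sobel kernel, e.g.:
#   sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3)
#   def soft_edges(m):  # m: Bx1xHxW probability map
#       gx = F.conv2d(m, sobel_x.to(m.device), padding=1)
#       gy = F.conv2d(m, sobel_x.transpose(2, 3).to(m.device), padding=1)
#       return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)
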

# ------- 2. set the directory of training dataset --------

model_name = 'u2net'  # 'u2net', 'u2netp', or 'u2netm'

data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'miou_im' + os.sep)
tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'miou_gt' + os.sep)

image_ext = '.jpg'
label_ext = '.png'

model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

epoch_num = 666
batch_size_train = 10
print("---batch size = ", batch_size_train)
batch_size_val = 1
train_num = 0
val_num = 0

tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)

# build the label path list: same file stem as the image, with the label extension
tra_lbl_name_list = []
for img_path in tra_img_name_list:
    img_name = img_path.split(os.sep)[-1]
    imidx = os.path.splitext(img_name)[0]  # strip only the final extension
    tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")

train_num = len(tra_img_name_list)

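# Example mapping (hypothetical file name):
#   train_data/DUTS/DUTS-TR/miou_im/0001.jpg -> train_data/DUTS/DUTS-TR/miou_gt/0001.png
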
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
# Compose applies the listed transforms in order and returns the processed sample
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)

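# A quick shape sanity check (commented out so it does not consume a batch;
# assumes ToTensorLab yields 3x288x288 images and 1x288x288 labels):
# sample = next(iter(salobj_dataloader))
# print(sample['image'].shape, sample['label'].shape)  # e.g. [10, 3, 288, 288], [10, 1, 288, 288]
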
# ------- 3. define model --------
# define the net
if model_name == 'u2net':
    net = U2NET(3, 1)
elif model_name == 'u2netp':
    net = U2NETP(3, 1)
elif model_name == 'u2netm':
    net = U2NETM(3, 1)
else:
    raise ValueError('unknown model_name: ' + model_name)

if torch.cuda.is_available():
    net.cuda()

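# Optional: report the model size (commented out to keep the log clean)
# print("---params: ", sum(p.numel() for p in net.parameters()))
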
learning_rate = 1.e-7

# ------- 4. define optimizer --------
print("---defining optimizer...")
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
print('---optimizer defined successfully---')

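# An equivalent built-in alternative to the manual halving done inside the
# training loop below (a sketch, not used here): step once per iteration.
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=250, gamma=0.5)
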

# ------- 5. training process --------
print("---start TRAINING...")

# ------- 5(1). load model data --------
# resume from a previous checkpoint when one exists
ckpt_path = model_dir + model_name + '.pth'
if os.path.exists(ckpt_path):
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(ckpt_path))
        print('---GPU::the former model data LOADED successfully---')
    else:
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        print('---CPU::the former model data LOADED successfully---')
# ------------------------------------------

ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 10  # save the model every 10 iterations
iou_sum = 0
n = 0.  # iteration counter for the (optional) MIOU average below

for epoch in range(0, epoch_num):
    net.train()

    for i, data in enumerate(salobj_dataloader):
        n = n + 1.
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1

        # halve the learning rate every 250 iterations; update the existing
        # optimizer in place so its internal (moment) state is preserved,
        # instead of re-creating it on every iteration
        if ite_num % 250 == 0:
            learning_rate = learning_rate / 2.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            print("---Optimizer: learning_rate REDUCED by half successfully---")
            print('---NOTE: present learning_rate = ', learning_rate)
            print("---GOOD LUCK---")

        inputs, labels = data['image'], data['label']

        inputs = inputs.type(torch.FloatTensor)
        # print(inputs.size())  # debug
        labels = labels.type(torch.FloatTensor)

        # move to the GPU when available (Variable wrappers are no longer needed)
        if torch.cuda.is_available():
            inputs_v, labels_v = inputs.cuda(), labels.cuda()
        else:
            inputs_v, labels_v = inputs, labels

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)

        # ------ compute mask_canny and label_canny ------------------------------------
        # NOTE: with shuffle=True the batch index i does not correspond to
        # tra_img_name_list[i], and only the first image of the batch is used here
        mask_c = d0[:, 0, :, :]
        mask_c = normPRED(mask_c)
        mask_c = mask_c.squeeze()
        mask_c = mask_c.cpu().data.numpy()
        mask_c = Image.fromarray(np.uint8(mask_c[0] * 255))  # to PIL format
        canny_image_name = tra_img_name_list[i]  # path of the original image
        canny_image = io.imread(canny_image_name)
        mask_c = mask_c.resize((canny_image.shape[1], canny_image.shape[0]), resample=Image.BILINEAR)  # back to the original size
        mask_c = np.asarray(mask_c)
        mask_canny = cv2.Canny(mask_c, 100, 200, apertureSize=5)
        label_c = cv2.imread(tra_lbl_name_list[i])
        label_canny = cv2.Canny(label_c, 100, 200, apertureSize=5)
        # BCE loss expects float tensors with values in [0, 1]
        mask_canny = torch.from_numpy(mask_canny.astype(np.float32) / 255.0)
        label_canny = torch.from_numpy(label_canny.astype(np.float32) / 255.0)
        if torch.cuda.is_available():
            # keep the edge tensors on the same device as the other loss terms
            mask_canny, label_canny = mask_canny.cuda(), label_canny.cuda()
        # remember to keep muti_bce_loss_fusion at the top in sync
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v, mask_canny, label_canny)
        # ------ canny loss ends here ---------------------------------------------------

        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        running_tar_loss += loss2.item()

        # ------ optional: running MIOU against the ground-truth masks (scratch code,
        # kept commented; note the same batch-index caveat as above) ------
        # pred = normPRED(d0[0, 0, :, :]).squeeze().cpu().data.numpy() * 255
        # image = cv2.imread(tra_img_name_list[i], cv2.IMREAD_GRAYSCALE)
        # pred = cv2.resize(pred, (image.shape[1], image.shape[0]))
        # pred = (pred > 127).astype(np.uint8)   # binarize to {0, 1}
        # mask = cv2.imread(tra_lbl_name_list[i], cv2.IMREAD_GRAYSCALE)
        # mask = (mask > 127).astype(np.uint8)   # binarize to {0, 1}
        # iou_sum = iou_sum + IOU(pred, mask)
        # MIOU = iou_sum / n
        # print("MIOU = ", MIOU)

print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d]train loss: %3f, tar: %3f" % (
|
|
|
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val,
|
|
|
running_tar_loss / ite_num4val))
|
|
|
# del temporary outputs and loss
|
|
|
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
|
|
|
|
|
|
if ite_num % save_frq == 0:
|
|
|
torch.save(net, model_dir + model_name + '.pth')
|
|
|
# torch.save(model, f"../output/ckpt_%d.pth" % epoch)
|
|
|
# torch.save(net.state_dict(), model_dir + model_name + '.pth')
|
|
|
running_loss = 0.0
|
|
|
running_tar_loss = 0.0
|
|
|
net.train() # resume train
|
|
|
ite_num4val = 0
|
|
|
print('Congrats: model saved successfully!')
|
|
|
|
|
|