You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

236 lines
8.1 KiB

import argparse
import json
import os
import shutil
import time
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import *
import dataset
from model_ENET_SAD import ENet_SAD
# from utils.tensorboard import TensorBoard
from utils.transforms import *
from utils.lr_scheduler import PolyLR
from multiprocessing import Process, JoinableQueue
from threading import Lock
import pickle
#from torch.multiprocessing import Process, SimpleQueue
def parse_args():
    """Parse command-line arguments for the training script.

    Returns an ``argparse.Namespace`` with:
        exp_dir (str): experiment directory containing ``cfg.json``
                       (default ``./experiments/exp0``).
        resume (bool): resume training from the saved checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="./experiments/exp0")
    parser.add_argument("--resume", "-r", action="store_true")
    return parser.parse_args()
args = parse_args()

# ------------ config ------------
exp_dir = args.exp_dir
# Strip trailing slashes so the experiment name is the last path component.
while exp_dir[-1]=='/':
    exp_dir = exp_dir[:-1]
exp_name = exp_dir.split('/')[-1]

# Experiment configuration lives next to the experiment outputs.
with open(os.path.join(exp_dir, "cfg.json")) as f:
    exp_cfg = json.load(f)
resize_shape = tuple(exp_cfg['dataset']['resize_shape'])
device = torch.device(exp_cfg['device'])
# tensorboard = TensorBoard(exp_dir)

# ------------ train data ------------
# # CULane mean, std
mean=(0.3598, 0.3653, 0.3662)
std=(0.2573, 0.2663, 0.2756)
# Imagenet mean, std
# mean=(0.485, 0.456, 0.406)
# std=(0.229, 0.224, 0.225)
transform_train = Compose(Resize(resize_shape), Rotation(2), ToTensor(),
                          Normalize(mean=mean, std=std))
# NOTE: pop() removes 'dataset_name' so the remaining dict entries
# (e.g. batch_size) can be read below without that key in the way.
dataset_name = exp_cfg['dataset'].pop('dataset_name')
# Dataset class is looked up by name in the project-local `dataset` module.
Dataset_Type = getattr(dataset, dataset_name)
train_dataset = Dataset_Type(Dataset_Path[dataset_name], "train", transform_train)
train_loader = DataLoader(train_dataset, batch_size=exp_cfg['dataset']['batch_size'], shuffle=True, collate_fn=train_dataset.collate, num_workers=0)

# ------------ val data ------------
# Resize is kept separate so val() can re-apply it to raw images for visualization.
transform_val_img = Resize(resize_shape)
transform_val_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
transform_val = Compose(transform_val_img, transform_val_x)
val_dataset = Dataset_Type(Dataset_Path[dataset_name], "val", transform_val)
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=val_dataset.collate, num_workers=0)

# ------------ preparation ------------
net = ENet_SAD(resize_shape, sad=True)
net = net.to(device)
# Wrapped in DataParallel unconditionally; per-GPU losses are summed in train()/val().
net = torch.nn.DataParallel(net)
optimizer = optim.SGD(net.parameters(), **exp_cfg['optim'])
lr_scheduler = PolyLR(optimizer, 0.9, **exp_cfg['lr_scheduler'])
best_val_loss = 1e6

# Dead code kept from an earlier multiprocessing prefetch experiment.
"""
def batch_processor(arg):
    b_queue, data_loader = arg
    while True:
        if b_queue.empty():
            sample = next(data_loader)
            b_queue.put(sample)
            b_queue.join()
"""
def train(epoch):
    """Run one training epoch and save a checkpoint.

    Uses module globals: net, optimizer, lr_scheduler, train_loader,
    exp_cfg, device, exp_dir, exp_name, best_val_loss.

    Args:
        epoch (int): zero-based epoch index; also used to compute the
            global iteration count that gates SAD activation.
    """
    print("Train Epoch: {}".format(epoch))
    net.train()
    train_loss = 0
    train_loss_seg = 0
    train_loss_exist = 0
    progressbar = tqdm(range(len(train_loader)))

    for batch_idx, sample in enumerate(train_loader):
        img = sample['img'].to(device)
        segLabel = sample['segLabel'].to(device)
        exist = sample['exist'].to(device)

        optimizer.zero_grad()
        if exp_cfg['model'] == "scnn":
            seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist)
        elif exp_cfg['model'] == "enet_sad":
            # SAD self-distillation is only switched on after a warm-up
            # number of global iterations (sad_start_iter).
            if (epoch * len(train_loader) + batch_idx) < exp_cfg['sad_start_iter']:
                seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist, False)
            else:
                print("sad activated")
                seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist, True)

        if isinstance(net, torch.nn.DataParallel):
            # DataParallel returns one loss value per GPU; reduce to a scalar.
            loss_seg = loss_seg.sum()
            loss_exist = loss_exist.sum()
            loss = loss.sum()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        iter_idx = epoch * len(train_loader) + batch_idx
        # BUG FIX: accumulate with += (was `=`, overwriting each batch),
        # matching the zero-initialization above and the accumulation in val().
        train_loss += loss.item()
        train_loss_seg += loss_seg.item()
        train_loss_exist += loss_exist.item()
        progressbar.set_description("batch loss: {:.3f}".format(loss.item()))
        progressbar.update(1)
        lr = optimizer.param_groups[0]['lr']

    progressbar.close()

    # Save a checkpoint every epoch (epoch % 1 == 0 always holds; kept for
    # easy adjustment of the save interval).
    if epoch % 1 == 0:
        save_dict = {
            "epoch": epoch,
            # Unwrap DataParallel so the checkpoint loads into a bare model too.
            "net": net.module.state_dict() if isinstance(net, torch.nn.DataParallel) else net.state_dict(),
            "optim": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "best_val_loss": best_val_loss
        }
        save_name = os.path.join(exp_dir, exp_name + '.pth')
        torch.save(save_dict, save_name)
        print("model is saved: {}".format(save_name))

    print("------------------------\n")
def val(epoch):
    """Run one validation epoch; copy the checkpoint if it is the best so far.

    Uses module globals: net, val_loader, train_loader, device, exp_dir,
    exp_name, transform_val_img, best_val_loss (updated in place).
    Assumes train() has already saved <exp_name>.pth for the current epoch,
    since the "best" file is produced by copying it.

    Args:
        epoch (int): zero-based epoch index (used for logging alignment).
    """
    global best_val_loss

    print("Val Epoch: {}".format(epoch))
    net.eval()
    val_loss = 0
    val_loss_seg = 0
    val_loss_exist = 0
    progressbar = tqdm(range(len(val_loader)))

    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            img = sample['img'].to(device)
            segLabel = sample['segLabel'].to(device)
            exist = sample['exist'].to(device)

            seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist)
            if isinstance(net, torch.nn.DataParallel):
                # DataParallel returns one loss value per GPU; reduce to a scalar.
                loss_seg = loss_seg.sum()
                loss_exist = loss_exist.sum()
                loss = loss.sum()

            # visualize validation every 5 frame, 50 frames in all
            gap_num = 5
            if batch_idx % gap_num == 0 and batch_idx < 50 * gap_num:
                origin_imgs = []
                seg_pred = seg_pred.detach().cpu().numpy()
                exist_pred = exist_pred.detach().cpu().numpy()
                for b in range(len(img)):
                    img_name = sample['img_name'][b]
                    # BUG FIX: was `img = cv2.imread(...)`, shadowing the batch
                    # tensor `img` being iterated; use a distinct local name.
                    raw_img = cv2.imread(img_name)
                    raw_img = transform_val_img({'img': raw_img})['img']
                    lane_img = np.zeros_like(raw_img)
                    color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8')
                    coord_mask = np.argmax(seg_pred[b], axis=0)
                    # Channel 0 is background; lanes 1..4 are drawn only when
                    # the existence head is confident (> 0.5).
                    for i in range(0, 4):
                        if exist_pred[b, i] > 0.5:
                            lane_img[coord_mask == (i + 1)] = color[i]
                    overlay = cv2.addWeighted(src1=lane_img, alpha=0.8, src2=raw_img, beta=1., gamma=0.)
                    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
                    lane_img = cv2.cvtColor(lane_img, cv2.COLOR_BGR2RGB)
                    cv2.putText(lane_img, "{}".format([1 if exist_pred[b, i]>0.5 else 0 for i in range(4)]), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (255, 255, 255), 2)
                    # NOTE(review): origin_imgs is built but never consumed —
                    # presumably fed the (now commented-out) TensorBoard logger.
                    origin_imgs.append(overlay)
                    origin_imgs.append(lane_img)

            val_loss += loss.item()
            val_loss_seg += loss_seg.item()
            val_loss_exist += loss_exist.item()
            progressbar.set_description("batch loss: {:.3f}".format(loss.item()))
            progressbar.update(1)

    progressbar.close()
    iter_idx = (epoch + 1) * len(train_loader)  # keep align with training process iter_idx
    print("------------------------\n")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_name = os.path.join(exp_dir, exp_name + '.pth')
        copy_name = os.path.join(exp_dir, exp_name + '_best.pth')
        shutil.copyfile(save_name, copy_name)
def main():
    """Entry point: optionally resume from checkpoint, then train/validate.

    Uses module globals: args, net, optimizer, lr_scheduler, exp_cfg,
    exp_dir, exp_name, device, best_val_loss (updated on resume).
    """
    global best_val_loss

    if args.resume:
        # BUG FIX: map_location makes a GPU-saved checkpoint resumable on a
        # different device (e.g. CPU-only host) instead of failing to load.
        save_dict = torch.load(os.path.join(exp_dir, exp_name + '.pth'),
                               map_location=device)
        if isinstance(net, torch.nn.DataParallel):
            net.module.load_state_dict(save_dict['net'])
        else:
            net.load_state_dict(save_dict['net'])
        optimizer.load_state_dict(save_dict['optim'])
        lr_scheduler.load_state_dict(save_dict['lr_scheduler'])
        start_epoch = save_dict['epoch'] + 1
        # Older checkpoints may predate this key; fall back to the initial bound.
        best_val_loss = save_dict.get("best_val_loss", 1e6)
    else:
        start_epoch = 0

    for epoch in range(start_epoch, exp_cfg['MAX_EPOCHES']):
        train(epoch)
        # Validate every epoch (interval kept as a tweakable modulus).
        if epoch % 1 == 0:
            print("\nValidation For Experiment: ", exp_dir)
            print(time.strftime('%H:%M:%S', time.localtime()))
            val(epoch)


if __name__ == "__main__":
    main()