import argparse
import json
import os
import shutil
import time
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import *
import dataset
from model_ENET_SAD import ENet_SAD
# from utils.tensorboard import TensorBoard
from utils.transforms import *
from utils.lr_scheduler import PolyLR
from multiprocessing import Process, JoinableQueue
from threading import Lock
import pickle
# from torch.multiprocessing import Process, SimpleQueue


def parse_args():
    """Parse command-line options: the experiment directory and resume flag."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="./experiments/exp0")
    parser.add_argument("--resume", "-r", action="store_true")
    return parser.parse_args()


args = parse_args()

# ------------ config ------------
# Normalize the experiment path: strip trailing slashes, then use the last
# path component as the experiment name (also the checkpoint base name).
exp_dir = args.exp_dir
while exp_dir[-1] == '/':
    exp_dir = exp_dir[:-1]
exp_name = exp_dir.split('/')[-1]

# Per-experiment hyperparameters live next to the checkpoints in cfg.json.
with open(os.path.join(exp_dir, "cfg.json")) as f:
    exp_cfg = json.load(f)
resize_shape = tuple(exp_cfg['dataset']['resize_shape'])
device = torch.device(exp_cfg['device'])
# tensorboard = TensorBoard(exp_dir)

# ------------ train data ------------
# # CULane mean, std
mean = (0.3598, 0.3653, 0.3662)
std = (0.2573, 0.2663, 0.2756)
# Imagenet mean, std
# mean=(0.485, 0.456, 0.406)
# std=(0.229, 0.224, 0.225)

# Training augmentation: resize, small random rotation, tensor + normalize.
transform_train = Compose(
    Resize(resize_shape),
    Rotation(2),
    ToTensor(),
    Normalize(mean=mean, std=std),
)

# The dataset class is looked up by name inside the local `dataset` module;
# pop() so the remaining cfg['dataset'] keys are loader options only.
dataset_name = exp_cfg['dataset'].pop('dataset_name')
Dataset_Type = getattr(dataset, dataset_name)

train_dataset = Dataset_Type(Dataset_Path[dataset_name], "train", transform_train)
train_loader = DataLoader(
    train_dataset,
    batch_size=exp_cfg['dataset']['batch_size'],
    shuffle=True,
    collate_fn=train_dataset.collate,
    num_workers=0,
)

# ------------ val data ------------
# The resize is kept separate so val() can re-apply it to raw images
# when building visualization overlays.
transform_val_img = Resize(resize_shape)
transform_val_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
transform_val = Compose(transform_val_img, transform_val_x)
val_dataset = Dataset_Type(Dataset_Path[dataset_name], "val", transform_val)
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=val_dataset.collate, num_workers=0)

# ------------ preparation ------------
net = ENet_SAD(resize_shape, sad=True)
net = net.to(device)
net = torch.nn.DataParallel(net)
optimizer = optim.SGD(net.parameters(), **exp_cfg['optim'])
lr_scheduler = PolyLR(optimizer, 0.9, **exp_cfg['lr_scheduler'])
best_val_loss = 1e6


def train(epoch):
    """Run one training epoch, then checkpoint the model.

    The SAD distillation term is enabled only once the global iteration
    count reaches ``exp_cfg['sad_start_iter']``. Under DataParallel the
    losses arrive as one scalar per GPU and are summed before backprop.
    """
    print("Train Epoch: {}".format(epoch))
    net.train()
    train_loss = 0
    train_loss_seg = 0
    train_loss_exist = 0
    progressbar = tqdm(range(len(train_loader)))

    for batch_idx, sample in enumerate(train_loader):
        img = sample['img'].to(device)
        segLabel = sample['segLabel'].to(device)
        exist = sample['exist'].to(device)

        optimizer.zero_grad()
        if exp_cfg['model'] == "scnn":
            seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist)
        elif exp_cfg['model'] == "enet_sad":
            # Warm up without self-attention distillation for the first
            # `sad_start_iter` global iterations.
            if (epoch * len(train_loader) + batch_idx) < exp_cfg['sad_start_iter']:
                seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist, False)
            else:
                print("sad activated")
                seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist, True)
        else:
            # Fail loudly instead of hitting a NameError on the undefined
            # loss variables below.
            raise ValueError("unknown model type: {}".format(exp_cfg['model']))

        if isinstance(net, torch.nn.DataParallel):
            # One scalar per GPU -> reduce to a single scalar.
            loss_seg = loss_seg.sum()
            loss_exist = loss_exist.sum()
            loss = loss.sum()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # PolyLR is stepped per iteration, not per epoch

        iter_idx = epoch * len(train_loader) + batch_idx  # kept for the (disabled) TensorBoard hook
        # Bug fix: accumulate epoch losses (was `=`, which kept only the
        # last batch), matching the accumulation in val().
        train_loss += loss.item()
        train_loss_seg += loss_seg.item()
        train_loss_exist += loss_exist.item()
        progressbar.set_description("batch loss: {:.3f}".format(loss.item()))
        progressbar.update(1)
        lr = optimizer.param_groups[0]['lr']  # kept for the (disabled) TensorBoard hook
    progressbar.close()

    # Checkpoint every epoch; the modulus is a knob for less frequent saves.
    if epoch % 1 == 0:
        save_dict = {
            "epoch": epoch,
            "net": net.module.state_dict() if isinstance(net, torch.nn.DataParallel) else net.state_dict(),
            "optim": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "best_val_loss": best_val_loss
        }
        save_name = os.path.join(exp_dir, exp_name + '.pth')
        torch.save(save_dict, save_name)
        print("model is saved: {}".format(save_name))
    print("------------------------\n")


def val(epoch):
    """Evaluate on the validation set and promote the best checkpoint.

    Also builds lane-overlay visualizations for every ``gap_num``-th batch
    (first 50 such batches); the images are collected but only consumed by
    the currently disabled TensorBoard hook.
    """
    global best_val_loss

    print("Val Epoch: {}".format(epoch))
    net.eval()
    val_loss = 0
    val_loss_seg = 0
    val_loss_exist = 0
    progressbar = tqdm(range(len(val_loader)))

    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            img = sample['img'].to(device)
            segLabel = sample['segLabel'].to(device)
            exist = sample['exist'].to(device)

            seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist)
            if isinstance(net, torch.nn.DataParallel):
                loss_seg = loss_seg.sum()
                loss_exist = loss_exist.sum()
                loss = loss.sum()

            # visualize validation every 5 frame, 50 frames in all
            gap_num = 5
            if batch_idx % gap_num == 0 and batch_idx < 50 * gap_num:
                origin_imgs = []
                seg_pred = seg_pred.detach().cpu().numpy()
                exist_pred = exist_pred.detach().cpu().numpy()
                for b in range(len(img)):
                    img_name = sample['img_name'][b]
                    # Bug fix: use a dedicated variable instead of clobbering
                    # the batch tensor `img`, whose length the loop depends on.
                    raw_img = cv2.imread(img_name)
                    raw_img = transform_val_img({'img': raw_img})['img']
                    lane_img = np.zeros_like(raw_img)
                    color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8')
                    # Per-pixel argmax over classes; channel 0 is background.
                    coord_mask = np.argmax(seg_pred[b], axis=0)
                    for i in range(0, 4):
                        if exist_pred[b, i] > 0.5:
                            lane_img[coord_mask == (i + 1)] = color[i]
                    raw_img = cv2.addWeighted(src1=lane_img, alpha=0.8, src2=raw_img, beta=1., gamma=0.)
                    raw_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)
                    lane_img = cv2.cvtColor(lane_img, cv2.COLOR_BGR2RGB)
                    cv2.putText(lane_img,
                                "{}".format([1 if exist_pred[b, i] > 0.5 else 0 for i in range(4)]),
                                (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (255, 255, 255), 2)
                    origin_imgs.append(raw_img)
                    origin_imgs.append(lane_img)

            val_loss += loss.item()
            val_loss_seg += loss_seg.item()
            val_loss_exist += loss_exist.item()
            progressbar.set_description("batch loss: {:.3f}".format(loss.item()))
            progressbar.update(1)
    progressbar.close()
    iter_idx = (epoch + 1) * len(train_loader)  # keep align with training process iter_idx
    print("------------------------\n")

    # train() has already written this epoch's checkpoint, so promoting the
    # best model is just a file copy.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_name = os.path.join(exp_dir, exp_name + '.pth')
        copy_name = os.path.join(exp_dir, exp_name + '_best.pth')
        shutil.copyfile(save_name, copy_name)


def main():
    """Entry point: optionally resume from the latest checkpoint, then
    alternate training and validation epochs."""
    global best_val_loss
    if args.resume:
        save_dict = torch.load(os.path.join(exp_dir, exp_name + '.pth'))
        if isinstance(net, torch.nn.DataParallel):
            net.module.load_state_dict(save_dict['net'])
        else:
            net.load_state_dict(save_dict['net'])
        optimizer.load_state_dict(save_dict['optim'])
        lr_scheduler.load_state_dict(save_dict['lr_scheduler'])
        start_epoch = save_dict['epoch'] + 1
        # Older checkpoints may predate this key; fall back to the sentinel.
        best_val_loss = save_dict.get("best_val_loss", 1e6)
    else:
        start_epoch = 0

    for epoch in range(start_epoch, exp_cfg['MAX_EPOCHES']):
        train(epoch)
        # Validate every epoch; the modulus is a frequency knob.
        if epoch % 1 == 0:
            print("\nValidation For Experiment: ", exp_dir)
            print(time.strftime('%H:%M:%S', time.localtime()))
            val(epoch)


if __name__ == "__main__":
    main()