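"""Training script for SCNN lane detection.

Reads cfg.json from the experiment directory given by --exp_dir, builds the train and
val data loaders, trains SCNN with SGD and a polynomial LR schedule, validates after
every epoch, and keeps the latest checkpoint (<exp_name>.pth) plus the best one
(<exp_name>_best.pth) inside the experiment directory.

Fields read from cfg.json: dataset.dataset_name, dataset.resize_shape,
dataset.batch_size, device, optim (keyword arguments for optim.SGD),
lr_scheduler (keyword arguments for PolyLR) and MAX_EPOCHES. An illustrative
sketch (the values below are assumptions, not defaults shipped with the repo):

    {
        "dataset": {"dataset_name": "CULane", "resize_shape": [800, 288], "batch_size": 12},
        "device": "cuda",
        "optim": {"lr": 0.01},
        "lr_scheduler": {},
        "MAX_EPOCHES": 12
    }
"""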
import argparse
import json
import os
import shutil
import time

# cv2, numpy and torch are used further down (device, checkpointing, visualization);
# import them explicitly rather than relying on the wildcard imports below.
import cv2
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import *
import dataset
from model import SCNN
from utils.transforms import *
from utils.lr_scheduler import PolyLR


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="experiments/exp0",
                        help="experiment directory containing cfg.json; checkpoints are written here")
    parser.add_argument("--resume", "-r", action="store_true", default=False,
                        help="resume from the checkpoint previously saved in exp_dir")
    args = parser.parse_args()
    return args


args = parse_args()

# ------------ config ------------
exp_dir = args.exp_dir
# strip any trailing '/' so the experiment name is simply the last path component
while exp_dir[-1] == '/':
    exp_dir = exp_dir[:-1]
exp_name = exp_dir.split('/')[-1]

with open(os.path.join(exp_dir, "cfg.json")) as f:
    exp_cfg = json.load(f)
resize_shape = tuple(exp_cfg['dataset']['resize_shape'])

device = torch.device(exp_cfg['device'])

# tensorboard = TensorBoard(exp_dir)


# ------------ train data ------------
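# The dataset class is looked up by name from the `dataset` module and its root folder
# from Dataset_Path in config.py. Training images are resized to resize_shape, lightly
# rotated (Rotation(2)), converted to tensors and normalized with ImageNet statistics.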
# ImageNet mean, std
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform_train = Compose(Resize(resize_shape), Rotation(2), ToTensor(),
                          Normalize(mean=mean, std=std))
dataset_name = exp_cfg['dataset'].pop('dataset_name')
Dataset_Type = getattr(dataset, dataset_name)
train_dataset = Dataset_Type(Dataset_Path[dataset_name], "train", transform_train)
train_loader = DataLoader(train_dataset, batch_size=exp_cfg['dataset']['batch_size'], shuffle=True,
                          collate_fn=train_dataset.collate, num_workers=0)


# ------------ val data ------------
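# The resize step is kept separate (transform_val_img) so val() can reuse it when it
# re-reads the original frames for visualization.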
transform_val_img = Resize(resize_shape)
transform_val_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
transform_val = Compose(transform_val_img, transform_val_x)
val_dataset = Dataset_Type(Dataset_Path[dataset_name], "val", transform_val)
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=val_dataset.collate, num_workers=0)


# ------------ preparation ------------
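# SCNN is built from scratch (pretrained=False), moved to the configured device and
# wrapped in DataParallel; SGD and the polynomial LR schedule take their remaining
# hyper-parameters straight from the "optim" / "lr_scheduler" sections of cfg.json.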
net = SCNN(resize_shape, pretrained=False)
net = net.to(device)
net = torch.nn.DataParallel(net)

optimizer = optim.SGD(net.parameters(), **exp_cfg['optim'])
lr_scheduler = PolyLR(optimizer, 0.9, **exp_cfg['lr_scheduler'])
best_val_loss = 1e6


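# train(epoch): one pass over train_loader -- forward pass, per-GPU losses summed under
# DataParallel, backward, SGD step and Poly LR step; a checkpoint with the network,
# optimizer, scheduler and best_val_loss is written to exp_dir at the end of the epoch.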
def train(epoch):
    print("Train Epoch: {}".format(epoch))
    net.train()
    train_loss = 0
    train_loss_seg = 0
    train_loss_exist = 0
    progressbar = tqdm(range(len(train_loader)))

    for batch_idx, sample in enumerate(train_loader):
        img = sample['img'].to(device)
        segLabel = sample['segLabel'].to(device)
        exist = sample['exist'].to(device)

        optimizer.zero_grad()
        seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist)
        if isinstance(net, torch.nn.DataParallel):
            # DataParallel returns one loss per GPU; reduce them to scalars
            loss_seg = loss_seg.sum()
            loss_exist = loss_exist.sum()
            loss = loss.sum()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # per-iteration bookkeeping, kept for the TensorBoard writer commented out above
        iter_idx = epoch * len(train_loader) + batch_idx
        train_loss = loss.item()
        train_loss_seg = loss_seg.item()
        train_loss_exist = loss_exist.item()
        progressbar.set_description("batch loss: {:.3f}".format(loss.item()))
        progressbar.update(1)

        lr = optimizer.param_groups[0]['lr']

    progressbar.close()

    if epoch % 1 == 0:  # always true: a checkpoint is written after every epoch
        save_dict = {
            "epoch": epoch,
            "net": net.module.state_dict() if isinstance(net, torch.nn.DataParallel) else net.state_dict(),
            "optim": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "best_val_loss": best_val_loss
        }
        save_name = os.path.join(exp_dir, exp_name + '.pth')
        torch.save(save_dict, save_name)
        print("model is saved: {}".format(save_name))

    print("------------------------\n")


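# val(epoch): evaluates the network on val_loader without gradients, accumulates the
# segmentation / existence / total losses, and builds lane-overlay visualizations for
# every 5th batch (presumably intended for the TensorBoard writer that is commented
# out). The running checkpoint is copied to <exp_name>_best.pth whenever val_loss improves.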
def val(epoch):
    global best_val_loss

    print("Val Epoch: {}".format(epoch))

    net.eval()
    val_loss = 0
    val_loss_seg = 0
    val_loss_exist = 0
    progressbar = tqdm(range(len(val_loader)))

    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            img = sample['img'].to(device)
            segLabel = sample['segLabel'].to(device)
            exist = sample['exist'].to(device)

            seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist)
            if isinstance(net, torch.nn.DataParallel):
                loss_seg = loss_seg.sum()
                loss_exist = loss_exist.sum()
                loss = loss.sum()

            # visualize every 5th batch, at most 50 batches in total
            gap_num = 5
            if batch_idx % gap_num == 0 and batch_idx < 50 * gap_num:
                origin_imgs = []
                seg_pred = seg_pred.detach().cpu().numpy()
                exist_pred = exist_pred.detach().cpu().numpy()

                for b in range(len(img)):
                    # note: `img` is reused here for the full-resolution frame read from disk
                    img_name = sample['img_name'][b]
                    img = cv2.imread(img_name)
                    img = transform_val_img({'img': img})['img']

                    lane_img = np.zeros_like(img)
                    color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8')

                    # per-pixel argmax over the class channels: label 0 is background, 1-4 are the lanes
                    coord_mask = np.argmax(seg_pred[b], axis=0)
                    for i in range(0, 4):
                        if exist_pred[b, i] > 0.5:
                            lane_img[coord_mask == (i + 1)] = color[i]
                    img = cv2.addWeighted(src1=lane_img, alpha=0.8, src2=img, beta=1., gamma=0.)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    lane_img = cv2.cvtColor(lane_img, cv2.COLOR_BGR2RGB)
                    cv2.putText(lane_img, "{}".format([1 if exist_pred[b, i] > 0.5 else 0 for i in range(4)]),
                                (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (255, 255, 255), 2)
                    origin_imgs.append(img)
                    origin_imgs.append(lane_img)

            val_loss += loss.item()
            val_loss_seg += loss_seg.item()
            val_loss_exist += loss_exist.item()

            progressbar.set_description("batch loss: {:.3f}".format(loss.item()))
            progressbar.update(1)

    progressbar.close()
    iter_idx = (epoch + 1) * len(train_loader)  # keep aligned with the training iteration index

    print("------------------------\n")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_name = os.path.join(exp_dir, exp_name + '.pth')
        copy_name = os.path.join(exp_dir, exp_name + '_best.pth')
        shutil.copyfile(save_name, copy_name)


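# main(): optionally restores the network, optimizer, scheduler and best_val_loss from
# the checkpoint in exp_dir (--resume), then alternates train() and val() every epoch
# up to MAX_EPOCHES from cfg.json.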
def main():
    global best_val_loss
    if args.resume:
        save_dict = torch.load(os.path.join(exp_dir, exp_name + '.pth'))
        if isinstance(net, torch.nn.DataParallel):
            net.module.load_state_dict(save_dict['net'])
        else:
            net.load_state_dict(save_dict['net'])
        optimizer.load_state_dict(save_dict['optim'])
        lr_scheduler.load_state_dict(save_dict['lr_scheduler'])
        start_epoch = save_dict['epoch'] + 1
        best_val_loss = save_dict.get("best_val_loss", 1e6)
    else:
        start_epoch = 0

    for epoch in range(start_epoch, exp_cfg['MAX_EPOCHES']):
        train(epoch)
        if epoch % 1 == 0:
            print("\nValidation For Experiment: ", exp_dir)
            print(time.strftime('%H:%M:%S', time.localtime()))
            val(epoch)


if __name__ == "__main__":
    main()