You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
129 lines
5.1 KiB
129 lines
5.1 KiB
from utils.loss import SoftmaxFocalLoss, ParsingRelationLoss, ParsingRelationDis
|
|
from utils.metrics import MultiLabelAcc, AccTopk, Metric_mIoU
|
|
from utils.dist_utils import DistSummaryWriter
|
|
|
|
import torch
|
|
|
|
|
|
def get_optimizer(net, cfg):
    """Create the optimizer selected by ``cfg.optimizer`` for ``net``.

    Only parameters with ``requires_grad`` set are handed to the optimizer,
    so frozen parameters are excluded from the update.

    Args:
        net: module whose ``parameters()`` are optimized.
        cfg: config object; reads ``optimizer`` ('Adam' or 'SGD'),
            ``learning_rate``, ``weight_decay`` and, for SGD, ``momentum``.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        NotImplementedError: if ``cfg.optimizer`` is neither 'Adam' nor 'SGD'.
    """
    training_params = [p for p in net.parameters() if p.requires_grad]

    if cfg.optimizer == 'Adam':
        return torch.optim.Adam(training_params, lr=cfg.learning_rate,
                                weight_decay=cfg.weight_decay)
    if cfg.optimizer == 'SGD':
        return torch.optim.SGD(training_params, lr=cfg.learning_rate,
                               momentum=cfg.momentum,
                               weight_decay=cfg.weight_decay)
    # Name the offending value so a config typo is obvious from the traceback.
    raise NotImplementedError(
        "Unsupported optimizer {!r}; expected 'Adam' or 'SGD'".format(cfg.optimizer))
|
|
|
|
def get_scheduler(optimizer, cfg, iters_per_epoch):
    """Build the per-iteration learning-rate scheduler named by ``cfg.scheduler``.

    'multi' -> ``MultiStepLR`` using ``cfg.steps`` (epochs) and ``cfg.gamma``.
    'cos'   -> ``CosineAnnealingLR`` over ``cfg.epoch * iters_per_epoch``
               iterations with ``eta_min = 0``.
    Both receive ``cfg.warmup``. When ``cfg.warmup_iters`` is None, the warmup
    length defaults to one epoch (``iters_per_epoch``).

    Raises:
        NotImplementedError: if ``cfg.scheduler`` is neither 'multi' nor 'cos'.
    """
    if cfg.scheduler not in ('multi', 'cos'):
        raise NotImplementedError(
            "Unsupported scheduler {!r}; expected 'multi' or 'cos'".format(cfg.scheduler))

    # Previously only the 'multi' branch applied this default. 'cos' forwarded
    # None, and a linear warmup then failed comparing an int against None.
    warmup_iters = iters_per_epoch if cfg.warmup_iters is None else cfg.warmup_iters

    if cfg.scheduler == 'multi':
        return MultiStepLR(optimizer, cfg.steps, cfg.gamma, iters_per_epoch,
                           cfg.warmup, warmup_iters)
    return CosineAnnealingLR(optimizer, cfg.epoch * iters_per_epoch, eta_min=0,
                             warmup=cfg.warmup, warmup_iters=warmup_iters)
|
|
|
|
def get_loss_dict(cfg):
    """Describe the training losses as parallel lists.

    The returned dict has the keys 'name', 'op', 'weight' and 'data_src'.
    Entry ``i`` of each list belongs to the same loss:
      * ``name``     - label for the loss term,
      * ``op``       - the loss module instance,
      * ``weight``   - scalar multiplier (``cfg.sim_loss_w`` / ``cfg.shp_loss_w``
                       for the relation terms, 1.0 otherwise),
      * ``data_src`` - tuple of keys naming the op's inputs (presumably keys of
                       the model-output/label dict assembled by the caller;
                       confirm against the training loop).

    The 'aux_loss' term (cross-entropy on 'seg_out'/'seg_label') is included
    only when ``cfg.use_aux`` is truthy. It sits third, before 'relation_dis'.
    """
    table = [
        ('cls_loss', SoftmaxFocalLoss(2), 1.0, ('cls_out', 'cls_label')),
        ('relation_loss', ParsingRelationLoss(), cfg.sim_loss_w, ('cls_out',)),
    ]
    if cfg.use_aux:
        table.append(('aux_loss', torch.nn.CrossEntropyLoss(), 1.0, ('seg_out', 'seg_label')))
    table.append(('relation_dis', ParsingRelationDis(), cfg.shp_loss_w, ('cls_out',)))

    # Transpose the row table into one list per column.
    names, ops, weights, sources = (list(column) for column in zip(*table))
    return {'name': names, 'op': ops, 'weight': weights, 'data_src': sources}
|
|
|
|
def get_metric_dict(cfg):
    """Describe the evaluation metrics as parallel lists.

    The returned dict has the keys 'name', 'op' and 'data_src'. Entry ``i`` of
    each list belongs to the same metric:
      * 'top1' -> ``MultiLabelAcc()``,
      * 'top2' / 'top3' -> ``AccTopk(cfg.griding_num, k)`` with k = 2 / 3,
      * 'iou' -> ``Metric_mIoU(cfg.num_lanes + 1)``, only when ``cfg.use_aux``
        is truthy, reading 'seg_out'/'seg_label'.
    ``data_src`` holds the tuple of keys each op is fed from.
    """
    cls_pair = ('cls_out', 'cls_label')
    rows = [
        ('top1', MultiLabelAcc(), cls_pair),
        ('top2', AccTopk(cfg.griding_num, 2), cls_pair),
        ('top3', AccTopk(cfg.griding_num, 3), cls_pair),
    ]
    if cfg.use_aux:
        rows.append(('iou', Metric_mIoU(cfg.num_lanes + 1), ('seg_out', 'seg_label')))

    names, ops, sources = (list(column) for column in zip(*rows))
    return {'name': names, 'op': ops, 'data_src': sources}
|
|
|
|
|
|
class MultiStepLR:
    """Epoch-milestone step-decay learning-rate schedule with optional linear warmup.

    Meant to be stepped once per training iteration. After warmup, each param
    group's learning rate is ``base_lr * gamma ** k``, where ``k`` is the number
    of milestones in ``steps`` that are <= the current epoch
    (``iters // iters_per_epoch``). ``base_lr`` is the lr each group had when
    this scheduler was constructed.

    Args:
        optimizer: optimizer whose ``param_groups[i]['lr']`` values are rewritten.
        steps: epoch milestones. A sorted copy is stored in ``self.steps``; the
            caller's list is not modified.
        gamma: multiplicative decay applied once per milestone reached.
        iters_per_epoch: iterations per epoch; must be set before ``step()``.
        warmup: 'linear' enables linear warmup; any other value disables it.
        warmup_iters: warmup length in iterations; must be set when
            ``warmup == 'linear'``.
    """

    def __init__(self, optimizer, steps, gamma=0.1, iters_per_epoch=None, warmup=None, warmup_iters=None):
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.optimizer = optimizer
        # Sorted copy: sorting in place would silently reorder the caller's
        # list (e.g. cfg.steps).
        self.steps = sorted(steps)
        self.gamma = gamma
        self.iters_per_epoch = iters_per_epoch
        self.iters = 0
        # Learning rates at construction time; every scheduled lr is derived from these.
        self.base_lr = [group['lr'] for group in optimizer.param_groups]

    def step(self, external_iter=None):
        """Advance one iteration (or jump to ``external_iter``) and set each group's lr."""
        self.iters += 1
        if external_iter is not None:
            self.iters = external_iter

        if self.warmup == 'linear' and self.iters < self.warmup_iters:
            rate = self.iters / self.warmup_iters
            for group, lr in zip(self.optimizer.param_groups, self.base_lr):
                group['lr'] = lr * rate
            return

        # multi policy. The value is constant within an epoch. Recomputing it
        # every iteration (rather than only when iters is a multiple of
        # iters_per_epoch) restores the lr as soon as a warmup ends mid-epoch,
        # and applies the correct decay right after resuming via external_iter.
        epoch = self.iters // self.iters_per_epoch
        power = sum(1 for milestone in self.steps if milestone <= epoch)
        for group, lr in zip(self.optimizer.param_groups, self.base_lr):
            group['lr'] = lr * (self.gamma ** power)
|
|
import math
|
|
class CosineAnnealingLR:
    """Cosine-annealing learning-rate schedule with optional linear warmup.

    Stepped once per iteration. Outside warmup, each group's lr is
    ``eta_min + (base - eta_min) * (1 + cos(pi * iters / T_max)) / 2``, where
    ``base`` is the group's lr when this scheduler was constructed. During a
    'linear' warmup (``iters < warmup_iters``) the lr is ``base * iters / warmup_iters``.

    Args:
        optimizer: optimizer whose ``param_groups[i]['lr']`` values are rewritten.
        T_max: iteration count at which the cosine reaches ``eta_min``.
        eta_min: lower bound of the cosine curve.
        warmup: 'linear' enables linear warmup; any other value disables it.
        warmup_iters: warmup length in iterations; must be set when
            ``warmup == 'linear'``.
    """

    def __init__(self, optimizer, T_max, eta_min=0, warmup=None, warmup_iters=None):
        self.optimizer = optimizer
        self.T_max = T_max
        self.eta_min = eta_min
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.iters = 0
        # Snapshot of the starting learning rates; the schedule scales these.
        self.base_lr = [group['lr'] for group in optimizer.param_groups]

    def step(self, external_iter=None):
        """Advance one iteration (or jump to ``external_iter``) and set each group's lr."""
        self.iters = self.iters + 1 if external_iter is None else external_iter

        if self.warmup == 'linear' and self.iters < self.warmup_iters:
            scale = self.iters / self.warmup_iters
            targets = [base * scale for base in self.base_lr]
        else:
            # cos policy
            angle = math.pi * self.iters / self.T_max
            targets = [
                self.eta_min + (base - self.eta_min) * (1 + math.cos(angle)) / 2
                for base in self.base_lr
            ]

        for group, value in zip(self.optimizer.param_groups, targets):
            group['lr'] = value
|
|
|
|
|