You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
1.9 KiB

4 years ago
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.registry import TRAINER
def dice_loss(input, target, eps=0.001):
    """Return the mean soft Dice loss between ``input`` and ``target``.

    Each tensor is flattened to ``(size(0), -1)`` so that every entry along
    the first dimension is scored independently; the per-entry losses are
    then averaged.

    For one flattened row ``x`` (from ``input``) and ``t`` (from ``target``)
    the loss is::

        1 - 2 * sum(x * t) / (sum(x * x) + eps + sum(t * t) + eps)

    Only the denominator is smoothed, so a row where both ``x`` and ``t``
    are all zeros yields a loss of 1, not 0.

    Args:
        input: Prediction tensor. Must flatten to the same per-row length as
            ``target``.
        target: Reference tensor; it is cast to float before use.
        eps: Constant added to each squared-sum term to keep the denominator
            away from zero. Defaults to the previously hard-coded 0.001.

    Returns:
        A zero-dimensional tensor holding the mean loss over the first
        dimension.
    """
    pred = input.contiguous().view(input.size(0), -1)
    truth = target.contiguous().view(target.size(0), -1).float()

    overlap = torch.sum(pred * truth, 1)
    pred_sq = torch.sum(pred * pred, 1) + eps
    truth_sq = torch.sum(truth * truth, 1) + eps

    dice = (2 * overlap) / (pred_sq + truth_sq)
    return (1 - dice).mean()
# NOTE(review): ``TRAINER`` is imported at the top of this file but is not used
# anywhere in the visible code -- confirm whether a registry decorator
# belonging on this class was lost.
class RESA(nn.Module):
    """Compute the training loss for a network that emits a ``'seg'`` output.

    The segmentation term is selected by ``cfg.loss_type``:

    * ``'cross_entropy'``: class-weighted ``NLLLoss`` on ``log_softmax`` of the
      network's ``'seg'`` output, with class 0 weighted by ``cfg.bg_weight``
      and ``cfg.ignore_label`` excluded.
    * ``'dice_loss'``: ``dice_loss`` on the softmax probabilities against the
      one-hot labels, with class 0 dropped from both.

    When the network output also contains an ``'exist'`` entry, a
    ``BCEWithLogitsLoss`` term scaled by 0.1 is added.

    Config attributes read: ``loss_type``, ``num_classes``, ``bg_weight``,
    ``ignore_label`` and ``seg_loss_weight``.
    """

    def __init__(self, cfg):
        """Store ``cfg`` and build the loss criteria it asks for.

        Args:
            cfg: Configuration object exposing the attributes listed in the
                class docstring.
        """
        super().__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type

        # Place the criteria on the GPU when one is present (what the code
        # did unconditionally before), but stay usable on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if self.loss_type == 'cross_entropy':
            # Every class weighs 1 except class 0, which uses ``bg_weight``.
            class_weights = torch.ones(cfg.num_classes)
            class_weights[0] = cfg.bg_weight
            self.criterion = nn.NLLLoss(
                weight=class_weights,
                ignore_index=cfg.ignore_label,
            ).to(device)

        self.criterion_exist = nn.BCEWithLogitsLoss().to(device)

    def forward(self, net, batch):
        """Run ``net`` on ``batch['img']`` and return the combined loss.

        Args:
            net: Callable returning a mapping with a ``'seg'`` entry and,
                optionally, an ``'exist'`` entry.
            batch: Mapping with ``'img'`` and ``'label'``; ``'exist'`` is
                also read when the network returns an ``'exist'`` output.

        Returns:
            ``{'loss': total, 'loss_stats': stats}`` where ``stats`` holds
            ``'seg_loss'`` (unweighted) and, when present, ``'exist_loss'``
            (already scaled by 0.1).
        """
        output = net(batch['img'])
        labels = batch['label'].long()

        if self.loss_type == 'dice_loss':
            # one_hot appends the class axis last; move it to dim 1 so it
            # lines up with the class axis used by softmax(dim=1).
            # NOTE(review): one_hot rejects label values >= num_classes, so
            # this presumably expects labels without ``ignore_label`` -- confirm.
            one_hot = F.one_hot(
                labels, num_classes=self.cfg.num_classes
            ).permute(0, 3, 1, 2)
            probs = F.softmax(output['seg'], dim=1)
            # Class 0 is excluded from the Dice term.
            seg_loss = dice_loss(probs[:, 1:], one_hot[:, 1:])
        else:
            log_probs = F.log_softmax(output['seg'], dim=1)
            seg_loss = self.criterion(log_probs, labels)

        loss = seg_loss * self.cfg.seg_loss_weight
        loss_stats = {'seg_loss': seg_loss}

        if 'exist' in output:
            exist_loss = 0.1 * self.criterion_exist(
                output['exist'], batch['exist'].float()
            )
            loss = loss + exist_loss
            loss_stats['exist_loss'] = exist_loss

        return {'loss': loss, 'loss_stats': loss_stats}