# ENet PyTorch code retrieved from https://github.com/davidtvs/PyTorch-ENet/blob/master/models/enet.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from utils.utils import mIoULoss, to_one_hot


class InitialBlock(nn.Module):
    """The initial block is composed of two branches:
    1. a main branch which performs a regular convolution with stride 2;
    2. an extension branch which performs max-pooling.
    Doing both operations in parallel and concatenating their results
    allows for efficient downsampling and expansion. The main branch
    outputs 13 feature maps while the extension branch outputs 3, for a
    total of 16 feature maps after concatenation.

    Keyword arguments:
    - in_channels (int): the number of input channels.
    - out_channels (int): the number of output channels.
    - bias (bool, optional): Adds a learnable bias to the output if
    ``True``. Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 relu=True):
        super().__init__()

        if relu:
            activation = nn.ReLU
        else:
            activation = nn.PReLU

        # Main branch - As stated above the number of output channels for this
        # branch is the total minus 3, since the remaining channels come from
        # the extension branch
        self.main_branch = nn.Conv2d(
            in_channels,
            out_channels - 3,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias)

        # Extension branch
        self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1)

        # Initialize batch normalization to be used after concatenation
        self.batch_norm = nn.BatchNorm2d(out_channels)

        # Activation (ReLU or PReLU) to apply after concatenating the branches
        self.out_activation = activation()

    def forward(self, x):
        main = self.main_branch(x)
        ext = self.ext_branch(x)

        # Concatenate branches
        out = torch.cat((main, ext), 1)

        # Apply batch normalization
        out = self.batch_norm(out)

        return self.out_activation(out)


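# Usage sketch (not part of the original file, illustrative only): the initial
# block halves the spatial resolution and concatenates 13 convolutional feature
# maps with the 3 max-pooled input channels, giving 16 output channels.
def _demo_initial_block():
    block = InitialBlock(3, 16)
    x = torch.randn(1, 3, 64, 64)
    y = block(x)
    assert y.shape == (1, 16, 32, 32)

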
class RegularBottleneck(nn.Module):
    """Regular bottlenecks are the main building block of ENet.
    Main branch:
    1. Shortcut connection.
    Extension branch:
    1. 1x1 convolution which decreases the number of channels by
    ``internal_ratio``, also called a projection;
    2. regular, dilated or asymmetric convolution;
    3. 1x1 convolution which increases the number of channels back to
    ``channels``, also called an expansion;
    4. dropout as a regularizer.

    Keyword arguments:
    - channels (int): the number of input and output channels.
    - internal_ratio (int, optional): a scale factor applied to ``channels``
    used to compute the number of channels after the projection. E.g. given
    ``channels`` equal to 128 and internal_ratio equal to 2 the number of
    channels after the projection is 64. Default: 4.
    - kernel_size (int, optional): the kernel size of the filters used in
    the convolution layer described above in item 2 of the extension
    branch. Default: 3.
    - padding (int, optional): zero-padding added to both sides of the
    input. Default: 0.
    - dilation (int, optional): spacing between kernel elements for the
    convolution described in item 2 of the extension branch. Default: 1.
    - asymmetric (bool, optional): flags if the convolution described in
    item 2 of the extension branch is asymmetric or not. Default: False.
    - dropout_prob (float, optional): probability of an element to be
    zeroed. Default: 0 (no dropout).
    - bias (bool, optional): Adds a learnable bias to the output if
    ``True``. Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.
    """

    def __init__(self,
                 channels,
                 internal_ratio=4,
                 kernel_size=3,
                 padding=0,
                 dilation=1,
                 asymmetric=False,
                 dropout_prob=0,
                 bias=False,
                 relu=True):
        super().__init__()

        # Check if the internal_ratio parameter is within the expected range
        # [1, channels]
        if internal_ratio <= 1 or internal_ratio > channels:
            raise RuntimeError("Value out of range. Expected value in the "
                               "interval [1, {0}], got internal_ratio={1}."
                               .format(channels, internal_ratio))

        internal_channels = channels // internal_ratio

        if relu:
            activation = nn.ReLU
        else:
            activation = nn.PReLU

        # Main branch - shortcut connection

        # Extension branch - 1x1 convolution, followed by a regular, dilated or
        # asymmetric convolution, followed by another 1x1 convolution, and,
        # finally, a regularizer (spatial dropout). Number of channels is constant.

        # 1x1 projection convolution
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(
                channels,
                internal_channels,
                kernel_size=1,
                stride=1,
                bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # If the convolution is asymmetric we split the main convolution in
        # two. E.g. for a 5x5 asymmetric convolution we have two convolutions:
        # the first is 5x1 and the second is 1x5.
        if asymmetric:
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(
                    internal_channels,
                    internal_channels,
                    kernel_size=(kernel_size, 1),
                    stride=1,
                    padding=(padding, 0),
                    dilation=dilation,
                    bias=bias), nn.BatchNorm2d(internal_channels), activation(),
                nn.Conv2d(
                    internal_channels,
                    internal_channels,
                    kernel_size=(1, kernel_size),
                    stride=1,
                    padding=(0, padding),
                    dilation=dilation,
                    bias=bias), nn.BatchNorm2d(internal_channels), activation())
        else:
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(
                    internal_channels,
                    internal_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # 1x1 expansion convolution
        self.ext_conv3 = nn.Sequential(
            nn.Conv2d(
                internal_channels,
                channels,
                kernel_size=1,
                stride=1,
                bias=bias), nn.BatchNorm2d(channels), activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation (ReLU or PReLU) to apply after adding the branches
        self.out_activation = activation()

    def forward(self, x):
        # Main branch shortcut
        main = x

        # Extension branch
        ext = self.ext_conv1(x)
        ext = self.ext_conv2(ext)
        ext = self.ext_conv3(ext)
        ext = self.ext_regul(ext)

        # Add main and extension branches
        out = main + ext

        return self.out_activation(out)


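# Usage sketch (not part of the original file, illustrative only): regular,
# dilated and asymmetric bottlenecks all preserve the input shape, which is
# what allows the residual addition of the main and extension branches. The
# arguments mirror how the blocks are instantiated further down in ENet.
def _demo_regular_bottleneck():
    x = torch.randn(2, 128, 32, 32)
    regular = RegularBottleneck(128, padding=1)
    dilated = RegularBottleneck(128, dilation=4, padding=4)
    asymmetric = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True)
    for block in (regular, dilated, asymmetric):
        assert block(x).shape == x.shape

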
class DownsamplingBottleneck(nn.Module):
    """Downsampling bottlenecks further downsample the feature map size.
    Main branch:
    1. max pooling with stride 2; indices are saved to be used for
    unpooling later.
    Extension branch:
    1. 2x2 convolution with stride 2 that decreases the number of channels
    by ``internal_ratio``, also called a projection;
    2. regular convolution (by default, 3x3);
    3. 1x1 convolution which increases the number of channels to
    ``out_channels``, also called an expansion;
    4. dropout as a regularizer.

    Keyword arguments:
    - in_channels (int): the number of input channels.
    - out_channels (int): the number of output channels.
    - internal_ratio (int, optional): a scale factor applied to ``in_channels``
    used to compute the number of channels after the projection. E.g. given
    ``in_channels`` equal to 128 and internal_ratio equal to 2 the number of
    channels after the projection is 64. Default: 4.
    - return_indices (bool, optional): if ``True``, will return the max
    indices along with the outputs. Useful when unpooling later.
    - dropout_prob (float, optional): probability of an element to be
    zeroed. Default: 0 (no dropout).
    - bias (bool, optional): Adds a learnable bias to the output if
    ``True``. Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 internal_ratio=4,
                 return_indices=False,
                 dropout_prob=0,
                 bias=False,
                 relu=True):
        super().__init__()

        # Store parameters that are needed later
        self.return_indices = return_indices

        # Check if the internal_ratio parameter is within the expected range
        # [1, in_channels]
        if internal_ratio <= 1 or internal_ratio > in_channels:
            raise RuntimeError("Value out of range. Expected value in the "
                               "interval [1, {0}], got internal_ratio={1}."
                               .format(in_channels, internal_ratio))

        internal_channels = in_channels // internal_ratio

        if relu:
            activation = nn.ReLU
        else:
            activation = nn.PReLU

        # Main branch - max pooling followed by feature map (channels) padding
        self.main_max1 = nn.MaxPool2d(
            2,
            stride=2,
            return_indices=return_indices)

        # Extension branch - 2x2 projection convolution with stride 2, followed
        # by a regular convolution, followed by a 1x1 expansion convolution.

        # 2x2 projection convolution with stride 2
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                internal_channels,
                kernel_size=2,
                stride=2,
                bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # Convolution
        self.ext_conv2 = nn.Sequential(
            nn.Conv2d(
                internal_channels,
                internal_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # 1x1 expansion convolution
        self.ext_conv3 = nn.Sequential(
            nn.Conv2d(
                internal_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                bias=bias), nn.BatchNorm2d(out_channels), activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation (ReLU or PReLU) to apply after adding the branches
        self.out_activation = activation()

    def forward(self, x):
        # Main branch shortcut
        if self.return_indices:
            main, max_indices = self.main_max1(x)
        else:
            main = self.main_max1(x)
            max_indices = None

        # Extension branch
        ext = self.ext_conv1(x)
        ext = self.ext_conv2(ext)
        ext = self.ext_conv3(ext)
        ext = self.ext_regul(ext)

        # Main branch channel padding
        n, ch_ext, h, w = ext.size()
        ch_main = main.size()[1]
        padding = torch.zeros(n, ch_ext - ch_main, h, w)

        # Before concatenating, check if main is on the CPU or GPU and
        # convert padding accordingly
        if main.is_cuda:
            padding = padding.cuda()

        # Concatenate
        main = torch.cat((main, padding), 1)

        # Add main and extension branches
        out = main + ext

        return self.out_activation(out), max_indices


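# Usage sketch (not part of the original file, illustrative only): a
# downsampling bottleneck halves the spatial resolution, grows the channel
# count, and (with return_indices=True) also returns the max-pooling indices
# needed later by the matching upsampling bottleneck.
def _demo_downsampling_bottleneck():
    block = DownsamplingBottleneck(16, 64, return_indices=True)
    x = torch.randn(2, 16, 64, 64)
    y, indices = block(x)
    assert y.shape == (2, 64, 32, 32)
    assert indices.shape == (2, 16, 32, 32)

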
class UpsamplingBottleneck(nn.Module):
    """The upsampling bottlenecks upsample the feature map resolution using max
    pooling indices stored from the corresponding downsampling bottleneck.
    Main branch:
    1. 1x1 convolution with stride 1 that reduces the number of channels from
    ``in_channels`` to ``out_channels``;
    2. max unpool layer using the max pool indices from the corresponding
    downsampling max pool layer.
    Extension branch:
    1. 1x1 convolution with stride 1 that decreases the number of channels by
    ``internal_ratio``, also called a projection;
    2. 2x2 transposed convolution with stride 2;
    3. 1x1 convolution which increases the number of channels to
    ``out_channels``, also called an expansion;
    4. dropout as a regularizer.

    Keyword arguments:
    - in_channels (int): the number of input channels.
    - out_channels (int): the number of output channels.
    - internal_ratio (int, optional): a scale factor applied to ``in_channels``
    used to compute the number of channels after the projection. E.g. given
    ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number
    of channels after the projection is 64. Default: 4.
    - dropout_prob (float, optional): probability of an element to be zeroed.
    Default: 0 (no dropout).
    - bias (bool, optional): Adds a learnable bias to the output if ``True``.
    Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 internal_ratio=4,
                 dropout_prob=0,
                 bias=False,
                 relu=True):
        super().__init__()

        # Check if the internal_ratio parameter is within the expected range
        # [1, in_channels]
        if internal_ratio <= 1 or internal_ratio > in_channels:
            raise RuntimeError("Value out of range. Expected value in the "
                               "interval [1, {0}], got internal_ratio={1}."
                               .format(in_channels, internal_ratio))

        internal_channels = in_channels // internal_ratio

        if relu:
            activation = nn.ReLU
        else:
            activation = nn.PReLU

        # Main branch - 1x1 convolution followed by max unpooling
        self.main_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels))

        # Remember that the stride is the same as the kernel_size, just like
        # the max pooling layers
        self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)

        # Extension branch - 1x1 projection convolution, followed by a
        # transposed convolution, followed by a 1x1 expansion convolution.

        # 1x1 projection convolution with stride 1
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, internal_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(internal_channels), activation())

        # Transposed convolution
        self.ext_tconv1 = nn.ConvTranspose2d(
            internal_channels,
            internal_channels,
            kernel_size=2,
            stride=2,
            bias=bias)
        self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels)
        self.ext_tconv1_activation = activation()

        # 1x1 expansion convolution
        self.ext_conv2 = nn.Sequential(
            nn.Conv2d(
                internal_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels), activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation (ReLU or PReLU) to apply after adding the branches
        self.out_activation = activation()

    def forward(self, x, max_indices, output_size):
        # Main branch shortcut
        main = self.main_conv1(x)
        main = self.main_unpool1(
            main, max_indices, output_size=output_size)

        # Extension branch
        ext = self.ext_conv1(x)
        ext = self.ext_tconv1(ext, output_size=output_size)
        ext = self.ext_tconv1_bnorm(ext)
        ext = self.ext_tconv1_activation(ext)
        ext = self.ext_conv2(ext)
        ext = self.ext_regul(ext)

        # Add main and extension branches
        out = main + ext

        return self.out_activation(out)


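# Usage sketch (not part of the original file, illustrative only): an
# upsampling bottleneck undoes a downsampling bottleneck by reusing its
# max-pooling indices and the size of the tensor that was downsampled.
def _demo_up_down_pair():
    down = DownsamplingBottleneck(16, 64, return_indices=True)
    up = UpsamplingBottleneck(64, 16)
    x = torch.randn(2, 16, 64, 64)
    y, indices = down(x)
    restored = up(y, indices, output_size=x.size())
    assert restored.shape == x.shape

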
class ENet(nn.Module):
    """Generate the ENet model.

    Keyword arguments:
    - num_classes (int): the number of classes to segment.
    - encoder_relu (bool, optional): When ``True`` ReLU is used as the
    activation function in the encoder blocks/layers; otherwise, PReLU
    is used. Default: False.
    - decoder_relu (bool, optional): When ``True`` ReLU is used as the
    activation function in the decoder blocks/layers; otherwise, PReLU
    is used. Default: True.
    """

    def __init__(self, num_classes, encoder_relu=False, decoder_relu=True):
        super().__init__()

        self.initial_block = InitialBlock(3, 16, relu=encoder_relu)

        # Stage 1 - Encoder
        self.downsample1_0 = DownsamplingBottleneck(
            16,
            64,
            return_indices=True,
            dropout_prob=0.01,
            relu=encoder_relu)
        self.regular1_1 = RegularBottleneck(
            64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_2 = RegularBottleneck(
            64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_3 = RegularBottleneck(
            64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_4 = RegularBottleneck(
            64, padding=1, dropout_prob=0.01, relu=encoder_relu)

        # Stage 2 - Encoder
        self.downsample2_0 = DownsamplingBottleneck(
            64,
            128,
            return_indices=True,
            dropout_prob=0.1,
            relu=encoder_relu)
        self.regular2_1 = RegularBottleneck(
            128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_2 = RegularBottleneck(
            128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric2_3 = RegularBottleneck(
            128,
            kernel_size=5,
            padding=2,
            asymmetric=True,
            dropout_prob=0.1,
            relu=encoder_relu)
        self.dilated2_4 = RegularBottleneck(
            128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
        self.regular2_5 = RegularBottleneck(
            128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_6 = RegularBottleneck(
            128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric2_7 = RegularBottleneck(
            128,
            kernel_size=5,
            asymmetric=True,
            padding=2,
            dropout_prob=0.1,
            relu=encoder_relu)
        self.dilated2_8 = RegularBottleneck(
            128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 3 - Encoder
        self.regular3_0 = RegularBottleneck(
            128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated3_1 = RegularBottleneck(
            128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric3_2 = RegularBottleneck(
            128,
            kernel_size=5,
            padding=2,
            asymmetric=True,
            dropout_prob=0.1,
            relu=encoder_relu)
        self.dilated3_3 = RegularBottleneck(
            128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
        self.regular3_4 = RegularBottleneck(
            128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated3_5 = RegularBottleneck(
            128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric3_6 = RegularBottleneck(
            128,
            kernel_size=5,
            asymmetric=True,
            padding=2,
            dropout_prob=0.1,
            relu=encoder_relu)
        self.dilated3_7 = RegularBottleneck(
            128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 4 - Decoder
        self.upsample4_0 = UpsamplingBottleneck(
            128, 64, dropout_prob=0.1, relu=decoder_relu)
        self.regular4_1 = RegularBottleneck(
            64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.regular4_2 = RegularBottleneck(
            64, padding=1, dropout_prob=0.1, relu=decoder_relu)

        # Stage 5 - Decoder
        self.upsample5_0 = UpsamplingBottleneck(
            64, 16, dropout_prob=0.1, relu=decoder_relu)
        self.regular5_1 = RegularBottleneck(
            16, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.transposed_conv = nn.ConvTranspose2d(
            16,
            num_classes,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

    def forward(self, x):
        # Initial block
        input_size = x.size()
        x = self.initial_block(x)

        # Stage 1 - Encoder
        stage1_input_size = x.size()
        x, max_indices1_0 = self.downsample1_0(x)
        x = self.regular1_1(x)
        x = self.regular1_2(x)
        x = self.regular1_3(x)
        x = self.regular1_4(x)

        # Stage 2 - Encoder
        stage2_input_size = x.size()
        x, max_indices2_0 = self.downsample2_0(x)
        x = self.regular2_1(x)
        x = self.dilated2_2(x)
        x = self.asymmetric2_3(x)
        x = self.dilated2_4(x)
        x = self.regular2_5(x)
        x = self.dilated2_6(x)
        x = self.asymmetric2_7(x)
        x = self.dilated2_8(x)

        # Stage 3 - Encoder
        x = self.regular3_0(x)
        x = self.dilated3_1(x)
        x = self.asymmetric3_2(x)
        x = self.dilated3_3(x)
        x = self.regular3_4(x)
        x = self.dilated3_5(x)
        x = self.asymmetric3_6(x)
        x = self.dilated3_7(x)

        # Stage 4 - Decoder
        x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size)
        x = self.regular4_1(x)
        x = self.regular4_2(x)

        # Stage 5 - Decoder
        x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size)
        x = self.regular5_1(x)
        x = self.transposed_conv(x, output_size=input_size)

        return x


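# Usage sketch (not part of the original file, illustrative only): the plain
# ENet maps an image to a per-pixel class-score map of the same spatial size.
# The resolution here is chosen divisible by 8 so the three downsampling
# stages can be undone exactly.
def _demo_enet():
    model = ENet(num_classes=5)
    img = torch.randn(2, 3, 96, 96)
    out = model(img)
    assert out.shape == (2, 5, 96, 96)

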
class SpatialSoftmax(nn.Module):
    def __init__(self, temperature=1, device='cpu'):
        super(SpatialSoftmax, self).__init__()

        if temperature:
            self.temperature = Parameter(torch.ones(1) * temperature).to(device)
        else:
            self.temperature = 1.

    def forward(self, feature):
        feature = feature.view(feature.shape[0], -1, feature.shape[1] * feature.shape[2])
        softmax_attention = F.softmax(feature / self.temperature, dim=-1)

        return softmax_attention


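# Usage sketch (not part of the original file, illustrative only):
# SpatialSoftmax flattens the spatial dimensions of an activation map and
# normalises them with a softmax, so every map sums to one. at_gen below uses
# this to turn squared activations into comparable attention maps.
def _demo_spatial_softmax():
    sps = SpatialSoftmax()
    activation = torch.randn(2, 36, 100)   # (N, H, W), e.g. a sum of squares over channels
    attention = sps(activation)            # shape (N, 1, H * W)
    assert attention.shape == (2, 1, 36 * 100)
    assert torch.allclose(attention.sum(dim=-1), torch.ones(2, 1))

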
class ENet_SAD(nn.Module):
    """Generate the ENet-SAD model.
    The number of segmentation classes is fixed to 5 (background + 4 lanes).

    Keyword arguments:
    - input_size (tuple): the input image size as (width, height).
    - pretrained (bool, optional): stored as ``self.pretrained``; not used
    inside this module. Default: False.
    - sad (bool, optional): When ``True``, Self Attention Distillation (SAD)
    is added to the model; otherwise the SAD distillation loss is not
    computed. Default: False.
    - encoder_relu (bool, optional): When ``True`` ReLU is used as the
    activation function in the encoder blocks/layers; otherwise, PReLU
    is used. Default: False.
    - decoder_relu (bool, optional): When ``True`` ReLU is used as the
    activation function in the decoder blocks/layers; otherwise, PReLU
    is used. Default: True.
    - weight_share (bool, optional): When ``True``, encoder stages E2~E4
    share the same bottleneck modules. Default: True.
    """

    def __init__(self, input_size, pretrained=False, sad=False, encoder_relu=False, decoder_relu=True, weight_share=True):
        super().__init__()

        # Init parameters
        input_w, input_h = input_size
        self.fc_input_feature = 5 * int(input_w / 16) * int(input_h / 16)

        self.num_classes = 5
        self.pretrained = pretrained

        self.scale_background = 0.4

        # Loss scale factors for ENet w/o SAD
        self.scale_seg = 1.0
        self.scale_exist = 0.1

        # Loss scale factors for ENet w/ SAD
        self.scale_sad_seg = 1.0
        self.scale_sad_iou = 0.1
        self.scale_sad_exist = 0.1
        self.scale_sad_distill = 0.1

        # Loss functions
        self.ce_loss = nn.CrossEntropyLoss(weight=torch.tensor([self.scale_background, 1, 1, 1, 1]))
        self.bce_loss = nn.BCELoss()
        self.iou_loss = mIoULoss(n_classes=4)

        # Stage 0 - Initial block
        self.initial_block = InitialBlock(3, 16, relu=encoder_relu)
        self.sad = sad

        # Stage 1 - Encoder (E1)
        self.downsample1_0 = DownsamplingBottleneck(16, 64, return_indices=True, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_1 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_2 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_3 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_4 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)

        # Shared Encoder (E2~E4)
        # Stage 2 - Encoder (E2)
        self.downsample2_0 = DownsamplingBottleneck(64, 128, return_indices=True, dropout_prob=0.1, relu=encoder_relu)
        self.regular2_1 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_2 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric2_3 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_4 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
        self.regular2_5 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_6 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric2_7 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_8 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 3 - Encoder (E3)
        if weight_share:
            self.regular3_0 = self.regular2_1
            self.dilated3_1 = self.dilated2_2
            self.asymmetric3_2 = self.asymmetric2_3
            self.dilated3_3 = self.dilated2_4
            self.regular3_4 = self.regular2_5
            self.dilated3_5 = self.dilated2_6
            self.asymmetric3_6 = self.asymmetric2_7
            self.dilated3_7 = self.dilated2_8
        else:
            self.regular3_0 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
            self.dilated3_1 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
            self.asymmetric3_2 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, relu=encoder_relu)
            self.dilated3_3 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
            self.regular3_4 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
            self.dilated3_5 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
            self.asymmetric3_6 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, relu=encoder_relu)
            self.dilated3_7 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 4 - Encoder (E4)
        if weight_share:
            self.regular4_0 = self.regular2_1
            self.dilated4_1 = self.dilated2_2
            self.asymmetric4_2 = self.asymmetric2_3
            self.dilated4_3 = self.dilated2_4
            self.regular4_4 = self.regular2_5
            self.dilated4_5 = self.dilated2_6
            self.asymmetric4_6 = self.asymmetric2_7
            self.dilated4_7 = self.dilated2_8
        else:
            self.regular4_0 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
            self.dilated4_1 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
            self.asymmetric4_2 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, relu=encoder_relu)
            self.dilated4_3 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
            self.regular4_4 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
            self.dilated4_5 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
            self.asymmetric4_6 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, relu=encoder_relu)
            self.dilated4_7 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 5 - Decoder (D1); takes the concatenation of the E3 and E4 outputs
        # self.upsample4_0 = UpsamplingBottleneck(128, 64, dropout_prob=0.1, relu=decoder_relu)
        self.upsample4_0 = UpsamplingBottleneck(256, 64, dropout_prob=0.1, relu=decoder_relu)
        self.regular4_1 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.regular4_2 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)

        # Stage 6 - Decoder (D2)
        self.upsample5_0 = UpsamplingBottleneck(64, 16, dropout_prob=0.1, relu=decoder_relu)
        self.regular5_1 = RegularBottleneck(16, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.transposed_conv = nn.ConvTranspose2d(16, self.num_classes, kernel_size=3, stride=2, padding=1, bias=False)

        # AT-GEN (attention generator) used for the SAD distillation loss
        if self.sad:
            self.at_gen_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            self.at_gen_l2_loss = nn.MSELoss(reduction='mean')

        # Lane existence branch (P1)
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 5, 1),
            nn.Softmax(dim=1),
            nn.AvgPool2d(2, 2),
        )
        self.fc = nn.Sequential(
            nn.Linear(self.fc_input_feature, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
            nn.Sigmoid()
        )

    def at_gen(self, x1, x2):
        """
        x1 - previous encoder step feature map
        x2 - current encoder step feature map
        """

        # G^2_sum
        sps = SpatialSoftmax(device=x1.device)

        if x1.size() != x2.size():
            x1 = torch.sum(x1 * x1, dim=1)
            x1 = sps(x1)
            x2 = torch.sum(x2 * x2, dim=1, keepdim=True)
            x2 = torch.squeeze(self.at_gen_upsample(x2), dim=1)
            x2 = sps(x2)
        else:
            x1 = torch.sum(x1 * x1, dim=1)
            x1 = sps(x1)
            x2 = torch.sum(x2 * x2, dim=1)
            x2 = sps(x2)

        loss = self.at_gen_l2_loss(x1, x2)
        return loss

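    # Sketch of the attention maps at_gen compares (not part of the original
    # file, shown for illustration only): squaring and summing over channels
    # collapses a feature map to a single spatial attention map per image, and
    # SpatialSoftmax normalises it so maps from different encoder stages can
    # be matched with an L2 loss.
    #
    #   feat = torch.randn(2, 128, 36, 100)     # (N, C, H, W) encoder output
    #   att = torch.sum(feat * feat, dim=1)     # (N, H, W) squared-activation map
    #   att = SpatialSoftmax()(att)             # (N, 1, H*W), each map sums to 1
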
    def forward(self, img, seg_gt=None, exist_gt=None, sad_loss=False):
        # Stage 0 - Initial block
        input_size = img.size()
        x_0 = self.initial_block(img)

        # AT-GEN after each of E2, E3, E4
        # Stage 1 - Encoder (E1)
        stage1_input_size = x_0.size()
        x, max_indices1_0 = self.downsample1_0(x_0)
        x = self.regular1_1(x)
        x = self.regular1_2(x)
        x = self.regular1_3(x)
        x_1 = self.regular1_4(x)
        # if self.sad:
        #     loss_1 = self.at_gen(x_0, x_1)

        # Stage 2 - Encoder (E2)
        stage2_input_size = x_1.size()
        x, max_indices2_0 = self.downsample2_0(x_1)
        x = self.regular2_1(x)
        x = self.dilated2_2(x)
        x = self.asymmetric2_3(x)
        x = self.dilated2_4(x)
        x = self.regular2_5(x)
        x = self.dilated2_6(x)
        x = self.asymmetric2_7(x)
        x_2 = self.dilated2_8(x)
        if self.sad:
            loss_2 = self.at_gen(x_1, x_2)

        # Stage 3 - Encoder (E3)
        x = self.regular3_0(x_2)
        x = self.dilated3_1(x)
        x = self.asymmetric3_2(x)
        x = self.dilated3_3(x)
        x = self.regular3_4(x)
        x = self.dilated3_5(x)
        x = self.asymmetric3_6(x)
        x_3 = self.dilated3_7(x)
        if self.sad:
            loss_3 = self.at_gen(x_2, x_3)

        # Stage 4 - Encoder (E4)
        x = self.regular4_0(x_3)
        x = self.dilated4_1(x)
        x = self.asymmetric4_2(x)
        x = self.dilated4_3(x)
        x = self.regular4_4(x)
        x = self.dilated4_5(x)
        x = self.asymmetric4_6(x)
        x_4 = self.dilated4_7(x)
        if self.sad:
            loss_4 = self.at_gen(x_3, x_4)

        # Concatenate E3, E4
        x_34 = torch.cat((x_3, x_4), dim=1)

        # Stage 4 - Decoder (D1)
        x = self.upsample4_0(x_34, max_indices2_0, output_size=stage2_input_size)
        x = self.regular4_1(x)
        x = self.regular4_2(x)

        # Stage 5 - Decoder (D2)
        x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size)
        x = self.regular5_1(x)
        seg_pred = self.transposed_conv(x, output_size=input_size)

        # Lane existence prediction
        y = self.layer3(x_4)
        y = y.view(-1, self.fc_input_feature)
        exist_pred = self.fc(y)

        # Loss calculation
        if seg_gt is not None and exist_gt is not None:
            # L = L_seg + a * L_iou + b * L_exist + c * L_distill
            if self.sad:
                loss_seg = self.ce_loss(seg_pred, seg_gt)
                seg_gt_onehot = to_one_hot(seg_gt, 5)
                loss_iou = self.iou_loss(seg_pred[:, 1:self.num_classes, :, :], seg_gt_onehot[:, 1:self.num_classes, :, :])
                loss_exist = self.bce_loss(exist_pred, exist_gt)
                loss_distill = loss_2 + loss_3 + loss_4
                loss = loss_seg * self.scale_sad_seg + loss_iou * self.scale_sad_iou + loss_exist * self.scale_sad_exist

                # Add SAD loss after 40K episodes
                if sad_loss:
                    loss += loss_distill * self.scale_sad_distill

            else:
                loss_seg = self.ce_loss(seg_pred, seg_gt)
                loss_exist = self.bce_loss(exist_pred, exist_gt)
                loss = loss_seg * self.scale_seg + loss_exist * self.scale_exist

        else:
            loss_seg = torch.tensor(0, dtype=img.dtype, device=img.device)
            loss_exist = torch.tensor(0, dtype=img.dtype, device=img.device)
            loss = torch.tensor(0, dtype=img.dtype, device=img.device)

        return seg_pred, exist_pred, loss_seg, loss_exist, loss


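# Usage sketch (not part of the original file, illustrative only): at inference
# time the model is called without ground truth, so the returned losses are
# zero tensors and only seg_pred / exist_pred are of interest. The input
# resolution must match the input_size the model was built with (width, height).
def _demo_enet_sad_inference():
    model = ENet_SAD((800, 288), sad=False)
    model.eval()
    with torch.no_grad():
        img = torch.randn(1, 3, 288, 800)
        seg_pred, exist_pred, _, _, _ = model(img)
    assert seg_pred.shape == (1, 5, 288, 800)
    assert exist_pred.shape == (1, 4)

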
if __name__ == '__main__':
    tensor = torch.ones((8, 3, 288, 800)).cuda()
    seg_gt = torch.zeros((8, 288, 800)).long().cuda()
    exist_gt = torch.ones((8, 4)).cuda()
    enet_sad = ENet_SAD((800, 288), sad=True)
    enet_sad.cuda()
    enet_sad.train(mode=True)
    result = enet_sad(tensor, seg_gt, exist_gt, sad_loss=True)