import torch import torch.nn as nn from torch.nn import init import functools from torch.optim import lr_scheduler import torch.nn.functional as F from .external_function import SpectralNorm, GroupNorm import numpy as np ###################################################################################### # base function for network structure ###################################################################################### def init_weights(net, init_type='normal', gain=0.02): """Get different initial method for the network weights""" def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, 0.02) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) def get_norm_layer(norm_type='batch'): """Get the normalization layer for the networks""" if norm_type == 'batch': norm_layer = functools.partial(nn.BatchNorm2d, momentum=0.1, affine=True) elif norm_type == 'instance': norm_layer = functools.partial(nn.InstanceNorm2d, affine=True) elif norm_type == 'group': norm_layer = functools.partial(GroupNorm) elif norm_type == 'none': norm_layer = None else: raise NotImplementedError('normalization layer [%s] is not found' % norm_type) return norm_layer def get_nonlinearity_layer(activation_type='PReLU'): """Get the activation layer for the networks""" if activation_type == 'ReLU': nonlinearity_layer = nn.ReLU() elif activation_type == 'SELU': nonlinearity_layer = nn.SELU() elif activation_type == 'LeakyReLU': nonlinearity_layer = nn.LeakyReLU(0.1) elif activation_type == 'PReLU': nonlinearity_layer = nn.PReLU() else: raise NotImplementedError('activation layer [%s] is not found' % activation_type) return nonlinearity_layer def get_scheduler(optimizer, opt): """Get the training learning rate for different epoch""" if opt.lr_policy == 'lambda': def lambda_rule(epoch): lr_l = 1.0 - max(0, epoch + 1 + 1 + opt.iter_count - opt.niter) / float(opt.niter_decay + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) elif opt.lr_policy == 'exponent': scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95) else: raise NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) return scheduler def print_network(net): """print the network""" num_params = 0 for param in net.parameters(): num_params += param.numel() print(net) print('total number of parameters: %.3f M' % (num_params / 1e6)) def init_net(net, init_type='normal', activation='relu', gpu_ids=[]): """print the network structure and initial the network""" print_network(net) if len(gpu_ids) > 0: assert (torch.cuda.is_available()) net.cuda() net = torch.nn.DataParallel(net, gpu_ids) init_weights(net, init_type) return net def _freeze(*args): """freeze the network for forward process""" for module in args: if module: for p in module.parameters(): p.requires_grad = False def _unfreeze(*args): """ unfreeze the network for parameter update""" for module in args: if module: for p in module.parameters(): p.requires_grad = True def spectral_norm(module, use_spect=True): """use spectral normal layer to stable the training process""" if use_spect: return SpectralNorm(module) else: return module def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, use_gated=False, **kwargs): """use coord convolution layer to add position information""" if use_gated: return Gated_Conv(input_nc, output_nc, **kwargs) elif use_coord: return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs) else: return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) ###################################################################################### # Network basic function ###################################################################################### class AddCoords(nn.Module): """ Add Coords to a tensor """ def __init__(self, with_r=False): super(AddCoords, self).__init__() self.with_r = with_r def forward(self, x): """ :param x: shape (batch, channel, x_dim, y_dim) :return: shape (batch, channel+2, x_dim, y_dim) """ B, _, x_dim, y_dim = x.size() # coord calculate xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x) yy_cahnnel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x) # normalization xx_channel = xx_channel.float() / (x_dim - 1) yy_cahnnel = yy_cahnnel.float() / (y_dim - 1) xx_channel = xx_channel * 2 - 1 yy_cahnnel = yy_cahnnel * 2 - 1 ret = torch.cat([x, xx_channel, yy_cahnnel], dim=1) if self.with_r: rr = torch.sqrt(xx_channel ** 2 + yy_cahnnel ** 2) ret = torch.cat([ret, rr], dim=1) return ret class CoordConv(nn.Module): """ CoordConv operation """ def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs): super(CoordConv, self).__init__() self.addcoords = AddCoords(with_r=with_r) input_nc = input_nc + 2 if with_r: input_nc = input_nc + 1 self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) def forward(self, x): ret = self.addcoords(x) ret = self.conv(ret) return ret class ResBlock(nn.Module): """ Define an Residual block for different types """ def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), sample_type='none', use_spect=False, use_coord=False, use_gated=False): super(ResBlock, self).__init__() hidden_nc = output_nc if hidden_nc is None else hidden_nc self.in_nc = input_nc self.hi_nc = hidden_nc self.ou_nc = output_nc self.sample = True if sample_type == 'none': self.sample = False elif sample_type == 'up': output_nc = output_nc * 4 self.pool = nn.PixelShuffle(upscale_factor=2) elif sample_type == 'down': self.pool = nn.AvgPool2d(kernel_size=2, stride=2) else: raise NotImplementedError('sample type [%s] is not found' % sample_type) kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0} self.conv1 = coord_conv(input_nc, hidden_nc, use_spect, use_coord, use_gated=use_gated, **kwargs) self.conv2 = coord_conv(hidden_nc, output_nc, use_spect, use_coord, use_gated=use_gated, **kwargs) self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, use_gated=use_gated, **kwargs_short) if type(norm_layer) == type(None): self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2, ) else: self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2, ) self.shortcut = nn.Sequential(self.bypass, ) def forward(self, x): if self.sample: out = self.pool(self.model(x)) + self.pool(self.shortcut(x)) else: out = self.model(x) + self.shortcut(x) return out class ResBlockEncoderOptimized(nn.Module): """ Define an Encoder block for the first layer of the discriminator and representation network """ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False, use_coord=False, use_gated=False): super(ResBlockEncoderOptimized, self).__init__() kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0} self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, use_gated, **kwargs) self.conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, use_gated, **kwargs) self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, use_gated, **kwargs_short) if type(norm_layer) == type(None): self.model = nn.Sequential(self.conv1, nonlinearity, self.conv2, nn.AvgPool2d(kernel_size=2, stride=2)) else: self.model = nn.Sequential(self.conv1, norm_layer(output_nc), nonlinearity, self.conv2, nn.AvgPool2d(kernel_size=2, stride=2)) self.shortcut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2), self.bypass) def forward(self, x): out = self.model(x) + self.shortcut(x) return out class ResBlockDecoder(nn.Module): """ Define a decoder block """ def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False, use_coord=False): super(ResBlockDecoder, self).__init__() hidden_nc = output_nc if hidden_nc is None else hidden_nc self.conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect) self.conv2 = spectral_norm( nn.ConvTranspose2d(hidden_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect) self.bypass = spectral_norm( nn.ConvTranspose2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect) if type(norm_layer) == type(None): self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2, ) else: self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2, ) self.shortcut = nn.Sequential(self.bypass) def forward(self, x): out = self.model(x) + self.shortcut(x) return out class Output(nn.Module): """ Define the output layer """ def __init__(self, input_nc, output_nc, kernel_size=3, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False, use_coord=False, use_gated=False): super(Output, self).__init__() kwargs = {'kernel_size': kernel_size, 'padding': 0, 'bias': True} self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, use_gated=use_gated, **kwargs) if type(norm_layer) == type(None): self.model = nn.Sequential(nonlinearity, nn.ReflectionPad2d(int(kernel_size / 2)), self.conv1, nn.Tanh()) else: self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, nn.ReflectionPad2d(int(kernel_size / 2)), self.conv1, nn.Tanh()) def forward(self, x): out = self.model(x) return out class Auto_Attn(nn.Module): """ Short+Long attention Layer""" def __init__(self, input_nc, norm_layer=nn.BatchNorm2d): super(Auto_Attn, self).__init__() self.input_nc = input_nc self.query_conv = nn.Conv2d(input_nc, input_nc // 4, kernel_size=1) self.gamma = nn.Parameter(torch.zeros(1)) self.alpha = nn.Parameter(torch.zeros(1)) self.softmax = nn.Softmax(dim=-1) self.model = ResBlock(int(input_nc * 2), input_nc, input_nc, norm_layer=norm_layer, use_spect=True) def forward(self, x, pre=None, mask=None): """ inputs : x : input feature maps( B X C X W X H) returns : out : self attention value + input feature attention: B X N X N (N is Width*Height) """ # print(x.shape) B, C, W, H = x.size() proj_query = self.query_conv(x).view(B, -1, W * H) # B X (N)X C proj_key = proj_query # B X C x (N) energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key) # transpose check attention = self.softmax(energy) # BX (N) X (N) proj_value = x.view(B, -1, W * H) # B X C X N out = torch.bmm(proj_value, attention.permute(0, 2, 1)) out = out.view(B, C, W, H) out = self.gamma * out + x if type(pre) != type(None): # using long distance attention layer to copy information from valid regions context_flow = torch.bmm(pre.view(B, -1, W * H), attention.permute(0, 2, 1)).view(B, -1, W, H) context_flow = self.alpha * (1 - mask) * context_flow + (mask) * pre out = self.model(torch.cat([out, context_flow], dim=1)) return out, attention class SimAM(nn.Module): def __init__(self, channels=None, e_lambda=1e-4): super(SimAM, self).__init__() self.activaton = nn.Sigmoid() self.e_lambda = e_lambda def __repr__(self): s = self.__class__.__name__ + '(' s += ('lambda=%f)' % self.e_lambda) return s @staticmethod def get_module_name(): return "simam" def forward(self, x): b, c, h, w = x.size() n = w * h - 1 x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 return x * self.activaton(y) class HardSPDNorm(nn.Module): def __init__(self, n, k, p_input_nc, F_in_nc): super(HardSPDNorm, self).__init__() self.n = n self.k = k self.F_in_nc = F_in_nc self.gamma_conv = Gated_Conv(2 * p_input_nc, self.F_in_nc, kernel_size=3, stride=1, padding=1) self.beta_conv = Gated_Conv(2 * p_input_nc, self.F_in_nc, kernel_size=3, stride=1, padding=1) self.dsample_p = nn.AvgPool2d(kernel_size=2, stride=2) self.dsample_m = nn.MaxPool2d(kernel_size=2, stride=2) self.innorm = nn.InstanceNorm2d(self.F_in_nc, affine=False) def forward(self, F_in, img_p, mask, n_ds): # print(self.F_in_nc) mask = mask.clone() # downsample for i in range(n_ds): img_p = self.dsample_p(img_p) mask = self.dsample_m(mask) img_p = torch.cat([img_p, mask], dim=1) mask, _, _ = torch.chunk(mask, dim=1, chunks=3) # D_h kernel = torch.ones(mask.shape[1], mask.shape[1], 3, 3).cuda() D_h = mask msk = D_h.detach() msk = torch.where(msk == 1, True, False) for i in range(1, self.n + 1): D_h = F.conv2d(D_h, kernel, stride=1, padding=1) tmp = D_h.detach() tmp = torch.where(tmp > 0, True, False) tmp = tmp & ~msk msk = msk | tmp mask[tmp] = 1 / self.k ** (i) D_h = mask gamma_hp = self.gamma_conv(img_p) beta_hp = self.beta_conv(img_p) gamma_hd = gamma_hp * D_h beta_hd = beta_hp * D_h F_in_norm = self.innorm(F_in) # print(gamma_hd.shape) # print(F_in.shape) F_hard = F_in_norm * (gamma_hd + 1) + beta_hd return F_hard class SoftSPDNorm(nn.Module): def __init__(self, p_input_nc, F_in_nc): super(SoftSPDNorm, self).__init__() self.gamma_conv = Gated_Conv(2 * p_input_nc, F_in_nc, kernel_size=3, stride=1, padding=1) self.beta_conv = Gated_Conv(2 * p_input_nc, F_in_nc, kernel_size=3, stride=1, padding=1) self.p_conv = Gated_Conv(6, F_in_nc, stride=1, kernel_size=1) self.f_conv = Gated_Conv(2 * F_in_nc, 1, kernel_size=1, stride=1) self.dsample_p = nn.AvgPool2d(kernel_size=2, stride=2) self.dsample_m = nn.MaxPool2d(kernel_size=2, stride=2) self.sigmoid = nn.Sigmoid() self.innorm = nn.InstanceNorm2d(F_in_nc, affine=False) def forward(self, F_in, img_p, mask, n_ds): mask = mask.clone() # downsample for i in range(n_ds): img_p = self.dsample_p(img_p) mask = self.dsample_m(mask) img_p = torch.cat([img_p, mask], dim=1) mask, _, _ = torch.chunk(mask, dim=1, chunks=3) F_in_norm = self.innorm(F_in) F_p = self.p_conv(img_p) F_mix = torch.cat([F_p, F_in_norm], dim=1) F_conv = self.f_conv(F_mix) D_s = self.sigmoid(F_conv * (1 - mask) + mask) gamma_sp = self.gamma_conv(img_p) beta_sp = self.beta_conv(img_p) gamma_sd = gamma_sp * D_s beta_sd = beta_sp * D_s f_in = F_in_norm * gamma_sd F_soft = f_in + beta_sd return F_soft, mask class ResBlockSPDNorm(nn.Module): def __init__(self, F_in_nc, p_input_nc, n_ds, n, k): super(ResBlockSPDNorm, self).__init__() self.HardSPDNorm = HardSPDNorm(n, k, p_input_nc, F_in_nc) self.SoftSPDNorm = SoftSPDNorm(p_input_nc, F_in_nc) self.relu = nn.ReLU() self.Conv1 = Gated_Conv(F_in_nc, F_in_nc, stride=1, padding=1, kernel_size=3) self.Conv2 = Gated_Conv(F_in_nc, F_in_nc, stride=1, padding=1, kernel_size=3) self.Conv3 = Gated_Conv(F_in_nc, F_in_nc, stride=1, padding=1, kernel_size=3) self.n_ds = n_ds def forward(self, F_in, img_p, mask): # the HardSPDNorm img_p_re = F.interpolate( img_p, size=F_in.size()[2:], mode="nearest" ) mask_re = F.interpolate(mask, size=F_in.size()[2:], mode="nearest") out_h1 = self.HardSPDNorm(F_in, img_p_re, mask_re, self.n_ds) out_h1 = self.relu(out_h1) out_h1 = self.Conv1(out_h1) out_h = self.HardSPDNorm(out_h1, img_p_re, mask_re, self.n_ds) out_h = self.relu(out_h) out_h = self.Conv2(out_h) # the SoftSPDNorm out_s, mask = self.SoftSPDNorm(F_in, img_p_re, mask_re, self.n_ds) out_s = self.relu(out_s) out_s = self.Conv3(out_s) # output out = out_h + out_s return out, mask class Gated_Conv(nn.Module): """ Gated Convolution with spetral normalization """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)): super(Gated_Conv, self).__init__() self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.activation = activation self.batch_norm = batch_norm self.batch_norm2d = torch.nn.BatchNorm2d(out_channels) self.sigmoid = torch.nn.Sigmoid() self.conv2d = torch.nn.utils.spectral_norm(self.conv2d) self.mask_conv2d = torch.nn.utils.spectral_norm(self.mask_conv2d) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) def gated(self, mask): return self.sigmoid(mask) #return torch.clamp(mask, -1, 1) def forward(self, input): x = self.conv2d(input) mask = self.mask_conv2d(input) if self.activation is not None: x = self.activation(x) * self.gated(mask) else: x = x * self.gated(mask) if self.batch_norm: return self.batch_norm2d(x) else: return x class SPDNormResnetBlock(nn.Module): def __init__(self, fin, fout, mask_number, mask_ks): super().__init__() nhidden = 128 fmiddle = min(fin, fout) lable_nc = 3 # create conv layers self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) self.learned_shortcut = True # apply spectral norm if specified self.conv_0 = spectral_norm(self.conv_0) self.conv_1 = spectral_norm(self.conv_1) if self.learned_shortcut: self.conv_s = spectral_norm(self.conv_s) # define normalization layers self.norm_0 = SPDNorm(fin, norm_type="position") self.norm_1 = SPDNorm(fmiddle, norm_type="position") self.norm_s = SPDNorm(fin, norm_type="position") # define the mask weight self.v = nn.Parameter(torch.zeros(1)) self.activeweight = nn.Sigmoid() # define the feature and mask process conv self.mask_number = mask_number self.mask_ks = mask_ks pw_mask = int(np.ceil((self.mask_ks - 1.0) / 2)) self.mask_conv = MaskGet(1, 1, kernel_size=self.mask_ks, padding=pw_mask) self.conv_to_f = nn.Sequential( Gated_Conv(lable_nc * 2, nhidden, kernel_size=3, padding=1), nn.InstanceNorm2d(nhidden), nn.ReLU(), Gated_Conv(nhidden, fin, kernel_size=3, padding=1), ) self.attention = nn.Sequential( nn.Conv2d(fin * 2, fin, kernel_size=3, padding=1), nn.Sigmoid() ) # the semantic feature prior_f form pretrained encoder def forward(self, x, prior_image, mask): """ Args: x: input feature prior_image: the output of PCConv mask: mask """ # prepare the forward prior_features = [] b, c, h, w = x.size() prior_image_resize = F.interpolate( prior_image, size=x.size()[2:], mode="nearest" ) mask_resize = F.interpolate(mask, size=x.size()[2:], mode="nearest") prior_image_resize = torch.cat([prior_image_resize, mask_resize], dim=1) mask_resize, _, _ = torch.chunk(mask_resize, dim=1, chunks=3) # Mask Original and Res path res weight/ value attention prior_feature = self.conv_to_f(prior_image_resize) prior_features.append(prior_feature) soft_map = self.attention(torch.cat([prior_feature, x], 1)) # print(soft_map.shape) # print(mask_resize.shape) soft_map = (1 - mask_resize) * soft_map + mask_resize # Mask weight for next process mask_pre = mask_resize hard_map = 0.0 for i in range(self.mask_number): mask_out = self.mask_conv(mask_pre) mask_generate = (mask_out - mask_pre) * ( 1 / (torch.exp(torch.tensor(i + 1).cuda())) ) mask_pre = mask_out hard_map = hard_map + mask_generate hard_map_inner = (1 - mask_out) * (1 / (torch.exp(torch.tensor(i + 1).cuda()))) hard_map = hard_map + mask_resize + hard_map_inner soft_out = self.conv_s(self.norm_s(x, prior_image_resize, soft_map)) hard_out = self.conv_0(self.actvn(self.norm_0(x, prior_image_resize, hard_map))) hard_out = self.conv_1( self.actvn(self.norm_1(hard_out, prior_image_resize, hard_map)) ) # Res add out = soft_out + hard_out return out, prior_features def actvn(self, x): return F.leaky_relu(x, 2e-1) def PositionalNorm2d(x, epsilon=1e-5): # x: B*C*W*H normalize in C dim mean = x.mean(dim=1, keepdim=True) std = x.var(dim=1, keepdim=True).add(epsilon).sqrt() output = (x - mean) / std return output class SPDNorm(nn.Module): def __init__(self, norm_channel, norm_type="batch"): super().__init__() label_nc = 3 param_free_norm_type = norm_type ks = 3 if param_free_norm_type == "instance": self.param_free_norm = nn.InstanceNorm2d(norm_channel, affine=False) elif param_free_norm_type == "batch": self.param_free_norm = nn.BatchNorm2d(norm_channel, affine=False) elif param_free_norm_type == "position": self.param_free_norm = PositionalNorm2d else: raise ValueError( "%s is not a recognized param-free norm type in SPADE" % param_free_norm_type ) # The dimension of the intermediate embedding space. Yes, hardcoded. pw = ks // 2 nhidden = 128 self.mlp_activate = nn.Sequential( Gated_Conv(label_nc * 2, nhidden, kernel_size=ks, padding=pw), nn.ReLU() ) self.mlp_gamma = nn.Conv2d(nhidden, norm_channel, kernel_size=ks, padding=pw) self.mlp_beta = nn.Conv2d(nhidden, norm_channel, kernel_size=ks, padding=pw) def forward(self, x, prior_f, weight): normalized = self.param_free_norm(x) # Part 2. produce scaling and bias conditioned on condition feature actv = self.mlp_activate(prior_f) gamma = self.mlp_gamma(actv) * weight beta = self.mlp_beta(actv) * weight # apply scale and bias out = normalized * (1 + gamma) + beta return out class MaskGet(nn.Module): def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, ): super().__init__() self.mask_conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, False, ) torch.nn.init.constant_(self.mask_conv.weight, 1.0) def forward(self, input): # depart from partial conv # hole region should sed to 0 with torch.no_grad(): output_mask = self.mask_conv(input) no_update_holes = output_mask == 0 new_mask = torch.ones_like(output_mask) new_mask = new_mask.masked_fill_(no_update_holes.bool(), 0.0) return new_mask class BasicConv(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): super(BasicConv, self).__init__() self.out_channels = out_planes self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None self.relu = nn.ReLU() if relu else None def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(ChannelGate, self).__init__() self.gate_channels = gate_channels self.mlp = nn.Sequential( Flatten(), nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels) ) self.pool_types = pool_types def forward(self, x): channel_att_sum = None for pool_type in self.pool_types: if pool_type=='avg': avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( avg_pool ) elif pool_type=='max': max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( max_pool ) elif pool_type=='lp': lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( lp_pool ) elif pool_type=='lse': # LSE pool only lse_pool = logsumexp_2d(x) channel_att_raw = self.mlp( lse_pool ) if channel_att_sum is None: channel_att_sum = channel_att_raw else: channel_att_sum = channel_att_sum + channel_att_raw scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) return x * scale def logsumexp_2d(tensor): tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() return outputs class ChannelPool(nn.Module): def forward(self, x): return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) class SpatialGate(nn.Module): def __init__(self): super(SpatialGate, self).__init__() kernel_size = 7 self.compress = ChannelPool() self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) def forward(self, x): x_compress = self.compress(x) x_out = self.spatial(x_compress) scale = torch.sigmoid(x_out) # broadcasting return x * scale class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): super(CBAM, self).__init__() self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) self.no_spatial = no_spatial if not no_spatial: self.SpatialGate = SpatialGate() def forward(self, x): x_out = self.ChannelGate(x) if not self.no_spatial: x_out = self.SpatialGate(x_out) return x_out class MyAttention(nn.Module): """ Layer implements my version of the self-attention module it is mostly same as self attention, but generalizes to (k x k) convolutions instead of (1 x 1) args: in_channels: number of input channels out_channels: number of output channels activation: activation function to be applied (default: lrelu(0.2)) kernel_size: kernel size for convolution (default: (1 x 1)) transpose_conv: boolean denoting whether to use convolutions or transpose convolutions squeeze_factor: squeeze factor for query and keys (default: 8) stride: stride for the convolutions (default: 1) padding: padding for the applied convolutions (default: 1) bias: whether to apply bias or not (default: True) """ def __init__(self, in_channels, out_channels, activation=None, kernel_size=(3, 3), transpose_conv=False, use_spectral_norm=True, use_batch_norm=True, squeeze_factor=8, stride=1, padding=1, bias=True): """ constructor for the layer """ from torch.nn import Conv2d, Parameter, \ Softmax, ConvTranspose2d, BatchNorm2d # base constructor call super(MyAttention, self).__init__() # state of the layer self.activation = activation self.simam = SimAM() self.cbam = CBAM(128) self.gamma = Parameter(torch.zeros(1)) self.beta = Parameter(torch.zeros(1)) self.alpha = Parameter(torch.zeros(1)) self.in_channels = in_channels self.out_channels = out_channels self.squeezed_channels = in_channels // squeeze_factor self.use_batch_norm = use_batch_norm # Modules required for computations if transpose_conv: self.query_conv = ConvTranspose2d( # query convolution in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.key_conv = ConvTranspose2d( in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.value_conv = ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.residual_conv = ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) if not use_spectral_norm else SpectralNorm( ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) ) else: self.query_conv = Conv2d( # query convolution in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.key_conv = Conv2d( in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.value_conv = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.residual_conv = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) if not use_spectral_norm else SpectralNorm( Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) ) # softmax module for applying attention self.softmax = Softmax(dim=-1) self.batch_norm = BatchNorm2d(out_channels) def forward(self, x, f): """ forward computations of the layer :param x: input feature maps (B x C x H x W) :return: out: self attention value + input feature (B x O x H x W) attention: attention map (B x C x H x W) """ # extract the batch size of the input tensor m_batchsize, _, _, _ = x.size() # print(f.shape) # create the query projection proj_query = self.query_conv(x).view( m_batchsize, self.squeezed_channels, -1).permute(0, 2, 1) # B x (N) x C # create the key projection proj_key = self.key_conv(x).view( m_batchsize, self.squeezed_channels, -1) # B x C x (N) # calculate the attention maps energy = torch.bmm(proj_query, proj_key) # energy attention = self.softmax(energy) # attention (B x (N) x (N)) # create the value projection proj_value = self.value_conv(x).view( m_batchsize, self.out_channels, -1) # B X C X N # calculate the output out = torch.bmm(proj_value, attention.permute(0, 2, 1)) # calculate the residual output res_out = self.residual_conv(x) out = out.view(m_batchsize, self.out_channels, res_out.shape[-2], res_out.shape[-1]) attention = attention.view(m_batchsize, -1, res_out.shape[-2], res_out.shape[-1]) if self.use_batch_norm: res_out = self.batch_norm(res_out) if self.activation is not None: out = self.activation(out) res_out = self.activation(res_out) # apply the residual connections # from torchvision import transforms sm = self.simam(f) # toPIL = transforms.ToPILImage() # 这个函数可以将张量转为PIL图片,由小数转为0-255之间的像素值 # print(cbam.chunk(2)[-1].view(128, 32, 32).chunk(128)[-1]) # pic = toPIL(cbam.chunk(2)[-1].view(128, 32, 32).chunk(128)[-1]) # pic.save('cbam.jpg') # print(cbam) # pic = toPIL(sm.chunk(2)[-1].view(128, 32, 32).chunk(128)[-1]) # pic.save('sm1.jpg') sm[:, :, :5, :] = 0 sm[:, :, :, :5] = 0 sm[:, :, :, 27:] = 0 sm[:, :, 27:, :] = 0 cbam = self.cbam(f) # print(cbam.type()) # print(sm.type()) map_mean = cbam + sm / 2.0 # from torchvision import transforms # toPIL = transforms.ToPILImage() # 这个函数可以将张量转为PIL图片,由小数转为0-255之间的像素值 # print(sm.chunk(2)[-1].view(128, 32, 32).chunk(128)[-1]) # pic = toPIL(sm.chunk(2)[-1].view(128, 32, 32).chunk(128)[-1]) # pic.save('map.jpg') # out = ((self.gamma * out) + ((1 - self.gamma) * res_out) * (1 - self.beta)) + self.beta * sm # out = self.gamma * out + (1 - self.gamma) * res_out # out = self.gamma * out + (1 - self.gamma) * sm # out = (self.gamma * out + (1 - self.gamma) * res_out) * sm * self.beta o1 = self.gamma * out + (1 - self.gamma) * res_out out = o1 * self.beta + (1 - self.beta) * map_mean * self.alpha # print(cbam) # out = self.gamma * out + (1 - self.gamma) * cbam return out, attention