from Encoder import Encoder
import torch
import torch.nn as nn
import config


class Chem_trans(nn.Module):
    """Encoder followed by a 1-D convolution head and a 2 -> 2 linear projection.

    The attribute names ``Encoder``, ``projection`` and ``conv1d`` become the
    ``state_dict`` key prefixes, and they are registered in that order. Keep
    both stable so existing checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        # Project-local encoder; forward() calls it as (enc_inputs, CCO, enc)
        # and unpacks two return values.
        self.Encoder = Encoder()
        # Square 2 -> 2 map; its input width must equal conv1d's out_channels.
        self.projection = nn.Linear(2, 2, bias=True)
        # Convolves over the concatenated sequence axis. The channel count and
        # window come from the project's config module. Note that ``kernal_size``
        # is the attribute name config exposes.
        self.conv1d = nn.Conv1d(
            in_channels=config.input_dim,
            out_channels=2,
            kernel_size=config.kernal_size,
            stride=config.stride,
            padding=0,
        )

    def forward(self, enc_inputs, CCO, enc):
        """Encode the inputs, convolve over the joined sequence, and project.

        Args:
            enc_inputs, CCO, enc: Passed through unchanged to ``self.Encoder``.
                Their meaning is defined by the Encoder module (not visible
                here; TODO confirm against Encoder).

        Returns:
            A 2-D tensor of shape ``(N * L_out, 2)``. ``N`` is the batch size.
            ``L_out`` is the Conv1d output length over the concatenated
            sequence axis.
        """
        encoded, extra = self.Encoder(enc_inputs, CCO, enc)
        # Stack the two encoder outputs along dim 1. The last dim must be
        # config.input_dim, because it becomes Conv1d's channel axis below.
        joined = torch.cat((encoded, extra), dim=1)
        # Conv1d expects (N, C_in, L), so swap the sequence and feature axes.
        conv_out = self.conv1d(joined.transpose(1, 2))
        # Swap back to (N, L_out, 2) so the Linear acts on the channel axis.
        projected = self.projection(conv_out.transpose(1, 2))
        # Flatten batch and time into rows. view (not reshape) relies on the
        # Linear output being contiguous.
        width = projected.size(-1)
        return projected.view(-1, width)