diff --git a/chem_trans.py b/chem_trans.py deleted file mode 100644 index 69e7327..0000000 --- a/chem_trans.py +++ /dev/null @@ -1,24 +0,0 @@ -from Encoder import Encoder -import torch -import torch.nn as nn -import config - - -class Chem_trans(nn.Module): - def __init__(self,): - super(Chem_trans, self).__init__() - self.Encoder = Encoder() - self.projection = nn.Linear(2, 2, bias=True) - self.conv1d = nn.Conv1d(in_channels=config.input_dim, out_channels=2, kernel_size=60, stride=2, padding=0) - - - def forward(self,enc_inputs,CCO,enc): - enc_outputs,enc= self.Encoder(enc_inputs,CCO,enc) - dec=torch.cat((enc_outputs,enc),dim=1) - - dec_logits=self.conv1d(dec.transpose(1,2)) - - dec_logits = self.projection(dec_logits.transpose(1,2)) - #dec_logits = self.projection1(dec_logits) - return dec_logits.view(-1, dec_logits.size(-1)) -