From 1ff05db5fcfe146b3783c51182521e307dbb7e0f Mon Sep 17 00:00:00 2001 From: hnu202409060624 <2804411502@qq.com> Date: Mon, 30 Dec 2024 18:25:38 +0800 Subject: [PATCH] ADD file via upload --- chem_trans.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 chem_trans.py diff --git a/chem_trans.py b/chem_trans.py new file mode 100644 index 0000000..69e7327 --- /dev/null +++ b/chem_trans.py @@ -0,0 +1,24 @@ +from Encoder import Encoder +import torch +import torch.nn as nn +import config + + +class Chem_trans(nn.Module): + def __init__(self,): + super(Chem_trans, self).__init__() + self.Encoder = Encoder() + self.projection = nn.Linear(2, 2, bias=True) + self.conv1d = nn.Conv1d(in_channels=config.input_dim, out_channels=2, kernel_size=60, stride=2, padding=0) + + + def forward(self,enc_inputs,CCO,enc): + enc_outputs,enc= self.Encoder(enc_inputs,CCO,enc) + dec=torch.cat((enc_outputs,enc),dim=1) + + dec_logits=self.conv1d(dec.transpose(1,2)) + + dec_logits = self.projection(dec_logits.transpose(1,2)) + #dec_logits = self.projection1(dec_logits) + return dec_logits.view(-1, dec_logits.size(-1)) +