forked from hnu202409060624/python
parent
71e6215535
commit
a994a9a82b
@ -0,0 +1,24 @@
|
|||||||
|
from Encoder import Encoder
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import config
|
||||||
|
|
||||||
|
|
||||||
|
class Chem_trans(nn.Module):
|
||||||
|
def __init__(self,):
|
||||||
|
super(Chem_trans, self).__init__()
|
||||||
|
self.Encoder = Encoder()
|
||||||
|
self.projection = nn.Linear(2, 2, bias=True)
|
||||||
|
self.conv1d = nn.Conv1d(in_channels=config.input_dim, out_channels=2, kernel_size=config.kernal_size, stride=config.stride, padding=0)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self,enc_inputs,CCO,enc):
|
||||||
|
enc_outputs,enc= self.Encoder(enc_inputs,CCO,enc)
|
||||||
|
dec=torch.cat((enc_outputs,enc),dim=1)
|
||||||
|
|
||||||
|
dec_logits=self.conv1d(dec.transpose(1,2))
|
||||||
|
|
||||||
|
dec_logits = self.projection(dec_logits.transpose(1,2))
|
||||||
|
|
||||||
|
return dec_logits.view(-1, dec_logits.size(-1))
|
||||||
|
|
Loading…
Reference in new issue