parent
							
								
									405831f45b
								
							
						
					
					
						commit
						1ff05db5fc
					
				@ -0,0 +1,24 @@
 | 
				
			||||
from Encoder import Encoder
 | 
				
			||||
import torch
 | 
				
			||||
import torch.nn as nn
 | 
				
			||||
import config
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class Chem_trans(nn.Module):
 | 
				
			||||
    def __init__(self,):
 | 
				
			||||
        super(Chem_trans, self).__init__()
 | 
				
			||||
        self.Encoder = Encoder()
 | 
				
			||||
        self.projection = nn.Linear(2, 2, bias=True)
 | 
				
			||||
        self.conv1d = nn.Conv1d(in_channels=config.input_dim, out_channels=2, kernel_size=60, stride=2, padding=0)
 | 
				
			||||
        
 | 
				
			||||
 | 
				
			||||
    def forward(self,enc_inputs,CCO,enc):
 | 
				
			||||
        enc_outputs,enc= self.Encoder(enc_inputs,CCO,enc)
 | 
				
			||||
        dec=torch.cat((enc_outputs,enc),dim=1)
 | 
				
			||||
 | 
				
			||||
        dec_logits=self.conv1d(dec.transpose(1,2))
 | 
				
			||||
 | 
				
			||||
        dec_logits = self.projection(dec_logits.transpose(1,2))
 | 
				
			||||
        #dec_logits = self.projection1(dec_logits)
 | 
				
			||||
        return dec_logits.view(-1, dec_logits.size(-1))
 | 
				
			||||
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue