forked from hnu202409060624/python
parent
ce1fe2a73d
commit
c1370d0e81
@ -1,77 +0,0 @@
import math

import torch
from torch import nn
from torch_geometric.nn import GCNConv

import config
from attention import MultiHeadAttention
from Feed_Forward import PoswiseFeedForwardNet


def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(0) marks the PAD token positions
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]; True means masked
    # expand so there is one mask row per query position
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]


def get_sinusoid_encoding_table(max_len, d_model):
    # build a positional-encoding table of size [max_len, d_model]
    position_enc = torch.zeros(max_len, d_model)
    # fill in the encoding for every position
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # standard Transformer sinusoid: the sin/cos pair at dimensions (i, i+1)
            # shares the divisor 10000 ** (i / d_model)
            position_enc[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
            if i + 1 < d_model:
                position_enc[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
    return position_enc


class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        # three graph-convolution layers for the graph branch
        self.conv = GCNConv(config.embedding_dim, config.embedding_dim, normalize=True, bias=config.bias, aggr='mean')
        self.conv1 = GCNConv(config.embedding_dim, config.embedding_dim, normalize=True, bias=config.bias, aggr='mean')
        self.conv2 = GCNConv(config.embedding_dim, config.embedding_dim, normalize=True, bias=config.bias, aggr='mean')
        self.enc_feed_forward1 = PoswiseFeedForwardNet()
        self.enc_feed_forward2 = PoswiseFeedForwardNet()
        # four stacked self-attention blocks for the sequence branch
        self.Model_list = nn.ModuleList([MultiHeadAttention() for _ in range(4)])

    def forward(self, enc_inputs, enc2, enc_self_attn_mask, edge_index):
        # enc_inputs: [batch_size, src_len, d_model]
        # graph branch: alternate GCN convolutions with position-wise feed-forward layers
        enc_outputs = self.conv(enc_inputs, edge_index)
        enc_outputs = self.enc_feed_forward2(enc_outputs)
        enc_outputs = self.conv1(enc_outputs, edge_index)
        enc_outputs = self.enc_feed_forward2(enc_outputs)
        enc_outputs = self.conv2(enc_outputs, edge_index)
        enc_outputs = self.enc_feed_forward2(enc_outputs)

        # sequence branch: enc2 is passed in three times and multiplied by W_q, W_k and W_v
        # inside MultiHeadAttention to obtain Q, K and V
        attn = 0
        for attn_layer in self.Model_list:
            enc2, attn = attn_layer(enc2, enc2, enc2, enc_self_attn_mask)
            enc2 = self.enc_feed_forward1(enc2)
        return enc_outputs, enc2, attn


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.embedding1 = nn.Embedding(5, 1024)
        self.attention = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(config.Encoder_n_layers)])
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, enc_inputs, edge_index, enc):
        enc = self.embedding1(enc)
        enc_outputs = self.embedding(enc_inputs)
        atoms_enc_self_attns1 = []
        # enc_self_attn_mask: [batch_size, src_len, src_len]
        enc_self_attn_mask = get_attn_pad_mask(enc.squeeze(0), enc.squeeze(0))
        enc_outputs = enc_outputs.squeeze(0)
        edge_index = edge_index.squeeze(0)
        enc_outputs = enc_outputs.unsqueeze(0)

        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model]
            enc_outputs, enc, attn = layer(enc_outputs, enc, enc_self_attn_mask, edge_index)
            atoms_enc_self_attns1.append(attn)

        return enc_outputs, enc
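
A quick way to sanity-check the two standalone helpers above is to call them directly. The snippet below is only an illustrative sketch, not part of the original file: the token ids, sequence length and d_model are made up, and it assumes the helpers are available in the current namespace (e.g. the file has been imported together with its config, attention and Feed_Forward dependencies).

# Illustrative check of the helper functions (values are arbitrary).
import torch

seq = torch.tensor([[5, 7, 9, 0, 0]])       # [batch=1, len=5]; 0 is the PAD id
mask = get_attn_pad_mask(seq, seq)          # [1, 5, 5]; True where the key position is PAD
print(mask[0, 0])                           # tensor([False, False, False,  True,  True])

pe = get_sinusoid_encoding_table(max_len=5, d_model=8)
print(pe.shape)                             # torch.Size([5, 8])
print(pe[0])                                # position 0: sin terms are 0, cos terms are 1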