You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
1.9 KiB
52 lines
1.9 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on 2024/9/16 12:31
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
import torch
|
|
from torch import nn
|
|
from torch import Tensor
|
|
import math
|
|
|
|
|
|
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding for Transformer inputs.

    Implements the non-learned encoding from "Attention Is All You Need":
    even feature columns carry sine terms and odd columns carry cosine
    terms, with wavelengths forming a geometric progression controlled by
    the 10000 base constant.

    Args:
        d_model: Embedding dimension of the model. Odd values are
            supported (the original implementation raised a shape
            mismatch for odd ``d_model``).
        max_len: Maximum sequence length the table supports. Default 5000.
    """

    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__()

        # Precompute the whole (max_len, d_model) table once; it is fixed,
        # so it never needs gradients.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        # Positions 0..max_len-1 as a column vector: shape (max_len, 1).
        position = torch.arange(0, max_len).float().unsqueeze(1)

        # Frequency terms computed in log space for numerical stability:
        # div_term[i] = 10000^(-2i / d_model), shape (ceil(d_model/2),).
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()

        # Shared argument for both sin and cos: (max_len, ceil(d_model/2)).
        angles = position * div_term

        # Even columns get sin; odd columns get cos. The cos block is
        # sliced so odd d_model works: pe[:, 1::2] only has
        # floor(d_model/2) columns, one fewer than `angles` when d_model
        # is odd. For even d_model the slice is a no-op.
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]

        # Registered as a buffer: moves with the module (.to()/.cuda())
        # and is saved in state_dict, but is not a trainable parameter.
        # Leading batch dimension gives shape (1, max_len, d_model).
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encodings to the input.

        Args:
            x: Tensor of shape (batch, seq_len, d_model); seq_len must
               not exceed ``max_len``.

        Returns:
            ``x`` plus the first ``seq_len`` rows of the encoding table
            (broadcast over the batch dimension).
        """
        return x + self.pe[:, : x.size(1)]
|