# -*- coding: utf-8 -*-
"""
Created on 2024/9/16 12:31
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
import math

import torch
from torch import nn
from torch import Tensor


class PositionalEmbedding(nn.Module):
    """Add sinusoidal positional encoding to the Transformer input."""

    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encoding once, in log space.
        # Zero tensor of shape (max_len, d_model) to hold the encodings.
        pe = torch.zeros(max_len, d_model).float()
        # The positional encodings are fixed and do not require gradients.
        pe.requires_grad = False
        # Positions 0 .. max_len-1 with a trailing dimension, shape (max_len, 1).
        position = torch.arange(0, max_len).float().unsqueeze(1)
        # Divisor term of the positional encoding formula.
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
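        # Note: div_term equals 10000^(-2i / d_model), so position * div_term
        # yields the argument pos / 10000^(2i / d_model) used by both branches
        # below.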
        # Sine on the even columns, cosine on the odd columns.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a batch dimension, changing the shape to (1, max_len, d_model).
        pe = pe.unsqueeze(0)
        # Register the encodings as a buffer: saved with the module's state
        # but never updated as a model parameter.
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        # Slice the first x.size(1) positions so the encoding matches the
        # sequence length of the input, then add it element-wise.
        return x + self.pe[:, : x.size(1)]
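

if __name__ == "__main__":
    # Minimal usage sketch: the sizes below are arbitrary, chosen only to
    # illustrate the expected (batch, seq_len, d_model) input layout.
    d_model, seq_len, batch_size = 64, 96, 8
    embed = PositionalEmbedding(d_model=d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    out = embed(x)
    print(out.shape)  # torch.Size([8, 96, 64])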