You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.3 KiB
75 lines
2.3 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on 2024/9/16 10:29
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
import torch
|
|
from torch import nn
|
|
from torch import Tensor
|
|
from typing import Tuple, Union, Callable
|
|
|
|
|
|
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose``.

    Args:
        *dims: dimension indices forwarded unchanged to ``Tensor.transpose``.
        contiguous: if True, ``.contiguous()`` is applied to the transposed
            tensor before it is returned.
    """

    def __init__(self, *dims, contiguous=False) -> None:
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x: Tensor) -> Tensor:
        swapped = x.transpose(*self.dims)
        # Optionally materialize a contiguous copy (e.g. for a later .view()).
        return swapped.contiguous() if self.contiguous else swapped
|
|
|
|
|
|
def get_batch_norm(d_model: int) -> nn.Module:
    """Build a BatchNorm1d layer wrapped between two dim-1/dim-2 transposes.

    ``nn.BatchNorm1d`` normalizes over dimension 1 (channels), so the input's
    dims 1 and 2 are swapped before normalization and swapped back afterwards.
    Presumably the input is laid out as (batch, seq_len, d_model) — confirm
    with callers.

    Args:
        d_model: number of features passed to ``nn.BatchNorm1d``.

    Returns:
        An ``nn.Sequential`` of ``Transpose(1, 2)``, ``BatchNorm1d(d_model)``,
        ``Transpose(1, 2)``.
    """
    layers = [
        Transpose(1, 2),
        nn.BatchNorm1d(d_model),
        Transpose(1, 2),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
def get_activation_fn(activation: Union[str, Callable]) -> nn.Module:
    """Select the activation function to use.

    Args:
        activation: one of
            * an ``nn.Module`` instance (e.g. ``nn.GELU()``), returned as-is;
            * any other callable (typically a module class such as
              ``nn.ReLU``), which is called with no arguments and its result
              returned;
            * a case-insensitive name, ``"relu"`` or ``"gelu"``.

    Returns:
        The activation module.

    Raises:
        ValueError: if ``activation`` is an unsupported name, or is neither a
            string nor a callable.
    """
    # Module instances are callable too; calling one with no input would run
    # forward() without arguments and fail, so hand them back unchanged.
    if isinstance(activation, nn.Module):
        return activation
    if callable(activation):
        return activation()
    if isinstance(activation, str):
        name = activation.lower()
        if name == "relu":
            return nn.ReLU()
        if name == "gelu":
            return nn.GELU()
    # Unknown names and non-string, non-callable values (e.g. None) land here.
    raise ValueError(
        f'{activation} is not available. You can use "relu", "gelu", or a callable'
    )
|
|
|
|
|
|
class moving_avg(nn.Module):
    """Moving average block to highlight the trend of time series.

    ``forward`` replicate-pads the input along dim 1 using its first and last
    steps, then averages along dim 1 with ``nn.AvgPool1d``. With ``stride=1``
    the output has the same size along dim 1 as the input, for both odd and
    even ``kernel_size``.

    Args:
        kernel_size: length of the averaging window.
        stride: stride of the underlying ``nn.AvgPool1d``.
    """

    def __init__(self, kernel_size: int = 25, stride: int = 1) -> None:
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        # x is indexed as a rank-3 tensor and averaged along dim 1; presumably
        # (batch, seq_len, channels) — TODO confirm with callers.
        # Pad a total of kernel_size - 1 steps so a stride-1 pool keeps dim 1's
        # size. For even kernels the split is uneven (one extra step at the
        # end); a symmetric (k-1)//2 split would lose a step there and break
        # `x - moving_avg(x)` in series_decomp.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size // 2  # equals pad_front when kernel_size is odd
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move dim 1 there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)
|
|
|
|
|
|
class series_decomp(nn.Module):
    """Series decomposition block.

    ``forward`` splits ``x`` into ``(x - trend, trend)``, where ``trend`` is
    produced by a stride-1 ``moving_avg`` with the given ``kernel_size``.

    Args:
        kernel_size: window length of the moving average.
    """

    def __init__(self, kernel_size: int) -> None:
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        trend = self.moving_avg(x)
        # Residual first, smoothed trend second.
        return x - trend, trend
|