# -*- coding: utf-8 -*-
"""
Created on 2024/9/16 17:22
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
import numpy as np
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from einops import rearrange
from layers import Transpose, get_activation_fn
from layers import PositionalEmbedding
from typing import Optional, Tuple


class TSTEncoder(nn.Module):
    """Time series encoder backbone of SymTime"""

    def __init__(
        self,
        patch_len: int = 16,
        n_layers: int = 3,
        d_model: int = 128,
        n_heads: int = 16,
        d_k: int = None,
        d_v: int = None,
        d_ff: int = 256,
        norm: str = "BatchNorm",
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        act: str = "gelu",
        store_attn: bool = False,
        pre_norm: bool = False,
        forward_layers: int = 6,
    ) -> None:
        super().__init__()
        self.forward_layers = forward_layers
        # Patch embedding: project each patch of length patch_len to d_model
        self.W_P = nn.Linear(patch_len, d_model)
        # Positional encoding
        self.pe = PositionalEmbedding(d_model=d_model)
        # Residual dropout
        self.dropout = nn.Dropout(dropout)
        # Create the [CLS] token and its non-trainable attention-mask entry
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.cls_mask = nn.Parameter(torch.ones(1, 1).bool(), requires_grad=False)
        # Create the encoder layers of the model backbone
        self.layers = nn.ModuleList(
            [
                TSTEncoderLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    d_k=d_k,
                    d_v=d_v,
                    d_ff=d_ff,
                    norm=norm,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    activation=act,
                    pre_norm=pre_norm,
                    store_attn=store_attn,
                )
                for _ in range(n_layers)
            ]
        )
        # Initialize model parameters
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize model parameters via the ``apply`` method"""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        x: Tensor,  # x: [batch_size, patch_num, patch_len]
        attn_mask: Optional[Tensor] = None,  # attn_mask: [batch_size, patch_num]
    ) -> Tensor:
        batch_size = x.size(0)
        # Input patching embedding
        x = self.W_P(x)  # x: [batch_size, patch_num, d_model]
        # Prepend the [CLS] token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        # Extend the attention mask to cover the [CLS] token
        if attn_mask is not None:
            attn_mask = torch.cat(
                [self.cls_mask.expand(batch_size, -1), attn_mask], dim=1
            )
        # Add the positional embedding
        x = self.pe(x)
        x = self.dropout(x)  # x: [batch_size, patch_num + 1, d_model]
        for mod in self.layers[: self.forward_layers]:
            x = mod(x, attn_mask=attn_mask)
        return x
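# A minimal usage sketch for TSTEncoder (shapes are assumptions read off the
# forward signature above; batch_size=8 and patch_num=64 are illustrative):
#     encoder = TSTEncoder(patch_len=16, n_layers=3, d_model=128, n_heads=16)
#     x = torch.randn(8, 64, 16)   # [batch_size, patch_num, patch_len]
#     out = encoder(x)             # [batch_size, patch_num + 1, d_model], [CLS] token prepended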
class TSTEncoderLayer(nn.Module):
    """Patch-based Transformer encoder sublayer"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_k: int = None,
        d_v: int = None,
        d_ff: int = 256,
        store_attn: bool = False,
        norm: str = "BatchNorm",
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        bias: bool = True,
        activation: str = "gelu",
        pre_norm: bool = False,
    ) -> None:
        super(TSTEncoderLayer, self).__init__()
        assert (
            not d_model % n_heads
        ), f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        # If not specified, default the per-head dimensions to d_model // n_heads
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        # Create the multi-head attention
        self.self_attn = MultiHeadAttention(
            d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout
        )
        # Add & Norm
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        else:
            self.norm_attn = nn.LayerNorm(d_model)
        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )
        # Add & Norm
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        else:
            self.norm_ffn = nn.LayerNorm(d_model)
        # Whether to use pre-norm and whether to store the attention weights
        self.pre_norm = pre_norm
        self.store_attn = store_attn
        self.attn = None

    def forward(self, src: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
        """Multi-Head attention sublayer followed by a position-wise feed-forward sublayer"""
        # Whether to use pre-norm for the attention sublayer
        if self.pre_norm:
            src = self.norm_attn(src)
        # Multi-Head attention
        src2, attn = self.self_attn(src, src, src, attn_mask=attn_mask)
        if self.store_attn:
            self.attn = attn
        # Add: residual connection with residual dropout
        src = src + self.dropout_attn(src2)
        if not self.pre_norm:
            src = self.norm_attn(src)
        # Whether to use pre-norm for the feed-forward sublayer
        if self.pre_norm:
            src = self.norm_ffn(src)
        # Position-wise Feed-Forward
        src2 = self.ff(src)
        # Add: residual connection with residual dropout
        src = src + self.dropout_ffn(src2)
        if not self.pre_norm:
            src = self.norm_ffn(src)
        return src
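# Data-flow sketch for TSTEncoderLayer.forward (a reading of the code above):
#     post-norm (pre_norm=False, default): x = norm(x + dropout(sublayer(x)))
#     pre-norm  (pre_norm=True):           x = norm(x); x = x + dropout(sublayer(x))
# applied first with the multi-head attention sublayer, then with the feed-forward sublayer.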
class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism layer"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_k: int = None,
        d_v: int = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        qkv_bias: bool = True,
    ) -> None:
        """Multi-Head Attention layer
        Input shape:
            Q: [batch_size (bs) x max_q_len x d_model]
            K, V: [batch_size (bs) x q_len x d_model]
            attn_mask: [bs x q_len]
        """
        super().__init__()
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
        # Scaled Dot-Product Attention (multiple heads)
        self.sdp_attn = _ScaledDotProductAttention(
            d_model, n_heads, attn_dropout=attn_dropout
        )
        # Project output
        self.to_out = nn.Sequential(
            nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)
        )

    def forward(
        self,
        q: Tensor,
        k: Optional[Tensor] = None,
        v: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        bs = q.size(0)
        if k is None:
            k = q
        if v is None:
            v = q
        # Linear projections (+ split into multiple heads)
        q_s = self.W_Q(q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(k).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)
        v_s = self.W_V(v).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)
        # Apply Scaled Dot-Product Attention (multiple heads)
        output, attn_weights = self.sdp_attn(q_s, k_s, v_s, attn_mask=attn_mask)
        # Merge heads back to the original input dimension
        output = (
            output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        )
        output = self.to_out(output)
        return output, attn_weights
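# Shape sketch for MultiHeadAttention.forward (derived from the projections above;
# bs and q_len denote the batch size and sequence length):
#     q, k, v : [bs, q_len, d_model]
#     q_s     : [bs, n_heads, q_len, d_k]
#     k_s     : [bs, n_heads, d_k, q_len]   (pre-transposed so that q_s @ k_s yields the score matrix)
#     v_s     : [bs, n_heads, q_len, d_v]
#     output  : [bs, q_len, d_model]        (heads merged, then projected by to_out)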
class _ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention Is All You Need, Vaswani et al., 2017)
    with optional residual attention from the previous layer (RealFormer: Transformer Likes
    Residual Attention, He et al., 2020) and locality self-attention (Vision Transformer for
    Small-Size Datasets, Lee et al., 2021)"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_dropout: float = 0.0,
        res_attention: bool = False,
    ):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim**-0.5), requires_grad=False)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        :param q: [batch_size, n_heads, num_token, d_k]
        :param k: [batch_size, n_heads, d_k, num_token]
        :param v: [batch_size, n_heads, num_token, d_v]
        :param attn_mask: [batch_size, num_token]
        """
        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in the input sequence
        attn_scores = torch.matmul(q, k) * self.scale
        # Attention mask (optional): expand the per-token mask to a pairwise
        # [batch_size, 1, num_token, num_token] mask before applying it
        if attn_mask is not None:
            attn_mask = rearrange(attn_mask, "b i -> b 1 i 1") * rearrange(
                attn_mask, "b i -> b 1 1 i"
            )
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask
        # Normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        # Compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)
        return output, attn_weights
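

if __name__ == "__main__":
    # Minimal smoke-test sketch: it assumes the local `layers` module imported at the
    # top of this file is available and that PositionalEmbedding preserves the input
    # shape; the sizes below are illustrative only.
    encoder = TSTEncoder(patch_len=16, n_layers=2, d_model=64, n_heads=4)
    dummy = torch.randn(2, 32, 16)  # [batch_size, patch_num, patch_len]
    hidden = encoder(dummy)
    print(hidden.shape)  # expected: torch.Size([2, 33, 64])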