# -*- coding: utf-8 -*-
"""
Created on 2024/9/16 17:22
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
import numpy as np

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from einops import rearrange

from layers import Transpose, get_activation_fn
from layers import PositionalEmbedding

from typing import Optional, Union, Tuple


class TSTEncoder(nn.Module):
    """Time series encoder backbone of SymTime."""

    def __init__(
        self,
        patch_len: int = 16,
        n_layers: int = 3,
        d_model: int = 128,
        n_heads: int = 16,
        d_k: Optional[int] = None,
        d_v: Optional[int] = None,
        d_ff: int = 256,
        norm: str = "BatchNorm",
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        act: str = "gelu",
        store_attn: bool = False,
        pre_norm: bool = False,
        forward_layers: int = 6,
    ) -> None:
        super().__init__()
        self.forward_layers = forward_layers

        # Patch embedding: project each patch of length patch_len to d_model
        self.W_P = nn.Linear(patch_len, d_model)

        # Positional encoding
        self.pe = PositionalEmbedding(d_model=d_model)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Create the [CLS] token and its always-visible attention mask entry
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.cls_mask = nn.Parameter(torch.ones(1, 1).bool(), requires_grad=False)

        # Create the encoder layers of the model backbone
        self.layers = nn.ModuleList(
            [
                TSTEncoderLayer(
                    d_model=d_model,
                    n_heads=n_heads,
                    d_k=d_k,
                    d_v=d_v,
                    d_ff=d_ff,
                    norm=norm,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    activation=act,
                    pre_norm=pre_norm,
                    store_attn=store_attn,
                )
                for _ in range(n_layers)
            ]
        )

        # Model parameter initialization
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize model parameters via nn.Module.apply."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        x: Tensor,  # x: [batch_size, patch_num, patch_len]
        attn_mask: Optional[Tensor] = None,  # attn_mask: [batch_size, patch_num]
    ) -> Tensor:
        batch_size = x.size(0)

        # Input patch embedding
        x = self.W_P(x)  # x: [batch_size, patch_num, d_model]

        # Prepend the [CLS] token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        # Extend the attention mask so it also covers the [CLS] position
        if attn_mask is not None:
            attn_mask = torch.cat(
                [self.cls_mask.expand(batch_size, -1), attn_mask], dim=1
            )

        # Add the positional embedding
        x = self.pe(x)
        x = self.dropout(x)  # x: [batch_size, patch_num + 1, d_model]

        # Only the first `forward_layers` encoder layers are used
        for mod in self.layers[: self.forward_layers]:
            x = mod(x, attn_mask=attn_mask)

        return x


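# Illustrative only: a minimal shape-check sketch for TSTEncoder, not part of the
# original SymTime code. It assumes `layers.PositionalEmbedding` accepts a
# [batch, seq_len, d_model] tensor and returns a tensor of the same shape; the
# helper name and the sizes below are hypothetical.
def _tst_encoder_shape_sketch() -> None:
    """Run a random patched series through the encoder and report the output shape."""
    batch_size, patch_num, patch_len = 4, 12, 16  # hypothetical sizes
    encoder = TSTEncoder(patch_len=patch_len, n_layers=3, d_model=128, n_heads=16)
    x = torch.randn(batch_size, patch_num, patch_len)  # [batch, patch_num, patch_len]
    out = encoder(x)  # the [CLS] token is prepended inside forward
    # Expected: [batch_size, patch_num + 1, d_model], assuming PositionalEmbedding
    # preserves the input shape.
    print(out.shape)

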
class TSTEncoderLayer(nn.Module):
    """Patch-based Transformer encoder sublayer (multi-head attention + FFN)."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_k: Optional[int] = None,
        d_v: Optional[int] = None,
        d_ff: int = 256,
        store_attn: bool = False,
        norm: str = "BatchNorm",
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        bias: bool = True,
        activation: str = "gelu",
        pre_norm: bool = False,
    ) -> None:
        super(TSTEncoderLayer, self).__init__()

        assert (
            not d_model % n_heads
        ), f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        # If not specified, split d_model evenly across the heads
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        # Create the multi-head attention
        self.self_attn = MultiHeadAttention(
            d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout
        )

        # Add & Norm
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        else:
            self.norm_attn = nn.LayerNorm(d_model)

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )

        # Add & Norm
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        else:
            self.norm_ffn = nn.LayerNorm(d_model)

        # Use pre-norm or post-norm, and optionally store the attention map
        self.pre_norm = pre_norm
        self.store_attn = store_attn
        self.attn = None

    def forward(self, src: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
        """Multi-head attention sublayer followed by a position-wise FFN."""

        # Pre-norm for the attention sublayer (if enabled)
        if self.pre_norm:
            src = self.norm_attn(src)

        # Multi-head attention
        src2, attn = self.self_attn(src, src, src, attn_mask=attn_mask)
        if self.store_attn:
            self.attn = attn

        # Add: residual connection with residual dropout
        src = src + self.dropout_attn(src2)
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Pre-norm for the FFN sublayer (if enabled)
        if self.pre_norm:
            src = self.norm_ffn(src)

        # Position-wise Feed-Forward
        src2 = self.ff(src)

        # Add: residual connection with residual dropout
        src = src + self.dropout_ffn(src2)
        if not self.pre_norm:
            src = self.norm_ffn(src)

        return src


class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism layer."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_k: Optional[int] = None,
        d_v: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        qkv_bias: bool = True,
    ) -> None:
        """Multi-Head Attention Layer

        Input shape:
            Q:    [batch_size (bs) x max_q_len x d_model]
            K, V: [batch_size (bs) x q_len x d_model]
            attn_mask: [batch_size (bs) x q_len]
        """
        super().__init__()
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        # Linear projections for queries, keys and values
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.sdp_attn = _ScaledDotProductAttention(
            d_model, n_heads, attn_dropout=attn_dropout
        )

        # Project output
        self.to_out = nn.Sequential(
            nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)
        )

    def forward(
        self,
        q: Tensor,
        k: Optional[Tensor] = None,
        v: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        bs = q.size(0)
        # Default to self-attention when keys/values are not given
        if k is None:
            k = q
        if v is None:
            v = q

        # Linear projections + split into multiple heads
        q_s = self.W_Q(q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(k).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)
        v_s = self.W_V(v).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)

        # Apply Scaled Dot-Product Attention (multiple heads)
        output, attn_weights = self.sdp_attn(q_s, k_s, v_s, attn_mask=attn_mask)

        # Merge the heads back to the original input dimensions
        output = (
            output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        )
        output = self.to_out(output)

        return output, attn_weights


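# Illustrative only: a small self-attention shape sketch for MultiHeadAttention, not
# part of the original SymTime code. With a single [batch, seq_len, d_model] input
# used as q, k and v, the layer returns the attended values in the same shape plus
# the per-head attention weights; the sizes below are hypothetical.
def _multi_head_attention_shape_sketch() -> None:
    batch_size, seq_len, d_model, n_heads = 2, 8, 128, 16  # hypothetical sizes
    mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
    x = torch.randn(batch_size, seq_len, d_model)
    output, attn_weights = mha(x)  # k and v default to q (self-attention)
    print(output.shape)        # [batch_size, seq_len, d_model]
    print(attn_weights.shape)  # [batch_size, n_heads, seq_len, seq_len]

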
class _ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module ("Attention Is All You Need", Vaswani et al., 2017)
    with optional residual attention from the previous layer ("RealFormer: Transformer Likes
    Residual Attention", He et al., 2020) and locality self-attention ("Vision Transformer
    for Small-Size Datasets", Lee et al., 2021)."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_dropout: float = 0.0,
        res_attention: bool = False,
    ):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim**-0.5), requires_grad=False)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        """
        :param q: [batch_size, n_heads, num_token, d_k]
        :param k: [batch_size, n_heads, d_k, num_token]
        :param v: [batch_size, n_heads, num_token, d_v]
        :param attn_mask: [batch_size, num_token]
        """

        # Scaled MatMul (q, k): similarity scores for all pairs of positions in the input sequence
        attn_scores = torch.matmul(q, k) * self.scale  # [bs, n_heads, num_token, num_token]

        # Attention mask (optional): expand the per-token mask to a pairwise mask
        if attn_mask is not None:
            attn_mask = rearrange(attn_mask, "b i -> b 1 i 1") * rearrange(
                attn_mask, "b i -> b 1 1 i"
            )
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)

        return output, attn_weights
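
# Illustrative only: running this module directly performs a minimal smoke test of the
# scaled dot-product attention core; this block is not part of the original SymTime code.
# Shapes follow the forward() docstring above: q and v are [bs, n_heads, num_token, head_dim]
# and k is pre-transposed to [bs, n_heads, head_dim, num_token]; the sizes are hypothetical.
if __name__ == "__main__":
    bs, n_heads, num_token, head_dim = 2, 16, 8, 8  # hypothetical sizes
    sdp = _ScaledDotProductAttention(d_model=n_heads * head_dim, n_heads=n_heads)
    q = torch.randn(bs, n_heads, num_token, head_dim)
    k = torch.randn(bs, n_heads, head_dim, num_token)
    v = torch.randn(bs, n_heads, num_token, head_dim)
    out, attn = sdp(q, k, v)
    print(out.shape)   # [bs, n_heads, num_token, head_dim]
    print(attn.shape)  # [bs, n_heads, num_token, num_token]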