You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

345 lines
13 KiB

# -*- coding: utf-8 -*-
"""
Created on 2024/9/30 21:27
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
# from functools import partial
import numpy as np
import torch
from torch import nn
from layers import TSTEncoder
from layers import Flatten_Heads
from layers import series_decomp
class SymTime(nn.Module):
    """Network architecture used for fine-tuning downstream tasks.

    A pre-trained ``TSTEncoder`` backbone is paired with a head chosen by
    ``args.task_name``:

    * ``"long_term_forecast"`` / ``"short_term_forecast"`` -> :meth:`forcast`
    * ``"classification"`` -> :meth:`classification`
    * ``"imputation"`` -> :meth:`imputation`
    * ``"anomaly_detection"`` -> :meth:`anomaly_detection`

    Every entry point takes ``x_enc`` laid out as
    ``[batch_size, seq_len, num_vars]``: statistics are computed over dim 1
    and the channel axis is moved to dim 1 before patching.
    """

    def __init__(self, args, configs) -> None:
        """Build the backbone and the head required by ``args.task_name``.

        Args:
            args: run options. Read here: ``task_name``, ``patch_len``,
                ``stride``, ``padding_patch``, ``seq_len``, ``pred_len``,
                ``forward_layers``, ``individual``, ``pretrain_path``,
                ``out_dropout``, ``use_avg`` (+ ``moving_avg``), ``enc_in``;
                for classification also ``conv1d``, ``out_channels`` and
                ``num_classes``.
            configs: backbone hyper-parameters (``time_layers``, ``d_model``,
                ``n_heads``, ``d_ff``, ``norm``, ``attn_dropout``,
                ``dropout``, ``act``, ``pre_norm``).

        Raises:
            ValueError: if ``args.task_name`` is not one of the supported tasks.

        Note:
            For classification with ``args.conv1d is True`` this overwrites
            ``args.enc_in`` with ``args.out_channels`` in place, so any caller
            sharing the same ``args`` object observes the change.
        """
        super().__init__()
        # Downstream task to be completed
        self.task_name = args.task_name
        self.patch_len = args.patch_len
        self.stride = args.stride
        self.padding_patch = args.padding_patch
        # Number of patches obtained by unfolding seq_len with (patch_len, stride).
        self.patch_num = int((args.seq_len - self.patch_len) / self.stride + 1)
        if self.padding_patch is True:
            # Replicating the last value `stride` times makes exactly one more patch fit.
            self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
            self.patch_num += 1
        # Input and output sequence lengths
        self.seq_len = args.seq_len
        self.pred_len = args.pred_len
        self.n_layers = configs["time_layers"]
        self.forward_layers = args.forward_layers
        self.d_model = configs["d_model"]
        self.n_heads = configs["n_heads"]
        self.d_ff = configs["d_ff"]
        # Passed to Flatten_Heads as `individual` for the output heads
        self.individual = args.individual
        self.pretrain_path = args.pretrain_path
        # Dropout applied inside the output heads
        self.out_dropout = args.out_dropout

        # Time-series encoder (the pre-trained backbone)
        self.time_encoder = TSTEncoder(
            patch_len=self.patch_len,
            n_layers=self.n_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
            d_ff=self.d_ff,
            norm=configs["norm"],
            attn_dropout=configs["attn_dropout"],
            dropout=configs["dropout"],
            act=configs["act"],
            pre_norm=configs["pre_norm"],
            forward_layers=self.forward_layers,
        )
        # Load the pre-training params
        self.load_pretrained()

        # Freeze encoder layers with index in [forward_layers, n_layers).
        # Matching is a substring test on the parameter name ("layers.<i>").
        frozen_tags = [
            f"layers.{index}" for index in range(self.forward_layers, self.n_layers)
        ]
        for name, param in self.time_encoder.named_parameters():
            if any(tag in name for tag in frozen_tags):
                param.requires_grad = False

        # Optional seasonal/trend decomposition of the input series
        self.use_avg = args.use_avg
        if self.use_avg is True:
            self.decompsition = series_decomp(kernel_size=args.moving_avg)
            # Trend part is projected separately onto the output length.
            # NOTE(review): created only when use_avg is True (the only case in
            # which it is used); confirm this matches existing checkpoints' keys.
            self.projection_trend = nn.Linear(
                in_features=self.seq_len,
                out_features=(
                    args.pred_len if "forecast" in self.task_name else self.seq_len
                ),
            )

        # Interface module for the downstream task
        if self.task_name in ("long_term_forecast", "short_term_forecast"):
            self.flatten_head = Flatten_Heads(
                individual=self.individual,
                n_vars=args.enc_in,
                patch_num=self.patch_num,
                nf=self.d_model,
                targets_window=args.pred_len,
                head_dropout=self.out_dropout,
            )
        elif self.task_name == "classification":
            if args.conv1d is True:
                # Mix the input channels into `out_channels` before patching.
                self.conv1d = nn.Conv1d(
                    in_channels=args.enc_in,
                    out_channels=args.out_channels,
                    kernel_size=(3,),
                    stride=(1,),
                    padding=1,
                )
                # Deliberate in-place update so the head below uses the
                # post-convolution channel count (see the Note above).
                args.enc_in = args.out_channels
                self.use_conv1d = True
            else:
                self.use_conv1d = False
            self.act = nn.GELU()
            # Flattened encoder output size: d_model per token over
            # (enc_in * patch_num + 1) tokens. The "+ 1" presumably accounts
            # for an extra token emitted by TSTEncoder — TODO confirm.
            cls_features = self.d_model * (self.patch_num * args.enc_in + 1)
            self.ln_proj = nn.LayerNorm(cls_features)
            self.classifier = nn.Linear(
                in_features=cls_features,
                out_features=args.num_classes,
            )
        elif self.task_name == "anomaly_detection":
            self.flatten_head = Flatten_Heads(
                individual=self.individual,
                n_vars=args.enc_in,
                patch_num=self.patch_num,
                nf=self.d_model,
                targets_window=self.seq_len,
                head_dropout=self.out_dropout,
            )
        elif self.task_name == "imputation":
            self.flatten_head = Flatten_Heads(
                individual=self.individual,
                n_vars=args.enc_in,
                patch_num=self.patch_num,
                nf=self.d_model,
                targets_window=self.seq_len,
                head_dropout=self.out_dropout,
                cls_token=True,
            )
        else:
            raise ValueError("task name wrong!")

    def forward(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Dispatch ``x_enc`` to the entry point of the configured task."""
        if self.task_name in ("long_term_forecast", "short_term_forecast"):
            return self.forcast(x_enc=x_enc)
        if self.task_name == "classification":
            return self.classification(x_enc=x_enc)
        if self.task_name == "imputation":
            return self.imputation(x_enc=x_enc)
        # __init__ rejects any other task name, so this is anomaly detection.
        return self.anomaly_detection(x_enc=x_enc)

    @staticmethod
    def _normalize(x_enc: torch.Tensor):
        """Per-sample, per-channel normalisation over dim 1 (Non-stationary Transformer).

        Returns:
            ``(normalised, means, stdev)``; ``means`` and ``stdev`` keep dim 1
            with size 1 so they broadcast back over the time axis.
        """
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: torch.var saved `x_enc` for backward, so an
        # in-place `/=` breaks autograd whenever the input requires grad.
        return x_enc / stdev, means, stdev

    @staticmethod
    def _denormalize(
        x_dec: torch.Tensor, means: torch.Tensor, stdev: torch.Tensor
    ) -> torch.Tensor:
        """Undo :meth:`_normalize`; the statistics broadcast over x_dec's time axis."""
        return x_dec * stdev + means

    def _channel_independent_forward(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Shared path for forecasting, imputation and anomaly detection.

        Normalise, optionally decompose into seasonal/trend parts, encode each
        channel's patches independently, map them with ``flatten_head``, add
        the projected trend (if used) and de-normalise.

        Returns:
            ``[batch_size, out_len, num_vars]`` where ``out_len`` is the
            head's target window (``pred_len`` or ``seq_len``).
        """
        x_enc, means, stdev = self._normalize(x_enc)

        trend_part = None
        if self.use_avg is True:
            seasonal_part, trend_part = self.decompsition(x_enc)
            x_enc = seasonal_part.permute(0, 2, 1)
            # Map the trend part to the target length along the time axis.
            trend_part = self.projection_trend(trend_part.permute(0, 2, 1))
            trend_part = trend_part.permute(0, 2, 1)
        else:
            x_enc = x_enc.permute(0, 2, 1)  # [batch_size, num_vars, seq_len]

        patches = self.patching(ts=x_enc)  # [batch_size, num_vars, patch_num, patch_len]
        batch_size, num_vars, patch_num, patch_len = patches.size()
        # Channel independence: every variable is its own sequence for the encoder.
        x_dec = self.time_encoder(
            torch.reshape(patches, (batch_size * num_vars, patch_num, patch_len))
        )
        # Restore the per-channel layout; token count / width come from the encoder.
        x_dec = torch.reshape(
            x_dec, (batch_size, num_vars, x_dec.shape[-2], x_dec.shape[-1])
        )
        x_dec = self.flatten_head(x_dec).permute(0, 2, 1)  # [batch_size, out_len, num_vars]

        if trend_part is not None:
            x_dec = x_dec + trend_part
        return self._denormalize(x_dec, means, stdev)

    def forcast(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Forward for long and short term forecasting.

        Args:
            x_enc: ``[batch_size, seq_len, num_vars]``.

        Returns:
            ``[batch_size, pred_len, num_vars]`` in the scale of the input.
        """
        return self._channel_independent_forward(x_enc)

    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Forward for the classification task.

        Unlike the other tasks, all channels' patches of a sample are fed to
        the encoder as one token sequence.

        Args:
            x_enc: ``[batch_size, seq_len, num_vars]``.

        Returns:
            Logits of shape ``[batch_size, num_classes]``.
        """
        x_enc, _, _ = self._normalize(x_enc)
        x_enc = x_enc.permute(0, 2, 1)  # [batch_size, num_vars, seq_len]
        # Adjust the number of input channels through Conv1d
        if self.use_conv1d is True:
            x_enc = self.conv1d(x_enc)  # [batch_size, out_channels, seq_len]
        patches = self.patching(ts=x_enc)  # [batch_size, num_vars, patch_num, patch_len]
        batch_size, num_vars, patch_num, patch_len = patches.size()
        # Learn features through the Transformer backbone
        x_dec = self.time_encoder(
            torch.reshape(patches, (batch_size, num_vars * patch_num, patch_len))
        )
        # Output processing: activation, flatten, LayerNorm, linear classifier
        x_dec = self.act(x_dec)
        x_dec = torch.reshape(x_dec, (batch_size, -1))
        return self.classifier(self.ln_proj(x_dec))

    def imputation(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Interface for time series imputation.

        Args:
            x_enc: ``[batch_size, seq_len, num_vars]``.

        Returns:
            Reconstruction of shape ``[batch_size, seq_len, num_vars]``.
        """
        return self._channel_independent_forward(x_enc)

    def anomaly_detection(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Interface for time series anomaly detection.

        Args:
            x_enc: ``[batch_size, seq_len, num_vars]``.

        Returns:
            Reconstruction of shape ``[batch_size, seq_len, num_vars]``.
        """
        return self._channel_independent_forward(x_enc)

    def patching(self, ts: torch.Tensor) -> torch.Tensor:
        """Divide the series into patches along the last dimension.

        Args:
            ts: ``[..., length]`` (here ``[batch_size, num_vars, seq_len]``).

        Returns:
            ``[..., patch_num, patch_len]``, windows of ``patch_len`` taken
            every ``stride`` steps (after end-padding when enabled).
        """
        if self.padding_patch is True:
            ts = self.padding_patch_layer(ts)
        return ts.unfold(dimension=-1, size=self.patch_len, step=self.stride)

    def load_pretrained(self) -> None:
        """Load pre-trained parameters from ``self.pretrain_path`` into the encoder.

        The checkpoint is mapped to CPU first: ``load_state_dict`` copies the
        values into the encoder's existing parameters (whatever their device),
        and this keeps GPU-saved checkpoints loadable on CPU-only hosts.
        """
        print("Now loading pre-trained model params...")
        state_dict = torch.load(self.pretrain_path, map_location="cpu", weights_only=True)
        self.time_encoder.load_state_dict(state_dict)