# NOTE(review): removed repository-page scrape residue ("topics" help text,
# "355 lines / 13 KiB") that was not part of the module source.
# -*- coding: utf-8 -*-
"""
Created on 2024/9/30 21:27
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
import torch
from torch import nn
from torch.nn import functional as F
from layers import TSTEncoder
from layers import LLM
from typing import Union, Tuple
class SymTime(nn.Module):
    """SymTime architecture for pre-training.

    Pairs a Transformer time-series encoder (``TSTEncoder``) with an
    LLM-based symbolic (text) encoder and pre-trains them jointly with:
      * masked time-series reconstruction (loss_mtm),
      * masked language modelling on the symbolic stream (loss_mlm),
      * momentum-queue contrastive alignment between the two modalities
        (loss_t2s / loss_s2t), MoCo/ALBEF-style.
    """

    def __init__(
        self,
        configs,
        context_window: int,
        time_mask_ratio: float,
        sym_mask_ratio: float,
    ) -> None:
        """Build both encoders, their momentum copies and the feature queues.

        :param configs: dict-like hyper-parameter container (keys read below).
        :param context_window: length of the time-series context window.
        :param time_mask_ratio: fraction of valid time patches to mask out.
        :param sym_mask_ratio: fraction of symbol (text) tokens to mask out.
        """
        super().__init__()
        self.context_window = context_window
        # Patch slicing of the raw series; stride defaults to patch_len
        # (non-overlapping patches) when configs["stride"] is None.
        self.patch_len = configs["patch_len"]
        self.stride = self.patch_len if configs["stride"] is None else configs["stride"]
        self.padding_patch = configs["padding_patch"]
        self.time_layers = configs["time_layers"]
        self.d_model = configs["d_model"]
        self.n_heads = configs["n_heads"]
        self.d_ff = configs["d_ff"]
        self.llm_name = configs["llm_name"]
        self.llm_layers = configs["llm_layers"]
        self.hidden_size = configs["hidden_size"]
        # Freeze the first n layers of parameters of the LLM
        self.freeze_layers = configs["freeze_layers"]
        # Shared embedding dimension of the contrastive projection heads
        self.embed_dim = configs["embed_dim"]
        # The ratio of momentum model parameter updates (EMA decay)
        self.momentum = configs["momentum"]
        # Size of the momentum feature queue (columns per modality)
        self.queue_size = configs["queue_size"]
        # Learnable contrastive temperature; clamped to [0.001, 0.5] in forward
        self.temp = nn.Parameter(torch.ones([]) * configs["temp"])
        # Proportion of soft (momentum-distilled) targets in the contrastive loss
        self.alpha = configs["alpha"]
        # Mask ratio of the time series and symbol data
        self.time_mask_ratio = time_mask_ratio
        self.sym_mask_ratio = sym_mask_ratio
        # Creating an encoder for time series data
        self.time_encoder = TSTEncoder(
            patch_len=self.patch_len,
            n_layers=self.time_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
            d_ff=self.d_ff,
            norm=configs["norm"],
            attn_dropout=configs["attn_dropout"],
            dropout=configs["dropout"],
            act=configs["act"],
            pre_norm=configs["pre_norm"],
        )
        # Projects time-series tokens down to the shared embedding dimension
        self.time_proj = nn.Linear(
            in_features=self.d_model, out_features=self.embed_dim
        )
        # Linear mapping for time series patch reconstruction
        self.reconstruct_project = nn.Linear(
            in_features=self.d_model,
            out_features=self.patch_len,
            bias=configs["time_project_bias"],
        )
        # Creating an encoder for symbol data
        self.symbolic_encoder = LLM(
            llm_name=self.llm_name,
            llm_layers=self.llm_layers,
            hidden_size=configs["hidden_size"],
            freeze_layers=self.freeze_layers,
        )
        # Projects symbol tokens down to the shared embedding dimension
        self.sym_proj = nn.Linear(
            in_features=self.hidden_size, out_features=self.embed_dim
        )
        # Get the tokenizer used by the text encoder
        self.tokenizer = self.symbolic_encoder.tokenizer
        # Vocabulary size, used for random-word replacement in nlp_mask
        self.vocab_size = self.tokenizer.vocab_size
        # Create momentum (EMA) copies of both encoders and projections
        self.time_encoder_m = TSTEncoder(
            patch_len=self.patch_len,
            n_layers=self.time_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
            d_ff=self.d_ff,
            norm=configs["norm"],
            attn_dropout=configs["attn_dropout"],
            dropout=configs["dropout"],
            act=configs["act"],
            pre_norm=configs["pre_norm"],
        )
        self.time_proj_m = nn.Linear(
            in_features=self.d_model, out_features=self.embed_dim
        )
        # NOTE(review): the momentum LLM omits freeze_layers; presumably
        # harmless because copy_params() disables gradients on every momentum
        # parameter anyway — confirm against the LLM constructor's default.
        self.symbolic_encoder_m = LLM(
            llm_name=self.llm_name,
            llm_layers=self.llm_layers,
            hidden_size=configs["hidden_size"],
        )
        self.sym_proj_m = nn.Linear(
            in_features=self.hidden_size, out_features=self.embed_dim
        )
        # (online module, momentum module) pairs driven by copy_params /
        # momentum_update
        self.model_pairs = [
            [self.time_encoder, self.time_encoder_m],
            [self.time_proj, self.time_proj_m],
            [self.symbolic_encoder, self.symbolic_encoder_m],
            [self.sym_proj, self.sym_proj_m],
        ]
        # Initialise the momentum modules as exact copies (grads disabled)
        self.copy_params()
        # Feature queues: one column per queued sample, embed_dim rows
        self.register_buffer("time_queue", torch.randn(self.embed_dim, self.queue_size))
        self.register_buffer("sym_queue", torch.randn(self.embed_dim, self.queue_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        # Queue entries are compared by dot product, so keep them unit-norm
        self.time_queue = F.normalize(self.time_queue, dim=0)
        self.sym_queue = F.normalize(self.sym_queue, dim=0)
def forward(
    self,
    time: torch.Tensor,
    time_mask: torch.Tensor,
    input_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one pre-training step on a paired (time-series, symbol) batch.

    :param time: patched time series — assumed (batch, num_patches,
        patch_len); TODO confirm against TSTEncoder's expected input.
    :param time_mask: boolean attention mask over time patches
        (True = valid patch).
    :param input_ids: token ids of the symbolic (text) stream; note that
        nlp_mask corrupts this tensor in place.
    :param attn_mask: attention mask for the symbolic stream.
    :return: (loss_mtm, loss_mlm, loss_t2s, loss_s2t).
    """
    with torch.no_grad():
        # Keep the learnable contrastive temperature in a sane range
        self.temp.clamp_(0.001, 0.5)
    # Masking the time series data
    time_masked, time_attn_mask = self.time_masking(
        inputs=time, attn_mask=time_mask
    )
    # time_masked = torch.concat([time_masked, time[:, padding_index:, :]], dim=1)
    # Masking the symbol data as natural language; labels keep the
    # uncorrupted ids for the MLM loss
    labels = input_ids.clone()
    inputs_ids, labels = self.nlp_mask(
        input_ids=input_ids,
        vocab_size=self.vocab_size,
        device=time.device,
        targets=labels,
    )
    # Forward propagation of time series data through the time series encoder
    time_embeds = self.time_encoder(x=time_masked, attn_mask=time_attn_mask)
    time_reconstruct = self.reconstruct_project(time_embeds)
    # Reconstruction loss; token 0 is dropped, presumably a [CLS]-style
    # token prepended by TSTEncoder — TODO confirm
    loss_mtm = (time_reconstruct[:, 1:, :] - time) ** 2
    loss_mtm = loss_mtm.mean(dim=-1)
    # Average only over masked positions.
    # NOTE(review): ~time_attn_mask also covers positions that were padding
    # in the ORIGINAL mask (if any), not only freshly-masked patches —
    # confirm inputs are always fully valid or padding loss is intended.
    loss_mtm = (loss_mtm * (~time_attn_mask).int()).sum() / (~time_attn_mask).sum()
    # Forward propagation for symbolic data in natural language
    sym_outputs = self.symbolic_encoder(inputs_ids, attn_mask, labels)
    # Masked-language-modelling loss computed inside the LLM wrapper
    loss_mlm = sym_outputs.loss
    # Take position-0 features of both streams as global features and
    # L2-normalise them for cosine-similarity comparison
    time_features = F.normalize(self.time_proj(time_embeds[:, 0, :]), dim=-1)
    sym_features = F.normalize(
        self.sym_proj(sym_outputs.hidden_states[-1][:, 0, :]), dim=-1
    )
    # Momentum branch: no gradients flow into the EMA encoders
    with torch.no_grad():
        # Update the parameters of the momentum modules (EMA step)
        self.momentum_update()
        time_embeds_m = self.time_encoder_m(x=time_masked, attn_mask=time_attn_mask)
        time_features_m = F.normalize(
            self.time_proj_m(time_embeds_m[:, 0, :]), dim=-1
        )
        # Current momentum features concatenated with the queued negatives
        time_features_all = torch.cat(
            [time_features_m.t(), self.time_queue.clone().detach()], dim=1
        )
        sym_outputs_m = self.symbolic_encoder_m(inputs_ids, attn_mask, labels)
        sym_features_m = F.normalize(
            self.sym_proj_m(sym_outputs_m.hidden_states[-1][:, 0, :]), dim=-1
        )
        # Same for the symbol modality
        sym_features_all = torch.cat(
            [sym_features_m.t(), self.sym_queue.clone().detach()], dim=1
        )
        # Momentum similarities: [batch_size, batch_size + queue_size]
        sim_t2s_m = time_features_m @ sym_features_all / self.temp  # s(I, Tm) / tao
        sim_s2t_m = sym_features_m @ time_features_all / self.temp  # s(T, Im) / tao
        # One-hot targets: diagonal pairs are the true matches
        sim_targets = torch.zeros(sim_t2s_m.size()).to(time_masked.device)
        sim_targets.fill_diagonal_(1)
        # Soft targets: alpha-blend of momentum distribution and one-hot
        sim_t2s_targets = (
            self.alpha * F.softmax(sim_t2s_m, dim=1)
            + (1 - self.alpha) * sim_targets
        )
        sim_s2t_targets = (
            self.alpha * F.softmax(sim_s2t_m, dim=1)
            + (1 - self.alpha) * sim_targets
        )
    sim_t2s = time_features @ sym_features_all / self.temp
    sim_s2t = sym_features @ time_features_all / self.temp
    # NOTE(review): sim_*_targets is already a probability mixture; the extra
    # F.softmax re-normalises it (ALBEF multiplies by the targets directly)
    # — confirm this flattening of the target distribution is intentional.
    loss_t2s = -torch.sum(
        F.log_softmax(sim_t2s, dim=1) * F.softmax(sim_t2s_targets, dim=1), dim=1
    ).mean()
    loss_s2t = -torch.sum(
        F.log_softmax(sim_s2t, dim=1) * F.softmax(sim_s2t_targets, dim=1), dim=1
    ).mean()
    # Let the new momentum features enqueue and the oldest dequeue
    self.enqueue_and_dequeue(time_features_m, sym_features_m)
    return loss_mtm, loss_mlm, loss_t2s, loss_s2t
def time_masking(
    self, inputs: torch.Tensor, attn_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Zero out a random fraction of the valid patches in every series.

    For each sample, ``time_mask_ratio`` of its valid (attended) patches are
    selected uniformly at random, zeroed in the data and flagged False in the
    returned attention mask. The inputs themselves are left untouched.

    :param inputs: patched series, shaped (batch, num_tokens, patch_len).
    :param attn_mask: boolean mask of valid patches (True = valid).
    :return: (masked series copy, updated attention-mask copy).
    """
    masked_series = inputs.clone()
    masked_attn = attn_mask.clone()
    # Unpack to assert a 3-D input, as the encoder expects
    batch_size, _num_tokens, _patch_len = inputs.size()
    # Valid patches per sample, and how many of them to hide
    valid_counts = torch.sum(attn_mask, dim=1)
    hide_counts = (valid_counts * self.time_mask_ratio).int()
    for sample_idx in range(batch_size):
        valid_len = valid_counts[sample_idx]
        n_hidden = hide_counts[sample_idx]
        # Random scores over the leading valid positions; the n_hidden
        # lowest-scoring positions are the ones that get masked
        scores = torch.rand(valid_len)
        chosen = torch.argsort(scores)[:n_hidden]
        masked_series[sample_idx, chosen, :] = 0
        masked_attn[sample_idx, chosen] = False
    return masked_series, masked_attn
def nlp_mask(
    self,
    input_ids: torch.Tensor,
    vocab_size: int,
    device: torch.device,
    targets: torch.Tensor = None,
    masked_indices: torch.Tensor = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Apply BERT-style random masking to symbolic token ids (in place).

    Of the selected positions: 80% become [MASK], 10% become a random
    vocabulary word and the remaining 10% are kept unchanged.

    :param input_ids: token ids to corrupt; NOTE: modified in place.
    :param vocab_size: tokenizer vocabulary size (for random replacements).
    :param device: device on which the random replacement words are placed.
    :param targets: optional labels with the same shape as input_ids;
        non-masked positions are set to -100 so the LM loss ignores them.
    :param masked_indices: optional boolean pre-selection of masked
        positions; when None they are drawn with probability sym_mask_ratio.
    :return: (input_ids, targets) when targets is given, else input_ids.
    """
    # Build the selection probabilities from input_ids rather than targets:
    # the original used targets.shape, which crashed for the documented
    # targets=None default (shapes are identical when targets is given).
    probability_matrix = torch.full(input_ids.shape, self.sym_mask_ratio)
    if masked_indices is None:
        # No mask specified: draw one at random, token by token
        masked_indices = torch.bernoulli(probability_matrix).bool()
    # Never mask out the two special tokens [PAD] and [CLS]
    masked_indices[input_ids == self.tokenizer.pad_token_id] = False
    masked_indices[input_ids == self.tokenizer.cls_token_id] = False
    # Restrict the loss to the masked positions only
    if targets is not None:
        targets[~masked_indices] = -100  # We only compute loss on masked tokens
    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    )
    input_ids[indices_replaced] = self.tokenizer.mask_token_id
    # 10% of the time, we replace masked input tokens with random word
    # (0.5 of the remaining 20% not already replaced by [MASK])
    indices_random = (
        torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(
        device
    )
    input_ids[indices_random] = random_words[indices_random]
    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    if targets is not None:
        return input_ids, targets
    else:
        return input_ids
@torch.no_grad()
def copy_params(self) -> None:
    """Initialise every momentum module as an exact copy of its online twin.

    Momentum parameters are copied once here and afterwards only follow the
    online weights through momentum_update, never through back-propagation.
    """
    for online, momentum in self.model_pairs:
        for src, dst in zip(online.parameters(), momentum.parameters()):
            dst.data.copy_(src.data)  # start from the online weights
            dst.requires_grad = False  # momentum weights take no gradients
@torch.no_grad()
def momentum_update(self) -> None:
    """Move every momentum parameter toward its online twin by EMA.

    Computes dst = decay * dst + (1 - decay) * src with decay =
    self.momentum; the result is rebound to .data, mirroring the original
    (no in-place mutation of the old tensor).
    """
    decay = self.momentum
    for online, momentum in self.model_pairs:
        for src, dst in zip(online.parameters(), momentum.parameters()):
            # exponential-moving-average step
            dst.data = dst.data * decay + src.data * (1.0 - decay)
@torch.no_grad()
def enqueue_and_dequeue(
    self, time_features: torch.Tensor, sym_features: torch.Tensor
) -> None:
    """Write the newest momentum features over the oldest queue columns.

    Both queues store one feature per column; the write head queue_ptr
    advances circularly by the batch size after every call.
    """
    n_new = time_features.shape[0]
    head = int(self.queue_ptr)
    # Queue length must divide evenly by the batch size so a batch never
    # wraps around the end of the buffer.
    assert self.queue_size % n_new == 0  # for simplicity
    window = slice(head, head + n_new)
    self.time_queue[:, window] = time_features.T
    self.sym_queue[:, window] = sym_features.T
    # advance (circularly) to where the next batch will be written
    self.queue_ptr[0] = (head + n_new) % self.queue_size