# -*- coding: utf-8 -*-
"""
Created on 2024/9/30 21:27
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
import torch
from torch import nn
from torch.nn import functional as F

from layers import TSTEncoder
from layers import LLM

from typing import Union, Tuple

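# Note: TSTEncoder and LLM are project-local modules (see layers/): a Transformer
# encoder over time-series patches and a wrapper around a pre-trained language
# model that exposes its tokenizer and hidden states, respectively.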
class SymTime(nn.Module):
    """SymTime architecture for pre-training"""

    def __init__(
        self,
        configs,
        context_window: int,
        time_mask_ratio: float,
        sym_mask_ratio: float,
    ) -> None:
        super().__init__()
        self.context_window = context_window
        self.patch_len = configs["patch_len"]
        self.stride = self.patch_len if configs["stride"] is None else configs["stride"]
        self.padding_patch = configs["padding_patch"]
        self.time_layers = configs["time_layers"]
        self.d_model = configs["d_model"]
        self.n_heads = configs["n_heads"]
        self.d_ff = configs["d_ff"]
        self.llm_name = configs["llm_name"]
        self.llm_layers = configs["llm_layers"]
        self.hidden_size = configs["hidden_size"]

        # Freeze the first n layers of the LLM's parameters
        self.freeze_layers = configs["freeze_layers"]
        self.embed_dim = configs["embed_dim"]

        # Momentum coefficient for updating the momentum models
        self.momentum = configs["momentum"]

        # Size of the momentum queue
        self.queue_size = configs["queue_size"]

        # Learnable temperature for the contrastive losses
        self.temp = nn.Parameter(torch.ones([]) * configs["temp"])

        # Weight of the soft (pseudo) targets in momentum distillation
        self.alpha = configs["alpha"]

        # Mask ratios for the time series and symbolic data
        self.time_mask_ratio = time_mask_ratio
        self.sym_mask_ratio = sym_mask_ratio

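        # forward() returns four pre-training losses:
        #   loss_mtm - masked reconstruction of the time-series patches
        #   loss_mlm - masked language modeling on the symbolic data
        #   loss_t2s / loss_s2t - contrastive time-to-symbol / symbol-to-time
        #       alignment with momentum-distilled soft targets
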
        # Create the encoder for the time series data
        self.time_encoder = TSTEncoder(
            patch_len=self.patch_len,
            n_layers=self.time_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
            d_ff=self.d_ff,
            norm=configs["norm"],
            attn_dropout=configs["attn_dropout"],
            dropout=configs["dropout"],
            act=configs["act"],
            pre_norm=configs["pre_norm"],
        )

        # Project the time-series tokens down to the shared embedding dimension
        self.time_proj = nn.Linear(
            in_features=self.d_model, out_features=self.embed_dim
        )

        # Linear mapping for time-series patch reconstruction
        self.reconstruct_project = nn.Linear(
            in_features=self.d_model,
            out_features=self.patch_len,
            bias=configs["time_project_bias"],
        )

        # Create the encoder for the symbolic data
        self.symbolic_encoder = LLM(
            llm_name=self.llm_name,
            llm_layers=self.llm_layers,
            hidden_size=configs["hidden_size"],
            freeze_layers=self.freeze_layers,
        )

        # Project the symbolic tokens down to the shared embedding dimension
        self.sym_proj = nn.Linear(
            in_features=self.hidden_size, out_features=self.embed_dim
        )

        # Get the tokenizer used by the text encoder
        self.tokenizer = self.symbolic_encoder.tokenizer

        # Vocabulary size of the tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        # Create the momentum models (EMA copies of the encoders and projections)
        self.time_encoder_m = TSTEncoder(
            patch_len=self.patch_len,
            n_layers=self.time_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
            d_ff=self.d_ff,
            norm=configs["norm"],
            attn_dropout=configs["attn_dropout"],
            dropout=configs["dropout"],
            act=configs["act"],
            pre_norm=configs["pre_norm"],
        )
        self.time_proj_m = nn.Linear(
            in_features=self.d_model, out_features=self.embed_dim
        )
        self.symbolic_encoder_m = LLM(
            llm_name=self.llm_name,
            llm_layers=self.llm_layers,
            hidden_size=configs["hidden_size"],
        )
        self.sym_proj_m = nn.Linear(
            in_features=self.hidden_size, out_features=self.embed_dim
        )

        # Pair each online module with its momentum counterpart
        self.model_pairs = [
            [self.time_encoder, self.time_encoder_m],
            [self.time_proj, self.time_proj_m],
            [self.symbolic_encoder, self.symbolic_encoder_m],
            [self.sym_proj, self.sym_proj_m],
        ]
        # Initialize the momentum models with the online parameters
        self.copy_params()

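        # The queues cache recent momentum features column-wise and serve as a
        # large pool of negatives for the contrastive losses (MoCo-style); they
        # are normalized along dim 0 so that each column is a unit vector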
        # Create the feature queues
        self.register_buffer("time_queue", torch.randn(self.embed_dim, self.queue_size))
        self.register_buffer("sym_queue", torch.randn(self.embed_dim, self.queue_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.time_queue = F.normalize(self.time_queue, dim=0)
        self.sym_queue = F.normalize(self.sym_queue, dim=0)

    def forward(
        self,
        time: torch.Tensor,
        time_mask: torch.Tensor,
        input_ids: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the four pre-training losses for one batch."""
        with torch.no_grad():
            # Keep the contrastive temperature in a stable range
            self.temp.clamp_(0.001, 0.5)

        # Mask the time series data
        time_masked, time_attn_mask = self.time_masking(
            inputs=time, attn_mask=time_mask
        )
        # time_masked = torch.concat([time_masked, time[:, padding_index:, :]], dim=1)

        # Mask the symbolic data as natural language
        labels = input_ids.clone()
        input_ids, labels = self.nlp_mask(
            input_ids=input_ids,
            vocab_size=self.vocab_size,
            device=time.device,
            targets=labels,
        )

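        # Shapes: time and time_masked are [batch_size, num_patches, patch_len];
        # the encoder output below carries one extra leading global token, which
        # is why reconstruction compares time_reconstruct[:, 1:, :] against time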
        # Forward the masked time series through the time-series encoder
        time_embeds = self.time_encoder(x=time_masked, attn_mask=time_attn_mask)
        time_reconstruct = self.reconstruct_project(time_embeds)

        # Masked time-series modeling loss: reconstruct the masked patches
        loss_mtm = (time_reconstruct[:, 1:, :] - time) ** 2
        loss_mtm = loss_mtm.mean(dim=-1)
        # Average the loss over the masked positions only
        loss_mtm = (loss_mtm * (~time_attn_mask).int()).sum() / (~time_attn_mask).sum()

        # Forward the masked symbolic data through the language model
        sym_outputs = self.symbolic_encoder(input_ids, attn_mask, labels)

        # Masked language modeling loss from the LLM output
        loss_mlm = sym_outputs.loss

        # Get the [CLS] features of time series and symbol data as global features
        time_features = F.normalize(self.time_proj(time_embeds[:, 0, :]), dim=-1)
        sym_features = F.normalize(
            self.sym_proj(sym_outputs.hidden_states[-1][:, 0, :]), dim=-1
        )

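        # Momentum distillation: the one-hot similarity targets are smoothed with
        # the momentum encoders' own similarity distribution,
        #   targets = alpha * softmax(sim_m) + (1 - alpha) * one_hot,
        # which is intended to tolerate noisy time-symbol pairs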
        # Get the momentum features
        with torch.no_grad():
            # Update the parameters of the momentum modules
            self.momentum_update()
            time_embeds_m = self.time_encoder_m(x=time_masked, attn_mask=time_attn_mask)
            time_features_m = F.normalize(
                self.time_proj_m(time_embeds_m[:, 0, :]), dim=-1
            )
            # Time features: the current momentum batch plus the queue of negatives
            time_features_all = torch.cat(
                [time_features_m.t(), self.time_queue.clone().detach()], dim=1
            )

            sym_outputs_m = self.symbolic_encoder_m(input_ids, attn_mask, labels)
            sym_features_m = F.normalize(
                self.sym_proj_m(sym_outputs_m.hidden_states[-1][:, 0, :]), dim=-1
            )

            # Symbol features: the current momentum batch plus the queue of negatives
            sym_features_all = torch.cat(
                [sym_features_m.t(), self.sym_queue.clone().detach()], dim=1
            )

            # Match time-series features against all symbol features,
            # shape [batch_size, batch_size + queue_size]
            sim_t2s_m = time_features_m @ sym_features_all / self.temp  # s(T, S_m) / temp

            # Match symbol features against all time-series features
            sim_s2t_m = sym_features_m @ time_features_all / self.temp  # s(S, T_m) / temp

            sim_targets = torch.zeros(sim_t2s_m.size()).to(time_masked.device)
            sim_targets.fill_diagonal_(1)
            sim_t2s_targets = (
                self.alpha * F.softmax(sim_t2s_m, dim=1)
                + (1 - self.alpha) * sim_targets
            )
            sim_s2t_targets = (
                self.alpha * F.softmax(sim_s2t_m, dim=1)
                + (1 - self.alpha) * sim_targets
            )

        sim_t2s = time_features @ sym_features_all / self.temp
        sim_s2t = sym_features @ time_features_all / self.temp

        loss_t2s = -torch.sum(
            F.log_softmax(sim_t2s, dim=1) * F.softmax(sim_t2s_targets, dim=1), dim=1
        ).mean()
        loss_s2t = -torch.sum(
            F.log_softmax(sim_s2t, dim=1) * F.softmax(sim_s2t_targets, dim=1), dim=1
        ).mean()

        # Enqueue the new momentum features and dequeue the oldest ones
        self.enqueue_and_dequeue(time_features_m, sym_features_m)

        return loss_mtm, loss_mlm, loss_t2s, loss_s2t

    def time_masking(
        self, inputs: torch.Tensor, attn_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Randomly mask patches of the time series data."""
        ts = inputs.clone()
        mask = attn_mask.clone()
        # Get the batch information
        batch_size, num_tokens, patch_len = inputs.size()
        # Number of valid (non-padding) tokens in each sample
        token_array = torch.sum(attn_mask, dim=1)
        # Number of tokens to mask in each sample
        num_array = (token_array * self.time_mask_ratio).int()

        for i in range(batch_size):
            padding_index = token_array[i]
            number = num_array[i]
            # Randomly choose `number` of the valid tokens to mask
            noise = torch.rand(int(padding_index))
            ids_shuffle = torch.argsort(noise)[:number]
            ts[i, ids_shuffle, :] = 0
            mask[i, ids_shuffle] = False

        return ts, mask

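    # nlp_mask follows the standard BERT recipe: each token is selected with
    # probability sym_mask_ratio; of the selected tokens, 80% become [MASK],
    # 10% become a random token (0.5 of the remaining 20%), and 10% stay unchanged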
    def nlp_mask(
        self,
        input_ids: torch.Tensor,
        vocab_size: int,
        device: torch.device,
        targets: torch.Tensor = None,
        masked_indices=None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Randomly mask tokens of the symbolic data."""
        probability_matrix = torch.full(
            input_ids.shape, self.sym_mask_ratio, device=device
        )
        if masked_indices is None:
            # If the masked positions are not specified, sample them randomly
            masked_indices = torch.bernoulli(probability_matrix).bool()

        # Make sure the two special tokens [PAD] and [CLS] are never masked
        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        # Restrict the subsequent loss computation to the masked positions
        if targets is not None:
            targets[~masked_indices] = -100  # we only compute loss on masked tokens

        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(input_ids.shape, 0.8, device=device)).bool()
            & masked_indices
        )
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, replace masked input tokens with a random word
        indices_random = (
            torch.bernoulli(torch.full(input_ids.shape, 0.5, device=device)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            vocab_size, input_ids.shape, dtype=torch.long, device=device
        )
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids

    @torch.no_grad()
    def copy_params(self) -> None:
        """Copy the online parameters into the momentum models."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)  # initialize the momentum model params
                param_m.requires_grad = False  # momentum models are not updated by gradients

    @torch.no_grad()
    def momentum_update(self) -> None:
        """Update each momentum model as an exponential moving average:
        param_m <- momentum * param_m + (1 - momentum) * param
        """
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1.0 - self.momentum
                )

    @torch.no_grad()
    def enqueue_and_dequeue(
        self, time_features: torch.Tensor, sym_features: torch.Tensor
    ) -> None:
        """Control feature enqueue and dequeue for the momentum queues."""
        batch_size = time_features.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # Replace the keys at ptr (dequeue and enqueue)
        self.time_queue[:, ptr : ptr + batch_size] = time_features.T
        self.sym_queue[:, ptr : ptr + batch_size] = sym_features.T

        # Move the pointer ptr
        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr
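

if __name__ == "__main__":
    # Minimal smoke-test sketch. The config values below are illustrative
    # assumptions (including the LLM name and the Hugging Face-style tokenizer
    # call), not settings prescribed by the repository.
    configs = {
        "patch_len": 16,
        "stride": None,  # falls back to patch_len
        "padding_patch": None,
        "time_layers": 2,
        "d_model": 128,
        "n_heads": 8,
        "d_ff": 256,
        "norm": "BatchNorm",
        "attn_dropout": 0.1,
        "dropout": 0.1,
        "act": "gelu",
        "pre_norm": False,
        "time_project_bias": True,
        "llm_name": "bert-base-uncased",  # assumed; any LLM supported by layers.LLM
        "llm_layers": 6,
        "hidden_size": 768,
        "freeze_layers": 4,
        "embed_dim": 64,
        "momentum": 0.995,
        "queue_size": 1024,  # must be divisible by the batch size (see enqueue_and_dequeue)
        "temp": 0.07,
        "alpha": 0.4,
    }
    model = SymTime(
        configs, context_window=512, time_mask_ratio=0.4, sym_mask_ratio=0.15
    )

    # Dummy batch: 8 series split into 32 patches of length patch_len, all valid
    batch_size, num_patches = 8, 32
    time = torch.randn(batch_size, num_patches, configs["patch_len"])
    time_mask = torch.ones(batch_size, num_patches, dtype=torch.bool)
    tokens = model.tokenizer(
        ["x_1 + sin(x_2)"] * batch_size, return_tensors="pt", padding=True
    )

    loss_mtm, loss_mlm, loss_t2s, loss_s2t = model(
        time, time_mask, tokens["input_ids"], tokens["attention_mask"]
    )
    print(loss_mtm.item(), loss_mlm.item(), loss_t2s.item(), loss_s2t.item())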