# -*- coding: utf-8 -*-
"""
Created on 2024/9/23 15:56

@author: Whenxuan Wang

@email: wwhenxuan@gmail.com

@url: https://github.com/wwhenxuan/SymTime
"""
import os
import sys
from os import path
from time import sleep
from typing import Callable, Dict, Tuple

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

from accelerate import Accelerator
from colorama import Fore, Style
from tqdm import tqdm

from utils.tools import makedir
from utils.logging import Logging
|
class Exp_PreTraining(object):
    """The Interface for the pre-training of SymTime.

    Owns the model, optimizer, scheduler and the ``accelerate.Accelerator``
    for a (possibly multi-process) self-supervised pre-training run and
    exposes :meth:`fit` to execute the whole loop.  Only the main process
    (``process_index == 0``) creates output directories, writes logs and
    saves checkpoints.
    """

    def __init__(
        self,
        args,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable,
        scheduler: LRScheduler,
        accelerator: Accelerator,
        data_interface,
    ) -> None:
        """Store the training components; on the main process also create
        the save directories and the ``Logging`` helper.

        :param args: configuration namespace; this class reads
            ``num_epochs``, ``save_epochs`` and (in ``init_path``)
            ``save_path`` from it.
        :param model: the SymTime model; its forward pass returns the four
            pre-training losses consumed in :meth:`fit`.
        :param optimizer: optimizer updating the model parameters.
        :param criterion: loss callable (stored; not called anywhere inside
            this class).
        :param scheduler: learning-rate scheduler, stepped in :meth:`fit`.
        :param accelerator: device / distributed-process manager.
        :param data_interface: project data source; must support
            ``len(...)`` and ``get_dataloader()``.
        """
        self.args = args
        # Get the number of training rounds
        self.num_epochs = args.num_epochs
        self.save_epochs = args.save_epochs

        # Get the model and training dataset
        self.model = model

        # Get the optimizer for the neural network
        self.optimizer = optimizer

        # Get the loss function
        self.criterion = criterion

        # Get dynamically adjusted learning rate
        self.scheduler = scheduler

        # Get the Synergy Accelerator
        self.accelerator = accelerator

        # Record the current process ID
        self.process_index = self.accelerator.process_index

        # Get the training set and validation set
        self.data_interface = data_interface
        # Get the current training device
        self.device = self.accelerator.device

        if (
            self.process_index == 0
        ):  # Only the main process records losses and saves model parameters
            # Get the address where the model and parameters are saved
            self.main_path, self.params_path = self.init_path()
            # Create a Logging module for model training
            self.logging = Logging(
                is_pretrain=True, logging_path=self.main_path, datasets=[]
            )

    def fit(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Training model to fit data.

        Runs ``self.num_epochs`` epochs.  Within one epoch, every chunk of
        ``self.data_interface`` is wrapped in a DataLoader, prepared by the
        accelerator and traversed once; the scheduler is stepped after each
        chunk.

        :return: five 1-D tensors of length ``num_epochs`` holding the
            per-sample averaged total loss and its mtm/mlm/t2s/s2t
            components, one entry per epoch (per process; cross-process
            averaging happens in :meth:`logging_epoch`).
        """
        self.accelerator.print(
            Fore.GREEN + "Starting SymTime Model Pre-training..." + Style.RESET_ALL
        )

        # Record the four loss functions used by SymTime in self-supervised pre-training
        train_loss = torch.zeros(self.num_epochs, device=self.device)
        train_loss_mtm = torch.zeros(self.num_epochs, device=self.device)
        train_loss_mlm = torch.zeros(self.num_epochs, device=self.device)
        train_loss_t2s = torch.zeros(self.num_epochs, device=self.device)
        train_loss_s2t = torch.zeros(self.num_epochs, device=self.device)

        # idx is the 0-based tensor index, epoch the 1-based display number
        for idx, epoch in enumerate(range(1, self.num_epochs + 1), 0):
            """Here begins an Epoch"""
            num_samples = 0
            # The cumulative number of samples traversed in this Epoch

            for ii in range(1, len(self.data_interface) + 1):
                """In one Epoch, all data must be traversed and read"""
                self.accelerator.print(
                    Fore.RED + "Now is loading pre-training data" + Style.RESET_ALL,
                    end=" -> ",
                )
                train_loader = self.data_interface.get_dataloader()
                train_loader = self.accelerator.prepare_data_loader(
                    train_loader, device_placement=True
                )
                # NOTE(review): presumably this sleep only paces the console
                # output — the loader is already prepared at this point
                sleep(2)
                self.accelerator.print(
                    Fore.GREEN + "successfully loaded!" + Style.RESET_ALL
                )
                self.model.train()
                data_loader = tqdm(train_loader, file=sys.stdout)
                for step, (time, time_mask, sym_ids, sym_mask) in enumerate(
                    data_loader, 1
                ):
                    self.optimizer.zero_grad()
                    # Batch size taken from the first tensor dimension
                    num_samples += time.shape[0]
                    # Obtain the loss directly during the forward propagation of the model
                    loss_mtm, loss_mlm, loss_t2s, loss_s2t = self.model(
                        time, time_mask, sym_ids, sym_mask
                    )

                    # Acquiring and integrating errors: the t2s/s2t pair is
                    # averaged before being added to the total
                    loss = loss_mtm + loss_mlm + (loss_t2s + loss_s2t) / 2
                    # Back propagation of error
                    self.accelerator.backward(loss)

                    # Parameter update
                    self.optimizer.step()
                    # Checking model loss (exits the process on NaN/inf)
                    check_loss(loss, train_type="Pretrain")

                    # Calculate the cumulative loss of this epoch
                    train_loss[idx] += loss.item()
                    train_loss_mtm[idx] += loss_mtm.item()
                    train_loss_mlm[idx] += loss_mlm.item()
                    train_loss_t2s[idx] += loss_t2s.item()
                    train_loss_s2t[idx] += loss_s2t.item()
                    # Progress-bar text shows running per-sample averages
                    data_loader.desc = (
                        "["
                        + Fore.GREEN
                        + f"Epoch {epoch}"
                        + Style.RESET_ALL
                        + "] "
                        + "Loss="
                        + Fore.GREEN
                        + f"{round(train_loss[idx].item() / num_samples, 6)}"
                        + Style.RESET_ALL
                        + f" loss_mtm: {round(train_loss_mtm[idx].item() / num_samples, 6)}, loss_mlm: {round(train_loss_mlm[idx].item() / num_samples, 6)}, "
                        f"loss_t2s: {round(train_loss_t2s[idx].item() / num_samples, 6)}, loss_s2t: {round(train_loss_s2t[idx].item() / num_samples, 6)}"
                    )
                # Dynamically adjust learning rate
                self.scheduler.step()
                # Freeing up memory for training optimizers
                self.accelerator.clear(train_loader)

            # Record changes in final loss: convert the running sums into
            # per-sample averages for this epoch
            train_loss[idx] = train_loss[idx] / num_samples
            train_loss_mtm[idx] = train_loss_mtm[idx] / num_samples
            train_loss_mlm[idx] = train_loss_mlm[idx] / num_samples
            train_loss_t2s[idx] = train_loss_t2s[idx] / num_samples
            train_loss_s2t[idx] = train_loss_s2t[idx] / num_samples

            if epoch % self.save_epochs == 0:
                # Save the parameters of a pre-trained model
                self.save_model(loss=train_loss[idx], epoch=epoch)

            # Logging training process to register the current epoch and the final loss
            self.logging_epoch(
                epoch,
                train_loss[idx],
                train_loss_mtm[idx],
                train_loss_mlm[idx],
                train_loss_t2s[idx],
                train_loss_s2t[idx],
            )

        return (
            train_loss,
            train_loss_mtm,
            train_loss_mlm,
            train_loss_t2s,
            train_loss_s2t,
        )

    def init_path(self) -> Tuple[str, str]:
        """Get the address of the pre-training saved model and logging.

        Creates ``<save_path>/exp{K+1}`` where ``K`` is the number of
        entries already present in ``save_path``, plus a ``params``
        sub-folder for checkpoints.

        :return: ``(main_path, params_path)``.
        """

        # Directory where the model is saved
        save_path = self.args.save_path

        # Determine how many files are in the save directory
        num_folder = len(os.listdir(save_path))

        # Create a folder to save the model
        folder_name = f"exp{num_folder + 1}"
        makedir(save_path, folder_name)

        # Update the main address of the save directory
        main_path = path.join(save_path, folder_name)

        # Create a folder to save model parameters
        makedir(main_path, "params")
        params_path = path.join(main_path, "params")

        print(f"Attention the logging path is {main_path}")
        return main_path, params_path

    def save_model(self, epoch: int, loss: torch.Tensor) -> None:
        """Save the parameters of SymTime time encoder in pre-training.

        Only the main process writes the file; the name encodes the epoch
        and the rounded loss as ``{epoch}_{loss}.pth``.

        :param epoch: 1-based epoch number.
        :param loss: scalar tensor holding the epoch's averaged loss.
        """
        if self.process_index == 0:
            self.accelerator.print(
                Fore.RED + "Now is saving the pre-trained params" + Style.RESET_ALL,
                end=" -> ",
            )
            save_name = f"{epoch}_{round(loss.item(), 4)}.pth"
            # Only the time-encoder weights are checkpointed here, not the
            # whole model
            torch.save(
                self.model.time_encoder.state_dict(),
                path.join(self.params_path, save_name),
            )
            self.accelerator.print(Fore.GREEN + "successfully saved!" + Style.RESET_ALL)

    def logging_epoch(
        self,
        epoch: int,
        train_loss: torch.Tensor,
        train_loss_mtm: torch.Tensor,
        train_loss_mlm: torch.Tensor,
        train_loss_t2s: torch.Tensor,
        train_loss_s2t: torch.Tensor,
    ) -> None:
        """Record the changes in training loss of an Epoch.

        The ``gather`` calls run on every process (they are outside the
        process-0 guard); afterwards only the main process writes the
        cross-process mean of each loss component to the log.
        """
        gather_train_loss = self.accelerator.gather(train_loss).mean().item()
        gather_train_loss_mtm = self.accelerator.gather(train_loss_mtm).mean().item()
        gather_train_loss_mlm = self.accelerator.gather(train_loss_mlm).mean().item()
        gather_train_loss_t2s = self.accelerator.gather(train_loss_t2s).mean().item()
        gather_train_loss_s2t = self.accelerator.gather(train_loss_s2t).mean().item()
        if self.process_index == 0:
            # Record the average loss of all processes under an Epoch
            self.logging.logging_epoch(
                epoch,
                gather_train_loss,
                gather_train_loss_mtm,
                gather_train_loss_mlm,
                gather_train_loss_t2s,
                gather_train_loss_s2t,
            )
|
|
def init_path(save_path) -> str:
    """Create and return a fresh experiment directory under ``save_path``.

    The new folder is named ``exp{K + 1}``, where ``K`` is the number of
    entries already present in ``save_path``.

    :param save_path: existing directory that collects experiment folders.
    :return: full path of the newly created experiment folder.
    """
    # Number the new experiment after the entries already present
    exp_index = len(os.listdir(save_path)) + 1
    exp_name = f"exp{exp_index}"

    # Create the folder via the project helper and hand back its full path
    makedir(save_path, exp_name)
    return path.join(save_path, exp_name)
|
def check_loss(loss: torch.Tensor, train_type: str) -> None:
    """Check training and validation losses to avoid exploding gradients.

    :param loss: scalar loss tensor produced by the current step.
    :param train_type: label used in the error message (e.g. ``"Pretrain"``).
    :raises SystemExit: exits with status 1 when the loss is NaN or +/-inf.
    """
    if not torch.isfinite(loss):
        # Append RESET_ALL so the red style does not leak into every line
        # the terminal prints after this process dies.
        print(
            Fore.RED
            + f"{train_type} now occurs ERROR: non-finite loss, end training!"
            + Style.RESET_ALL
        )
        sys.exit(1)