# -*- coding: utf-8 -*-
"""
Created on 2024/9/23 15:56
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
import os
from os import path
import sys
from time import sleep
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from accelerate import Accelerator
from typing import Callable, Tuple, Dict
from colorama import Fore, Style
from tqdm import tqdm
from utils.tools import makedir
from utils.logging import Logging
class Exp_PreTraining(object):
    """Experiment driver for the self-supervised pre-training of SymTime.

    Wires the model, optimizer, loss, LR scheduler and a Hugging Face
    ``accelerate.Accelerator`` together for (possibly multi-process)
    training.  Only the main process (``process_index == 0``) creates the
    save/log directories and writes checkpoints and logs.
    """
    def __init__(
        self,
        args,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable,
        scheduler: LRScheduler,
        accelerator: Accelerator,
        data_interface,
    ) -> None:
        # Config namespace; must expose num_epochs, save_epochs and
        # save_path (save_path is read in init_path on the main process).
        self.args = args
        # Get the number of training rounds
        self.num_epochs = args.num_epochs
        # A checkpoint is written every `save_epochs` epochs (see fit)
        self.save_epochs = args.save_epochs
        # Get the model and training dataset
        self.model = model
        # Get the optimizer for the neural network
        self.optimizer = optimizer
        # Get the loss function
        # NOTE(review): `criterion` is stored but never used in this class —
        # fit() relies on the losses returned by the model's forward pass.
        self.criterion = criterion
        # Get dynamically adjusted learning rate
        self.scheduler = scheduler
        # Get the Synergy Accelerator
        self.accelerator = accelerator
        # Record the current process ID
        self.process_index = self.accelerator.process_index
        # Get the training set and validation set.  fit() calls
        # data_interface.get_dataloader() len(data_interface)+... times per
        # epoch, so this object is expected to hand out fresh DataLoaders.
        self.data_interface = data_interface
        # Get the current training device
        self.device = self.accelerator.device
        if (
            self.process_index == 0
        ):  # Only the main process records losses and saves model parameters
            # Get the address where the model and parameters are saved
            self.main_path, self.params_path = self.init_path()
            # Create a Logging module for model training
            self.logging = Logging(
                is_pretrain=True, logging_path=self.main_path, datasets=[]
            )
    def fit(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full pre-training loop.

        Returns five 1-D tensors of length ``num_epochs`` holding, per
        epoch, the per-sample average of: the total loss, and the four
        component losses (mtm, mlm, t2s, s2t).  Values are this process's
        local averages; cross-process averaging happens in logging_epoch.
        """
        self.accelerator.print(
            Fore.GREEN + "Starting SymTime Model Pre-training..." + Style.RESET_ALL
        )
        # Record the four loss functions used by SymTime in self-supervised
        # pre-training.  Accumulators live on-device so the hot loop only
        # pays for .item() when updating the progress-bar text.
        train_loss = torch.zeros(self.num_epochs, device=self.device)
        train_loss_mtm = torch.zeros(self.num_epochs, device=self.device)
        train_loss_mlm = torch.zeros(self.num_epochs, device=self.device)
        train_loss_t2s = torch.zeros(self.num_epochs, device=self.device)
        train_loss_s2t = torch.zeros(self.num_epochs, device=self.device)
        for idx, epoch in enumerate(range(1, self.num_epochs + 1), 0):
            """Here begins an Epoch"""
            num_samples = 0
            # The cumulative number of samples traversed in this Epoch
            for ii in range(1, len(self.data_interface) + 1):
                """In one Epoch, all data must be traversed and read"""
                self.accelerator.print(
                    Fore.RED + "Now is loading pre-training data" + Style.RESET_ALL,
                    end=" -> ",
                )
                # A fresh DataLoader per shard — presumably the dataset is
                # split into len(data_interface) chunks; confirm against the
                # data interface implementation.
                train_loader = self.data_interface.get_dataloader()
                train_loader = self.accelerator.prepare_data_loader(
                    train_loader, device_placement=True
                )
                sleep(2)
                self.accelerator.print(
                    Fore.GREEN + "successfully loaded!" + Style.RESET_ALL
                )
                self.model.train()
                data_loader = tqdm(train_loader, file=sys.stdout)
                for step, (time, time_mask, sym_ids, sym_mask) in enumerate(
                    data_loader, 1
                ):
                    self.optimizer.zero_grad()
                    # time.shape[0] is the local batch size of this process
                    num_samples += time.shape[0]
                    # Obtain the loss directly during the forward propagation of the model
                    loss_mtm, loss_mlm, loss_t2s, loss_s2t = self.model(
                        time, time_mask, sym_ids, sym_mask
                    )
                    # Acquiring and integrating errors: the two contrastive
                    # directions (t2s, s2t) are averaged before summation
                    loss = loss_mtm + loss_mlm + (loss_t2s + loss_s2t) / 2
                    # Back propagation of error
                    self.accelerator.backward(loss)
                    # Parameter update
                    self.optimizer.step()
                    # Checking model loss; exits the process on NaN/Inf
                    check_loss(loss, train_type="Pretrain")
                    # Calculate the cumulative loss of this epoch
                    train_loss[idx] += loss.item()
                    train_loss_mtm[idx] += loss_mtm.item()
                    train_loss_mlm[idx] += loss_mlm.item()
                    train_loss_t2s[idx] += loss_t2s.item()
                    train_loss_s2t[idx] += loss_s2t.item()
                    # Progress-bar text: running per-sample averages so far
                    data_loader.desc = (
                        "["
                        + Fore.GREEN
                        + f"Epoch {epoch}"
                        + Style.RESET_ALL
                        + "] "
                        + "Loss="
                        + Fore.GREEN
                        + f"{round(train_loss[idx].item() / num_samples, 6)}"
                        + Style.RESET_ALL
                        + f" loss_mtm: {round(train_loss_mtm[idx].item() / num_samples, 6)}, loss_mlm: {round(train_loss_mlm[idx].item() / num_samples, 6)}, "
                        f"loss_t2s: {round(train_loss_t2s[idx].item() / num_samples, 6)}, loss_s2t: {round(train_loss_s2t[idx].item() / num_samples, 6)}"
                    )
                # Dynamically adjust learning rate — NOTE(review): stepped
                # once per data shard, not once per epoch; confirm this is
                # intentional with the scheduler's configuration.
                self.scheduler.step()
                # Freeing up memory for training optimizers — assumes an
                # accelerate version whose clear() accepts objects to
                # release (alias of free_memory); verify against the pinned
                # accelerate release.
                self.accelerator.clear(train_loader)
            # Record changes in final loss: convert the running sums into
            # per-sample averages for this epoch
            train_loss[idx] = train_loss[idx] / num_samples
            train_loss_mtm[idx] = train_loss_mtm[idx] / num_samples
            train_loss_mlm[idx] = train_loss_mlm[idx] / num_samples
            train_loss_t2s[idx] = train_loss_t2s[idx] / num_samples
            train_loss_s2t[idx] = train_loss_s2t[idx] / num_samples
            if epoch % self.save_epochs == 0:
                # Save the parameters of a pre-trained model
                self.save_model(loss=train_loss[idx], epoch=epoch)
            # Logging training process to register the current epoch and the final loss
            # (all processes must enter: logging_epoch performs a collective
            # gather before the main-process-only write)
            self.logging_epoch(
                epoch,
                train_loss[idx],
                train_loss_mtm[idx],
                train_loss_mlm[idx],
                train_loss_t2s[idx],
                train_loss_s2t[idx],
            )
        return (
            train_loss,
            train_loss_mtm,
            train_loss_mlm,
            train_loss_t2s,
            train_loss_s2t,
        )
    def init_path(self) -> Tuple[str, str]:
        """Create and return (main_path, params_path) for this experiment.

        Only called on the main process (see __init__).  A new
        ``exp{N+1}`` folder is created under args.save_path, with a
        ``params`` subfolder for checkpoints.
        """
        # Directory where the model is saved
        save_path = self.args.save_path
        # Determine how many files are in the save directory
        # NOTE(review): numbering by entry count collides if earlier exp
        # folders were deleted or unrelated files exist in save_path.
        num_folder = len(os.listdir(save_path))
        # Create a folder to save the model
        folder_name = f"exp{num_folder + 1}"
        makedir(save_path, folder_name)
        # Update the main address of the save directory
        main_path = path.join(save_path, folder_name)
        # Create a folder to save model parameters
        makedir(main_path, "params")
        params_path = path.join(main_path, "params")
        print(f"Attention the logging path is {main_path}")
        return main_path, params_path
    def save_model(self, epoch: int, loss: torch.Tensor) -> None:
        """Save the parameters of SymTime time encoder in pre-training.

        Only the time-encoder submodule's state_dict is written, as
        ``{epoch}_{loss}.pth`` under params_path; main process only.
        """
        if self.process_index == 0:
            self.accelerator.print(
                Fore.RED + "Now is saving the pre-trained params" + Style.RESET_ALL,
                end=" -> ",
            )
            save_name = f"{epoch}_{round(loss.item(), 4)}.pth"
            torch.save(
                self.model.time_encoder.state_dict(),
                path.join(self.params_path, save_name),
            )
            self.accelerator.print(Fore.GREEN + "successfully saved!" + Style.RESET_ALL)
    def logging_epoch(
        self,
        epoch: int,
        train_loss: torch.Tensor,
        train_loss_mtm: torch.Tensor,
        train_loss_mlm: torch.Tensor,
        train_loss_t2s: torch.Tensor,
        train_loss_s2t: torch.Tensor,
    ) -> None:
        """Record the changes in training loss of an Epoch.

        gather() is a collective operation, so EVERY process must call this
        method; afterwards only the main process writes the cross-process
        mean of each loss to the log.
        """
        gather_train_loss = self.accelerator.gather(train_loss).mean().item()
        gather_train_loss_mtm = self.accelerator.gather(train_loss_mtm).mean().item()
        gather_train_loss_mlm = self.accelerator.gather(train_loss_mlm).mean().item()
        gather_train_loss_t2s = self.accelerator.gather(train_loss_t2s).mean().item()
        gather_train_loss_s2t = self.accelerator.gather(train_loss_s2t).mean().item()
        if self.process_index == 0:
            # Record the average loss of all processes under an Epoch
            self.logging.logging_epoch(
                epoch,
                gather_train_loss,
                gather_train_loss_mtm,
                gather_train_loss_mlm,
                gather_train_loss_t2s,
                gather_train_loss_s2t,
            )
def init_path(save_path) -> str:
    """Create and return a fresh experiment directory under ``save_path``.

    Parameters
    ----------
    save_path : str
        Existing directory that holds the per-experiment ``exp{N}`` folders.

    Returns
    -------
    str
        Path of the newly created ``exp{N}`` folder.
    """
    # Start numbering from the entry count, as before ...
    num_folder = len(os.listdir(save_path))
    # ... but probe upward until a free name is found: counting entries alone
    # collides when earlier exp folders were deleted or save_path contains
    # unrelated files (e.g. 3 entries but "exp4" already exists).
    while path.exists(path.join(save_path, f"exp{num_folder + 1}")):
        num_folder += 1
    # Create a folder to save the model
    folder_name = f"exp{num_folder + 1}"
    makedir(save_path, folder_name)
    # Main address of the save directory for this experiment
    main_path = path.join(save_path, folder_name)
    return main_path
def check_loss(loss: torch.Tensor, train_type: str) -> None:
    """Abort the process when the loss is NaN/Inf (exploding gradients).

    Parameters
    ----------
    loss : torch.Tensor
        Scalar loss of the current training step.
    train_type : str
        Label used in the error message, e.g. "Pretrain".

    Returns None when the loss is finite; otherwise prints an error and
    terminates the process with exit status 1.
    """
    if not torch.isfinite(loss):
        # Fix: append Style.RESET_ALL — the original left the terminal
        # stuck in red after exiting.
        print(
            Fore.RED
            + f"{train_type} now occurs ERROR: non-finite loss, end training!"
            + Style.RESET_ALL
        )
        sys.exit(1)