You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
244 lines
6.6 KiB
244 lines
6.6 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
For the pre-training of SymTime using self-supervised learning.
|
|
|
|
Created on 2024/10/9 17:20
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
import random
|
|
import os
|
|
import argparse
|
|
|
|
import numpy as np
|
|
import torch
|
|
from accelerate import Accelerator, DeepSpeedPlugin
|
|
from accelerate import DistributedDataParallelKwargs
|
|
|
|
from data_provider import PreTrainDataLoader
|
|
from exp import Exp_PreTraining
|
|
from utils import ModelInterface, OptimInterface, get_criterion
|
|
|
|
os.environ["CURL_CA_BUNDLE"] = ""
|
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
|
|
|
|
parser = argparse.ArgumentParser(description="SymTime-pretrain")
|
|
|
|
# basic config
|
|
parser.add_argument(
|
|
"--is_pretrain",
|
|
type=bool,
|
|
default=True,
|
|
help="Whether to perform model pre - training",
|
|
)
|
|
parser.add_argument(
|
|
"--model", type=str, default="base", help="Model type used: small, base, large"
|
|
)
|
|
parser.add_argument(
|
|
"--context_window", type=int, default=256, help="Original length of data loaded"
|
|
)
|
|
parser.add_argument(
|
|
"--time_mask_ratio",
|
|
type=float,
|
|
default=0.40,
|
|
help="Masking ratio of signal patches",
|
|
)
|
|
parser.add_argument(
|
|
"--sym_mask_ratio",
|
|
type=float,
|
|
default=0.15,
|
|
help="Masking ratio of natural language symbols",
|
|
)
|
|
parser.add_argument(
|
|
"--patch_len",
|
|
type=int,
|
|
default=16,
|
|
help="Length of each patch for data embedding in Transformer",
|
|
)
|
|
parser.add_argument(
|
|
"--stride",
|
|
type=int,
|
|
default=None,
|
|
help="Stride size for patching using the sliding windows. If None, non-overlapping patches are used",
|
|
)
|
|
|
|
# Data path related parameters
|
|
parser.add_argument(
|
|
"--data_path",
|
|
type=str,
|
|
default=r"./datasets/pretrain_data/",
|
|
help="Path to store training data",
|
|
)
|
|
parser.add_argument(
|
|
"--save_path", type=str, default="./logging", help="Path to save the model"
|
|
)
|
|
parser.add_argument(
|
|
"--number", type=int, default=2, help="Number of data read per round"
|
|
)
|
|
parser.add_argument(
|
|
"--llm_name", type=str, default="DistilBert", help="Large - language model used"
|
|
)
|
|
|
|
# Parameters related to model optimization
|
|
parser.add_argument(
|
|
"--num_epochs",
|
|
type=int,
|
|
default=1000,
|
|
help="Number of rounds for model pre - training",
|
|
)
|
|
parser.add_argument(
|
|
"--warmup_epochs",
|
|
type=int,
|
|
default=50,
|
|
help="Number of rounds for learning rate warm - up",
|
|
)
|
|
parser.add_argument(
|
|
"--save_epochs",
|
|
type=int,
|
|
default=1,
|
|
help="Save the model once every few rounds of training",
|
|
)
|
|
parser.add_argument(
|
|
"--batch_size", type=int, default=48, help="Batch size used for training"
|
|
)
|
|
parser.add_argument(
|
|
"--num_workers", type=int, default=1, help="Number of workers for data loader"
|
|
)
|
|
parser.add_argument(
|
|
"--shuffle",
|
|
type=bool,
|
|
default=True,
|
|
help="Whether to shuffle the order of training samples during training",
|
|
)
|
|
parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer used")
|
|
parser.add_argument("--criterion", type=str, default="MSE", help="Loss function used")
|
|
parser.add_argument(
|
|
"--warmup", type=str, default="LinearLR", help="Learning rate warm - up method used"
|
|
)
|
|
parser.add_argument(
|
|
"--scheduler",
|
|
type=str,
|
|
default="OneCycle",
|
|
help="Dynamic learning rate adjustment method used",
|
|
)
|
|
parser.add_argument(
|
|
"--learning_rate", type=float, default=1e-5, help="Training learning rate"
|
|
)
|
|
parser.add_argument(
|
|
"--momentum",
|
|
type=float,
|
|
default=0.9,
|
|
help="Momentum size used in stochastic gradient descent",
|
|
)
|
|
parser.add_argument(
|
|
"--weight_decay",
|
|
type=float,
|
|
default=1e-4,
|
|
help="L2 regularization strength suitable for Adam",
|
|
)
|
|
parser.add_argument(
|
|
"--beta1",
|
|
type=float,
|
|
default=0.9,
|
|
help="Decay rate of first - order moment estimate, degree of retention of historical gradients, default 0.9",
|
|
)
|
|
parser.add_argument(
|
|
"--beta2",
|
|
type=float,
|
|
default=0.999,
|
|
help="Decay rate of second - order moment estimate, conducive to improving stability, default 0.999",
|
|
)
|
|
parser.add_argument(
|
|
"--eps", type=float, default=1e-8, help="Constant to prevent division by zero"
|
|
)
|
|
parser.add_argument(
|
|
"--amsgrad", type=bool, default=False, help="Whether to use the AMSgrad variant"
|
|
)
|
|
parser.add_argument(
|
|
"--step_size",
|
|
type=int,
|
|
default=10,
|
|
help="Number of Epochs in StepLR that multiply the learning rate by gamma at regular intervals",
|
|
)
|
|
parser.add_argument(
|
|
"--gamma",
|
|
type=float,
|
|
default=0.99,
|
|
help="Learning rate decay multiplier for StepLR and ExponLR",
|
|
)
|
|
parser.add_argument(
|
|
"--cycle_momentum",
|
|
type=bool,
|
|
default=True,
|
|
help="Whether to use periodic momentum adjustment strategy in OneCycle",
|
|
)
|
|
parser.add_argument(
|
|
"--base_momentum",
|
|
type=float,
|
|
default=0.85,
|
|
help="Base momentum value set during learning rate adjustment",
|
|
)
|
|
parser.add_argument(
|
|
"--max_momentum",
|
|
type=float,
|
|
default=0.95,
|
|
help="Momentum value set when learning rate reaches maximum",
|
|
)
|
|
parser.add_argument(
|
|
"--anneal_strategy",
|
|
type=str,
|
|
default="cos",
|
|
help="Learning rate decay strategy used: cos or linear",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Set the random seed for reproducibility
|
|
fix_seed = 2025
|
|
random.seed(fix_seed)
|
|
torch.manual_seed(fix_seed)
|
|
np.random.seed(fix_seed)
|
|
|
|
if __name__ == "__main__":
|
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
|
deepspeed_plugin = DeepSpeedPlugin(hf_ds_config="./configs/ds_config.json")
|
|
accelerator = Accelerator(
|
|
device_placement=True,
|
|
gradient_accumulation_steps=1,
|
|
cpu=False,
|
|
# kwargs_handlers=[ddp_kwargs],
|
|
# deepspeed_plugin=deepspeed_plugin
|
|
)
|
|
interface = {
|
|
"data": PreTrainDataLoader(args),
|
|
"model": ModelInterface(args, accelerator),
|
|
"criterion": get_criterion(name=args.criterion),
|
|
"optimizer": OptimInterface(args, accelerator),
|
|
}
|
|
train_data = interface["data"]
|
|
data_loader = train_data.get_dataloader()
|
|
train_data.pointer = 0
|
|
model = interface["model"].model
|
|
train_params = interface["model"].trainable_params()
|
|
criterion = interface["criterion"]
|
|
optimizer = interface["optimizer"].load_optimizer(parameters=train_params)
|
|
scheduler = interface["optimizer"].load_scheduler(
|
|
optimizer, loader_len=len(data_loader)
|
|
)
|
|
|
|
model, optimizer, scheduler, data_loader = accelerator.prepare(
|
|
model, optimizer, scheduler, data_loader
|
|
)
|
|
|
|
trainer = Exp_PreTraining(
|
|
args, model, optimizer, criterion, scheduler, accelerator, train_data
|
|
)
|
|
(
|
|
train_loss,
|
|
train_loss_mtm,
|
|
train_loss_mlm,
|
|
train_loss_t2s,
|
|
train_loss_s2t,
|
|
) = trainer.fit()
|