You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
1.9 KiB
63 lines
1.9 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on 2024/10/13 10:04
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
from os import path
|
|
import yaml
|
|
|
|
from torch import nn
|
|
from typing import List
|
|
from models import SymTime_pretrain as SymTime
|
|
from colorama import Fore, Style
|
|
|
|
|
|
class ModelInterface:
    """
    Single entry point for building the SymTime model.

    On construction the interface records the run arguments and the
    accelerator, then immediately instantiates the model through
    ``load_pretrain`` and keeps it on ``self.model``.
    """

    def __init__(self, args, accelerator) -> None:
        """
        Store the run configuration and build the model.

        :param args: argument namespace; this class reads ``is_pretrain``,
            ``model``, ``context_window``, ``time_mask_ratio`` and
            ``sym_mask_ratio`` from it.
        :param accelerator: object whose ``print`` method is used for status
            messages (presumably an ``accelerate.Accelerator`` - confirm
            against the caller).
        """
        self.args = args

        # Kept so that status messages go through the accelerator's print
        self.accelerator = accelerator

        # Flag taken from the arguments; it is only stored here, not acted on
        self.is_pretrain = args.is_pretrain

        # Name of the model variant; selects the YAML configuration file
        self.model_type = args.model

        # Must come last: load_pretrain reads the attributes assigned above
        self.model = self.load_pretrain()

    def load_pretrain(self) -> nn.Module:
        """
        Create a freshly initialized SymTime model from its YAML configuration.

        The configuration is read from ``configs/SymTime_<model_type>.yaml``
        (relative to the current working directory) and passed to the model
        constructor together with the window and mask-ratio arguments.

        :return: the newly constructed model.
        """
        start_banner = Fore.RED + "Now is loading model" + Style.RESET_ALL
        self.accelerator.print(start_banner, end=" -> ")

        # Locate and parse the configuration that belongs to this model type
        yaml_name = f"SymTime_{self.model_type}.yaml"
        yaml_path = path.join("configs", yaml_name)
        with open(yaml_path, "r", encoding="utf-8") as handle:
            settings = yaml.safe_load(handle)

        run_args = self.args
        network = SymTime(
            settings,
            context_window=run_args.context_window,
            time_mask_ratio=run_args.time_mask_ratio,
            sym_mask_ratio=run_args.sym_mask_ratio,
        )

        done_banner = Fore.GREEN + "successfully loaded!" + Style.RESET_ALL
        self.accelerator.print(done_banner)
        return network

    def trainable_params(self) -> List:
        """
        Collect the parameters of ``self.model`` that are trainable.

        A parameter is kept only when its ``requires_grad`` attribute is
        ``True``; every other parameter is treated as frozen and left out.

        :return: list of trainable parameters, in ``model.parameters()`` order.
        """
        return [
            weight
            for weight in self.model.parameters()
            if weight.requires_grad is True
        ]
|