import os

import yaml
import torch
from torch import nn

from models import SymTime_finetune as SymTime


class Exp_Basic:
    """Shared scaffolding for SymTime experiments.

    On construction it selects a compute device from ``args``, builds the
    SymTime model from ``./configs/SymTime_base.yaml`` and moves the model
    onto that device. ``_get_data``, ``vali``, ``train`` and ``test`` are
    no-op placeholders here (presumably overridden by subclasses).
    """

    def __init__(self, args):
        """Store ``args``, pick the device, then build and place the model.

        Args:
            args: Experiment namespace. Depending on the GPU flags this class
                reads ``use_gpu``, ``use_multi_gpu``, ``gpu``, ``devices`` and
                ``device_ids`` from it, and passes it to the model constructor.
        """
        self.args = args
        # The device has to be chosen before the model can be moved onto it.
        self.device = self._acquire_device()
        built = self._build_model()
        self.model = built.to(self.device)

    def _build_model(self):
        """Instantiate SymTime from the base YAML config.

        The config path is relative, so it is resolved against the current
        working directory. The model is converted with ``.float()``. When both
        ``args.use_multi_gpu`` and ``args.use_gpu`` are truthy, the model is
        wrapped in ``nn.DataParallel`` over ``args.device_ids``.

        Returns:
            The model, possibly ``DataParallel``-wrapped. It is not yet moved
            to ``self.device``; ``__init__`` does that.
        """
        config_path = "./configs/SymTime_base.yaml"
        with open(config_path, "r", encoding="utf-8") as cfg_file:
            # safe_load only builds plain YAML types, never arbitrary objects.
            cfg = yaml.safe_load(cfg_file)
        net = SymTime(args=self.args, configs=cfg).float()

        if not (self.args.use_multi_gpu and self.args.use_gpu):
            return net
        return nn.DataParallel(net, device_ids=self.args.device_ids)

    def _acquire_device(self):
        """Select the torch device and announce the choice on stdout.

        When ``args.use_gpu`` is truthy, this exports ``CUDA_VISIBLE_DEVICES``
        and returns ``cuda:<args.gpu>``. The exported value is
        ``args.devices`` in multi-GPU mode, otherwise ``str(args.gpu)``.
        Otherwise the CPU device is returned.

        Returns:
            torch.device: the selected device.
        """
        if not self.args.use_gpu:
            cpu = torch.device("cpu")
            print("Use CPU")
            return cpu

        if self.args.use_multi_gpu:
            visible = self.args.devices
        else:
            visible = str(self.args.gpu)
        # NOTE(review): CUDA generally honours this variable only if it is set
        # before CUDA is first initialised in the process. If it does take
        # effect, visible GPUs are renumbered from 0, so ``cuda:<args.gpu>``
        # may be an invalid ordinal for gpu > 0. Confirm against the launcher.
        os.environ["CUDA_VISIBLE_DEVICES"] = visible
        gpu_name = f"cuda:{self.args.gpu}"
        device = torch.device(gpu_name)
        print(f"Use GPU: {gpu_name}")
        return device

    def _get_data(self):
        """Placeholder for data loading; a no-op here that returns None."""

    def vali(self):
        """Placeholder for validation; a no-op here that returns None."""

    def train(self):
        """Placeholder for training; a no-op here that returns None."""

    def test(self):
        """Placeholder for testing; a no-op here that returns None."""