You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

103 lines
3.3 KiB

# -*- coding: utf-8 -*-
"""
Created on 2024/9/30 17:06
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
"""
import os
from os import path
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Dict, Any, Tuple
import warnings
# Silence every Python warning for the whole process as a side effect of
# importing this module (no category or module filter is passed).
# NOTE(review): this also hides warnings emitted by torch and any other
# library; consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")
class TSDataset(Dataset):
    """Map-style dataset used for the pre-training of SymTime.

    Wraps four aligned tensors. Sample ``i`` is the ``i``-th slice along the
    first dimension of each tensor, returned as
    ``(time, time_mask, sym_ids, sym_mask)``.
    """

    def __init__(
        self,
        time: torch.Tensor,
        time_mask: torch.Tensor,
        sym_ids: torch.Tensor,
        sym_mask: torch.Tensor,
    ) -> None:
        """
        :param time: time-series values; dim 0 indexes samples.
        :param time_mask: mask paired with ``time``; dim 0 indexes samples.
        :param sym_ids: symbol token ids; dim 0 indexes samples.
        :param sym_mask: mask paired with ``sym_ids``; dim 0 indexes samples.
        :raises ValueError: if the tensors disagree on the number of samples
            (their size along dim 0).
        """
        # Without this check a shorter tensor would only fail later with an
        # IndexError, and a longer one would be silently truncated because
        # ``__len__`` only looks at ``time``.
        sizes = {
            "time": time.size(0),
            "time_mask": time_mask.size(0),
            "sym_ids": sym_ids.size(0),
            "sym_mask": sym_mask.size(0),
        }
        if len(set(sizes.values())) != 1:
            raise ValueError(
                "All tensors must have the same size along dim 0, "
                f"got {sizes}"
            )
        self.time, self.time_mask = time, time_mask
        self.sym_ids, self.sym_mask = sym_ids, sym_mask

    def __len__(self) -> int:
        """Return the number of samples (size of dim 0)."""
        return self.time.size(0)

    def __getitem__(
        self, index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(time, time_mask, sym_ids, sym_mask)`` for one sample."""
        return (
            self.time[index],
            self.time_mask[index],
            self.sym_ids[index],
            self.sym_mask[index],
        )
class PreTrainDataLoader(object):
    """Hand out one DataLoader per chunk of pre-training files.

    The regular files in ``args.data_path`` are read with ``torch.load``.
    Each one must contain a mapping with the keys ``time``, ``time_mask``,
    ``sym_ids`` and ``sym_mask``.

    The files are split into consecutive chunks of ``args.number`` files.
    Every call to :meth:`load_data` (or :meth:`get_dataloader`) loads the
    next chunk, wrapping back to the first chunk after the last one.
    """

    def __init__(self, args: Any) -> None:
        """
        :param args: object exposing ``data_path``, ``number``,
            ``batch_size``, ``shuffle`` and ``num_workers`` attributes.
        """
        # Directory holding the pre-training files.
        self.data_path = args.data_path
        # Snapshot the file list once, sorted. ``os.listdir`` returns entries
        # in arbitrary order, and re-listing on every load could change which
        # files land in which chunk. Sub-directories are skipped because
        # ``torch.load`` cannot read them.
        self.files = sorted(
            name
            for name in os.listdir(self.data_path)
            if path.isfile(path.join(self.data_path, name))
        )
        # Total number of data files available for pre-training.
        self.num_data = len(self.files)
        # Number of files loaded and concatenated per chunk.
        self.number = args.number
        # Start offset (into ``self.files``) of every chunk.
        self.list = list(range(0, self.num_data, self.number))
        # Index into ``self.list`` of the chunk ``load_data`` reads next.
        self.pointer = 0
        # Defaults for the DataLoader objects built by ``get_dataloader``.
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle
        self.num_workers = args.num_workers

    def __len__(self) -> int:
        """Return the number of file chunks.

        This is how many ``load_data`` calls it takes to visit every file
        once.
        """
        return len(self.list)

    def load_data(self) -> Dict:
        """Load the next chunk of files and concatenate them along dim 0.

        Advances the internal pointer, wrapping to the first chunk after the
        last one.

        :return: dict mapping each of ``time``, ``time_mask``, ``sym_ids`` and
            ``sym_mask`` to the dim-0 concatenation of that entry across the
            chunk's files.
        :raises IndexError: if ``data_path`` contains no data files.
        """
        if not self.list:
            raise IndexError(f"No data files found in {self.data_path!r}")
        start = self.list[self.pointer]
        # Move the dataset pointer forward, wrapping around after the last chunk.
        self.pointer = (self.pointer + 1) % len(self.list)

        data_dict = dict(time=[], time_mask=[], sym_ids=[], sym_mask=[])
        for file in self.files[start : start + self.number]:
            # NOTE: weights_only=False lets torch.load unpickle arbitrary
            # objects; only point this loader at trusted data directories.
            data = torch.load(path.join(self.data_path, file), weights_only=False)
            for key, values in data_dict.items():
                values.append(data[key])
        # Concatenate the per-file pieces into one tensor per key.
        return {key: torch.cat(values, dim=0) for key, values in data_dict.items()}

    def get_dataloader(
        self, batch_size: Optional[int] = None, shuffle: Optional[bool] = None
    ) -> DataLoader:
        """Build a DataLoader over the next chunk of files.

        :param batch_size: overrides ``args.batch_size`` when not None.
        :param shuffle: overrides ``args.shuffle`` when not None.
        :return: DataLoader over a :class:`TSDataset` built from the chunk
            returned by :meth:`load_data`.
        """
        data_dict = self.load_data()
        dataset = TSDataset(
            time=data_dict["time"],
            time_mask=data_dict["time_mask"],
            sym_ids=data_dict["sym_ids"],
            sym_mask=data_dict["sym_mask"],
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size if batch_size is None else batch_size,
            shuffle=self.shuffle if shuffle is None else shuffle,
            num_workers=self.num_workers,
        )