You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

97 lines
2.6 KiB

from data_provider.data_loader import (
Dataset_ETT_hour,
Dataset_ETT_minute,
Dataset_Custom,
Dataset_M4,
PSMSegLoader,
MSLSegLoader,
SMAPSegLoader,
SMDSegLoader,
SWATSegLoader,
UEAloader,
)
from data_provider.uea import collate_fn
from torch.utils.data import DataLoader
# Registry consulted by ``data_provider``: maps the value of ``args.data`` to
# the Dataset class that gets instantiated for it. Order is significant only
# for anything that iterates over the registry (e.g. listing valid choices).
data_dict = {
    # ETT benchmarks: the "h" names use Dataset_ETT_hour, the "m" names use
    # Dataset_ETT_minute.
    "ETTh1": Dataset_ETT_hour,
    "ETTh2": Dataset_ETT_hour,
    "ETTm1": Dataset_ETT_minute,
    "ETTm2": Dataset_ETT_minute,
    # Generic loader for user-supplied data, followed by the M4 loader.
    "custom": Dataset_Custom,
    "m4": Dataset_M4,
    # Segment loaders; data_provider builds these with a ``win_size`` when
    # the task is anomaly detection (presumably their intended use).
    "PSM": PSMSegLoader,
    "MSL": MSLSegLoader,
    "SMAP": SMAPSegLoader,
    "SMD": SMDSegLoader,
    "SWAT": SWATSegLoader,
    # UEA loader (presumably the UEA classification archive).
    "UEA": UEAloader,
}
class _SeqLenCollate:
    """Picklable stand-in for ``lambda x: collate_fn(x, max_len=args.seq_len)``.

    DataLoader worker processes started with the ``spawn`` method (the
    default on Windows and macOS) receive ``collate_fn`` by pickling it.
    Lambdas and nested functions cannot be pickled. Instances of a
    module-level class can be, provided ``args`` itself is picklable.

    ``args.seq_len`` is looked up on every call rather than captured once,
    so a later change to ``args.seq_len`` is observed exactly as it was by
    the original lambda.
    """

    def __init__(self, args):
        self.args = args

    def __call__(self, batch):
        return collate_fn(batch, max_len=self.args.seq_len)


def data_provider(args, flag):
    """Build the Dataset and DataLoader for one split of an experiment.

    The Dataset class is looked up in ``data_dict`` by ``args.data``. It is
    constructed with keyword arguments chosen by ``args.task_name``:

    * ``"anomaly_detection"``: ``args``, ``root_path``,
      ``win_size=args.seq_len`` and ``flag``. The split name and dataset
      length are printed.
    * ``"classification"``: ``args``, ``root_path`` and ``flag``. Batches
      are assembled by ``collate_fn(batch, max_len=args.seq_len)``.
    * anything else: ``args``, ``root_path``, ``data_path``, ``flag``,
      ``size=[seq_len, label_len, pred_len]``, ``features``, ``target``,
      ``timeenc``, ``freq`` and ``seasonal_patterns``.

    Args:
        args: Experiment configuration object. Attributes read here are
            ``data``, ``embed``, ``batch_size``, ``task_name``,
            ``root_path``, ``seq_len`` and ``num_workers``. The default
            branch additionally reads ``data_path``, ``label_len``,
            ``pred_len``, ``features``, ``target``, ``freq`` and
            ``seasonal_patterns``.
        flag: Split name forwarded to the Dataset. ``"test"`` and
            ``"TEST"`` disable shuffling; any other value enables it.

    Returns:
        A ``(data_set, data_loader)`` tuple. The loader never drops the
        last incomplete batch.

    Raises:
        KeyError: If ``args.data`` is not a key of ``data_dict``.
    """
    try:
        Data = data_dict[args.data]
    except KeyError:
        # Keep the exception type callers may already rely on, but make the
        # message actionable instead of echoing only the bad key.
        raise KeyError(
            f"Unknown dataset {args.data!r}; "
            f"expected one of {sorted(data_dict)}"
        ) from None

    # timeenc is 1 only for the "timeF" embedding; its meaning is defined by
    # the Dataset classes that consume it.
    timeenc = 1 if args.embed == "timeF" else 0
    shuffle_flag = flag not in ("test", "TEST")
    batch_size = args.batch_size
    # Loader options that only some task branches need.
    extra_loader_kwargs = {}

    if args.task_name == "anomaly_detection":
        data_set = Data(
            args=args,
            root_path=args.root_path,
            win_size=args.seq_len,
            flag=flag,
        )
        print(flag, len(data_set))
    elif args.task_name == "classification":
        data_set = Data(
            args=args,
            root_path=args.root_path,
            flag=flag,
        )
        # A lambda here would break multi-worker loading under "spawn";
        # _SeqLenCollate is picklable and keeps the same late binding.
        extra_loader_kwargs["collate_fn"] = _SeqLenCollate(args)
    else:
        data_set = Data(
            args=args,
            root_path=args.root_path,
            data_path=args.data_path,
            flag=flag,
            size=[args.seq_len, args.label_len, args.pred_len],
            features=args.features,
            target=args.target,
            timeenc=timeenc,
            freq=args.freq,
            seasonal_patterns=args.seasonal_patterns,
        )

    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=False,  # every branch keeps the final, possibly smaller batch
        **extra_loader_kwargs,
    )
    return data_set, data_loader