You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
4.3 KiB
147 lines
4.3 KiB
# This source code is provided for the purposes of scientific reproducibility
|
|
# under the following limited license from Element AI Inc. The code is an
|
|
# implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
|
|
# expansion analysis for interpretable time series forecasting,
|
|
# https://arxiv.org/abs/1905.10437). The copyright to the source code is
|
|
# licensed under the Creative Commons - Attribution-NonCommercial 4.0
|
|
# International license (CC BY-NC 4.0):
|
|
# https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
|
|
# for the benefit of third parties or internally in production) requires an
|
|
# explicit license. The subject-matter of the N-BEATS model and associated
|
|
# materials are the property of Element AI Inc. and may be subject to patent
|
|
# protection. No license to patents is granted hereunder (whether express or
|
|
# implied). Copyright © 2020 Element AI Inc. All rights reserved.
|
|
|
|
"""
|
|
M4 Dataset
|
|
"""
|
|
import logging
|
|
import os
|
|
from collections import OrderedDict
|
|
from dataclasses import dataclass
|
|
from glob import glob
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import patoolib
|
|
from tqdm import tqdm
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
import sys
|
|
from urllib import request
|
|
|
|
|
|
def url_file_name(url: str) -> str:
    """
    Return the text that follows the last "/" of a URL.

    :param url: URL to extract file name from.
    :return: File name; an empty string when the URL is empty or ends with "/".
        A URL without any "/" is returned unchanged.
    """
    if len(url) == 0:
        return ""
    # rpartition always yields a 3-tuple; the last element is the tail after the final "/".
    _, _, tail = url.rpartition("/")
    return tail
|
|
|
|
|
|
def download(url: str, file_path: str) -> None:
    """
    Download a file to the given path.

    Nothing is fetched when ``file_path`` already exists as a regular file.
    Otherwise the content is first written to a sibling ``<file_path>.part``
    file and moved into place only after the transfer has completed, so an
    interrupted or failed download is never mistaken for a finished one on the
    next call.

    :param url: URL to download
    :param file_path: Where to download the content.
    """

    def progress(count, block_size, total_size):
        # urlretrieve reports total_size as -1 when the server does not announce
        # a size and as 0 for an empty resource; a percentage is only meaningful
        # for positive sizes.
        if total_size > 0:
            # The last block usually overshoots the real size, so clamp to 100%.
            progress_pct = min(
                float(count * block_size) / float(total_size) * 100.0, 100.0
            )
            status = "{:.1f}%".format(progress_pct)
        else:
            status = "{} bytes".format(count * block_size)
        sys.stdout.write("\rDownloading {} to {} {}".format(url, file_path, status))
        sys.stdout.flush()

    if os.path.isfile(file_path):
        file_info = os.stat(file_path)
        logging.info(f"File already exists: {file_path} {file_info.st_size} bytes.")
        return

    # Some servers reject the default urllib user agent, so install a
    # browser-like one (this is process-wide, as in the original code).
    opener = request.build_opener()
    opener.addheaders = [("User-agent", "Mozilla/5.0")]
    request.install_opener(opener)
    pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True)

    partial_path = f"{file_path}.part"
    try:
        request.urlretrieve(url, partial_path, progress)
        # Atomic move: file_path only ever appears with complete content.
        os.replace(partial_path, file_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise
    finally:
        # Terminate the "\r"-style progress line.
        sys.stdout.write("\n")
        sys.stdout.flush()

    file_info = os.stat(file_path)
    logging.info(
        f"Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes."
    )
|
|
|
|
|
|
@dataclass
class M4Dataset:
    """
    Container for the M4 info table columns plus the cached series values.
    """

    # "M4id" column of M4-info.csv
    ids: np.ndarray
    # "SP" column of M4-info.csv
    groups: np.ndarray
    # "Frequency" column of M4-info.csv
    frequencies: np.ndarray
    # "Horizon" column of M4-info.csv
    horizons: np.ndarray
    # content of the training or test cache file
    values: np.ndarray

    @staticmethod
    def load(
        training: bool = True, dataset_file: str = "../datasets/m4"
    ) -> "M4Dataset":
        """
        Load cached datasets.

        :param training: Load training part if training is True, test part otherwise.
        :param dataset_file: Directory holding "M4-info.csv", "training.npz"
            and "test.npz" (despite its name, it is used as a directory).
        :return: Dataset built from the info file and the selected cache file.
        """
        cache_name = "training.npz" if training else "test.npz"
        cache_path = os.path.join(dataset_file, cache_name)
        info_path = os.path.join(dataset_file, "M4-info.csv")

        info = pd.read_csv(info_path)
        ids = info.M4id.values
        groups = info.SP.values
        frequencies = info.Frequency.values
        horizons = info.Horizon.values

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects, so only trusted cache files should be loaded here.
        values = np.load(cache_path, allow_pickle=True)

        return M4Dataset(
            ids=ids,
            groups=groups,
            frequencies=frequencies,
            horizons=horizons,
            values=values,
        )
|
|
|
|
|
|
@dataclass
class M4Meta:
    """
    Constants for the six M4 seasonal patterns.

    All members are plain class attributes (none is annotated, so the
    dataclass has no instance fields).
    """

    seasonal_patterns = ["Yearly", "Quarterly", "Monthly", "Weekly", "Daily", "Hourly"]
    # Parallel to seasonal_patterns: one entry per pattern, same order.
    horizons = [6, 8, 18, 13, 14, 48]
    frequencies = [1, 4, 12, 1, 1, 24]

    # Pattern name -> horizon (different predict length per pattern).
    horizons_map = dict(zip(seasonal_patterns, horizons))
    # Pattern name -> frequency.
    frequency_map = dict(zip(seasonal_patterns, frequencies))

    # Pattern name -> history size, from interpretable.gin
    history_size = dict(
        Yearly=1.5,
        Quarterly=1.5,
        Monthly=1.5,
        Weekly=10,
        Daily=10,
        Hourly=10,
    )
|
|
|
|
|
|
def load_m4_info() -> pd.DataFrame:
    """
    Read the M4 info CSV into a DataFrame.

    :return: Pandas DataFrame of M4Info.
    """
    # NOTE(review): INFO_FILE_PATH is not defined in the visible part of this
    # module - confirm it is provided elsewhere, otherwise this call raises
    # NameError.
    m4_info = pd.read_csv(INFO_FILE_PATH)
    return m4_info
|