You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
206 lines
6.7 KiB
206 lines
6.7 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on 2024/10/19 21:20
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
import argparse
|
|
import torch
|
|
from exp import Exp_Anomaly_Detection
|
|
from utils.print_args import print_args
|
|
import random
|
|
import numpy as np
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="SymTime-Anomaly_Detection")
|
|
|
|
# basic config
|
|
parser.add_argument("--task_name", type=str, default="anomaly_detection")
|
|
parser.add_argument("--is_training", type=int, default=1, help="status")
|
|
parser.add_argument("--dataset_name", type=str, default="SWAT", help="model id")
|
|
parser.add_argument("--model", type=str, default="SymTime")
|
|
parser.add_argument(
|
|
"--pretrain_path", type=str, default="./models/params/finetuning.pth"
|
|
)
|
|
parser.add_argument("--pretrain_id", type=str, default="zero")
|
|
|
|
# data loader
|
|
parser.add_argument("--data", type=str, default="SWAT", help="datasets type")
|
|
parser.add_argument(
|
|
"--root_path",
|
|
type=str,
|
|
default="./datasets/SWaT/",
|
|
help="root path of the data file",
|
|
)
|
|
parser.add_argument(
|
|
"--data_path", type=str, default="electricity.csv", help="data file"
|
|
)
|
|
parser.add_argument(
|
|
"--features",
|
|
type=str,
|
|
default="M",
|
|
help="forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate",
|
|
)
|
|
parser.add_argument(
|
|
"--target", type=str, default="OT", help="target feature in S or MS task"
|
|
)
|
|
parser.add_argument(
|
|
"--freq",
|
|
type=str,
|
|
default="h",
|
|
help="freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h",
|
|
)
|
|
parser.add_argument(
|
|
"--checkpoints",
|
|
type=str,
|
|
default="./checkpoints/",
|
|
help="location of model checkpoints",
|
|
)
|
|
parser.add_argument("--individual", type=bool, default=False)
|
|
|
|
parser.add_argument("--anomaly_ratio", type=float, default=0.5)
|
|
# do patching
|
|
parser.add_argument(
|
|
"--forward_layers", type=int, default=3, help="the feed forward layers numbers"
|
|
)
|
|
parser.add_argument(
|
|
"--patch_len", type=int, default=16, help="Divide the length of the patch"
|
|
)
|
|
parser.add_argument(
|
|
"--stride", type=int, default=1, help="The division step size of patching"
|
|
)
|
|
parser.add_argument(
|
|
"--padding_patch", type=bool, default=True, help="Whether to fill the last Patch"
|
|
)
|
|
parser.add_argument(
|
|
"--out_dropout",
|
|
type=float,
|
|
default=0.1,
|
|
help="Dropout of the final output of the model",
|
|
)
|
|
parser.add_argument(
|
|
"--use_avg", type=bool, default=True, help="use moving average decomposition"
|
|
)
|
|
parser.add_argument(
|
|
"--moving_avg", type=int, default=25, help="window size of moving average"
|
|
)
|
|
# forecasting task
|
|
parser.add_argument("--seq_len", type=int, default=100, help="input sequence length")
|
|
parser.add_argument("--label_len", type=int, default=0, help="start token length")
|
|
parser.add_argument(
|
|
"--pred_len", type=int, default=0, help="prediction sequence length"
|
|
)
|
|
parser.add_argument(
|
|
"--seasonal_patterns", type=str, default="Monthly", help="subset for M4"
|
|
)
|
|
parser.add_argument(
|
|
"--inverse", action="store_true", help="inverse output data", default=False
|
|
)
|
|
|
|
# model define
|
|
parser.add_argument("--dropout", type=float, default=0.1, help="dropout")
|
|
parser.add_argument(
|
|
"--embed",
|
|
type=str,
|
|
default="timeF",
|
|
help="time features encoding, options:[timeF, fixed, learned]",
|
|
)
|
|
parser.add_argument("--activation", type=str, default="gelu", help="activation")
|
|
parser.add_argument(
|
|
"--channel_independence",
|
|
type=int,
|
|
default=1,
|
|
help="0: channel dependence 1: channel independence for FreTS model",
|
|
)
|
|
parser.add_argument(
|
|
"--use_norm", type=int, default=1, help="whether to use normalize; True 1 False 0"
|
|
)
|
|
|
|
# optimization
|
|
parser.add_argument("--enc_in", type=int, default=51, help="encoder input size")
|
|
parser.add_argument(
|
|
"--num_workers", type=int, default=5, help="data loader num workers"
|
|
)
|
|
parser.add_argument("--itr", type=int, default=1, help="experiments times")
|
|
parser.add_argument("--train_epochs", type=int, default=5, help="train epochs")
|
|
parser.add_argument(
|
|
"--batch_size", type=int, default=4, help="batch size of train input data"
|
|
)
|
|
parser.add_argument("--patience", type=int, default=3, help="early stopping patience")
|
|
parser.add_argument(
|
|
"--learning_rate", type=float, default=0.0001, help="optimizer learning rate"
|
|
)
|
|
parser.add_argument("--des", type=str, default="test", help="exp description")
|
|
parser.add_argument("--loss", type=str, default="MSE", help="loss function")
|
|
parser.add_argument("--lradj", type=str, default="type1", help="adjust learning rate")
|
|
parser.add_argument(
|
|
"--use_amp",
|
|
action="store_true",
|
|
help="use automatic mixed precision training",
|
|
default=False,
|
|
)
|
|
# GPU
|
|
parser.add_argument("--use_gpu", type=bool, default=True, help="use gpu")
|
|
parser.add_argument("--gpu", type=int, default=0, help="gpu")
|
|
parser.add_argument(
|
|
"--use_multi_gpu", action="store_true", help="use multiple gpus", default=False
|
|
)
|
|
parser.add_argument(
|
|
"--devices", type=str, default="0", help="device ids of multile gpus"
|
|
)
|
|
|
|
# metrics (dtw)
|
|
parser.add_argument(
|
|
"--use_dtw", type=bool, default=False, help="the controller of using dtw metric"
|
|
)
|
|
|
|
# Augmentation
|
|
parser.add_argument(
|
|
"--augmentation_ratio", type=int, default=0, help="How many times to augment"
|
|
)
|
|
parser.add_argument("--seed", type=int, default=2025, help="Randomization seed")
|
|
|
|
args = parser.parse_args()
|
|
args.use_gpu = True if torch.cuda.is_available() else False
|
|
|
|
# Set the random seed for reproducibility
|
|
random.seed(args.seed)
|
|
torch.manual_seed(args.seed)
|
|
np.random.seed(args.seed)
|
|
|
|
if __name__ == "__main__":
|
|
if args.use_gpu and args.use_multi_gpu:
|
|
args.devices = args.devices.replace(" ", "")
|
|
device_ids = args.devices.split(",")
|
|
args.device_ids = [int(id_) for id_ in device_ids]
|
|
args.gpu = args.device_ids[0]
|
|
torch.cuda.set_device(args.gpu)
|
|
|
|
print("Args in experiment:")
|
|
print_args(args)
|
|
|
|
exp = Exp_Anomaly_Detection(args) # set experiments
|
|
setting = "{}_{}_{}_{}_seq_len{}_moving_avg{}_forward_layers{}_patch_len{}_stride{}_batch_size{}_learning_rate{}_lradj{}_seed{}_{}".format(
|
|
args.task_name,
|
|
args.dataset_name,
|
|
args.model,
|
|
args.data,
|
|
args.seq_len,
|
|
args.use_avg,
|
|
args.forward_layers,
|
|
args.patch_len,
|
|
args.stride,
|
|
args.batch_size,
|
|
args.learning_rate,
|
|
args.lradj,
|
|
args.seed,
|
|
args.pretrain_id,
|
|
)
|
|
|
|
print(">>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>".format(setting))
|
|
exp.train(setting)
|
|
print(">>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<".format(setting))
|
|
exp.test(setting)
|
|
torch.cuda.empty_cache()
|