You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
245 lines
8.0 KiB
245 lines
8.0 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
This code is used to generate bimodal data for the SymTime,
|
|
including time series and symbolic expressions, for model pretraining.
|
|
We further encapsulate the S2Generator interface to enable multithreaded data processing for data generation.
|
|
To ensure diversity in the data generation mechanism, we iterate over different random seeds for each generation.
|
|
|
|
The PyTorch code for SymTime can be found here:
|
|
The Paper for SymTime can be found here:
|
|
|
|
Externally passed variables:
|
|
- root_path: The file path to save the generated S2 data;
|
|
- start_seed: The start seed to generate the S2 data;
|
|
- end_seed: The end seed (exclusive) at which generation stops;
|
|
- max_input_dim: The maximum input dimension;
|
|
- max_output_dim: The maximum output dimension;
|
|
- length: The length of the generated S2 data;
|
|
- num_threads: The number of threads to use;
|
|
|
|
Before running this code, please ensure that the S2Generator package is installed:
|
|
pip install s2generator
|
|
|
|
Created on 2025/09/02 20:14:50
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/S2Generator
|
|
"""
|
|
import argparse
|
|
import os
|
|
import threading
|
|
import queue
|
|
import time
|
|
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import torch
|
|
|
|
from colorama import Fore, Style
|
|
from S2Generator import SeriesParams, SymbolParams, Generator
|
|
|
|
from typing import List, Optional, Any, Union
|
|
|
|
# Command-line interface. --root_path is mandatory: it has no sensible default,
# and leaving it unset would only fail later with a TypeError when the output
# directory is created.
parser = argparse.ArgumentParser(
    description="Generate S2 data (time series and symbolic expressions) with multiple threads."
)
parser.add_argument(
    "--root_path",
    type=str,
    required=True,
    help="The file path to save the generated S2 data",
)
parser.add_argument(
    "--start_seed",
    type=int,
    default=0,
    help="The start seed to generate the S2 data",
)
parser.add_argument(
    "--end_seed",
    type=int,
    default=10,
    help="The end seed for stopping (exclusive)",
)
parser.add_argument(
    "--max_input_dim", type=int, default=6, help="The maximum input dimension"
)
parser.add_argument(
    "--max_output_dim", type=int, default=6, help="The maximum output dimension"
)
parser.add_argument(
    "--length", type=int, default=768, help="The length of the generated S2 data"
)
parser.add_argument("--num_threads", type=int, default=4, help="number of threads")

args = parser.parse_args()

# Parameters related to controlling data generation
series_params = SeriesParams()
symbol_params = SymbolParams(max_trials=64)

# Single generator instance shared by every worker thread
generator = Generator(series_params, symbol_params)
|
|
|
|
|
|
def process_item(item: int) -> str:
    """
    Generate and store the S2 data that belongs to one random seed.

    A sub-directory named after the seed is created below ``args.root_path``.
    Every (input dimension, output dimension) pair up to the configured maxima
    is then sampled with the module-level ``generator``; each successful sample
    is written to ``ID=<input_dim>_OD=<output_dim>.pt`` inside that directory.
    The module-level progress bar ``pbar`` advances once per pair, whether or
    not the sample succeeded.

    :param item: Random seed value used for reproducible data generation
    :return: Status message indicating completion of processing for this seed
    """
    # Output location dedicated to this seed
    seed_dir = os.path.join(args.root_path, str(item))
    os.makedirs(seed_dir, exist_ok=True)

    # One seeded random state is reused for all dimension pairs of this seed
    random_state = np.random.RandomState(item)

    for n_in in range(1, args.max_input_dim + 1):
        for n_out in range(1, args.max_output_dim + 1):
            symbol, excitation, response = generator.run(
                rng=random_state,
                n_inputs_points=args.length,
                input_dimension=n_in,
                output_dimension=n_out,
            )

            # A missing symbol is treated as a failed generation attempt
            succeeded = symbol is not None

            if succeeded:
                # The file name encodes both dimensions
                target = os.path.join(seed_dir, f"ID={n_in}_OD={n_out}.pt")
                payload = {
                    "symbol": symbol,
                    "excitation": torch.from_numpy(excitation).float(),
                    "response": torch.from_numpy(response).float(),
                }
                torch.save(obj=payload, f=target)

            # Report progress for this pair, successful or not
            pbar.update(1)
            if succeeded:
                status = Fore.GREEN + "Success" + Style.RESET_ALL
            else:
                status = "Failure"
            pbar.set_postfix(
                {
                    "Seed": f"{item}",
                    "Input Dim": f"{n_in}",
                    "Output Dim": f"{n_out}",
                    "Status": status,
                }
            )

            # Cooldown period between samples to prevent CPU overheating
            time.sleep(3)

    return f"Processed: {item}"
|
|
|
|
|
|
def worker(task_queue: queue.Queue, result_queue: Optional[queue.Queue]) -> None:
    """
    Worker thread function that processes tasks from a queue.

    Repeatedly takes an item from ``task_queue``, runs ``process_item`` on it
    and, if ``result_queue`` is given, stores the returned value there. The
    loop ends once no item could be fetched within one second.

    A failure while processing one item is reported on stdout, together with
    the item that caused it, and does not stop the thread. Every fetched item
    is marked as done exactly once, so ``task_queue.join()`` cannot block on a
    failed task.

    :param task_queue: Queue containing tasks to be processed
    :param result_queue: Optional queue for storing processing results
    """
    thread_name = threading.current_thread().name

    while True:
        # Only the fetch may signal "queue drained"; keeping it in its own try
        # prevents a queue.Empty raised during processing from being mistaken
        # for an empty task queue.
        try:
            item = task_queue.get(timeout=1)
        except queue.Empty:
            # Queue is empty, exit thread
            break

        try:
            result = process_item(item)

            # Place result in result queue if provided
            if result_queue is not None:
                result_queue.put(result)
        except Exception as e:
            # Keep the thread alive for the remaining tasks; name the item so
            # the failed seed can be identified and re-run.
            print(f"Error in thread {thread_name} while processing item {item!r}: {e}")
        finally:
            # Mark task as completed, whether it succeeded or failed
            task_queue.task_done()
|
|
|
|
|
|
def parallel_process(
    data: List[Any], num_threads: Optional[int] = None, return_results: bool = False
) -> Optional[List[Any]]:
    """
    Process data in parallel using multiple threads.

    All items are placed on a task queue up front; ``num_threads`` threads
    running ``worker`` then consume the queue. The call returns after every
    task has been marked as done and every thread has finished.

    :param data: List of items to be processed
    :param num_threads: Number of threads to use (defaults to min(4, len(data)));
        never more threads than items are started
    :param return_results: Whether to collect and return processing results
    :return: List of results if return_results=True, otherwise None. Results are
        collected in the order the workers produced them, which need not match
        the order of ``data``.
    :raises ValueError: If ``num_threads`` is given and is smaller than 1
    """
    if not data:
        return [] if return_results else None

    if num_threads is None:
        num_threads = 4  # Default to 4 threads maximum
    elif num_threads < 1:
        # Without any worker thread nothing would ever call task_done(), and
        # task_queue.join() below would block forever.
        raise ValueError(f"num_threads must be at least 1, got {num_threads}")

    # Ensure thread count doesn't exceed data size
    num_threads = min(num_threads, len(data))

    # Create task and result queues
    task_queue = queue.Queue()
    result_queue = queue.Queue() if return_results else None

    # Populate the task queue before any worker starts
    for item in data:
        task_queue.put(item)

    # Create and start worker threads
    threads = []
    for i in range(num_threads):
        thread = threading.Thread(
            target=worker, args=(task_queue, result_queue), name=f"Worker-{i + 1}"
        )
        thread.start()
        threads.append(thread)

    # Wait for all tasks to complete
    task_queue.join()

    # Wait for all threads to finish
    for thread in threads:
        thread.join()

    if not return_results:
        return None

    # All producers have finished, so the result queue can be drained safely
    results = []
    while not result_queue.empty():
        results.append(result_queue.get())
    return results
|
|
|
|
|
|
if __name__ == "__main__":
    # One task per random seed in the half-open range [start_seed, end_seed)
    data_list = list(range(args.start_seed, args.end_seed))
    seed_count = len(data_list)

    # Each seed contributes one progress step per (input dim, output dim) pair
    total_steps = args.max_input_dim * args.max_output_dim * seed_count

    print(f"Processing {seed_count} random seeds using {args.num_threads} threads.")
    print("Starting processing...")
    time.sleep(1)

    # Make sure the output root exists before any worker writes into it
    os.makedirs(args.root_path, exist_ok=True)

    # Wall-clock timing of the whole generation run
    start_time = time.time()

    # `pbar` is read as a module-level name by process_item, so it must keep
    # this name and stay open while the workers are running.
    with tqdm(total=total_steps, desc="S2Generation") as pbar:
        results = parallel_process(
            data_list,
            num_threads=args.num_threads,
            return_results=True,
        )

    end_time = time.time()
    elapsed = end_time - start_time
    processed_count = len(results) if results else 0

    print(f"\nProcessing completed in: {elapsed:.2f} seconds")
    print(f"Processed {processed_count} items")
|