import os
import shutil
import random
import argparse


def split_dataset(source_dir, output_dir, split_ratio=0.8):
    """
    Split the dataset into training and validation sets.

    Expects ``source_dir`` to contain one subdirectory per class; files in
    each class are shuffled and copied (not moved) into
    ``output_dir/train/<class>`` and ``output_dir/val/<class>``.

    :param source_dir: Directory containing the dataset with subfolders for each class.
    :param output_dir: Base directory where train and val subdirectories will be created.
    :param split_ratio: Ratio of the training set size to the total dataset size.
    :return: Tuple ``(train_dir, val_dir)`` of the created directory paths.
    """
    # Create the training and validation directories.
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...)` pattern.
    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    for class_dir in os.listdir(source_dir):
        class_path = os.path.join(source_dir, class_dir)
        if not os.path.isdir(class_path):
            # Skip stray files sitting next to the class folders.
            continue
        files = os.listdir(class_path)
        # Shuffle so the train/val assignment is random per class.
        random.shuffle(files)
        split_point = int(len(files) * split_ratio)
        train_files = files[:split_point]
        val_files = files[split_point:]

        train_class_dir = os.path.join(train_dir, class_dir)
        val_class_dir = os.path.join(val_dir, class_dir)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(val_class_dir, exist_ok=True)

        for file in train_files:
            shutil.copy(os.path.join(class_path, file),
                        os.path.join(train_class_dir, file))
        for file in val_files:
            shutil.copy(os.path.join(class_path, file),
                        os.path.join(val_class_dir, file))

    # Print dataset statistics.
    print(f"数据集分割完成:")
    print(f"- 训练集路径: {train_dir}")
    print(f"- 验证集路径: {val_dir}")
    print(f"- 分割比例: {split_ratio:.2f} (训练) / {1-split_ratio:.2f} (验证)")

    return train_dir, val_dir


if __name__ == "__main__":
    # Build the command-line argument parser.
    parser = argparse.ArgumentParser(description="将情感语音数据集分割为训练集和验证集")
    parser.add_argument("--source", type=str, default="wav",
                        help="源数据目录 (默认: wav)")
    parser.add_argument("--output", type=str, default="dataset",
                        help="输出基础目录,将在其中创建train和val子目录 (默认: dataset)")
    parser.add_argument("--ratio", type=float, default=0.8,
                        help="训练集比例 (0-1之间的浮点数, 默认: 0.8)")
    # BUGFIX: default was 42, contradicting the help text ("默认: 无") and
    # making the `args.seed is not None` guard below unreachable as "no seed".
    parser.add_argument("--seed", type=int, default=None,
                        help="随机种子 (默认: 无)")
    parser.add_argument("--force", action="store_true",
                        help="如果目标目录已存在则强制删除")

    args = parser.parse_args()

    # Seed the RNG only when the user explicitly asked for reproducibility.
    if args.seed is not None:
        random.seed(args.seed)
        print(f"已设置随机种子: {args.seed}")

    # Validate the split ratio before touching the filesystem.
    if args.ratio <= 0 or args.ratio >= 1:
        parser.error("分割比例必须在0和1之间")

    # Paths of the train/val directories that split_dataset will create.
    train_dir = os.path.join(args.output, "train")
    val_dir = os.path.join(args.output, "val")

    if args.force:
        # --force: silently remove any previous outputs.
        if os.path.exists(train_dir):
            shutil.rmtree(train_dir)
        if os.path.exists(val_dir):
            shutil.rmtree(val_dir)
    else:
        # BUGFIX: the original asked for confirmation but never deleted the
        # existing directories on "y" (old files would mix with the new
        # split); the rmtree calls sat in a dead `else` branch where the
        # directories provably did not exist. Cleanup now happens after the
        # user confirms.
        if os.path.exists(train_dir) or os.path.exists(val_dir):
            response = input(f"目标目录 ({train_dir} 或 {val_dir}) 已存在,是否继续并覆盖? (y/n): ")
            if response.lower() != 'y':
                print("操作已取消")
                exit(0)
            if os.path.exists(train_dir):
                shutil.rmtree(train_dir)
            if os.path.exists(val_dir):
                shutil.rmtree(val_dir)

    # Perform the dataset split.
    split_dataset(args.source, args.output, args.ratio)