import os import shutil import random import argparse def split_dataset(source_dir, output_dir, split_ratio=0.8): """ Split the dataset into training and validation sets. :param source_dir: Directory containing the dataset with subfolders for each class. :param output_dir: Base directory where train and val subdirectories will be created. :param split_ratio: Ratio of the training set size to the total dataset size. """ # 创建训练集和验证集目录 train_dir = os.path.join(output_dir, "train") val_dir = os.path.join(output_dir, "val") if not os.path.exists(train_dir): os.makedirs(train_dir) if not os.path.exists(val_dir): os.makedirs(val_dir) for class_dir in os.listdir(source_dir): class_path = os.path.join(source_dir, class_dir) if os.path.isdir(class_path): files = os.listdir(class_path) random.shuffle(files) split_point = int(len(files) * split_ratio) train_files = files[:split_point] val_files = files[split_point:] train_class_dir = os.path.join(train_dir, class_dir) val_class_dir = os.path.join(val_dir, class_dir) if not os.path.exists(train_class_dir): os.makedirs(train_class_dir) if not os.path.exists(val_class_dir): os.makedirs(val_class_dir) for file in train_files: shutil.copy(os.path.join(class_path, file), os.path.join(train_class_dir, file)) for file in val_files: shutil.copy(os.path.join(class_path, file), os.path.join(val_class_dir, file)) # 打印数据集统计信息 print(f"数据集分割完成:") print(f"- 训练集路径: {train_dir}") print(f"- 验证集路径: {val_dir}") print(f"- 分割比例: {split_ratio:.2f} (训练) / {1-split_ratio:.2f} (验证)") return train_dir, val_dir if __name__ == "__main__": # 创建命令行参数解析器 parser = argparse.ArgumentParser(description="将情感语音数据集分割为训练集和验证集") parser.add_argument("--source", type=str, default="wav", help="源数据目录 (默认: wav)") parser.add_argument("--output", type=str, default="dataset", help="输出基础目录,将在其中创建train和val子目录 (默认: dataset)") parser.add_argument("--ratio", type=float, default=0.8, help="训练集比例 (0-1之间的浮点数, 默认: 0.8)") parser.add_argument("--seed", type=int, default=42, help="随机种子 (默认: 无)") parser.add_argument("--force", action="store_true", help="如果目标目录已存在则强制删除") # 解析命令行参数 args = parser.parse_args() # 设置随机种子(如果提供) if args.seed is not None: random.seed(args.seed) print(f"已设置随机种子: {args.seed}") # 检查分割比例是否有效 if args.ratio <= 0 or args.ratio >= 1: parser.error("分割比例必须在0和1之间") # 构建实际的训练和验证目录路径 train_dir = os.path.join(args.output, "train") val_dir = os.path.join(args.output, "val") # 处理目标目录 if args.force: if os.path.exists(train_dir): shutil.rmtree(train_dir) if os.path.exists(val_dir): shutil.rmtree(val_dir) else: # 检查目标目录是否已存在 if os.path.exists(train_dir) or os.path.exists(val_dir): response = input(f"目标目录 ({train_dir} 或 {val_dir}) 已存在,是否继续并覆盖? (y/n): ") if response.lower() != 'y': print("操作已取消") exit(0) else: if os.path.exists(train_dir): shutil.rmtree(train_dir) if os.path.exists(val_dir): shutil.rmtree(val_dir) # 执行数据集分割 split_dataset(args.source, args.output, args.ratio)