You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

102 lines
4.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import shutil
import random
import argparse
def split_dataset(source_dir, output_dir, split_ratio=0.8):
"""
Split the dataset into training and validation sets.
:param source_dir: Directory containing the dataset with subfolders for each class.
:param output_dir: Base directory where train and val subdirectories will be created.
:param split_ratio: Ratio of the training set size to the total dataset size.
"""
# 创建训练集和验证集目录
train_dir = os.path.join(output_dir, "train")
val_dir = os.path.join(output_dir, "val")
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(val_dir):
os.makedirs(val_dir)
for class_dir in os.listdir(source_dir):
class_path = os.path.join(source_dir, class_dir)
if os.path.isdir(class_path):
files = os.listdir(class_path)
random.shuffle(files)
split_point = int(len(files) * split_ratio)
train_files = files[:split_point]
val_files = files[split_point:]
train_class_dir = os.path.join(train_dir, class_dir)
val_class_dir = os.path.join(val_dir, class_dir)
if not os.path.exists(train_class_dir):
os.makedirs(train_class_dir)
if not os.path.exists(val_class_dir):
os.makedirs(val_class_dir)
for file in train_files:
shutil.copy(os.path.join(class_path, file), os.path.join(train_class_dir, file))
for file in val_files:
shutil.copy(os.path.join(class_path, file), os.path.join(val_class_dir, file))
# 打印数据集统计信息
print(f"数据集分割完成:")
print(f"- 训练集路径: {train_dir}")
print(f"- 验证集路径: {val_dir}")
print(f"- 分割比例: {split_ratio:.2f} (训练) / {1-split_ratio:.2f} (验证)")
return train_dir, val_dir
if __name__ == "__main__":
    # CLI entry point: split an emotional-speech dataset into train/val sets.
    parser = argparse.ArgumentParser(description="将情感语音数据集分割为训练集和验证集")
    parser.add_argument("--source", type=str, default="wav",
                        help="源数据目录 (默认: wav)")
    parser.add_argument("--output", type=str, default="dataset",
                        help="输出基础目录将在其中创建train和val子目录 (默认: dataset)")
    parser.add_argument("--ratio", type=float, default=0.8,
                        help="训练集比例 (0-1之间的浮点数, 默认: 0.8)")
    # BUG FIX: help text says "默认: 无" (default: none) and the guard below
    # tests `is not None`, but the original hard-coded default=42, so the
    # seed was silently always applied. Make the default genuinely None.
    parser.add_argument("--seed", type=int, default=None,
                        help="随机种子 (默认: 无)")
    parser.add_argument("--force", action="store_true",
                        help="如果目标目录已存在则强制删除")
    args = parser.parse_args()

    # Seed the RNG only when explicitly requested, for reproducible splits.
    if args.seed is not None:
        random.seed(args.seed)
        print(f"已设置随机种子: {args.seed}")

    # Reject degenerate ratios (either split would be empty by construction).
    if args.ratio <= 0 or args.ratio >= 1:
        parser.error("分割比例必须在0和1之间")

    train_dir = os.path.join(args.output, "train")
    val_dir = os.path.join(args.output, "val")

    if args.force:
        # --force: remove any previous output without prompting.
        for target in (train_dir, val_dir):
            if os.path.exists(target):
                shutil.rmtree(target)
    elif os.path.exists(train_dir) or os.path.exists(val_dir):
        # Interactive confirmation before overwriting existing output.
        response = input(f"目标目录 ({train_dir}{val_dir}) 已存在,是否继续并覆盖? (y/n): ")
        if response.lower() != 'y':
            print("操作已取消")
            raise SystemExit(0)
        for target in (train_dir, val_dir):
            if os.path.exists(target):
                shutil.rmtree(target)

    # Perform the split.
    split_dataset(args.source, args.output, args.ratio)