@@ -0,0 +1,102 @@
import argparse
import os
import random
import shutil
import sys


def split_dataset(source_dir, output_dir, split_ratio=0.8):
    """
    Split the dataset into training and validation sets.

    :param source_dir: Directory containing the dataset with subfolders for each class.
    :param output_dir: Base directory where train and val subdirectories will be created.
    :param split_ratio: Ratio of the training set size to the total dataset size.
    """
    # Create the train and val directories
    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Split each class subfolder independently so both sets keep every class
    for class_dir in os.listdir(source_dir):
        class_path = os.path.join(source_dir, class_dir)
        if os.path.isdir(class_path):
            files = os.listdir(class_path)
            random.shuffle(files)
            split_point = int(len(files) * split_ratio)
            train_files = files[:split_point]
            val_files = files[split_point:]

            train_class_dir = os.path.join(train_dir, class_dir)
            val_class_dir = os.path.join(val_dir, class_dir)
            os.makedirs(train_class_dir, exist_ok=True)
            os.makedirs(val_class_dir, exist_ok=True)

            for file in train_files:
                shutil.copy(os.path.join(class_path, file), os.path.join(train_class_dir, file))
            for file in val_files:
                shutil.copy(os.path.join(class_path, file), os.path.join(val_class_dir, file))

    # Print dataset statistics
    print("Dataset split complete:")
    print(f"- Training set path: {train_dir}")
    print(f"- Validation set path: {val_dir}")
    print(f"- Split ratio: {split_ratio:.2f} (train) / {1 - split_ratio:.2f} (val)")

    return train_dir, val_dir
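
# A minimal programmatic usage sketch (the "wav" and "dataset" paths simply
# mirror the CLI defaults below and are only illustrative): this would copy
# about 80% of the files in each class subfolder of wav/ into
# dataset/train/<class>/ and the rest into dataset/val/<class>/.
#
#   random.seed(42)  # optional, for a reproducible shuffle
#   train_dir, val_dir = split_dataset("wav", "dataset", split_ratio=0.8)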


if __name__ == "__main__":
    # Build the command-line argument parser
    parser = argparse.ArgumentParser(description="Split an emotional speech dataset into training and validation sets")
    parser.add_argument("--source", type=str, default="wav",
                        help="Source data directory (default: wav)")
    parser.add_argument("--output", type=str, default="dataset",
                        help="Output base directory in which the train and val subdirectories are created (default: dataset)")
    parser.add_argument("--ratio", type=float, default=0.8,
                        help="Training set ratio (float between 0 and 1, default: 0.8)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed (default: None)")
    parser.add_argument("--force", action="store_true",
                        help="Delete existing target directories without asking")

    # Parse the command-line arguments
    args = parser.parse_args()

    # Set the random seed (if provided)
    if args.seed is not None:
        random.seed(args.seed)
        print(f"Random seed set to: {args.seed}")

    # Validate the split ratio
    if args.ratio <= 0 or args.ratio >= 1:
        parser.error("The split ratio must be strictly between 0 and 1")

    # Build the actual train and val directory paths
    train_dir = os.path.join(args.output, "train")
    val_dir = os.path.join(args.output, "val")

    # Handle pre-existing target directories
    if args.force:
        if os.path.exists(train_dir):
            shutil.rmtree(train_dir)
        if os.path.exists(val_dir):
            shutil.rmtree(val_dir)
    else:
        # Ask for confirmation before overwriting existing target directories
        if os.path.exists(train_dir) or os.path.exists(val_dir):
            response = input(f"Target directories ({train_dir} or {val_dir}) already exist. Continue and overwrite? (y/n): ")
            if response.lower() != 'y':
                print("Operation cancelled")
                sys.exit(0)
            else:
                if os.path.exists(train_dir):
                    shutil.rmtree(train_dir)
                if os.path.exists(val_dir):
                    shutil.rmtree(val_dir)

    # Run the dataset split
    split_dataset(args.source, args.output, args.ratio)
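
# Example invocation, assuming this file is saved as split_dataset.py
# (the filename is not given here, so adjust it to your project):
#
#   python split_dataset.py --source wav --output dataset --ratio 0.8 --seed 42 --force
#
# This shuffles each class subfolder under wav/ with seed 42, copies 80% of the
# files into dataset/train/<class>/ and 20% into dataset/val/<class>/, removing
# any existing dataset/train and dataset/val directories first.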