ADD file via upload

main
pnmfazke8 3 months ago
parent ec57548845
commit 6b371d5faa

@ -0,0 +1,102 @@
import os
import shutil
import random
import argparse
def split_dataset(source_dir, output_dir, split_ratio=0.8):
"""
Split the dataset into training and validation sets.
:param source_dir: Directory containing the dataset with subfolders for each class.
:param output_dir: Base directory where train and val subdirectories will be created.
:param split_ratio: Ratio of the training set size to the total dataset size.
"""
# 创建训练集和验证集目录
train_dir = os.path.join(output_dir, "train")
val_dir = os.path.join(output_dir, "val")
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(val_dir):
os.makedirs(val_dir)
for class_dir in os.listdir(source_dir):
class_path = os.path.join(source_dir, class_dir)
if os.path.isdir(class_path):
files = os.listdir(class_path)
random.shuffle(files)
split_point = int(len(files) * split_ratio)
train_files = files[:split_point]
val_files = files[split_point:]
train_class_dir = os.path.join(train_dir, class_dir)
val_class_dir = os.path.join(val_dir, class_dir)
if not os.path.exists(train_class_dir):
os.makedirs(train_class_dir)
if not os.path.exists(val_class_dir):
os.makedirs(val_class_dir)
for file in train_files:
shutil.copy(os.path.join(class_path, file), os.path.join(train_class_dir, file))
for file in val_files:
shutil.copy(os.path.join(class_path, file), os.path.join(val_class_dir, file))
# 打印数据集统计信息
print(f"数据集分割完成:")
print(f"- 训练集路径: {train_dir}")
print(f"- 验证集路径: {val_dir}")
print(f"- 分割比例: {split_ratio:.2f} (训练) / {1-split_ratio:.2f} (验证)")
return train_dir, val_dir
if __name__ == "__main__":
# 创建命令行参数解析器
parser = argparse.ArgumentParser(description="将情感语音数据集分割为训练集和验证集")
parser.add_argument("--source", type=str, default="wav",
help="源数据目录 (默认: wav)")
parser.add_argument("--output", type=str, default="dataset",
help="输出基础目录将在其中创建train和val子目录 (默认: dataset)")
parser.add_argument("--ratio", type=float, default=0.8,
help="训练集比例 (0-1之间的浮点数, 默认: 0.8)")
parser.add_argument("--seed", type=int, default=42,
help="随机种子 (默认: 无)")
parser.add_argument("--force", action="store_true",
help="如果目标目录已存在则强制删除")
# 解析命令行参数
args = parser.parse_args()
# 设置随机种子(如果提供)
if args.seed is not None:
random.seed(args.seed)
print(f"已设置随机种子: {args.seed}")
# 检查分割比例是否有效
if args.ratio <= 0 or args.ratio >= 1:
parser.error("分割比例必须在0和1之间")
# 构建实际的训练和验证目录路径
train_dir = os.path.join(args.output, "train")
val_dir = os.path.join(args.output, "val")
# 处理目标目录
if args.force:
if os.path.exists(train_dir):
shutil.rmtree(train_dir)
if os.path.exists(val_dir):
shutil.rmtree(val_dir)
else:
# 检查目标目录是否已存在
if os.path.exists(train_dir) or os.path.exists(val_dir):
response = input(f"目标目录 ({train_dir}{val_dir}) 已存在,是否继续并覆盖? (y/n): ")
if response.lower() != 'y':
print("操作已取消")
exit(0)
else:
if os.path.exists(train_dir):
shutil.rmtree(train_dir)
if os.path.exists(val_dir):
shutil.rmtree(val_dir)
# 执行数据集分割
split_dataset(args.source, args.output, args.ratio)
Loading…
Cancel
Save