import os import random import time import argparse import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import confusion_matrix, precision_recall_curve, f1_score, precision_score, recall_score, accuracy_score from tqdm import tqdm import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torch.optim.lr_scheduler import ReduceLROnPlateau from dataset import AudioDataset from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101, spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101 import warnings warnings.filterwarnings("ignore", message="At least one mel filterbank has all zero values") def parse_args(): parser = argparse.ArgumentParser(description='训练音频情感分类模型') # 数据集参数 parser.add_argument('--data_root', type=str, default='dataset', help='数据集根目录') parser.add_argument('--use_mfcc', action='store_true', help='使用MFCC特征而非原始波形') parser.add_argument('--val_split', type=float, default=0.2, help='验证集比例') parser.add_argument('--num_workers', type=int, default=4, help='数据加载器的工作线程数量') # 模型参数 parser.add_argument('--model', type=str, default='resnet18', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101'], help='模型架构') parser.add_argument('--pretrained', action='store_true', help='使用预训练权重') # 训练参数 parser.add_argument('--batch_size', type=int, default=32, help='批次大小') parser.add_argument('--epochs', type=int, default=30, help='训练轮数') parser.add_argument('--lr', type=float, default=0.001, help='学习率') parser.add_argument('--weight_decay', type=float, default=1e-5, help='权重衰减') parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='优化器选择') parser.add_argument('--scheduler', action='store_true', help='是否使用学习率调度器') # 其他参数 parser.add_argument('--seed', type=int, default=42, help='随机种子') parser.add_argument('--gpu', type=int, default=0, help='使用的GPU ID') parser.add_argument('--save_dir', type=str, default='checkpoints', help='模型保存目录') return parser.parse_args() def set_seed(seed): torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) # 如果有GPU,还需要设置以下代码 if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def get_model(args, num_classes): if args.model == 'waveform_resnet18': return waveform_resnet18(num_classes=num_classes) elif args.model == 'waveform_resnet34': return waveform_resnet34(num_classes=num_classes) elif args.model == 'waveform_resnet50': return waveform_resnet50(num_classes=num_classes) elif args.model == 'waveform_resnet101': return waveform_resnet101(num_classes=num_classes) elif args.model == 'spectrogram_resnet18': return spectrogram_resnet18(num_classes=num_classes, pretrained=args.pretrained) elif args.model == 'spectrogram_resnet34': return spectrogram_resnet34(num_classes=num_classes, pretrained=args.pretrained) elif args.model == 'spectrogram_resnet50': return spectrogram_resnet50(num_classes=num_classes, pretrained=args.pretrained) elif args.model == 'spectrogram_resnet101': return spectrogram_resnet101(num_classes=num_classes, pretrained=args.pretrained) else: raise ValueError(f"不支持的模型: {args.model}") def get_optimizer(args, model): if args.optimizer == 'adam': return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) elif args.optimizer == 'sgd': return optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) else: raise ValueError(f"不支持的优化器: {args.optimizer}") def calculate_metrics(outputs, targets, num_classes): """计算准确率、精确率、召回率和F1分数""" _, preds = torch.max(outputs, 1) # 转为CPU numpy数组以用于scikit-learn preds_np = preds.cpu().numpy() targets_np = targets.cpu().numpy() # 计算总体准确率 acc = accuracy_score(targets_np, preds_np) # 计算每个类别的精确率、召回率和F1分数,然后求平均 precision = precision_score(targets_np, preds_np, average='macro', zero_division=0) recall = recall_score(targets_np, preds_np, average='macro', zero_division=0) f1 = f1_score(targets_np, preds_np, average='macro', zero_division=0) return acc, precision, recall, f1 def validate(model, val_loader, criterion, device, num_classes): """在验证集上评估模型""" model.eval() val_loss = 0.0 all_outputs = [] all_targets = [] with torch.no_grad(): for inputs, targets in tqdm(val_loader, desc="验证中", leave=False): inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs) loss = criterion(outputs, targets) val_loss += loss.item() * inputs.size(0) all_outputs.append(outputs) all_targets.append(targets) val_loss /= len(val_loader.dataset) # 合并所有批次的输出和目标 all_outputs = torch.cat(all_outputs, dim=0) all_targets = torch.cat(all_targets, dim=0) # 计算评估指标 acc, precision, recall, f1 = calculate_metrics(all_outputs, all_targets, num_classes) return val_loss, acc, precision, recall, f1, all_outputs, all_targets def save_model(model, path): """保存模型和配置""" os.makedirs(os.path.dirname(path), exist_ok=True) torch.save(model.state_dict(), path) print(f"模型已保存到 {path}") def plot_metrics(train_metrics, val_metrics, save_dir): """绘制训练和验证指标""" metrics = ['loss', 'accuracy', 'precision', 'recall', 'f1'] # 创建保存目录 os.makedirs(save_dir, exist_ok=True) for i, metric in enumerate(metrics): plt.figure(figsize=(10, 6)) plt.plot(train_metrics[i], label=f'Train {metric}') plt.plot(val_metrics[i], label=f'Validation {metric}') plt.xlabel('Epoch') plt.ylabel(metric.capitalize()) plt.title(f'Training and Validation {metric.capitalize()}') plt.legend() plt.grid(True) plt.savefig(os.path.join(save_dir, f'{metric}_plot.png'), dpi=300, bbox_inches='tight') plt.close() def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path): """绘制混淆矩阵""" cm = confusion_matrix(true_labels, pred_labels) plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names) plt.xlabel('Predicted') plt.ylabel('True') plt.title('Confusion Matrix') plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() def plot_pr_curve(all_outputs, all_targets, num_classes, class_names, save_path): """绘制PR曲线""" # 将输出转换为概率 probs = torch.nn.functional.softmax(all_outputs, dim=1).cpu().numpy() targets = all_targets.cpu().numpy() plt.figure(figsize=(12, 10)) # 为每个类别绘制PR曲线 for i in range(num_classes): precision, recall, _ = precision_recall_curve( (targets == i).astype(int), probs[:, i] ) plt.plot(recall, precision, lw=2, label=f'Class {class_names[i]}') plt.xlabel('Recall') plt.ylabel('Precision') plt.title('Precision-Recall Curve') plt.legend(loc='best') plt.grid(True) plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() def main(): args = parse_args() # 判断是否使用MFCC特征,如果使用则在模型名前加上spectrogram if args.use_mfcc: args.model = 'spectrogram_' + args.model else: args.model = 'waveform_' + args.model set_seed(args.seed) # 设置设备 device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"使用设备: {device}") # 给save_dir添加模型名称、是否使用MFCC特征和日期时间 args.save_dir = os.path.join(args.save_dir, f"{args.model}_{time.strftime('%Y%m%d_%H%M%S')}") print(f"模型保存目录: {args.save_dir}") # 创建保存目录 os.makedirs(args.save_dir, exist_ok=True) # 加载数据集 - 修改为分别加载训练集和验证集 train_dir = os.path.join(args.data_root, 'train') val_dir = os.path.join(args.data_root, 'val') train_dataset = AudioDataset(train_dir, use_mfcc=args.use_mfcc) val_dataset = AudioDataset(val_dir, use_mfcc=args.use_mfcc) # 测试集使用验证集 test_dataset = val_dataset # 定义类别名称 class_names = list(train_dataset.label_to_int.keys()) num_classes = len(class_names) print(f"类别数量: {num_classes}") print(f"类别名称: {class_names}") print(f"训练集大小: {len(train_dataset)}") print(f"验证集大小: {len(val_dataset)}") print(f"测试集大小: {len(test_dataset)}") # 创建数据加载器 train_loader = DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True ) val_loader = DataLoader( val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True ) test_loader = DataLoader( test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True ) # 创建模型 model = get_model(args, num_classes).to(device) print(f"模型总参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = get_optimizer(args, model) # 学习率调度器 scheduler = None if args.scheduler: scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True) # 训练跟踪变量 best_val_f1 = 0.0 best_epoch = 0 best_model_path = "" # 保存训练和验证指标 train_losses = [] train_accs = [] train_precisions = [] train_recalls = [] train_f1s = [] val_losses = [] val_accs = [] val_precisions = [] val_recalls = [] val_f1s = [] # 记录最佳验证集输出和标签(用于绘图) best_val_outputs = None best_val_targets = None # 训练循环 start_time = time.time() for epoch in range(args.epochs): model.train() train_loss = 0.0 all_outputs = [] all_targets = [] train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}") for inputs, targets in train_pbar: inputs, targets = inputs.to(device), targets.to(device) # 前向传播 optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) # 反向传播 loss.backward() optimizer.step() train_loss += loss.item() * inputs.size(0) all_outputs.append(outputs) all_targets.append(targets) # 更新进度条 train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"}) # 计算平均训练损失 train_loss /= len(train_dataset) # 计算训练指标 all_outputs = torch.cat(all_outputs, dim=0) all_targets = torch.cat(all_targets, dim=0) train_acc, train_precision, train_recall, train_f1 = calculate_metrics(all_outputs, all_targets, num_classes) # 验证 val_loss, val_acc, val_precision, val_recall, val_f1, val_outputs, val_targets = validate( model, val_loader, criterion, device, num_classes ) # 学习率调整 if scheduler: scheduler.step(val_loss) # 保存指标 train_losses.append(train_loss) train_accs.append(train_acc) train_precisions.append(train_precision) train_recalls.append(train_recall) train_f1s.append(train_f1) val_losses.append(val_loss) val_accs.append(val_acc) val_precisions.append(val_precision) val_recalls.append(val_recall) val_f1s.append(val_f1) # 打印当前epoch的指标 print(f"Epoch {epoch+1}/{args.epochs}") print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}") print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}") print(f"Train Precision: {train_precision:.4f} | Val Precision: {val_precision:.4f}") print(f"Train Recall: {train_recall:.4f} | Val Recall: {val_recall:.4f}") print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}") # 保存最佳模型 if val_f1 > best_val_f1: best_val_f1 = val_f1 best_epoch = epoch best_model_path = os.path.join(args.save_dir, f"best.pth") save_model(model, best_model_path) # 保存最佳验证结果用于后续绘图 best_val_outputs = val_outputs best_val_targets = val_targets # 训练完成 total_time = time.time() - start_time print(f"训练完成,总耗时: {total_time/60:.2f} 分钟") print(f"最佳验证 F1: {best_val_f1:.4f},在 Epoch {best_epoch+1}") # 绘制训练和验证指标曲线 plot_metrics( [train_losses, train_accs, train_precisions, train_recalls, train_f1s], [val_losses, val_accs, val_precisions, val_recalls, val_f1s], os.path.join(args.save_dir, 'plots') ) # 加载最佳模型进行测试 print(f"加载最佳模型进行测试...") model.load_state_dict(torch.load(best_model_path)) test_loss, test_acc, test_precision, test_recall, test_f1, test_outputs, test_targets = validate( model, test_loader, criterion, device, num_classes ) print(f"测试指标:") print(f"Loss: {test_loss:.4f}") print(f"Accuracy: {test_acc:.4f}") print(f"Precision: {test_precision:.4f}") print(f"Recall: {test_recall:.4f}") print(f"F1 Score: {test_f1:.4f}") # 混淆矩阵 _, test_preds = torch.max(test_outputs, 1) plot_confusion_matrix( test_targets.cpu().numpy(), test_preds.cpu().numpy(), class_names, os.path.join(args.save_dir, 'plots', 'confusion_matrix.png') ) # PR曲线 plot_pr_curve( test_outputs, test_targets, num_classes, class_names, os.path.join(args.save_dir, 'plots', 'pr_curve.png') ) # 保存测试结果 results = { 'test_loss': test_loss, 'test_accuracy': test_acc, 'test_precision': test_precision, 'test_recall': test_recall, 'test_f1': test_f1, 'best_epoch': best_epoch + 1, 'best_val_f1': best_val_f1, 'train_time': total_time, 'model': args.model, 'use_mfcc': args.use_mfcc, 'pretrained': args.pretrained } # 将结果写入文件 with open(os.path.join(args.save_dir, 'results.txt'), 'w') as f: for key, value in results.items(): f.write(f"{key}: {value}\n") if __name__ == "__main__": main()