"""Training script for audio emotion classification (waveform / MFCC ResNets)."""
import os
import random
import time
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, f1_score, precision_score, recall_score, accuracy_score
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from dataset import AudioDataset
from models import (
    waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101,
    spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101,
)
import warnings

# Mel filterbanks with too few FFT bins trigger this librosa/torchaudio warning;
# it is harmless for our feature settings, so silence it.
warnings.filterwarnings("ignore", message="At least one mel filterbank has all zero values")


def parse_args():
    """Parse command-line arguments for a training run.

    Returns:
        argparse.Namespace with dataset, model, training and misc options.
    """
    parser = argparse.ArgumentParser(description='训练音频情感分类模型')

    # Dataset parameters
    parser.add_argument('--data_root', type=str, default='dataset', help='数据集根目录')
    parser.add_argument('--use_mfcc', action='store_true', help='使用MFCC特征而非原始波形')
    parser.add_argument('--val_split', type=float, default=0.2, help='验证集比例')
    parser.add_argument('--num_workers', type=int, default=4, help='数据加载器的工作线程数量')

    # Model parameters.  NOTE: users pass the bare ResNet name; main() prepends
    # 'waveform_' or 'spectrogram_' depending on --use_mfcc before get_model().
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50', 'resnet101'],
                        help='模型架构')
    parser.add_argument('--pretrained', action='store_true', help='使用预训练权重')

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=32, help='批次大小')
    parser.add_argument('--epochs', type=int, default=30, help='训练轮数')
    parser.add_argument('--lr', type=float, default=0.001, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='权重衰减')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='优化器选择')
    parser.add_argument('--scheduler', action='store_true', help='是否使用学习率调度器')

    # Misc parameters
    parser.add_argument('--seed', type=int, default=42, help='随机种子')
    parser.add_argument('--gpu', type=int, default=0, help='使用的GPU ID')
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='模型保存目录')

    return parser.parse_args()


def set_seed(seed):
    """Seed every RNG in play (torch CPU/CUDA, numpy, random) for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels trade speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Model factories keyed by the fully-prefixed architecture name
# (the prefix is added in main() according to --use_mfcc).
_WAVEFORM_FACTORIES = {
    'waveform_resnet18': waveform_resnet18,
    'waveform_resnet34': waveform_resnet34,
    'waveform_resnet50': waveform_resnet50,
    'waveform_resnet101': waveform_resnet101,
}
_SPECTROGRAM_FACTORIES = {
    'spectrogram_resnet18': spectrogram_resnet18,
    'spectrogram_resnet34': spectrogram_resnet34,
    'spectrogram_resnet50': spectrogram_resnet50,
    'spectrogram_resnet101': spectrogram_resnet101,
}


def get_model(args, num_classes):
    """Instantiate the model named by ``args.model``.

    Only spectrogram models accept ``pretrained`` (they wrap torchvision
    ResNets); waveform models are built from scratch.

    Raises:
        ValueError: if ``args.model`` is not a known architecture name.
    """
    if args.model in _WAVEFORM_FACTORIES:
        return _WAVEFORM_FACTORIES[args.model](num_classes=num_classes)
    if args.model in _SPECTROGRAM_FACTORIES:
        return _SPECTROGRAM_FACTORIES[args.model](num_classes=num_classes, pretrained=args.pretrained)
    raise ValueError(f"不支持的模型: {args.model}")


def get_optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer`` ('adam' or 'sgd').

    Raises:
        ValueError: if ``args.optimizer`` is not a supported choice.
    """
    if args.optimizer == 'adam':
        return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.optimizer == 'sgd':
        return optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    raise ValueError(f"不支持的优化器: {args.optimizer}")


def calculate_metrics(outputs, targets, num_classes):
    """Compute accuracy and macro-averaged precision/recall/F1.

    Args:
        outputs: raw logits, shape (N, C).
        targets: integer class labels, shape (N,).
        num_classes: unused; kept for interface compatibility with callers.

    Returns:
        Tuple (accuracy, precision, recall, f1) of Python floats.
    """
    _, preds = torch.max(outputs, 1)

    # Move to CPU numpy arrays for scikit-learn.
    preds_np = preds.cpu().numpy()
    targets_np = targets.cpu().numpy()

    acc = accuracy_score(targets_np, preds_np)

    # Macro average: every class contributes equally regardless of support;
    # zero_division=0 avoids warnings when a class is never predicted.
    precision = precision_score(targets_np, preds_np, average='macro', zero_division=0)
    recall = recall_score(targets_np, preds_np, average='macro', zero_division=0)
    f1 = f1_score(targets_np, preds_np, average='macro', zero_division=0)

    return acc, precision, recall, f1
def validate(model, val_loader, criterion, device, num_classes):
    """Evaluate ``model`` on the validation set without gradient tracking.

    Returns:
        Tuple (avg_loss, accuracy, precision, recall, f1, outputs, targets),
        where ``outputs``/``targets`` are the concatenated logits/labels of
        the entire split (used later for confusion matrix and PR curves).
    """
    model.eval()
    running_loss = 0.0
    collected_logits = []
    collected_labels = []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(val_loader, desc="验证中", leave=False):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            # Weight by batch size so the final average is per-sample.
            running_loss += batch_loss.item() * batch_inputs.size(0)
            collected_logits.append(logits)
            collected_labels.append(batch_labels)

    avg_loss = running_loss / len(val_loader.dataset)

    merged_logits = torch.cat(collected_logits, dim=0)
    merged_labels = torch.cat(collected_labels, dim=0)

    acc, precision, recall, f1 = calculate_metrics(merged_logits, merged_labels, num_classes)

    return avg_loss, acc, precision, recall, f1, merged_logits, merged_labels


def save_model(model, path):
    """Persist the model's state_dict to ``path``, creating parent dirs as needed."""
    parent_dir = os.path.dirname(path)
    os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"模型已保存到 {path}")


def plot_metrics(train_metrics, val_metrics, save_dir):
    """Plot per-epoch train/validation curves for each tracked metric.

    Args:
        train_metrics: five lists (loss, accuracy, precision, recall, f1),
            one value per epoch, from the training split.
        val_metrics: same structure for the validation split.
        save_dir: directory that receives one ``<metric>_plot.png`` per metric.
    """
    metric_names = ['loss', 'accuracy', 'precision', 'recall', 'f1']
    os.makedirs(save_dir, exist_ok=True)

    for name, train_series, val_series in zip(metric_names, train_metrics, val_metrics):
        plt.figure(figsize=(10, 6))
        plt.plot(train_series, label=f'Train {name}')
        plt.plot(val_series, label=f'Validation {name}')
        plt.xlabel('Epoch')
        plt.ylabel(name.capitalize())
        plt.title(f'Training and Validation {name.capitalize()}')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, f'{name}_plot.png'), dpi=300, bbox_inches='tight')
        plt.close()
def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path):
    """Render and save a confusion-matrix heatmap.

    Args:
        true_labels: 1-D array of ground-truth class indices.
        pred_labels: 1-D array of predicted class indices.
        class_names: ordered class names used for axis tick labels.
        save_path: output PNG path.
    """
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_pr_curve(all_outputs, all_targets, num_classes, class_names, save_path):
    """Plot one-vs-rest precision-recall curves, one per class.

    Args:
        all_outputs: raw logits, shape (N, num_classes).
        all_targets: integer class labels, shape (N,).
        num_classes: number of classes (curve count).
        class_names: display names indexed by class id.
        save_path: output PNG path.
    """
    # Softmax turns logits into per-class probabilities for the PR sweep.
    probs = torch.nn.functional.softmax(all_outputs, dim=1).cpu().numpy()
    targets = all_targets.cpu().numpy()

    plt.figure(figsize=(12, 10))

    for i in range(num_classes):
        # Binarize labels one-vs-rest for class i.
        precision, recall, _ = precision_recall_curve(
            (targets == i).astype(int),
            probs[:, i]
        )
        plt.plot(recall, precision, lw=2, label=f'Class {class_names[i]}')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='best')
    plt.grid(True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def main():
    """Train, validate and test an audio emotion classifier end to end."""
    args = parse_args()

    # Prefix the architecture name so get_model() picks the right family.
    if args.use_mfcc:
        args.model = 'spectrogram_' + args.model
    else:
        args.model = 'waveform_' + args.model

    set_seed(args.seed)

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Make the run directory unique: model name + timestamp.
    args.save_dir = os.path.join(args.save_dir, f"{args.model}_{time.strftime('%Y%m%d_%H%M%S')}")
    print(f"模型保存目录: {args.save_dir}")
    os.makedirs(args.save_dir, exist_ok=True)

    # Load the pre-split train/val sets from disk.
    train_dir = os.path.join(args.data_root, 'train')
    val_dir = os.path.join(args.data_root, 'val')
    train_dataset = AudioDataset(train_dir, use_mfcc=args.use_mfcc)
    val_dataset = AudioDataset(val_dir, use_mfcc=args.use_mfcc)

    # No separate test split is available; the validation set doubles as test.
    test_dataset = val_dataset

    class_names = list(train_dataset.label_to_int.keys())
    num_classes = len(class_names)
    print(f"类别数量: {num_classes}")
    print(f"类别名称: {class_names}")
    print(f"训练集大小: {len(train_dataset)}")
    print(f"验证集大小: {len(val_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    model = get_model(args, num_classes).to(device)
    print(f"模型总参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(args, model)

    # Optional scheduler: halve the LR after 3 epochs without val-loss improvement.
    scheduler = None
    if args.scheduler:
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)

    best_val_f1 = 0.0
    best_epoch = 0
    best_model_path = ""

    # Per-epoch metric histories for plotting.
    train_losses, train_accs, train_precisions, train_recalls, train_f1s = [], [], [], [], []
    val_losses, val_accs, val_precisions, val_recalls, val_f1s = [], [], [], [], []

    start_time = time.time()
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0.0
        all_outputs = []
        all_targets = []

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for inputs, targets in train_pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            # Standard step: forward, loss, backward, update.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample.
            train_loss += loss.item() * inputs.size(0)
            all_outputs.append(outputs)
            all_targets.append(targets)

            train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

        train_loss /= len(train_dataset)

        all_outputs = torch.cat(all_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        train_acc, train_precision, train_recall, train_f1 = calculate_metrics(all_outputs, all_targets, num_classes)

        val_loss, val_acc, val_precision, val_recall, val_f1, val_outputs, val_targets = validate(
            model, val_loader, criterion, device, num_classes
        )

        if scheduler:
            scheduler.step(val_loss)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_precisions.append(train_precision)
        train_recalls.append(train_recall)
        train_f1s.append(train_f1)

        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_precisions.append(val_precision)
        val_recalls.append(val_recall)
        val_f1s.append(val_f1)

        print(f"Epoch {epoch+1}/{args.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Train Precision: {train_precision:.4f} | Val Precision: {val_precision:.4f}")
        print(f"Train Recall: {train_recall:.4f} | Val Recall: {val_recall:.4f}")
        print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")

        # Checkpoint whenever macro-F1 improves on the validation split.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch
            best_model_path = os.path.join(args.save_dir, "best.pth")
            save_model(model, best_model_path)

    total_time = time.time() - start_time
    print(f"训练完成,总耗时: {total_time/60:.2f} 分钟")
    print(f"最佳验证 F1: {best_val_f1:.4f},在 Epoch {best_epoch+1}")

    # Bug fix: if no epoch ever beat the initial best_val_f1 of 0.0 (e.g. the
    # model never predicts any class correctly), best_model_path stays "" and
    # torch.load("") below would crash.  Fall back to saving the final weights.
    if not best_model_path:
        best_model_path = os.path.join(args.save_dir, "best.pth")
        save_model(model, best_model_path)

    plot_metrics(
        [train_losses, train_accs, train_precisions, train_recalls, train_f1s],
        [val_losses, val_accs, val_precisions, val_recalls, val_f1s],
        os.path.join(args.save_dir, 'plots')
    )

    # Reload the best checkpoint and evaluate on the test loader.
    print("加载最佳模型进行测试...")
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc, test_precision, test_recall, test_f1, test_outputs, test_targets = validate(
        model, test_loader, criterion, device, num_classes
    )

    print("测试指标:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")
    print(f"Precision: {test_precision:.4f}")
    print(f"Recall: {test_recall:.4f}")
    print(f"F1 Score: {test_f1:.4f}")

    _, test_preds = torch.max(test_outputs, 1)
    plot_confusion_matrix(
        test_targets.cpu().numpy(),
        test_preds.cpu().numpy(),
        class_names,
        os.path.join(args.save_dir, 'plots', 'confusion_matrix.png')
    )

    plot_pr_curve(
        test_outputs,
        test_targets,
        num_classes,
        class_names,
        os.path.join(args.save_dir, 'plots', 'pr_curve.png')
    )

    # Persist a plain-text summary of the run for later comparison.
    results = {
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_f1': test_f1,
        'best_epoch': best_epoch + 1,
        'best_val_f1': best_val_f1,
        'train_time': total_time,
        'model': args.model,
        'use_mfcc': args.use_mfcc,
        'pretrained': args.pretrained
    }
    with open(os.path.join(args.save_dir, 'results.txt'), 'w') as f:
        for key, value in results.items():
            f.write(f"{key}: {value}\n")


if __name__ == "__main__":
    main()