You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

432 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import random
import time
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, f1_score, precision_score, recall_score, accuracy_score
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from dataset import AudioDataset
from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101, spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101
import warnings
warnings.filterwarnings("ignore", message="At least one mel filterbank has all zero values")
def parse_args():
    """Parse command-line options for the audio emotion training script.

    Options are declared as a flat (flag, kwargs) table — grouped as
    dataset / model / training / misc — and registered in one loop.
    Returns the parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description='训练音频情感分类模型')
    option_table = [
        # Dataset options
        ('--data_root', dict(type=str, default='dataset', help='数据集根目录')),
        ('--use_mfcc', dict(action='store_true', help='使用MFCC特征而非原始波形')),
        ('--val_split', dict(type=float, default=0.2, help='验证集比例')),
        ('--num_workers', dict(type=int, default=4, help='数据加载器的工作线程数量')),
        # Model options
        ('--model', dict(type=str, default='resnet18',
                         choices=['resnet18', 'resnet34', 'resnet50', 'resnet101'],
                         help='模型架构')),
        ('--pretrained', dict(action='store_true', help='使用预训练权重')),
        # Training options
        ('--batch_size', dict(type=int, default=32, help='批次大小')),
        ('--epochs', dict(type=int, default=30, help='训练轮数')),
        ('--lr', dict(type=float, default=0.001, help='学习率')),
        ('--weight_decay', dict(type=float, default=1e-5, help='权重衰减')),
        ('--optimizer', dict(type=str, default='adam', choices=['adam', 'sgd'], help='优化器选择')),
        ('--scheduler', dict(action='store_true', help='是否使用学习率调度器')),
        # Miscellaneous options
        ('--seed', dict(type=int, default=42, help='随机种子')),
        ('--gpu', dict(type=int, default=0, help='使用的GPU ID')),
        ('--save_dir', dict(type=str, default='checkpoints', help='模型保存目录')),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def set_seed(seed):
    """Seed every random number generator the script uses, for reproducibility."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    # GPU RNGs and cuDNN behaviour must also be pinned when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic kernels, no autotuning — trades speed for repeatability.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def get_model(args, num_classes):
    """Instantiate the network selected by args.model.

    args.model is expected to carry the 'waveform_' or 'spectrogram_'
    prefix (added by main() based on --use_mfcc). Spectrogram variants
    additionally honour args.pretrained. Raises ValueError for an
    unrecognised model name.
    """
    waveform_builders = {
        'waveform_resnet18': waveform_resnet18,
        'waveform_resnet34': waveform_resnet34,
        'waveform_resnet50': waveform_resnet50,
        'waveform_resnet101': waveform_resnet101,
    }
    spectrogram_builders = {
        'spectrogram_resnet18': spectrogram_resnet18,
        'spectrogram_resnet34': spectrogram_resnet34,
        'spectrogram_resnet50': spectrogram_resnet50,
        'spectrogram_resnet101': spectrogram_resnet101,
    }
    if args.model in waveform_builders:
        return waveform_builders[args.model](num_classes=num_classes)
    if args.model in spectrogram_builders:
        return spectrogram_builders[args.model](num_classes=num_classes, pretrained=args.pretrained)
    raise ValueError(f"不支持的模型: {args.model}")
def get_optimizer(args, model):
    """Build the optimizer named by args.optimizer over the model's parameters.

    'adam' -> Adam(lr, weight_decay); 'sgd' -> SGD(lr, momentum=0.9,
    weight_decay). Raises ValueError for any other name.
    """
    params = model.parameters()
    if args.optimizer == 'sgd':
        return optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    raise ValueError(f"不支持的优化器: {args.optimizer}")
def calculate_metrics(outputs, targets, num_classes):
    """Compute accuracy and macro-averaged precision/recall/F1 from logits.

    outputs: (N, C) score tensor; targets: (N,) integer label tensor.
    num_classes is accepted for interface compatibility but unused —
    sklearn infers the label set from the data.
    Returns (accuracy, precision, recall, f1) as Python floats.
    """
    # Predicted class = index of the max logit per row.
    predictions = outputs.argmax(dim=1).cpu().numpy()
    labels = targets.cpu().numpy()
    accuracy = accuracy_score(labels, predictions)
    # Shared kwargs: macro average over classes, absent classes score 0.
    macro_kwargs = dict(average='macro', zero_division=0)
    precision = precision_score(labels, predictions, **macro_kwargs)
    recall = recall_score(labels, predictions, **macro_kwargs)
    f1 = f1_score(labels, predictions, **macro_kwargs)
    return accuracy, precision, recall, f1
def validate(model, val_loader, criterion, device, num_classes):
    """Run one full evaluation pass over val_loader without gradients.

    Returns (avg_loss, accuracy, precision, recall, f1, all_outputs,
    all_targets) where the last two are the concatenated per-sample
    logits and labels (still on `device`), for downstream plotting.
    """
    model.eval()
    running_loss = 0.0
    batch_outputs = []
    batch_targets = []
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="验证中", leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            # Weight the batch loss by batch size for a dataset-level mean.
            running_loss += criterion(outputs, targets).item() * inputs.size(0)
            batch_outputs.append(outputs)
            batch_targets.append(targets)
    avg_loss = running_loss / len(val_loader.dataset)
    # Stitch per-batch results back into full-dataset tensors.
    all_outputs = torch.cat(batch_outputs, dim=0)
    all_targets = torch.cat(batch_targets, dim=0)
    acc, precision, recall, f1 = calculate_metrics(all_outputs, all_targets, num_classes)
    return avg_loss, acc, precision, recall, f1, all_outputs, all_targets
def save_model(model, path):
    """Persist the model's state_dict to `path`, creating parent dirs as needed.

    Fix: the original unconditionally called
    os.makedirs(os.path.dirname(path), exist_ok=True), which raises
    FileNotFoundError when `path` is a bare filename (dirname is '').
    Only create the parent directory when there is one.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"模型已保存到 {path}")
def plot_metrics(train_metrics, val_metrics, save_dir):
    """Save one PNG line plot per metric comparing train vs validation curves.

    train_metrics / val_metrics are parallel lists of per-epoch series in
    the fixed order: loss, accuracy, precision, recall, f1. Plots are
    written to `save_dir` as '<metric>_plot.png'.
    """
    os.makedirs(save_dir, exist_ok=True)
    metric_names = ['loss', 'accuracy', 'precision', 'recall', 'f1']
    for name, train_series, val_series in zip(metric_names, train_metrics, val_metrics):
        plt.figure(figsize=(10, 6))
        plt.plot(train_series, label=f'Train {name}')
        plt.plot(val_series, label=f'Validation {name}')
        plt.xlabel('Epoch')
        plt.ylabel(name.capitalize())
        plt.title(f'Training and Validation {name.capitalize()}')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, f'{name}_plot.png'), dpi=300, bbox_inches='tight')
        plt.close()
def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path):
    """Render the confusion matrix as an annotated heatmap and save it.

    true_labels / pred_labels are 1-D integer arrays; class_names labels
    both axes in dataset order; the figure is written to `save_path`.
    """
    matrix = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
def plot_pr_curve(all_outputs, all_targets, num_classes, class_names, save_path):
    """Plot a one-vs-rest precision-recall curve per class and save the figure.

    all_outputs: (N, C) logit tensor; all_targets: (N,) label tensor.
    Logits are converted to probabilities with softmax, then each class
    is scored against a binarised ground truth.
    """
    probabilities = torch.nn.functional.softmax(all_outputs, dim=1).cpu().numpy()
    labels = all_targets.cpu().numpy()
    plt.figure(figsize=(12, 10))
    for cls in range(num_classes):
        # One-vs-rest: this class is the positive label.
        binary_truth = (labels == cls).astype(int)
        precision, recall, _ = precision_recall_curve(binary_truth, probabilities[:, cls])
        plt.plot(recall, precision, lw=2, label=f'Class {class_names[cls]}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='best')
    plt.grid(True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
def main():
    """End-to-end training entry point.

    Parses CLI args, builds datasets/loaders/model/optimizer, trains for
    args.epochs while tracking metrics, keeps the best-F1 checkpoint,
    then evaluates that checkpoint and writes plots and a results file.
    """
    args = parse_args()
    # --use_mfcc routes to the spectrogram model family; otherwise the
    # raw-waveform family. get_model() expects this prefixed name.
    if args.use_mfcc:
        args.model = 'spectrogram_' + args.model
    else:
        args.model = 'waveform_' + args.model
    set_seed(args.seed)
    # Select compute device (specific GPU id, else CPU).
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    # Make save_dir unique per run: model name + timestamp.
    args.save_dir = os.path.join(args.save_dir, f"{args.model}_{time.strftime('%Y%m%d_%H%M%S')}")
    print(f"模型保存目录: {args.save_dir}")
    os.makedirs(args.save_dir, exist_ok=True)
    # Datasets are loaded from fixed 'train' and 'val' subdirectories.
    train_dir = os.path.join(args.data_root, 'train')
    val_dir = os.path.join(args.data_root, 'val')
    train_dataset = AudioDataset(train_dir, use_mfcc=args.use_mfcc)
    val_dataset = AudioDataset(val_dir, use_mfcc=args.use_mfcc)
    # NOTE(review): the validation set doubles as the test set, so "test"
    # metrics below are not from held-out data — confirm this is intended.
    test_dataset = val_dataset
    # Class names come from the dataset's label-to-index mapping.
    class_names = list(train_dataset.label_to_int.keys())
    num_classes = len(class_names)
    print(f"类别数量: {num_classes}")
    print(f"类别名称: {class_names}")
    print(f"训练集大小: {len(train_dataset)}")
    print(f"验证集大小: {len(val_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")
    # Data loaders: shuffle only the training split.
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    # Model, loss, optimizer.
    model = get_model(args, num_classes).to(device)
    print(f"模型总参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(args, model)
    # Optional LR scheduler: halve LR after 3 epochs without val-loss improvement.
    scheduler = None
    if args.scheduler:
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)
    # Best-checkpoint tracking (selection metric: validation macro F1).
    # NOTE(review): if val F1 never exceeds 0.0, best_model_path stays ""
    # and the torch.load below will fail — consider saving the first epoch
    # unconditionally.
    best_val_f1 = 0.0
    best_epoch = 0
    best_model_path = ""
    # Per-epoch metric histories for plotting.
    train_losses = []
    train_accs = []
    train_precisions = []
    train_recalls = []
    train_f1s = []
    val_losses = []
    val_accs = []
    val_precisions = []
    val_recalls = []
    val_f1s = []
    # Validation outputs/labels from the best epoch (kept for plotting).
    best_val_outputs = None
    best_val_targets = None
    # Training loop.
    start_time = time.time()
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0.0
        all_outputs = []
        all_targets = []
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for inputs, targets in train_pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            # Forward pass.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            # Batch-size-weighted loss accumulation for a dataset mean.
            train_loss += loss.item() * inputs.size(0)
            # NOTE(review): appending `outputs` without .detach() keeps each
            # batch's autograd graph alive for the whole epoch — potential
            # memory growth; confirm whether .detach() is acceptable here.
            all_outputs.append(outputs)
            all_targets.append(targets)
            # Show the running batch loss on the progress bar.
            train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
        # Mean training loss over the epoch.
        train_loss /= len(train_dataset)
        # Epoch-level training metrics over all batches.
        all_outputs = torch.cat(all_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        train_acc, train_precision, train_recall, train_f1 = calculate_metrics(all_outputs, all_targets, num_classes)
        # Validation pass.
        val_loss, val_acc, val_precision, val_recall, val_f1, val_outputs, val_targets = validate(
            model, val_loader, criterion, device, num_classes
        )
        # Scheduler steps on validation loss (ReduceLROnPlateau).
        if scheduler:
            scheduler.step(val_loss)
        # Record epoch metrics.
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_precisions.append(train_precision)
        train_recalls.append(train_recall)
        train_f1s.append(train_f1)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_precisions.append(val_precision)
        val_recalls.append(val_recall)
        val_f1s.append(val_f1)
        # Epoch summary.
        print(f"Epoch {epoch+1}/{args.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Train Precision: {train_precision:.4f} | Val Precision: {val_precision:.4f}")
        print(f"Train Recall: {train_recall:.4f} | Val Recall: {val_recall:.4f}")
        print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")
        # Checkpoint when validation F1 improves.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch
            best_model_path = os.path.join(args.save_dir, f"best.pth")
            save_model(model, best_model_path)
            # Keep the best epoch's validation results for plotting.
            best_val_outputs = val_outputs
            best_val_targets = val_targets
    # Training finished.
    total_time = time.time() - start_time
    print(f"训练完成,总耗时: {total_time/60:.2f} 分钟")
    print(f"最佳验证 F1: {best_val_f1:.4f},在 Epoch {best_epoch+1}")
    # Metric curves over epochs.
    plot_metrics(
        [train_losses, train_accs, train_precisions, train_recalls, train_f1s],
        [val_losses, val_accs, val_precisions, val_recalls, val_f1s],
        os.path.join(args.save_dir, 'plots')
    )
    # Reload the best checkpoint and evaluate (on the val-as-test split).
    print(f"加载最佳模型进行测试...")
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc, test_precision, test_recall, test_f1, test_outputs, test_targets = validate(
        model, test_loader, criterion, device, num_classes
    )
    print(f"测试指标:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")
    print(f"Precision: {test_precision:.4f}")
    print(f"Recall: {test_recall:.4f}")
    print(f"F1 Score: {test_f1:.4f}")
    # Confusion matrix from test predictions.
    _, test_preds = torch.max(test_outputs, 1)
    plot_confusion_matrix(
        test_targets.cpu().numpy(),
        test_preds.cpu().numpy(),
        class_names,
        os.path.join(args.save_dir, 'plots', 'confusion_matrix.png')
    )
    # Per-class precision-recall curves.
    plot_pr_curve(
        test_outputs,
        test_targets,
        num_classes,
        class_names,
        os.path.join(args.save_dir, 'plots', 'pr_curve.png')
    )
    # Collect run results and configuration.
    results = {
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_f1': test_f1,
        'best_epoch': best_epoch + 1,
        'best_val_f1': best_val_f1,
        'train_time': total_time,
        'model': args.model,
        'use_mfcc': args.use_mfcc,
        'pretrained': args.pretrained
    }
    # Write results as simple "key: value" lines.
    with open(os.path.join(args.save_dir, 'results.txt'), 'w') as f:
        for key, value in results.items():
            f.write(f"{key}: {value}\n")
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()