|
|
import os
|
|
|
import random
|
|
|
import time
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
import seaborn as sns
|
|
|
from sklearn.metrics import confusion_matrix, precision_recall_curve, f1_score, precision_score, recall_score, accuracy_score
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from torch.utils.data import DataLoader
|
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
|
|
|
from dataset import AudioDataset
|
|
|
from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101, spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101
|
|
|
import warnings
|
|
|
warnings.filterwarnings("ignore", message="At least one mel filterbank has all zero values")
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for audio emotion classification training.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse falls back to ``sys.argv[1:]`` — so existing
            callers (``parse_args()``) behave exactly as before, while tests
            and programmatic callers can pass an explicit list.

    Returns:
        argparse.Namespace with dataset, model, training and misc options.
    """
    parser = argparse.ArgumentParser(description='训练音频情感分类模型')

    # Dataset options
    parser.add_argument('--data_root', type=str, default='dataset', help='数据集根目录')
    parser.add_argument('--use_mfcc', action='store_true', help='使用MFCC特征而非原始波形')
    parser.add_argument('--val_split', type=float, default=0.2, help='验证集比例')
    parser.add_argument('--num_workers', type=int, default=4, help='数据加载器的工作线程数量')

    # Model options (bare names; main() prefixes 'waveform_'/'spectrogram_')
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50', 'resnet101'],
                        help='模型架构')
    parser.add_argument('--pretrained', action='store_true', help='使用预训练权重')

    # Training options
    parser.add_argument('--batch_size', type=int, default=32, help='批次大小')
    parser.add_argument('--epochs', type=int, default=30, help='训练轮数')
    parser.add_argument('--lr', type=float, default=0.001, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='权重衰减')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='优化器选择')
    parser.add_argument('--scheduler', action='store_true', help='是否使用学习率调度器')

    # Misc options
    parser.add_argument('--seed', type=int, default=42, help='随机种子')
    parser.add_argument('--gpu', type=int, default=0, help='使用的GPU ID')
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='模型保存目录')

    return parser.parse_args(argv)
|
|
|
|
|
|
def set_seed(seed):
    """Seed every relevant RNG (random, numpy, torch, CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Without a CUDA device there is nothing more to configure.
    if not torch.cuda.is_available():
        return

    # CUDA keeps per-device generators; seed them all and force cuDNN into
    # deterministic (non-benchmarking) mode so kernel selection is stable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def get_model(args, num_classes):
    """Instantiate the network selected by ``args.model``.

    Waveform variants take only ``num_classes``; spectrogram variants also
    forward ``args.pretrained``. Raises ValueError for an unknown name.
    """
    waveform_builders = {
        'waveform_resnet18': waveform_resnet18,
        'waveform_resnet34': waveform_resnet34,
        'waveform_resnet50': waveform_resnet50,
        'waveform_resnet101': waveform_resnet101,
    }
    spectrogram_builders = {
        'spectrogram_resnet18': spectrogram_resnet18,
        'spectrogram_resnet34': spectrogram_resnet34,
        'spectrogram_resnet50': spectrogram_resnet50,
        'spectrogram_resnet101': spectrogram_resnet101,
    }

    builder = waveform_builders.get(args.model)
    if builder is not None:
        return builder(num_classes=num_classes)

    builder = spectrogram_builders.get(args.model)
    if builder is not None:
        return builder(num_classes=num_classes, pretrained=args.pretrained)

    raise ValueError(f"不支持的模型: {args.model}")
|
|
|
|
|
|
def get_optimizer(args, model):
    """Build the optimizer chosen by ``args.optimizer`` ('adam' or 'sgd').

    Both optimizers use ``args.lr`` and ``args.weight_decay``; SGD is
    always configured with momentum 0.9. Raises ValueError otherwise.
    """
    params = model.parameters()
    if args.optimizer == 'adam':
        return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    if args.optimizer == 'sgd':
        return optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    raise ValueError(f"不支持的优化器: {args.optimizer}")
|
|
|
|
|
|
def calculate_metrics(outputs, targets, num_classes):
    """Compute accuracy plus macro-averaged precision/recall/F1 from logits.

    ``num_classes`` is kept for interface compatibility but is unused:
    scikit-learn infers the label set from the data itself.
    """
    # Hard predictions = index of the highest logit per sample.
    predictions = outputs.argmax(dim=1)

    # scikit-learn operates on CPU numpy arrays.
    y_pred = predictions.cpu().numpy()
    y_true = targets.cpu().numpy()

    acc = accuracy_score(y_true, y_pred)

    # Macro averaging weights every class equally; zero_division=0 avoids
    # warnings for classes absent from the batch.
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    return acc, precision, recall, f1
|
|
|
|
|
|
def validate(model, val_loader, criterion, device, num_classes):
    """Run one full evaluation pass over ``val_loader``.

    Returns a 7-tuple: (avg_loss, acc, precision, recall, f1,
    all_outputs, all_targets) where the last two are the concatenated
    per-sample logits and labels for downstream plotting.
    """
    model.eval()
    total_loss = 0.0
    collected_outputs = []
    collected_targets = []

    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(val_loader, desc="验证中", leave=False):
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_targets)

            # Weight by batch size so the final mean is per-sample even when
            # the last batch is smaller.
            total_loss += batch_loss.item() * batch_inputs.size(0)
            collected_outputs.append(logits)
            collected_targets.append(batch_targets)

    avg_loss = total_loss / len(val_loader.dataset)

    # Stitch all batches back together for metric computation.
    all_outputs = torch.cat(collected_outputs, dim=0)
    all_targets = torch.cat(collected_targets, dim=0)

    acc, precision, recall, f1 = calculate_metrics(all_outputs, all_targets, num_classes)

    return avg_loss, acc, precision, recall, f1, all_outputs, all_targets
|
|
|
|
|
|
def save_model(model, path):
    """Save the model's state_dict to ``path``, creating parent dirs as needed.

    Bug fix: the original called ``os.makedirs(os.path.dirname(path))``
    unconditionally, which raises FileNotFoundError when ``path`` has no
    directory component (e.g. 'best.pth' → dirname is ''). Only create the
    directory when one is actually present.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"模型已保存到 {path}")
|
|
|
|
|
|
def plot_metrics(train_metrics, val_metrics, save_dir):
    """Save one train-vs-validation line plot per metric into ``save_dir``.

    ``train_metrics`` and ``val_metrics`` are parallel lists of per-epoch
    series, ordered as loss, accuracy, precision, recall, f1.
    """
    # Ensure the output directory exists before writing any figures.
    os.makedirs(save_dir, exist_ok=True)

    metric_names = ['loss', 'accuracy', 'precision', 'recall', 'f1']
    for idx, name in enumerate(metric_names):
        plt.figure(figsize=(10, 6))
        plt.plot(train_metrics[idx], label=f'Train {name}')
        plt.plot(val_metrics[idx], label=f'Validation {name}')
        plt.xlabel('Epoch')
        plt.ylabel(name.capitalize())
        plt.title(f'Training and Validation {name.capitalize()}')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, f'{name}_plot.png'), dpi=300, bbox_inches='tight')
        plt.close()
|
|
|
|
|
|
def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path):
    """Render the confusion matrix as an annotated heatmap and save it."""
    matrix = confusion_matrix(true_labels, pred_labels)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
|
|
|
|
|
|
def plot_pr_curve(all_outputs, all_targets, num_classes, class_names, save_path):
    """Plot one-vs-rest precision-recall curves for every class and save."""
    # Softmax converts raw logits into per-class probabilities for the sweep.
    probabilities = torch.nn.functional.softmax(all_outputs, dim=1).cpu().numpy()
    labels = all_targets.cpu().numpy()

    plt.figure(figsize=(12, 10))

    # One binary (class vs. rest) PR curve per class.
    for class_idx in range(num_classes):
        binary_truth = (labels == class_idx).astype(int)
        class_scores = probabilities[:, class_idx]
        precision, recall, _ = precision_recall_curve(binary_truth, class_scores)
        plt.plot(recall, precision, lw=2, label=f'Class {class_names[class_idx]}')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='best')
    plt.grid(True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
|
|
|
|
|
|
def main():
    """End-to-end training entry point.

    Parses CLI args, builds datasets and loaders, trains with per-epoch
    validation, checkpoints the best model by macro-F1, then re-evaluates
    the best checkpoint on the test set and writes plots plus a results
    summary under ``args.save_dir``.
    """
    args = parse_args()

    # Select the model family from the feature type: MFCC input routes to
    # the spectrogram_* models, raw audio to the waveform_* models.
    if args.use_mfcc:
        args.model = 'spectrogram_' + args.model
    else:
        args.model = 'waveform_' + args.model

    set_seed(args.seed)

    # Choose the compute device (specific GPU if available, else CPU).
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Make save_dir unique per run: model name + timestamp.
    args.save_dir = os.path.join(args.save_dir, f"{args.model}_{time.strftime('%Y%m%d_%H%M%S')}")
    print(f"模型保存目录: {args.save_dir}")

    # Create the save directory.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load datasets from fixed 'train' / 'val' subdirectories of data_root.
    train_dir = os.path.join(args.data_root, 'train')
    val_dir = os.path.join(args.data_root, 'val')

    train_dataset = AudioDataset(train_dir, use_mfcc=args.use_mfcc)
    val_dataset = AudioDataset(val_dir, use_mfcc=args.use_mfcc)

    # The validation set doubles as the test set (no separate test split).
    test_dataset = val_dataset

    # Class names come from the training dataset's label mapping.
    # NOTE(review): assumes AudioDataset exposes a label_to_int dict whose
    # key order matches the integer labels — confirm in dataset.py.
    class_names = list(train_dataset.label_to_int.keys())
    num_classes = len(class_names)
    print(f"类别数量: {num_classes}")
    print(f"类别名称: {class_names}")

    print(f"训练集大小: {len(train_dataset)}")
    print(f"验证集大小: {len(val_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")

    # Data loaders: shuffle only the training set.
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    # Build the model and move it to the selected device.
    model = get_model(args, num_classes).to(device)
    print(f"模型总参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(args, model)

    # Optional LR scheduler: halve the LR after 3 epochs without
    # validation-loss improvement.
    scheduler = None
    if args.scheduler:
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)

    # Best-model tracking (selection criterion: validation macro-F1).
    best_val_f1 = 0.0
    best_epoch = 0
    best_model_path = ""

    # Per-epoch training metrics.
    train_losses = []
    train_accs = []
    train_precisions = []
    train_recalls = []
    train_f1s = []

    # Per-epoch validation metrics.
    val_losses = []
    val_accs = []
    val_precisions = []
    val_recalls = []
    val_f1s = []

    # Outputs/labels from the best validation epoch (kept for plotting).
    best_val_outputs = None
    best_val_targets = None

    # Training loop.
    start_time = time.time()
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0.0
        all_outputs = []
        all_targets = []

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for inputs, targets in train_pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is per-sample.
            train_loss += loss.item() * inputs.size(0)
            all_outputs.append(outputs)
            all_targets.append(targets)

            # Show the running batch loss on the progress bar.
            train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # Average training loss over the whole dataset.
        train_loss /= len(train_dataset)

        # Compute training metrics over all batches of this epoch.
        all_outputs = torch.cat(all_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        train_acc, train_precision, train_recall, train_f1 = calculate_metrics(all_outputs, all_targets, num_classes)

        # Validation pass.
        val_loss, val_acc, val_precision, val_recall, val_f1, val_outputs, val_targets = validate(
            model, val_loader, criterion, device, num_classes
        )

        # LR scheduling keyed on validation loss.
        if scheduler:
            scheduler.step(val_loss)

        # Record this epoch's metrics.
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_precisions.append(train_precision)
        train_recalls.append(train_recall)
        train_f1s.append(train_f1)

        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_precisions.append(val_precision)
        val_recalls.append(val_recall)
        val_f1s.append(val_f1)

        # Print this epoch's metrics.
        print(f"Epoch {epoch+1}/{args.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Train Precision: {train_precision:.4f} | Val Precision: {val_precision:.4f}")
        print(f"Train Recall: {train_recall:.4f} | Val Recall: {val_recall:.4f}")
        print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")

        # Checkpoint whenever validation F1 improves.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch
            best_model_path = os.path.join(args.save_dir, f"best.pth")
            save_model(model, best_model_path)
            # Keep the best epoch's validation outputs for later plotting.
            best_val_outputs = val_outputs
            best_val_targets = val_targets

    # Training finished.
    total_time = time.time() - start_time
    print(f"训练完成,总耗时: {total_time/60:.2f} 分钟")
    print(f"最佳验证 F1: {best_val_f1:.4f},在 Epoch {best_epoch+1}")

    # Plot training/validation metric curves.
    plot_metrics(
        [train_losses, train_accs, train_precisions, train_recalls, train_f1s],
        [val_losses, val_accs, val_precisions, val_recalls, val_f1s],
        os.path.join(args.save_dir, 'plots')
    )

    # Reload the best checkpoint and evaluate on the test set.
    # NOTE(review): if no epoch ever improves F1, best_model_path stays ""
    # and torch.load will fail here — confirm F1 > 0 is always reachable.
    print(f"加载最佳模型进行测试...")
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc, test_precision, test_recall, test_f1, test_outputs, test_targets = validate(
        model, test_loader, criterion, device, num_classes
    )

    print(f"测试指标:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")
    print(f"Precision: {test_precision:.4f}")
    print(f"Recall: {test_recall:.4f}")
    print(f"F1 Score: {test_f1:.4f}")

    # Confusion matrix over test predictions.
    _, test_preds = torch.max(test_outputs, 1)
    plot_confusion_matrix(
        test_targets.cpu().numpy(),
        test_preds.cpu().numpy(),
        class_names,
        os.path.join(args.save_dir, 'plots', 'confusion_matrix.png')
    )

    # Precision-recall curves over test outputs.
    plot_pr_curve(
        test_outputs,
        test_targets,
        num_classes,
        class_names,
        os.path.join(args.save_dir, 'plots', 'pr_curve.png')
    )

    # Collect the final results summary.
    results = {
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_f1': test_f1,
        'best_epoch': best_epoch + 1,
        'best_val_f1': best_val_f1,
        'train_time': total_time,
        'model': args.model,
        'use_mfcc': args.use_mfcc,
        'pretrained': args.pretrained
    }

    # Write the summary to a plain-text file, one key per line.
    with open(os.path.join(args.save_dir, 'results.txt'), 'w') as f:
        for key, value in results.items():
            f.write(f"{key}: {value}\n")


if __name__ == "__main__":
    main()