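"""Train an audio emotion classification model.

Trains a ResNet-style classifier on raw audio waveforms or MFCC features
(--use_mfcc), tracks loss/accuracy/precision/recall/F1 per epoch, keeps the
checkpoint with the best validation F1, and writes metric plots, a confusion
matrix, a PR curve, and a results summary to the save directory.
"""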
import os
import random
import time
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, f1_score, precision_score, recall_score, accuracy_score
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from dataset import AudioDataset
from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101, spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101
import warnings
warnings.filterwarnings("ignore", message="At least one mel filterbank has all zero values")

def parse_args():
    parser = argparse.ArgumentParser(description='Train an audio emotion classification model')
    # Dataset arguments
    parser.add_argument('--data_root', type=str, default='dataset', help='root directory of the dataset')
    parser.add_argument('--use_mfcc', action='store_true', help='use MFCC features instead of raw waveforms')
    parser.add_argument('--val_split', type=float, default=0.2, help='validation split ratio')
    parser.add_argument('--num_workers', type=int, default=4, help='number of data loader workers')
    # Model arguments
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50', 'resnet101'],
                        help='model architecture')
    parser.add_argument('--pretrained', action='store_true', help='use pretrained weights')
    # Training arguments
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--epochs', type=int, default=30, help='number of training epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='optimizer choice')
    parser.add_argument('--scheduler', action='store_true', help='use a learning rate scheduler')
    # Other arguments
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--gpu', type=int, default=0, help='ID of the GPU to use')
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='directory for saving models')
    return parser.parse_args()

def set_seed(seed):
    """Seed all random number generators for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Additional settings are required when a GPU is available
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def get_model(args, num_classes):
    """Build the model named by args.model (prefixed with the input type in main())."""
    if args.model == 'waveform_resnet18':
        return waveform_resnet18(num_classes=num_classes)
    elif args.model == 'waveform_resnet34':
        return waveform_resnet34(num_classes=num_classes)
    elif args.model == 'waveform_resnet50':
        return waveform_resnet50(num_classes=num_classes)
    elif args.model == 'waveform_resnet101':
        return waveform_resnet101(num_classes=num_classes)
    elif args.model == 'spectrogram_resnet18':
        return spectrogram_resnet18(num_classes=num_classes, pretrained=args.pretrained)
    elif args.model == 'spectrogram_resnet34':
        return spectrogram_resnet34(num_classes=num_classes, pretrained=args.pretrained)
    elif args.model == 'spectrogram_resnet50':
        return spectrogram_resnet50(num_classes=num_classes, pretrained=args.pretrained)
    elif args.model == 'spectrogram_resnet101':
        return spectrogram_resnet101(num_classes=num_classes, pretrained=args.pretrained)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

def get_optimizer(args, model):
    if args.optimizer == 'adam':
        return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        return optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {args.optimizer}")

def calculate_metrics(outputs, targets, num_classes):
    """Compute accuracy, precision, recall, and F1 score."""
    _, preds = torch.max(outputs, 1)
    # Move to CPU numpy arrays for scikit-learn
    preds_np = preds.cpu().numpy()
    targets_np = targets.cpu().numpy()
    # Overall accuracy
    acc = accuracy_score(targets_np, preds_np)
    # Per-class precision, recall, and F1, macro-averaged
    precision = precision_score(targets_np, preds_np, average='macro', zero_division=0)
    recall = recall_score(targets_np, preds_np, average='macro', zero_division=0)
    f1 = f1_score(targets_np, preds_np, average='macro', zero_division=0)
    return acc, precision, recall, f1

def validate(model, val_loader, criterion, device, num_classes):
    """Evaluate the model on the validation set."""
    model.eval()
    val_loss = 0.0
    all_outputs = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Validating", leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
            all_outputs.append(outputs)
            all_targets.append(targets)
    val_loss /= len(val_loader.dataset)
    # Concatenate outputs and targets from all batches
    all_outputs = torch.cat(all_outputs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    # Compute the evaluation metrics
    acc, precision, recall, f1 = calculate_metrics(all_outputs, all_targets, num_classes)
    return val_loss, acc, precision, recall, f1, all_outputs, all_targets

def save_model(model, path):
    """Save the model state dict."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def plot_metrics(train_metrics, val_metrics, save_dir):
    """Plot training and validation metrics."""
    metrics = ['loss', 'accuracy', 'precision', 'recall', 'f1']
    # Create the output directory
    os.makedirs(save_dir, exist_ok=True)
    for i, metric in enumerate(metrics):
        plt.figure(figsize=(10, 6))
        plt.plot(train_metrics[i], label=f'Train {metric}')
        plt.plot(val_metrics[i], label=f'Validation {metric}')
        plt.xlabel('Epoch')
        plt.ylabel(metric.capitalize())
        plt.title(f'Training and Validation {metric.capitalize()}')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, f'{metric}_plot.png'), dpi=300, bbox_inches='tight')
        plt.close()

def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path):
    """Plot the confusion matrix."""
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_pr_curve(all_outputs, all_targets, num_classes, class_names, save_path):
    """Plot one-vs-rest precision-recall curves."""
    # Convert logits to probabilities
    probs = torch.nn.functional.softmax(all_outputs, dim=1).cpu().numpy()
    targets = all_targets.cpu().numpy()
    plt.figure(figsize=(12, 10))
    # Draw a PR curve for each class
    for i in range(num_classes):
        precision, recall, _ = precision_recall_curve(
            (targets == i).astype(int),
            probs[:, i]
        )
        plt.plot(recall, precision, lw=2, label=f'Class {class_names[i]}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='best')
    plt.grid(True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def main():
    args = parse_args()
    # Prefix the model name with the input type: spectrogram models consume
    # MFCC features, waveform models consume raw audio
    if args.use_mfcc:
        args.model = 'spectrogram_' + args.model
    else:
        args.model = 'waveform_' + args.model
    set_seed(args.seed)
    # Select the device
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Append the model name (which encodes the feature type) and a timestamp to save_dir
    args.save_dir = os.path.join(args.save_dir, f"{args.model}_{time.strftime('%Y%m%d_%H%M%S')}")
    print(f"Model save directory: {args.save_dir}")
    # Create the save directory
    os.makedirs(args.save_dir, exist_ok=True)
    # Load the datasets: train and validation sets come from separate directories
    train_dir = os.path.join(args.data_root, 'train')
    val_dir = os.path.join(args.data_root, 'val')
    train_dataset = AudioDataset(train_dir, use_mfcc=args.use_mfcc)
    val_dataset = AudioDataset(val_dir, use_mfcc=args.use_mfcc)
    # Use the validation set as the test set
    test_dataset = val_dataset
    # Derive the class names
    class_names = list(train_dataset.label_to_int.keys())
    num_classes = len(class_names)
    print(f"Number of classes: {num_classes}")
    print(f"Class names: {class_names}")
    print(f"Training set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")
    print(f"Test set size: {len(test_dataset)}")
    # Create the data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    # Build the model
    model = get_model(args, num_classes).to(device)
    print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(args, model)
    # Learning rate scheduler
    scheduler = None
    if args.scheduler:
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)
    # Training bookkeeping
    best_val_f1 = 0.0
    best_epoch = 0
    best_model_path = ""
    # Histories of training and validation metrics
    train_losses = []
    train_accs = []
    train_precisions = []
    train_recalls = []
    train_f1s = []
    val_losses = []
    val_accs = []
    val_precisions = []
    val_recalls = []
    val_f1s = []
    # Outputs and labels from the best validation epoch (used for plotting)
    best_val_outputs = None
    best_val_targets = None
    # Training loop
    start_time = time.time()
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0.0
        all_outputs = []
        all_targets = []
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for inputs, targets in train_pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Backward pass
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)
            all_outputs.append(outputs)
            all_targets.append(targets)
            # Update the progress bar
            train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
        # Average training loss over the epoch
        train_loss /= len(train_dataset)
        # Training metrics
        all_outputs = torch.cat(all_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        train_acc, train_precision, train_recall, train_f1 = calculate_metrics(all_outputs, all_targets, num_classes)
        # Validation
        val_loss, val_acc, val_precision, val_recall, val_f1, val_outputs, val_targets = validate(
            model, val_loader, criterion, device, num_classes
        )
        # Learning rate adjustment
        if scheduler:
            scheduler.step(val_loss)
        # Record the metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_precisions.append(train_precision)
        train_recalls.append(train_recall)
        train_f1s.append(train_f1)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_precisions.append(val_precision)
        val_recalls.append(val_recall)
        val_f1s.append(val_f1)
        # Print this epoch's metrics
        print(f"Epoch {epoch+1}/{args.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Train Precision: {train_precision:.4f} | Val Precision: {val_precision:.4f}")
        print(f"Train Recall: {train_recall:.4f} | Val Recall: {val_recall:.4f}")
        print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")
        # Save the best model
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch
            best_model_path = os.path.join(args.save_dir, "best.pth")
            save_model(model, best_model_path)
            # Keep the best validation results for later plotting
            best_val_outputs = val_outputs
            best_val_targets = val_targets
    # Training finished
    total_time = time.time() - start_time
    print(f"Training complete, total time: {total_time/60:.2f} minutes")
    print(f"Best validation F1: {best_val_f1:.4f} at epoch {best_epoch+1}")
    # Plot the training and validation metric curves
    plot_metrics(
        [train_losses, train_accs, train_precisions, train_recalls, train_f1s],
        [val_losses, val_accs, val_precisions, val_recalls, val_f1s],
        os.path.join(args.save_dir, 'plots')
    )
    # Load the best model for testing
    print("Loading the best model for testing...")
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc, test_precision, test_recall, test_f1, test_outputs, test_targets = validate(
        model, test_loader, criterion, device, num_classes
    )
    print("Test metrics:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")
    print(f"Precision: {test_precision:.4f}")
    print(f"Recall: {test_recall:.4f}")
    print(f"F1 Score: {test_f1:.4f}")
    # Confusion matrix
    _, test_preds = torch.max(test_outputs, 1)
    plot_confusion_matrix(
        test_targets.cpu().numpy(),
        test_preds.cpu().numpy(),
        class_names,
        os.path.join(args.save_dir, 'plots', 'confusion_matrix.png')
    )
    # PR curves
    plot_pr_curve(
        test_outputs,
        test_targets,
        num_classes,
        class_names,
        os.path.join(args.save_dir, 'plots', 'pr_curve.png')
    )
    # Collect the test results
    results = {
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_f1': test_f1,
        'best_epoch': best_epoch + 1,
        'best_val_f1': best_val_f1,
        'train_time': total_time,
        'model': args.model,
        'use_mfcc': args.use_mfcc,
        'pretrained': args.pretrained
    }
    # Write the results to a file
    with open(os.path.join(args.save_dir, 'results.txt'), 'w') as f:
        for key, value in results.items():
            f.write(f"{key}: {value}\n")

if __name__ == "__main__":
    main()
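
# Example invocation (assuming this file is saved as train.py):
#   python train.py --data_root dataset --model resnet18 --use_mfcc \
#       --batch_size 32 --epochs 30 --lr 0.001 --scheduler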