|
|
|
@ -0,0 +1,573 @@
|
|
|
|
|
import sys
|
|
|
|
|
import os
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import soundfile as sf
|
|
|
|
|
import torchaudio
|
|
|
|
|
import matplotlib.font_manager as font_manager
|
|
|
|
|
|
|
|
|
|
# 添加字体文件路径
|
|
|
|
|
font_path = 'SIMSUN.TTC' # 指定字体文件的完整路径
|
|
|
|
|
font_manager.fontManager.addfont(font_path)
|
|
|
|
|
|
|
|
|
|
# 设置为默认字体
|
|
|
|
|
plt.rcParams['font.family'] = 'SIMSUN'
|
|
|
|
|
|
|
|
|
|
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
|
|
|
|
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QVBoxLayout, QHBoxLayout,
|
|
|
|
|
QLabel, QFileDialog, QComboBox, QWidget, QGroupBox, QGridLayout,
|
|
|
|
|
QProgressBar, QMessageBox,QSizePolicy)
|
|
|
|
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QUrl
|
|
|
|
|
from PyQt5.QtGui import QFont
|
|
|
|
|
from PyQt5.QtMultimedia import QMediaPlayer, QMediaContent
|
|
|
|
|
|
|
|
|
|
# 导入我们的模型和数据处理
|
|
|
|
|
from models import spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101
|
|
|
|
|
from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101
|
|
|
|
|
from dataset import process_audio_file
|
|
|
|
|
|
|
|
|
|
class PredictionThread(QThread):
|
|
|
|
|
"""用于在后台运行预测的线程"""
|
|
|
|
|
prediction_complete = pyqtSignal(dict)
|
|
|
|
|
error = pyqtSignal(str)
|
|
|
|
|
|
|
|
|
|
def __init__(self, model, file_path, use_mfcc, class_names):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.model = model
|
|
|
|
|
self.file_path = file_path
|
|
|
|
|
self.use_mfcc = use_mfcc
|
|
|
|
|
self.class_names = class_names
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
try:
|
|
|
|
|
# 处理音频文件
|
|
|
|
|
audio_tensor = process_audio_file(self.file_path, self.use_mfcc)
|
|
|
|
|
|
|
|
|
|
# 确保模型在评估模式
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
# 添加批次维度并进行预测
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
audio_tensor = audio_tensor.unsqueeze(0)
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
audio_tensor = audio_tensor.cuda()
|
|
|
|
|
|
|
|
|
|
output = self.model(audio_tensor)
|
|
|
|
|
probabilities = torch.nn.functional.softmax(output, dim=1)[0]
|
|
|
|
|
|
|
|
|
|
# 获取预测和概率
|
|
|
|
|
predicted_idx = torch.argmax(probabilities).item()
|
|
|
|
|
predicted_class = self.class_names[predicted_idx]
|
|
|
|
|
probs = probabilities.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
# 返回结果
|
|
|
|
|
results = {
|
|
|
|
|
"predicted_class": predicted_class,
|
|
|
|
|
"predicted_idx": predicted_idx,
|
|
|
|
|
"probabilities": probs,
|
|
|
|
|
"class_names": self.class_names
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
self.prediction_complete.emit(results)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.error.emit(f"预测过程中出错: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmotionRecognitionApp(QMainWindow):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
# 初始化音频播放器
|
|
|
|
|
self.media_player = QMediaPlayer()
|
|
|
|
|
self.current_audio_path = None
|
|
|
|
|
|
|
|
|
|
# 初始化模型变量
|
|
|
|
|
self.model = None
|
|
|
|
|
self.class_names = None
|
|
|
|
|
self.use_mfcc = False
|
|
|
|
|
self.checkpoints_dir = "checkpoints"
|
|
|
|
|
|
|
|
|
|
self.initUI()
|
|
|
|
|
|
|
|
|
|
self.available_models = self.find_available_models()
|
|
|
|
|
self.update_model_combo()
|
|
|
|
|
|
|
|
|
|
def initUI(self):
|
|
|
|
|
self.setWindowTitle("语音情感识别系统")
|
|
|
|
|
self.setGeometry(100, 100, 1200, 800)
|
|
|
|
|
self.setStyleSheet("background-color: #f0f0f0;")
|
|
|
|
|
|
|
|
|
|
# 创建主布局
|
|
|
|
|
main_widget = QWidget()
|
|
|
|
|
main_layout = QGridLayout(main_widget)
|
|
|
|
|
|
|
|
|
|
# 添加标题
|
|
|
|
|
title_label = QLabel("语音情感识别系统")
|
|
|
|
|
title_label.setFont(QFont("SIMSUN", 16, QFont.Bold))
|
|
|
|
|
title_label.setAlignment(Qt.AlignCenter)
|
|
|
|
|
title_label.setStyleSheet("color: #2c3e50; margin: 10px;")
|
|
|
|
|
main_layout.addWidget(title_label, 0, 0, 1, 2)
|
|
|
|
|
|
|
|
|
|
# === 左上角:操作区域 ===
|
|
|
|
|
operations_group = QGroupBox("操作区")
|
|
|
|
|
operations_layout = QGridLayout(operations_group)
|
|
|
|
|
|
|
|
|
|
# 模型选择
|
|
|
|
|
model_label = QLabel("选择模型:")
|
|
|
|
|
self.model_combo = QComboBox()
|
|
|
|
|
self.model_combo.setMinimumWidth(200)
|
|
|
|
|
operations_layout.addWidget(model_label, 0, 0)
|
|
|
|
|
operations_layout.addWidget(self.model_combo, 0, 1)
|
|
|
|
|
|
|
|
|
|
# 加载模型按钮
|
|
|
|
|
self.load_model_btn = QPushButton("加载模型")
|
|
|
|
|
self.load_model_btn.setStyleSheet("background-color: #3498db; color: white;")
|
|
|
|
|
self.load_model_btn.clicked.connect(self.load_model)
|
|
|
|
|
operations_layout.addWidget(self.load_model_btn, 0, 2)
|
|
|
|
|
|
|
|
|
|
# 音频文件选择
|
|
|
|
|
file_label = QLabel("音频文件:")
|
|
|
|
|
# 使用自定义标签来显示文件路径
|
|
|
|
|
self.file_path_label = QLabel("未选择文件")
|
|
|
|
|
self.file_path_label.setStyleSheet("""
|
|
|
|
|
background-color: white;
|
|
|
|
|
padding: 5px;
|
|
|
|
|
border: 1px solid #cccccc;
|
|
|
|
|
border-radius: 3px;
|
|
|
|
|
""")
|
|
|
|
|
self.file_path_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
|
|
|
|
self.file_path_label.setToolTip("未选择文件")
|
|
|
|
|
self.file_path_label.setMinimumWidth(250)
|
|
|
|
|
self.file_path_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
|
|
|
|
|
self.file_path_label.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
|
|
|
|
|
self.file_path_label.setTextFormat(Qt.PlainText)
|
|
|
|
|
self.file_path_label.setWordWrap(True)
|
|
|
|
|
|
|
|
|
|
self.browse_btn = QPushButton("浏览...")
|
|
|
|
|
self.browse_btn.clicked.connect(self.browse_file)
|
|
|
|
|
operations_layout.addWidget(file_label, 1, 0)
|
|
|
|
|
operations_layout.addWidget(self.file_path_label, 1, 1)
|
|
|
|
|
operations_layout.addWidget(self.browse_btn, 1, 2)
|
|
|
|
|
|
|
|
|
|
# 预测和退出按钮
|
|
|
|
|
buttons_layout = QHBoxLayout()
|
|
|
|
|
self.predict_btn = QPushButton("进行预测")
|
|
|
|
|
self.predict_btn.setStyleSheet("background-color: #2ecc71; color: white; font-weight: bold;")
|
|
|
|
|
self.predict_btn.clicked.connect(self.predict)
|
|
|
|
|
self.predict_btn.setEnabled(False)
|
|
|
|
|
|
|
|
|
|
self.exit_btn = QPushButton("退出")
|
|
|
|
|
self.exit_btn.setStyleSheet("background-color: #e74c3c; color: white;")
|
|
|
|
|
self.exit_btn.clicked.connect(self.close)
|
|
|
|
|
|
|
|
|
|
buttons_layout.addWidget(self.predict_btn)
|
|
|
|
|
buttons_layout.addWidget(self.exit_btn)
|
|
|
|
|
operations_layout.addLayout(buttons_layout, 2, 0, 1, 3)
|
|
|
|
|
|
|
|
|
|
# 进度条初始化
|
|
|
|
|
self.progress_bar = QProgressBar()
|
|
|
|
|
self.progress_bar.setTextVisible(False)
|
|
|
|
|
self.progress_bar.setFixedHeight(20)
|
|
|
|
|
# 默认情况下进度条不活动,但占据空间
|
|
|
|
|
self.progress_bar.setRange(0, 100) # 设置正常范围
|
|
|
|
|
self.progress_bar.setValue(0) # 初始值为0%
|
|
|
|
|
self.progress_bar.setStyleSheet("""
|
|
|
|
|
QProgressBar {
|
|
|
|
|
border: 1px solid #bbb;
|
|
|
|
|
border-radius: 5px;
|
|
|
|
|
text-align: center;
|
|
|
|
|
background-color: #f0f0f0;
|
|
|
|
|
}
|
|
|
|
|
QProgressBar::chunk {
|
|
|
|
|
background-color: #3498db;
|
|
|
|
|
width: 10px;
|
|
|
|
|
margin: 0.5px;
|
|
|
|
|
}
|
|
|
|
|
""")
|
|
|
|
|
operations_layout.addWidget(self.progress_bar, 3, 0, 1, 3)
|
|
|
|
|
|
|
|
|
|
main_layout.addWidget(operations_group, 1, 0)
|
|
|
|
|
|
|
|
|
|
# === 右上角:波形图区域 ===
|
|
|
|
|
waveform_group = QGroupBox("音频波形图")
|
|
|
|
|
waveform_layout = QVBoxLayout(waveform_group)
|
|
|
|
|
|
|
|
|
|
# 波形图
|
|
|
|
|
self.waveform_figure = plt.figure(figsize=(6, 3))
|
|
|
|
|
self.waveform_canvas = FigureCanvas(self.waveform_figure)
|
|
|
|
|
waveform_layout.addWidget(self.waveform_canvas)
|
|
|
|
|
|
|
|
|
|
# 播放控制按钮
|
|
|
|
|
playback_layout = QHBoxLayout()
|
|
|
|
|
self.play_btn = QPushButton("播放")
|
|
|
|
|
self.play_btn.clicked.connect(self.play_audio)
|
|
|
|
|
self.play_btn.setEnabled(False)
|
|
|
|
|
|
|
|
|
|
self.stop_btn = QPushButton("暂停")
|
|
|
|
|
self.stop_btn.clicked.connect(self.stop_audio)
|
|
|
|
|
self.stop_btn.setEnabled(False)
|
|
|
|
|
|
|
|
|
|
playback_layout.addWidget(self.play_btn)
|
|
|
|
|
playback_layout.addWidget(self.stop_btn)
|
|
|
|
|
waveform_layout.addLayout(playback_layout)
|
|
|
|
|
|
|
|
|
|
main_layout.addWidget(waveform_group, 1, 1)
|
|
|
|
|
|
|
|
|
|
# === 左下角:MFCC区域 ===
|
|
|
|
|
mfcc_group = QGroupBox("MFCC 频谱图")
|
|
|
|
|
mfcc_layout = QVBoxLayout(mfcc_group)
|
|
|
|
|
|
|
|
|
|
self.mfcc_figure = plt.figure(figsize=(5, 4))
|
|
|
|
|
self.mfcc_canvas = FigureCanvas(self.mfcc_figure)
|
|
|
|
|
mfcc_layout.addWidget(self.mfcc_canvas)
|
|
|
|
|
|
|
|
|
|
main_layout.addWidget(mfcc_group, 2, 0)
|
|
|
|
|
|
|
|
|
|
# === 右下角:预测结果区域 ===
|
|
|
|
|
results_group = QGroupBox("预测结果")
|
|
|
|
|
results_layout = QVBoxLayout(results_group)
|
|
|
|
|
|
|
|
|
|
# 预测结果标签
|
|
|
|
|
self.result_label = QLabel("请先加载模型和选择音频文件")
|
|
|
|
|
self.result_label.setFont(QFont("SIMSUN", 12))
|
|
|
|
|
self.result_label.setAlignment(Qt.AlignCenter)
|
|
|
|
|
self.result_label.setStyleSheet("background-color: white; padding: 15px; border-radius: 5px;")
|
|
|
|
|
results_layout.addWidget(self.result_label)
|
|
|
|
|
|
|
|
|
|
# 概率可视化区域
|
|
|
|
|
self.result_figure = plt.figure(figsize=(5, 4))
|
|
|
|
|
self.result_canvas = FigureCanvas(self.result_figure)
|
|
|
|
|
results_layout.addWidget(self.result_canvas)
|
|
|
|
|
|
|
|
|
|
main_layout.addWidget(results_group, 2, 1)
|
|
|
|
|
|
|
|
|
|
# 设置行列比例
|
|
|
|
|
main_layout.setRowStretch(0, 1) # 标题行
|
|
|
|
|
main_layout.setRowStretch(1, 4) # 上半部分
|
|
|
|
|
main_layout.setRowStretch(2, 4) # 下半部分
|
|
|
|
|
main_layout.setColumnStretch(0, 1)
|
|
|
|
|
main_layout.setColumnStretch(1, 1)
|
|
|
|
|
|
|
|
|
|
# 配置媒体播放器状态监听
|
|
|
|
|
self.media_player.stateChanged.connect(self.handle_media_state_changed)
|
|
|
|
|
|
|
|
|
|
self.setCentralWidget(main_widget)
|
|
|
|
|
|
|
|
|
|
def find_available_models(self):
|
|
|
|
|
"""查找checkpoints文件夹中的可用模型"""
|
|
|
|
|
models = []
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(self.checkpoints_dir):
|
|
|
|
|
return models
|
|
|
|
|
|
|
|
|
|
for folder in os.listdir(self.checkpoints_dir):
|
|
|
|
|
folder_path = os.path.join(self.checkpoints_dir, folder)
|
|
|
|
|
if os.path.isdir(folder_path):
|
|
|
|
|
best_model_path = os.path.join(folder_path, f"best.pth")
|
|
|
|
|
# C:\Users\Malong\Desktop\wenhao\语音情感识别\checkpoints\spectrogram_resnet18_20250306_143234\best.pth
|
|
|
|
|
# 从中提取spectrogram_resnet18,首先提取最后一个文件夹名称,然后提取前两个单词
|
|
|
|
|
print(folder_path)
|
|
|
|
|
if os.path.exists(best_model_path):
|
|
|
|
|
# 提取信息
|
|
|
|
|
model_name = self.extract_model_name(folder_path)
|
|
|
|
|
is_mfcc = "spectrogram" in folder
|
|
|
|
|
models.append({
|
|
|
|
|
"name": folder,
|
|
|
|
|
"path": best_model_path,
|
|
|
|
|
"model_name": model_name,
|
|
|
|
|
"use_mfcc": is_mfcc
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return models
|
|
|
|
|
def extract_model_name(self, model_path):
|
|
|
|
|
"""
|
|
|
|
|
从模型路径中提取模型名称(如spectrogram_resnet18)
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
model_path (str): 模型文件的完整路径
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
str: 提取出的模型名称
|
|
|
|
|
"""
|
|
|
|
|
# 标准化路径分隔符
|
|
|
|
|
model_path = model_path.replace('\\', '/')
|
|
|
|
|
|
|
|
|
|
# 获取路径的最后一个目录
|
|
|
|
|
directory = model_path.split('/')[-1] if '/' in model_path else ""
|
|
|
|
|
|
|
|
|
|
# 如果目录为空,返回空字符串
|
|
|
|
|
if not directory:
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
# 分割目录名称
|
|
|
|
|
parts = directory.split('_')
|
|
|
|
|
|
|
|
|
|
# 如果至少有两个部分,提取前两个作为模型名称
|
|
|
|
|
if len(parts) >= 2:
|
|
|
|
|
return f"{parts[0]}_{parts[1]}"
|
|
|
|
|
|
|
|
|
|
return directory # 如果没有足够的部分,返回原始目录名
|
|
|
|
|
|
|
|
|
|
def update_model_combo(self):
|
|
|
|
|
"""更新模型下拉列表"""
|
|
|
|
|
self.model_combo.clear()
|
|
|
|
|
for model_info in self.available_models:
|
|
|
|
|
self.model_combo.addItem(model_info["name"], userData=model_info)
|
|
|
|
|
|
|
|
|
|
if not self.available_models:
|
|
|
|
|
self.model_combo.addItem("未找到模型")
|
|
|
|
|
|
|
|
|
|
def load_model(self):
|
|
|
|
|
"""加载选择的模型"""
|
|
|
|
|
if not self.available_models:
|
|
|
|
|
QMessageBox.warning(self, "错误", "未找到可用模型,请确认checkpoints目录中包含训练好的模型文件")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# 激活进度条(不确定模式)
|
|
|
|
|
self.progress_bar.setRange(0, 0) # 设置为不确定模式
|
|
|
|
|
|
|
|
|
|
# 获取选择的模型信息
|
|
|
|
|
selected_idx = self.model_combo.currentIndex()
|
|
|
|
|
model_info = self.model_combo.itemData(selected_idx)
|
|
|
|
|
|
|
|
|
|
# 模型类型和MFCC标志
|
|
|
|
|
model_name = model_info["model_name"]
|
|
|
|
|
self.use_mfcc = model_info["use_mfcc"]
|
|
|
|
|
model_path = model_info["path"]
|
|
|
|
|
|
|
|
|
|
# 确定类别数量 (从train.py中加载,这里默认为6)
|
|
|
|
|
num_classes = 6
|
|
|
|
|
|
|
|
|
|
# 设置设备
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
|
|
# 创建模型
|
|
|
|
|
if model_name == "waveform_resnet18":
|
|
|
|
|
self.model = waveform_resnet18(num_classes=num_classes)
|
|
|
|
|
elif model_name == "waveform_resnet34":
|
|
|
|
|
self.model = waveform_resnet34(num_classes=num_classes)
|
|
|
|
|
elif model_name == "waveform_resnet50":
|
|
|
|
|
self.model = waveform_resnet50(num_classes=num_classes)
|
|
|
|
|
elif model_name == "waveform_resnet101":
|
|
|
|
|
self.model = waveform_resnet101(num_classes=num_classes)
|
|
|
|
|
elif model_name == "spectrogram_resnet18":
|
|
|
|
|
self.model = spectrogram_resnet18(num_classes=num_classes)
|
|
|
|
|
elif model_name == "spectrogram_resnet34":
|
|
|
|
|
self.model = spectrogram_resnet34(num_classes=num_classes)
|
|
|
|
|
elif model_name == "spectrogram_resnet50":
|
|
|
|
|
self.model = spectrogram_resnet50(num_classes=num_classes)
|
|
|
|
|
elif model_name == "spectrogram_resnet101":
|
|
|
|
|
self.model = spectrogram_resnet101(num_classes=num_classes)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"不支持的模型类型: {model_name}")
|
|
|
|
|
|
|
|
|
|
# 加载权重
|
|
|
|
|
self.model.load_state_dict(torch.load(model_path, map_location=device))
|
|
|
|
|
self.model.to(device)
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
# 设置类别名称 (从数据集中获得,这里硬编码)
|
|
|
|
|
# 实际应用中,应该从训练数据中加载或保存在模型权重旁边
|
|
|
|
|
self.class_names = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
|
|
|
|
|
|
|
|
|
|
# 更新UI
|
|
|
|
|
self.predict_btn.setEnabled(True)
|
|
|
|
|
QMessageBox.information(self, "成功", f"成功加载模型: {model_name}")
|
|
|
|
|
self.result_label.setText(f"模型已加载: {model_name}\n请选择音频文件进行预测")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
QMessageBox.critical(self, "错误", f"加载模型时出错: {str(e)}")
|
|
|
|
|
finally:
|
|
|
|
|
# 恢复进度条为非活动状态
|
|
|
|
|
self.progress_bar.setRange(0, 100)
|
|
|
|
|
self.progress_bar.setValue(0)
|
|
|
|
|
|
|
|
|
|
def browse_file(self):
|
|
|
|
|
"""选择音频文件"""
|
|
|
|
|
file_path, _ = QFileDialog.getOpenFileName(
|
|
|
|
|
self, "选择音频文件", "", "音频文件 (*.wav *.mp3 *.flac *.ogg)"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if file_path:
|
|
|
|
|
# 设置文本和工具提示
|
|
|
|
|
self.file_path_label.setText(file_path)
|
|
|
|
|
self.file_path_label.setToolTip(file_path)
|
|
|
|
|
self.current_audio_path = file_path
|
|
|
|
|
|
|
|
|
|
# 如果模型已加载,则启用预测按钮
|
|
|
|
|
if self.model is not None:
|
|
|
|
|
self.predict_btn.setEnabled(True)
|
|
|
|
|
|
|
|
|
|
# 加载并显示波形图
|
|
|
|
|
self.load_and_display_waveform(file_path)
|
|
|
|
|
|
|
|
|
|
# 启用播放按钮
|
|
|
|
|
self.play_btn.setEnabled(True)
|
|
|
|
|
|
|
|
|
|
def load_and_display_waveform(self, file_path):
|
|
|
|
|
"""加载并显示波形图和MFCC图"""
|
|
|
|
|
try:
|
|
|
|
|
# 使用soundfile加载音频
|
|
|
|
|
data, samplerate = sf.read(file_path)
|
|
|
|
|
|
|
|
|
|
# 如果是立体声,转换为单声道
|
|
|
|
|
if len(data.shape) > 1 and data.shape[1] > 1:
|
|
|
|
|
data = data.mean(axis=1)
|
|
|
|
|
|
|
|
|
|
# 计算时间轴
|
|
|
|
|
time = np.arange(0, len(data)) / samplerate
|
|
|
|
|
|
|
|
|
|
# 绘制波形图
|
|
|
|
|
self.waveform_figure.clear()
|
|
|
|
|
ax1 = self.waveform_figure.add_subplot(111)
|
|
|
|
|
ax1.plot(time, data, color='blue')
|
|
|
|
|
ax1.set_title("音频波形")
|
|
|
|
|
ax1.set_xlabel("时间 (秒)")
|
|
|
|
|
ax1.set_ylabel("振幅")
|
|
|
|
|
self.waveform_figure.tight_layout()
|
|
|
|
|
self.waveform_canvas.draw()
|
|
|
|
|
|
|
|
|
|
# 生成MFCC特征图
|
|
|
|
|
self.mfcc_figure.clear()
|
|
|
|
|
ax2 = self.mfcc_figure.add_subplot(111)
|
|
|
|
|
|
|
|
|
|
# 使用torchaudio计算MFCC特征
|
|
|
|
|
waveform_tensor = torch.tensor(data).float().unsqueeze(0)
|
|
|
|
|
mfcc_transform = torchaudio.transforms.MFCC(
|
|
|
|
|
sample_rate=samplerate,
|
|
|
|
|
n_mfcc=13
|
|
|
|
|
)
|
|
|
|
|
mfcc_tensor = mfcc_transform(waveform_tensor)
|
|
|
|
|
mfcc_data = mfcc_tensor.numpy()[0]
|
|
|
|
|
|
|
|
|
|
# 绘制MFCC频谱图
|
|
|
|
|
img = ax2.imshow(mfcc_data, aspect='auto', origin='lower', interpolation='none')
|
|
|
|
|
ax2.set_title("MFCC特征")
|
|
|
|
|
ax2.set_ylabel("MFCC系数")
|
|
|
|
|
ax2.set_xlabel("时间帧")
|
|
|
|
|
|
|
|
|
|
# 添加颜色条
|
|
|
|
|
self.mfcc_figure.colorbar(img, ax=ax2)
|
|
|
|
|
self.mfcc_figure.tight_layout()
|
|
|
|
|
self.mfcc_canvas.draw()
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
QMessageBox.warning(self, "错误", f"无法加载或显示音频: {str(e)}")
|
|
|
|
|
|
|
|
|
|
def play_audio(self):
|
|
|
|
|
"""播放音频"""
|
|
|
|
|
if self.current_audio_path:
|
|
|
|
|
# 使用QMediaPlayer播放音频
|
|
|
|
|
url = QUrl.fromLocalFile(self.current_audio_path)
|
|
|
|
|
self.media_player.setMedia(QMediaContent(url))
|
|
|
|
|
self.media_player.play()
|
|
|
|
|
self.play_btn.setEnabled(False)
|
|
|
|
|
self.stop_btn.setEnabled(True)
|
|
|
|
|
|
|
|
|
|
def stop_audio(self):
|
|
|
|
|
"""停止音频播放"""
|
|
|
|
|
self.media_player.pause()
|
|
|
|
|
self.play_btn.setEnabled(True)
|
|
|
|
|
self.stop_btn.setEnabled(False)
|
|
|
|
|
|
|
|
|
|
def handle_media_state_changed(self, state):
|
|
|
|
|
"""处理媒体播放状态变化"""
|
|
|
|
|
if state == QMediaPlayer.StoppedState:
|
|
|
|
|
self.play_btn.setEnabled(True)
|
|
|
|
|
self.stop_btn.setEnabled(False)
|
|
|
|
|
|
|
|
|
|
def predict(self):
|
|
|
|
|
"""对选择的音频文件进行预测"""
|
|
|
|
|
if not self.model:
|
|
|
|
|
QMessageBox.warning(self, "警告", "请先加载模型")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
file_path = self.file_path_label.text()
|
|
|
|
|
if file_path == "未选择文件" or not os.path.exists(file_path):
|
|
|
|
|
QMessageBox.warning(self, "警告", "请选择有效的音频文件")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 激活进度条(不确定模式)
|
|
|
|
|
self.progress_bar.setRange(0, 0) # 设置为不确定模式
|
|
|
|
|
self.predict_btn.setEnabled(False)
|
|
|
|
|
self.load_model_btn.setEnabled(False)
|
|
|
|
|
self.browse_btn.setEnabled(False)
|
|
|
|
|
|
|
|
|
|
# 创建并启动预测线程
|
|
|
|
|
self.prediction_thread = PredictionThread(
|
|
|
|
|
self.model, file_path, self.use_mfcc, self.class_names
|
|
|
|
|
)
|
|
|
|
|
self.prediction_thread.prediction_complete.connect(self.handle_prediction_result)
|
|
|
|
|
self.prediction_thread.error.connect(self.handle_prediction_error)
|
|
|
|
|
self.prediction_thread.finished.connect(self.prediction_finished)
|
|
|
|
|
self.prediction_thread.start()
|
|
|
|
|
|
|
|
|
|
def handle_prediction_result(self, results):
|
|
|
|
|
"""处理预测结果"""
|
|
|
|
|
predicted_class = results["predicted_class"]
|
|
|
|
|
probabilities = results["probabilities"]
|
|
|
|
|
class_names = results["class_names"]
|
|
|
|
|
|
|
|
|
|
# 更新结果标签
|
|
|
|
|
self.result_label.setText(f"预测结果: {predicted_class}")
|
|
|
|
|
|
|
|
|
|
# 绘制概率条形图
|
|
|
|
|
self.result_figure.clear()
|
|
|
|
|
ax = self.result_figure.add_subplot(111)
|
|
|
|
|
|
|
|
|
|
# 设置中文情感名称标签
|
|
|
|
|
emotion_names_zh = {
|
|
|
|
|
'angry': '愤怒',
|
|
|
|
|
'fear': '恐惧',
|
|
|
|
|
'happy': '高兴',
|
|
|
|
|
'neutral': '中性',
|
|
|
|
|
'sad': '悲伤',
|
|
|
|
|
'surprise': '惊讶'
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
zh_labels = [emotion_names_zh.get(name, name) for name in class_names]
|
|
|
|
|
|
|
|
|
|
bars = ax.bar(zh_labels, probabilities, color='skyblue')
|
|
|
|
|
|
|
|
|
|
# 突出显示预测的类别
|
|
|
|
|
predicted_idx = results["predicted_idx"]
|
|
|
|
|
bars[predicted_idx].set_color('orange')
|
|
|
|
|
|
|
|
|
|
# 添加值标签
|
|
|
|
|
for bar in bars:
|
|
|
|
|
height = bar.get_height()
|
|
|
|
|
ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
|
|
|
|
f"{height:.2f}", ha='center', va='bottom')
|
|
|
|
|
|
|
|
|
|
ax.set_ylabel('概率')
|
|
|
|
|
# 设置y轴范围
|
|
|
|
|
ax.set_ylim(0, 1)
|
|
|
|
|
ax.set_title('各情感类别预测概率')
|
|
|
|
|
plt.xticks(rotation=45)
|
|
|
|
|
self.result_figure.tight_layout()
|
|
|
|
|
|
|
|
|
|
self.result_canvas.draw()
|
|
|
|
|
|
|
|
|
|
def handle_prediction_error(self, error_msg):
|
|
|
|
|
"""处理预测错误"""
|
|
|
|
|
QMessageBox.critical(self, "预测错误", error_msg)
|
|
|
|
|
self.result_label.setText("预测失败,请重试")
|
|
|
|
|
|
|
|
|
|
def prediction_finished(self):
|
|
|
|
|
"""预测完成后的清理工作"""
|
|
|
|
|
# 恢复进度条为非活动状态
|
|
|
|
|
self.progress_bar.setRange(0, 100)
|
|
|
|
|
self.progress_bar.setValue(0)
|
|
|
|
|
self.predict_btn.setEnabled(True)
|
|
|
|
|
self.load_model_btn.setEnabled(True)
|
|
|
|
|
self.browse_btn.setEnabled(True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
app = QApplication(sys.argv)
|
|
|
|
|
window = EmotionRecognitionApp()
|
|
|
|
|
window.show()
|
|
|
|
|
sys.exit(app.exec_())
|