ADD file via upload

main
pnmfazke8 3 months ago
parent c8ed6b827f
commit 9ea16e2e07

573
ui.py

@ -0,0 +1,573 @@
import sys
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import soundfile as sf
import torchaudio
import matplotlib.font_manager as font_manager
# Register the font file with matplotlib's font manager so plots can render
# the Chinese titles and axis labels used by this UI.
font_path = 'SIMSUN.TTC' # path to the font file; a bare file name, so it is looked up relative to the working directory
font_manager.fontManager.addfont(font_path)
# Make it the default font family for every matplotlib figure.
plt.rcParams['font.family'] = 'SIMSUN'
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QVBoxLayout, QHBoxLayout,
QLabel, QFileDialog, QComboBox, QWidget, QGroupBox, QGridLayout,
QProgressBar, QMessageBox,QSizePolicy)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QUrl
from PyQt5.QtGui import QFont
from PyQt5.QtMultimedia import QMediaPlayer, QMediaContent
# 导入我们的模型和数据处理
from models import spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101
from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101
from dataset import process_audio_file
class PredictionThread(QThread):
    """Worker thread that runs a single prediction off the GUI thread.

    Signals:
        prediction_complete (dict): emitted on success with the keys
            "predicted_class", "predicted_idx", "probabilities" (NumPy array
            of per-class softmax scores) and "class_names".
        error (str): emitted with a user-facing message when any step of the
            prediction raises.
    """

    prediction_complete = pyqtSignal(dict)
    error = pyqtSignal(str)

    def __init__(self, model, file_path, use_mfcc, class_names):
        """Store everything needed to run the prediction later in ``run``.

        Args:
            model: the loaded network; called as ``model(batch)``.
            file_path: path of the audio file to classify.
            use_mfcc: forwarded to ``process_audio_file`` as its second
                argument.
            class_names: sequence mapping an output index to a label.
        """
        super().__init__()
        self.model = model
        self.file_path = file_path
        self.use_mfcc = use_mfcc
        self.class_names = class_names

    def _model_device(self):
        """Return the device the input batch must be moved to.

        The device of the model's first parameter is used so the input always
        matches the model, wherever the caller placed it. A model without
        parameters falls back to the previous rule (CUDA when available).
        """
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def run(self):
        """Preprocess the file, run the model, and emit the result or an error."""
        try:
            audio_tensor = process_audio_file(self.file_path, self.use_mfcc)
            # Inference only: evaluation mode, no gradient tracking.
            self.model.eval()
            with torch.no_grad():
                # Add the batch dimension and co-locate input and model.
                batch = audio_tensor.unsqueeze(0).to(self._model_device())
                output = self.model(batch)
                probabilities = torch.nn.functional.softmax(output, dim=1)[0]
            predicted_idx = torch.argmax(probabilities).item()
            results = {
                "predicted_class": self.class_names[predicted_idx],
                "predicted_idx": predicted_idx,
                "probabilities": probabilities.cpu().numpy(),
                "class_names": self.class_names
            }
            self.prediction_complete.emit(results)
        except Exception as e:
            # Thread boundary: report every failure to the GUI instead of
            # letting the exception die silently inside the worker.
            self.error.emit(f"预测过程中出错: {e}")
class EmotionRecognitionApp(QMainWindow):
    """Main window of the speech emotion recognition GUI.

    Below a title row the window is a 2x2 grid: the control panel (model and
    file selection, predict/exit buttons, progress bar), the waveform plot
    with playback buttons, the MFCC plot, and the prediction result with a
    per-class probability bar chart.
    """

    # Number of output classes the checkpoints are built with (hard-coded;
    # the original note says this comes from train.py).
    _NUM_CLASSES = 6
    # Label for each output index. Hard-coded here; ideally this would be
    # stored next to the model weights.
    _CLASS_NAMES = ('angry', 'fear', 'happy', 'neutral', 'sad', 'surprise')
    # Chinese display names for the bar chart's x axis.
    _EMOTION_NAMES_ZH = {
        'angry': '愤怒',
        'fear': '恐惧',
        'happy': '高兴',
        'neutral': '中性',
        'sad': '悲伤',
        'surprise': '惊讶'
    }
    # Style sheets, spelled out line by line so the runtime strings are
    # independent of source indentation.
    _FILE_LABEL_STYLE = (
        "\n"
        "background-color: white;\n"
        "padding: 5px;\n"
        "border: 1px solid #cccccc;\n"
        "border-radius: 3px;\n"
    )
    _PROGRESS_BAR_STYLE = (
        "\n"
        "QProgressBar {\n"
        "border: 1px solid #bbb;\n"
        "border-radius: 5px;\n"
        "text-align: center;\n"
        "background-color: #f0f0f0;\n"
        "}\n"
        "QProgressBar::chunk {\n"
        "background-color: #3498db;\n"
        "width: 10px;\n"
        "margin: 0.5px;\n"
        "}\n"
    )

    def __init__(self):
        """Create the widgets and populate the model list from disk."""
        super().__init__()
        # Audio playback state.
        self.media_player = QMediaPlayer()
        self.current_audio_path = None
        # Path currently loaded into the media player (None until first play).
        self._loaded_audio_path = None
        # Model state; filled in by load_model().
        self.model = None
        self.class_names = None
        self.use_mfcc = False
        self.checkpoints_dir = "checkpoints"
        # Background worker of the most recent predict() call.
        self.prediction_thread = None
        self.initUI()
        self.available_models = self.find_available_models()
        self.update_model_combo()

    def initUI(self):
        """Build the window: a title row above a 2x2 grid of panels."""
        self.setWindowTitle("语音情感识别系统")
        self.setGeometry(100, 100, 1200, 800)
        self.setStyleSheet("background-color: #f0f0f0;")

        main_widget = QWidget()
        main_layout = QGridLayout(main_widget)

        title_label = QLabel("语音情感识别系统")
        title_label.setFont(QFont("SIMSUN", 16, QFont.Bold))
        title_label.setAlignment(Qt.AlignCenter)
        title_label.setStyleSheet("color: #2c3e50; margin: 10px;")
        main_layout.addWidget(title_label, 0, 0, 1, 2)

        # Top-left: controls, top-right: waveform,
        # bottom-left: MFCC, bottom-right: results.
        main_layout.addWidget(self._build_operations_group(), 1, 0)
        main_layout.addWidget(self._build_waveform_group(), 1, 1)
        main_layout.addWidget(self._build_mfcc_group(), 2, 0)
        main_layout.addWidget(self._build_results_group(), 2, 1)

        # The title row is thin; the two panel rows share the rest equally.
        main_layout.setRowStretch(0, 1)
        main_layout.setRowStretch(1, 4)
        main_layout.setRowStretch(2, 4)
        main_layout.setColumnStretch(0, 1)
        main_layout.setColumnStretch(1, 1)

        # Re-enable the play button when playback stops.
        self.media_player.stateChanged.connect(self.handle_media_state_changed)
        self.setCentralWidget(main_widget)

    def _build_operations_group(self):
        """Create the control panel: model/file pickers, buttons, progress bar."""
        operations_group = QGroupBox("操作区")
        operations_layout = QGridLayout(operations_group)

        # Row 0: model selection and the load button.
        model_label = QLabel("选择模型:")
        self.model_combo = QComboBox()
        self.model_combo.setMinimumWidth(200)
        operations_layout.addWidget(model_label, 0, 0)
        operations_layout.addWidget(self.model_combo, 0, 1)

        self.load_model_btn = QPushButton("加载模型")
        self.load_model_btn.setStyleSheet("background-color: #3498db; color: white;")
        self.load_model_btn.clicked.connect(self.load_model)
        operations_layout.addWidget(self.load_model_btn, 0, 2)

        # Row 1: audio file selection; a styled label shows the chosen path.
        file_label = QLabel("音频文件:")
        self.file_path_label = QLabel("未选择文件")
        self.file_path_label.setStyleSheet(self._FILE_LABEL_STYLE)
        self.file_path_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
        self.file_path_label.setToolTip("未选择文件")
        self.file_path_label.setMinimumWidth(250)
        self.file_path_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self.file_path_label.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
        self.file_path_label.setTextFormat(Qt.PlainText)
        self.file_path_label.setWordWrap(True)
        self.browse_btn = QPushButton("浏览...")
        self.browse_btn.clicked.connect(self.browse_file)
        operations_layout.addWidget(file_label, 1, 0)
        operations_layout.addWidget(self.file_path_label, 1, 1)
        operations_layout.addWidget(self.browse_btn, 1, 2)

        # Row 2: predict (disabled until a model is loaded) and exit.
        buttons_layout = QHBoxLayout()
        self.predict_btn = QPushButton("进行预测")
        self.predict_btn.setStyleSheet("background-color: #2ecc71; color: white; font-weight: bold;")
        self.predict_btn.clicked.connect(self.predict)
        self.predict_btn.setEnabled(False)
        self.exit_btn = QPushButton("退出")
        self.exit_btn.setStyleSheet("background-color: #e74c3c; color: white;")
        self.exit_btn.clicked.connect(self.close)
        buttons_layout.addWidget(self.predict_btn)
        buttons_layout.addWidget(self.exit_btn)
        operations_layout.addLayout(buttons_layout, 2, 0, 1, 3)

        # Row 3: progress bar. It always occupies its space; idle is 0 of
        # 100, busy is the indeterminate range (0, 0).
        self.progress_bar = QProgressBar()
        self.progress_bar.setTextVisible(False)
        self.progress_bar.setFixedHeight(20)
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setValue(0)
        self.progress_bar.setStyleSheet(self._PROGRESS_BAR_STYLE)
        operations_layout.addWidget(self.progress_bar, 3, 0, 1, 3)
        return operations_group

    def _build_waveform_group(self):
        """Create the waveform panel with its play/pause buttons."""
        waveform_group = QGroupBox("音频波形图")
        waveform_layout = QVBoxLayout(waveform_group)

        self.waveform_figure = plt.figure(figsize=(6, 3))
        self.waveform_canvas = FigureCanvas(self.waveform_figure)
        waveform_layout.addWidget(self.waveform_canvas)

        # Both buttons start disabled; browse_file() enables playback.
        playback_layout = QHBoxLayout()
        self.play_btn = QPushButton("播放")
        self.play_btn.clicked.connect(self.play_audio)
        self.play_btn.setEnabled(False)
        self.stop_btn = QPushButton("暂停")
        self.stop_btn.clicked.connect(self.stop_audio)
        self.stop_btn.setEnabled(False)
        playback_layout.addWidget(self.play_btn)
        playback_layout.addWidget(self.stop_btn)
        waveform_layout.addLayout(playback_layout)
        return waveform_group

    def _build_mfcc_group(self):
        """Create the panel holding the MFCC plot."""
        mfcc_group = QGroupBox("MFCC 频谱图")
        mfcc_layout = QVBoxLayout(mfcc_group)
        self.mfcc_figure = plt.figure(figsize=(5, 4))
        self.mfcc_canvas = FigureCanvas(self.mfcc_figure)
        mfcc_layout.addWidget(self.mfcc_canvas)
        return mfcc_group

    def _build_results_group(self):
        """Create the panel with the result label and the probability chart."""
        results_group = QGroupBox("预测结果")
        results_layout = QVBoxLayout(results_group)

        self.result_label = QLabel("请先加载模型和选择音频文件")
        self.result_label.setFont(QFont("SIMSUN", 12))
        self.result_label.setAlignment(Qt.AlignCenter)
        self.result_label.setStyleSheet("background-color: white; padding: 15px; border-radius: 5px;")
        results_layout.addWidget(self.result_label)

        self.result_figure = plt.figure(figsize=(5, 4))
        self.result_canvas = FigureCanvas(self.result_figure)
        results_layout.addWidget(self.result_canvas)
        return results_group

    def find_available_models(self):
        """Scan the checkpoints directory for usable models.

        A sub-directory counts as a model when it contains ``best.pth``,
        e.g. ``checkpoints/spectrogram_resnet18_20250306_143234/best.pth``.

        Returns:
            list[dict]: one entry per model with the keys "name" (folder
            name), "path" (path to best.pth), "model_name" (architecture
            prefix such as ``spectrogram_resnet18``) and "use_mfcc" (True
            when the folder name contains "spectrogram"). Empty when the
            checkpoints directory does not exist.
        """
        models = []
        if not os.path.isdir(self.checkpoints_dir):
            return models
        # Sorted so the combo box order is stable; os.listdir order is arbitrary.
        for folder in sorted(os.listdir(self.checkpoints_dir)):
            folder_path = os.path.join(self.checkpoints_dir, folder)
            best_model_path = os.path.join(folder_path, "best.pth")
            if not os.path.isfile(best_model_path):
                continue
            models.append({
                "name": folder,
                "path": best_model_path,
                "model_name": self.extract_model_name(folder_path),
                "use_mfcc": "spectrogram" in folder
            })
        return models

    def extract_model_name(self, model_path):
        """Derive the architecture name from a checkpoint directory path.

        The last path component is split on underscores and its first two
        parts are joined, e.g.
        ``checkpoints/spectrogram_resnet18_20250306_143234`` gives
        ``spectrogram_resnet18``.

        Args:
            model_path (str): path of the checkpoint directory; both ``/``
                and ``\\`` separators are accepted.

        Returns:
            str: the extracted name; the whole component when it has fewer
            than two underscore-separated parts; "" for an empty path.
        """
        # Normalise separators, then drop trailing ones so the final
        # component is never empty for a non-empty path.
        normalized = model_path.replace('\\', '/').rstrip('/')
        directory = normalized.rsplit('/', 1)[-1]
        if not directory:
            return ""
        parts = directory.split('_')
        if len(parts) >= 2:
            return f"{parts[0]}_{parts[1]}"
        return directory

    def update_model_combo(self):
        """Refill the model combo box from ``self.available_models``."""
        self.model_combo.clear()
        for model_info in self.available_models:
            # The info dict rides along as item data for load_model().
            self.model_combo.addItem(model_info["name"], userData=model_info)
        if not self.available_models:
            self.model_combo.addItem("未找到模型")

    @staticmethod
    def _build_model(model_name, num_classes):
        """Instantiate the network named ``model_name``.

        Raises:
            ValueError: if ``model_name`` is not a supported architecture.
        """
        builders = {
            "waveform_resnet18": waveform_resnet18,
            "waveform_resnet34": waveform_resnet34,
            "waveform_resnet50": waveform_resnet50,
            "waveform_resnet101": waveform_resnet101,
            "spectrogram_resnet18": spectrogram_resnet18,
            "spectrogram_resnet34": spectrogram_resnet34,
            "spectrogram_resnet50": spectrogram_resnet50,
            "spectrogram_resnet101": spectrogram_resnet101,
        }
        try:
            builder = builders[model_name]
        except KeyError:
            raise ValueError(f"不支持的模型类型: {model_name}") from None
        return builder(num_classes=num_classes)

    def load_model(self):
        """Load the checkpoint selected in the combo box.

        On success the model, its input mode and the class names are stored
        on the window and the predict button is enabled. On failure an error
        dialog is shown and the previously loaded state (if any) is left
        untouched.
        """
        if not self.available_models:
            QMessageBox.warning(self, "错误", "未找到可用模型请确认checkpoints目录中包含训练好的模型文件")
            return
        try:
            # Indeterminate ("busy") mode while loading.
            self.progress_bar.setRange(0, 0)
            model_info = self.model_combo.itemData(self.model_combo.currentIndex())
            model_name = model_info["model_name"]
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # Build and load into a local first: a failure must not leave a
            # half-initialised (randomly weighted) model on the window.
            model = self._build_model(model_name, self._NUM_CLASSES)
            model.load_state_dict(torch.load(model_info["path"], map_location=device))
            model.to(device)
            model.eval()

            # Publish the new state only after everything succeeded.
            self.model = model
            self.use_mfcc = model_info["use_mfcc"]
            self.class_names = list(self._CLASS_NAMES)

            self.predict_btn.setEnabled(True)
            QMessageBox.information(self, "成功", f"成功加载模型: {model_name}")
            self.result_label.setText(f"模型已加载: {model_name}\n请选择音频文件进行预测")
        except Exception as e:
            # UI boundary: surface any failure to the user.
            QMessageBox.critical(self, "错误", f"加载模型时出错: {str(e)}")
        finally:
            # Back to the idle progress bar.
            self.progress_bar.setRange(0, 100)
            self.progress_bar.setValue(0)

    def browse_file(self):
        """Let the user pick an audio file, then plot it and enable playback."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择音频文件", "", "音频文件 (*.wav *.mp3 *.flac *.ogg)"
        )
        if not file_path:
            return  # dialog cancelled
        self.file_path_label.setText(file_path)
        self.file_path_label.setToolTip(file_path)
        self.current_audio_path = file_path
        if self.model is not None:
            self.predict_btn.setEnabled(True)
        self.load_and_display_waveform(file_path)
        self.play_btn.setEnabled(True)

    def load_and_display_waveform(self, file_path):
        """Plot the waveform and a 13-coefficient MFCC image of ``file_path``.

        Any failure (unreadable file, transform error) is reported in a
        warning dialog instead of being raised.
        """
        try:
            data, samplerate = sf.read(file_path)
            # Multi-channel audio is averaged down to mono.
            if data.ndim > 1:
                data = data.mean(axis=1)
            time = np.arange(0, len(data)) / samplerate

            # Waveform plot.
            self.waveform_figure.clear()
            ax1 = self.waveform_figure.add_subplot(111)
            ax1.plot(time, data, color='blue')
            ax1.set_title("音频波形")
            ax1.set_xlabel("时间 (秒)")
            ax1.set_ylabel("振幅")
            self.waveform_figure.tight_layout()
            self.waveform_canvas.draw()

            # MFCC image, computed with torchaudio on a (1, samples) tensor.
            self.mfcc_figure.clear()
            ax2 = self.mfcc_figure.add_subplot(111)
            waveform_tensor = torch.tensor(data).float().unsqueeze(0)
            mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=samplerate,
                n_mfcc=13
            )
            mfcc_data = mfcc_transform(waveform_tensor).numpy()[0]
            img = ax2.imshow(mfcc_data, aspect='auto', origin='lower', interpolation='none')
            ax2.set_title("MFCC特征")
            ax2.set_ylabel("MFCC系数")
            ax2.set_xlabel("时间帧")
            self.mfcc_figure.colorbar(img, ax=ax2)
            self.mfcc_figure.tight_layout()
            self.mfcc_canvas.draw()
        except Exception as e:
            QMessageBox.warning(self, "错误", f"无法加载或显示音频: {str(e)}")

    def play_audio(self):
        """Play the selected file, resuming if it was paused."""
        if not self.current_audio_path:
            return
        paused_on_same_file = (
            self.media_player.state() == QMediaPlayer.PausedState
            and self._loaded_audio_path == self.current_audio_path
        )
        # Re-setting the media rewinds playback, so only do it when there is
        # nothing to resume; otherwise "pause" would act as "stop".
        if not paused_on_same_file:
            url = QUrl.fromLocalFile(self.current_audio_path)
            self.media_player.setMedia(QMediaContent(url))
            self._loaded_audio_path = self.current_audio_path
        self.media_player.play()
        self.play_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

    def stop_audio(self):
        """Pause playback (handler of the button labelled "暂停")."""
        self.media_player.pause()
        self.play_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)

    def handle_media_state_changed(self, state):
        """Reset the playback buttons once the player reports it has stopped."""
        if state == QMediaPlayer.StoppedState:
            self.play_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)

    def predict(self):
        """Start a background prediction for the selected audio file."""
        # Explicit None check: truthiness of a model object is not a
        # reliable "is loaded" test.
        if self.model is None:
            QMessageBox.warning(self, "警告", "请先加载模型")
            return
        file_path = self.current_audio_path
        if not file_path or not os.path.exists(file_path):
            QMessageBox.warning(self, "警告", "请选择有效的音频文件")
            return

        # Busy indicator on; block actions that would race with the worker.
        self.progress_bar.setRange(0, 0)
        self.predict_btn.setEnabled(False)
        self.load_model_btn.setEnabled(False)
        self.browse_btn.setEnabled(False)

        # Kept on self so the thread object outlives this method.
        self.prediction_thread = PredictionThread(
            self.model, file_path, self.use_mfcc, self.class_names
        )
        self.prediction_thread.prediction_complete.connect(self.handle_prediction_result)
        self.prediction_thread.error.connect(self.handle_prediction_error)
        self.prediction_thread.finished.connect(self.prediction_finished)
        self.prediction_thread.start()

    def handle_prediction_result(self, results):
        """Show the predicted label and draw the per-class probability chart.

        Args:
            results (dict): payload of ``PredictionThread.prediction_complete``
                with the keys "predicted_class", "predicted_idx",
                "probabilities" and "class_names".
        """
        predicted_class = results["predicted_class"]
        probabilities = results["probabilities"]
        class_names = results["class_names"]
        predicted_idx = results["predicted_idx"]

        self.result_label.setText(f"预测结果: {predicted_class}")

        self.result_figure.clear()
        ax = self.result_figure.add_subplot(111)
        # Chinese label where one is known, the raw class name otherwise.
        zh_labels = [self._EMOTION_NAMES_ZH.get(name, name) for name in class_names]
        bars = ax.bar(zh_labels, probabilities, color='skyblue')
        # Highlight the winning class.
        bars[predicted_idx].set_color('orange')
        # Print each probability just above its bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f"{height:.2f}", ha='center', va='bottom')
        ax.set_ylabel('概率')
        ax.set_ylim(0, 1)
        ax.set_title('各情感类别预测概率')
        # Rotate via the axes itself; plt.xticks() would act on whichever
        # figure pyplot currently considers "current".
        ax.tick_params(axis='x', labelrotation=45)
        self.result_figure.tight_layout()
        self.result_canvas.draw()

    def handle_prediction_error(self, error_msg):
        """Report a failed prediction to the user."""
        QMessageBox.critical(self, "预测错误", error_msg)
        self.result_label.setText("预测失败,请重试")

    def prediction_finished(self):
        """Restore the idle UI state after the worker thread has finished."""
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setValue(0)
        self.predict_btn.setEnabled(True)
        self.load_model_btn.setEnabled(True)
        self.browse_btn.setEnabled(True)

    def closeEvent(self, event):
        """Let a running prediction finish before the window is torn down.

        Destroying a QThread that is still running aborts the process, so the
        close waits for the worker first.
        """
        worker = self.prediction_thread
        if worker is not None and worker.isRunning():
            worker.wait()
        super().closeEvent(event)
if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window,
    # and use the event loop's return code as the process exit status.
    qt_app = QApplication(sys.argv)
    main_window = EmotionRecognitionApp()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
Loading…
Cancel
Save