diff --git a/ui.py b/ui.py
new file mode 100644
index 0000000..4b76ec8
--- /dev/null
+++ b/ui.py
@@ -0,0 +1,573 @@
+import sys
+import os
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import soundfile as sf
+import torchaudio
+import matplotlib.font_manager as font_manager
+
+# Register the font file
+font_path = 'SIMSUN.TTC'  # path to the font file
+font_manager.fontManager.addfont(font_path)
+
+# Use it as the default font
+plt.rcParams['font.family'] = 'SIMSUN'
+
+from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
+from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QVBoxLayout, QHBoxLayout,
+                             QLabel, QFileDialog, QComboBox, QWidget, QGroupBox, QGridLayout,
+                             QProgressBar, QMessageBox, QSizePolicy)
+from PyQt5.QtCore import Qt, QThread, pyqtSignal, QUrl
+from PyQt5.QtGui import QFont
+from PyQt5.QtMultimedia import QMediaPlayer, QMediaContent
+
+# Import our models and data processing
+from models import spectrogram_resnet18, spectrogram_resnet34, spectrogram_resnet50, spectrogram_resnet101
+from models import waveform_resnet18, waveform_resnet34, waveform_resnet50, waveform_resnet101
+from dataset import process_audio_file
+
+class PredictionThread(QThread):
+    """Thread that runs prediction in the background."""
+    prediction_complete = pyqtSignal(dict)
+    error = pyqtSignal(str)
+
+    def __init__(self, model, file_path, use_mfcc, class_names):
+        super().__init__()
+        self.model = model
+        self.file_path = file_path
+        self.use_mfcc = use_mfcc
+        self.class_names = class_names
+
+    def run(self):
+        try:
+            # Process the audio file
+            audio_tensor = process_audio_file(self.file_path, self.use_mfcc)
+
+            # Make sure the model is in evaluation mode
+            self.model.eval()
+
+            # Add a batch dimension and run the prediction
+            with torch.no_grad():
+                audio_tensor = audio_tensor.unsqueeze(0)
+                if torch.cuda.is_available():
+                    audio_tensor = audio_tensor.cuda()
+
+                output = self.model(audio_tensor)
+                probabilities = torch.nn.functional.softmax(output, dim=1)[0]
+
+            # Get the prediction and probabilities
+            predicted_idx = torch.argmax(probabilities).item()
+            predicted_class = self.class_names[predicted_idx]
+            probs = probabilities.cpu().numpy()
+
+            # Return the results
+            results = {
+                "predicted_class": predicted_class,
+                "predicted_idx": predicted_idx,
+                "probabilities": probs,
+                "class_names": self.class_names
+            }
+
+            self.prediction_complete.emit(results)
+
+        except Exception as e:
+            self.error.emit(f"预测过程中出错: {str(e)}")
+
+
+class EmotionRecognitionApp(QMainWindow):
+    def __init__(self):
+        super().__init__()
+
+        # Initialize the audio player
+        self.media_player = QMediaPlayer()
+        self.current_audio_path = None
+
+        # Initialize model-related state
+        self.model = None
+        self.class_names = None
+        self.use_mfcc = False
+        self.checkpoints_dir = "checkpoints"
+
+        self.initUI()
+
+        self.available_models = self.find_available_models()
+        self.update_model_combo()
+
+    def initUI(self):
+        self.setWindowTitle("语音情感识别系统")
+        self.setGeometry(100, 100, 1200, 800)
+        self.setStyleSheet("background-color: #f0f0f0;")
+
+        # Create the main layout
+        main_widget = QWidget()
+        main_layout = QGridLayout(main_widget)
+
+        # Add the title
+        title_label = QLabel("语音情感识别系统")
+        title_label.setFont(QFont("SIMSUN", 16, QFont.Bold))
+        title_label.setAlignment(Qt.AlignCenter)
+        title_label.setStyleSheet("color: #2c3e50; margin: 10px;")
+        main_layout.addWidget(title_label, 0, 0, 1, 2)
+
+        # === Top left: operations area ===
+        operations_group = QGroupBox("操作区")
+        operations_layout = QGridLayout(operations_group)
+
+        # Model selection
+        model_label = QLabel("选择模型:")
+        self.model_combo = QComboBox()
+        self.model_combo.setMinimumWidth(200)
+        operations_layout.addWidget(model_label, 0, 0)
+        operations_layout.addWidget(self.model_combo, 0, 1)
+
+        # Load-model button
+        self.load_model_btn = QPushButton("加载模型")
+        self.load_model_btn.setStyleSheet("background-color: #3498db; color: white;")
+        self.load_model_btn.clicked.connect(self.load_model)
+        operations_layout.addWidget(self.load_model_btn, 0, 2)
+
+        # Audio file selection
+        file_label = QLabel("音频文件:")
+        # Use a custom label to display the file path
+        self.file_path_label = QLabel("未选择文件")
+        self.file_path_label.setStyleSheet("""
+            background-color: white;
+            padding: 5px;
+            border: 1px solid #cccccc;
+            border-radius: 3px;
+        """)
+        self.file_path_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
+        self.file_path_label.setToolTip("未选择文件")
+        self.file_path_label.setMinimumWidth(250)
+        self.file_path_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
+        self.file_path_label.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
+        self.file_path_label.setTextFormat(Qt.PlainText)
+        self.file_path_label.setWordWrap(True)
+
+        self.browse_btn = QPushButton("浏览...")
+        self.browse_btn.clicked.connect(self.browse_file)
+        operations_layout.addWidget(file_label, 1, 0)
+        operations_layout.addWidget(self.file_path_label, 1, 1)
+        operations_layout.addWidget(self.browse_btn, 1, 2)
+
+        # Predict and exit buttons
+        buttons_layout = QHBoxLayout()
+        self.predict_btn = QPushButton("进行预测")
+        self.predict_btn.setStyleSheet("background-color: #2ecc71; color: white; font-weight: bold;")
+        self.predict_btn.clicked.connect(self.predict)
+        self.predict_btn.setEnabled(False)
+
+        self.exit_btn = QPushButton("退出")
+        self.exit_btn.setStyleSheet("background-color: #e74c3c; color: white;")
+        self.exit_btn.clicked.connect(self.close)
+
+        buttons_layout.addWidget(self.predict_btn)
+        buttons_layout.addWidget(self.exit_btn)
+        operations_layout.addLayout(buttons_layout, 2, 0, 1, 3)
+
+        # Progress bar setup
+        self.progress_bar = QProgressBar()
+        self.progress_bar.setTextVisible(False)
+        self.progress_bar.setFixedHeight(20)
+        # The progress bar is idle by default but still occupies its space
+        self.progress_bar.setRange(0, 100)  # normal (determinate) range
+        self.progress_bar.setValue(0)  # initial value 0%
+        self.progress_bar.setStyleSheet("""
+            QProgressBar {
+                border: 1px solid #bbb;
+                border-radius: 5px;
+                text-align: center;
+                background-color: #f0f0f0;
+            }
+            QProgressBar::chunk {
+                background-color: #3498db;
+                width: 10px;
+                margin: 0.5px;
+            }
+        """)
+        operations_layout.addWidget(self.progress_bar, 3, 0, 1, 3)
+
+        main_layout.addWidget(operations_group, 1, 0)
+
+        # === Top right: waveform area ===
+        waveform_group = QGroupBox("音频波形图")
+        waveform_layout = QVBoxLayout(waveform_group)
+
+        # Waveform plot
+        self.waveform_figure = plt.figure(figsize=(6, 3))
+        self.waveform_canvas = FigureCanvas(self.waveform_figure)
+        waveform_layout.addWidget(self.waveform_canvas)
+
+        # Playback control buttons
+        playback_layout = QHBoxLayout()
+        self.play_btn = QPushButton("播放")
+        self.play_btn.clicked.connect(self.play_audio)
+        self.play_btn.setEnabled(False)
+
+        self.stop_btn = QPushButton("暂停")
+        self.stop_btn.clicked.connect(self.stop_audio)
+        self.stop_btn.setEnabled(False)
+
+        playback_layout.addWidget(self.play_btn)
+        playback_layout.addWidget(self.stop_btn)
+        waveform_layout.addLayout(playback_layout)
+
+        main_layout.addWidget(waveform_group, 1, 1)
+
+        # === Bottom left: MFCC area ===
+        mfcc_group = QGroupBox("MFCC 频谱图")
+        mfcc_layout = QVBoxLayout(mfcc_group)
+
+        self.mfcc_figure = plt.figure(figsize=(5, 4))
+        self.mfcc_canvas = FigureCanvas(self.mfcc_figure)
+        mfcc_layout.addWidget(self.mfcc_canvas)
+
+        main_layout.addWidget(mfcc_group, 2, 0)
+
+        # === Bottom right: prediction results area ===
+        results_group = QGroupBox("预测结果")
+        results_layout = QVBoxLayout(results_group)
+
+        # Prediction result label
+        self.result_label = QLabel("请先加载模型和选择音频文件")
+        self.result_label.setFont(QFont("SIMSUN", 12))
+        self.result_label.setAlignment(Qt.AlignCenter)
+        self.result_label.setStyleSheet("background-color: white; padding: 15px; border-radius: 5px;")
+        results_layout.addWidget(self.result_label)
+
+        # Probability visualization area
+        self.result_figure = plt.figure(figsize=(5, 4))
+        self.result_canvas = FigureCanvas(self.result_figure)
+        results_layout.addWidget(self.result_canvas)
+
+        main_layout.addWidget(results_group, 2, 1)
+
+        # Set row/column stretch factors
+        main_layout.setRowStretch(0, 1)  # title row
+        main_layout.setRowStretch(1, 4)  # upper half
+        main_layout.setRowStretch(2, 4)  # lower half
+        main_layout.setColumnStretch(0, 1)
+        main_layout.setColumnStretch(1, 1)
+
+        # Listen for media player state changes
+        self.media_player.stateChanged.connect(self.handle_media_state_changed)
+
+        self.setCentralWidget(main_widget)
+
+    def find_available_models(self):
+        """Find the available models in the checkpoints folder."""
+        models = []
+
+        if not os.path.exists(self.checkpoints_dir):
+            return models
+
+        for folder in os.listdir(self.checkpoints_dir):
+            folder_path = os.path.join(self.checkpoints_dir, folder)
+            if os.path.isdir(folder_path):
+                best_model_path = os.path.join(folder_path, "best.pth")
+                # e.g. checkpoints/spectrogram_resnet18_20250306_143234/best.pth
+                # Extract "spectrogram_resnet18": take the last folder name, then its first two underscore-separated parts
+                print(folder_path)
+                if os.path.exists(best_model_path):
+                    # Extract the model info
+                    model_name = self.extract_model_name(folder_path)
+                    is_mfcc = "spectrogram" in folder
+                    models.append({
+                        "name": folder,
+                        "path": best_model_path,
+                        "model_name": model_name,
+                        "use_mfcc": is_mfcc
+                    })
+
+        return models
+
+    def extract_model_name(self, model_path):
+        """
+        Extract the model name (e.g. spectrogram_resnet18) from a model path.
+
+        Parameters:
+            model_path (str): full path to the model checkpoint directory
+
+        Returns:
+            str: the extracted model name
+        """
+        # Normalize path separators
+        model_path = model_path.replace('\\', '/')
+
+        # Get the last directory in the path
+        directory = model_path.split('/')[-1] if '/' in model_path else ""
+
+        # If the directory is empty, return an empty string
+        if not directory:
+            return ""
+
+        # Split the directory name
+        parts = directory.split('_')
+
+        # If there are at least two parts, use the first two as the model name
+        if len(parts) >= 2:
+            return f"{parts[0]}_{parts[1]}"
+
+        return directory  # not enough parts; return the original directory name
+
+    def update_model_combo(self):
+        """Refresh the model drop-down list."""
+        self.model_combo.clear()
+        for model_info in self.available_models:
+            self.model_combo.addItem(model_info["name"], userData=model_info)
+
+        if not self.available_models:
+            self.model_combo.addItem("未找到模型")
+
+    def load_model(self):
+        """Load the selected model."""
+        if not self.available_models:
+            QMessageBox.warning(self, "错误", "未找到可用模型,请确认checkpoints目录中包含训练好的模型文件")
+            return
+
+        try:
+            # Activate the progress bar (indeterminate mode)
+            self.progress_bar.setRange(0, 0)  # indeterminate mode
+
+            # Get the selected model info
+            selected_idx = self.model_combo.currentIndex()
+            model_info = self.model_combo.itemData(selected_idx)
+
+            # Model type and MFCC flag
+            model_name = model_info["model_name"]
+            self.use_mfcc = model_info["use_mfcc"]
+            model_path = model_info["path"]
+
+            # Number of classes (matches train.py; defaults to 6 here)
+            num_classes = 6
+
+            # Select the device
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            # Build the model
+            if model_name == "waveform_resnet18":
+                self.model = waveform_resnet18(num_classes=num_classes)
+            elif model_name == "waveform_resnet34":
+                self.model = waveform_resnet34(num_classes=num_classes)
+            elif model_name == "waveform_resnet50":
+                self.model = waveform_resnet50(num_classes=num_classes)
+            elif model_name == "waveform_resnet101":
+                self.model = waveform_resnet101(num_classes=num_classes)
+            elif model_name == "spectrogram_resnet18":
+                self.model = spectrogram_resnet18(num_classes=num_classes)
+            elif model_name == "spectrogram_resnet34":
+                self.model = spectrogram_resnet34(num_classes=num_classes)
+            elif model_name == "spectrogram_resnet50":
+                self.model = spectrogram_resnet50(num_classes=num_classes)
+            elif model_name == "spectrogram_resnet101":
+                self.model = spectrogram_resnet101(num_classes=num_classes)
+            else:
+                raise ValueError(f"不支持的模型类型: {model_name}")
+
+            # Load the weights
+            self.model.load_state_dict(torch.load(model_path, map_location=device))
+            self.model.to(device)
+            self.model.eval()
+
+            # Class names (taken from the dataset; hard-coded here)
+            # In practice they should be loaded from the training data or stored alongside the model weights
+            self.class_names = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
+
+            # Update the UI
+            self.predict_btn.setEnabled(True)
+            QMessageBox.information(self, "成功", f"成功加载模型: {model_name}")
+            self.result_label.setText(f"模型已加载: {model_name}\n请选择音频文件进行预测")
+
+        except Exception as e:
+            QMessageBox.critical(self, "错误", f"加载模型时出错: {str(e)}")
+        finally:
+            # Reset the progress bar to its idle state
+            self.progress_bar.setRange(0, 100)
+            self.progress_bar.setValue(0)
+
+    def browse_file(self):
+        """Choose an audio file."""
+        file_path, _ = QFileDialog.getOpenFileName(
+            self, "选择音频文件", "", "音频文件 (*.wav *.mp3 *.flac *.ogg)"
+        )
+
+        if file_path:
+            # Set the text and tooltip
+            self.file_path_label.setText(file_path)
+            self.file_path_label.setToolTip(file_path)
+            self.current_audio_path = file_path
+
+            # Enable the predict button if a model is already loaded
+            if self.model is not None:
+                self.predict_btn.setEnabled(True)
+
+            # Load and display the waveform
+            self.load_and_display_waveform(file_path)
+
+            # Enable the play button
+            self.play_btn.setEnabled(True)
+
+    def load_and_display_waveform(self, file_path):
+        """Load and display the waveform and MFCC plots."""
+        try:
+            # Load the audio with soundfile
+            data, samplerate = sf.read(file_path)
+
+            # Convert stereo to mono if needed
+            if len(data.shape) > 1 and data.shape[1] > 1:
+                data = data.mean(axis=1)
+
+            # Build the time axis
+            time = np.arange(0, len(data)) / samplerate
+
+            # Draw the waveform
+            self.waveform_figure.clear()
+            ax1 = self.waveform_figure.add_subplot(111)
+            ax1.plot(time, data, color='blue')
+            ax1.set_title("音频波形")
+            ax1.set_xlabel("时间 (秒)")
+            ax1.set_ylabel("振幅")
+            self.waveform_figure.tight_layout()
+            self.waveform_canvas.draw()
+
+            # Generate the MFCC feature plot
+            self.mfcc_figure.clear()
+            ax2 = self.mfcc_figure.add_subplot(111)
+
+            # Compute MFCC features with torchaudio
+            waveform_tensor = torch.tensor(data).float().unsqueeze(0)
+            mfcc_transform = torchaudio.transforms.MFCC(
+                sample_rate=samplerate,
+                n_mfcc=13
+            )
+            mfcc_tensor = mfcc_transform(waveform_tensor)
+            mfcc_data = mfcc_tensor.numpy()[0]
+
+            # Draw the MFCC spectrogram
+            img = ax2.imshow(mfcc_data, aspect='auto', origin='lower', interpolation='none')
+            ax2.set_title("MFCC特征")
+            ax2.set_ylabel("MFCC系数")
+            ax2.set_xlabel("时间帧")
+
+            # Add a colorbar
+            self.mfcc_figure.colorbar(img, ax=ax2)
+            self.mfcc_figure.tight_layout()
+            self.mfcc_canvas.draw()
+
+        except Exception as e:
+            QMessageBox.warning(self, "错误", f"无法加载或显示音频: {str(e)}")
+
+    def play_audio(self):
+        """Play the audio."""
+        if self.current_audio_path:
+            # Play the audio with QMediaPlayer
+            url = QUrl.fromLocalFile(self.current_audio_path)
+            self.media_player.setMedia(QMediaContent(url))
+            self.media_player.play()
+            self.play_btn.setEnabled(False)
+            self.stop_btn.setEnabled(True)
+
+    def stop_audio(self):
+        """Pause audio playback."""
+        self.media_player.pause()
+        self.play_btn.setEnabled(True)
+        self.stop_btn.setEnabled(False)
+
+    def handle_media_state_changed(self, state):
+        """Handle media player state changes."""
+        if state == QMediaPlayer.StoppedState:
+            self.play_btn.setEnabled(True)
+            self.stop_btn.setEnabled(False)
+
+    def predict(self):
+        """Run prediction on the selected audio file."""
+        if not self.model:
+            QMessageBox.warning(self, "警告", "请先加载模型")
+            return
+
+        file_path = self.file_path_label.text()
+        if file_path == "未选择文件" or not os.path.exists(file_path):
+            QMessageBox.warning(self, "警告", "请选择有效的音频文件")
+            return
+
+        # Activate the progress bar (indeterminate mode)
+        self.progress_bar.setRange(0, 0)  # indeterminate mode
+        self.predict_btn.setEnabled(False)
+        self.load_model_btn.setEnabled(False)
+        self.browse_btn.setEnabled(False)
+
+        # Create and start the prediction thread
+        self.prediction_thread = PredictionThread(
+            self.model, file_path, self.use_mfcc, self.class_names
+        )
+        self.prediction_thread.prediction_complete.connect(self.handle_prediction_result)
+        self.prediction_thread.error.connect(self.handle_prediction_error)
+        self.prediction_thread.finished.connect(self.prediction_finished)
+        self.prediction_thread.start()
+
+    def handle_prediction_result(self, results):
+        """Handle the prediction results."""
+        predicted_class = results["predicted_class"]
+        probabilities = results["probabilities"]
+        class_names = results["class_names"]
+
+        # Update the result label
+        self.result_label.setText(f"预测结果: {predicted_class}")
+
+        # Draw the probability bar chart
+        self.result_figure.clear()
+        ax = self.result_figure.add_subplot(111)
+
+        # Chinese display names for the emotion labels
+        emotion_names_zh = {
+            'angry': '愤怒',
+            'fear': '恐惧',
+            'happy': '高兴',
+            'neutral': '中性',
+            'sad': '悲伤',
+            'surprise': '惊讶'
+        }
+
+        zh_labels = [emotion_names_zh.get(name, name) for name in class_names]
+
+        bars = ax.bar(zh_labels, probabilities, color='skyblue')
+
+        # Highlight the predicted class
+        predicted_idx = results["predicted_idx"]
+        bars[predicted_idx].set_color('orange')
+
+        # Add a value label above each bar
+        for bar in bars:
+            height = bar.get_height()
+            ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
+                    f"{height:.2f}", ha='center', va='bottom')
+
+        ax.set_ylabel('概率')
+        # Set the y-axis range
+        ax.set_ylim(0, 1)
+        ax.set_title('各情感类别预测概率')
+        ax.tick_params(axis='x', rotation=45)
+        self.result_figure.tight_layout()
+
+        self.result_canvas.draw()
+
+    def handle_prediction_error(self, error_msg):
+        """Handle prediction errors."""
+        QMessageBox.critical(self, "预测错误", error_msg)
+        self.result_label.setText("预测失败,请重试")
+
+    def prediction_finished(self):
+        """Clean up after a prediction finishes."""
+        # Reset the progress bar to its idle state
+        self.progress_bar.setRange(0, 100)
+        self.progress_bar.setValue(0)
+        self.predict_btn.setEnabled(True)
+        self.load_model_btn.setEnabled(True)
+        self.browse_btn.setEnabled(True)
+
+
+if __name__ == "__main__":
+    app = QApplication(sys.argv)
+    window = EmotionRecognitionApp()
+    window.show()
+    sys.exit(app.exec_())
\ No newline at end of file