You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
git-02/audio_classifier_gui_01.py

945 lines
34 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import numpy as np
import joblib
import librosa
import pyaudio
import wave
import sqlite3
from datetime import datetime
from scipy.signal import hilbert, find_peaks
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
QTextEdit, QFileDialog, QVBoxLayout, QHBoxLayout,
QWidget, QProgressBar, QStackedWidget, QMessageBox,
QTableWidget, QTableWidgetItem, QHeaderView,
QLineEdit, QDialog, QDialogButtonBox, QFormLayout,
QComboBox, QSpinBox, QDoubleSpinBox, QTabWidget, QInputDialog)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
# ---------------------- 数据库管理模块 ----------------------
class DatabaseManager:
    """SQLite-backed store for classification results.

    Every public method opens its own short-lived connection. Connections are
    always closed, even when a statement raises. Write statements run inside a
    transaction that is committed on success and rolled back on error.
    """

    # Columns that update_result() may modify. Column names cannot be bound as
    # SQL parameters, so they are checked against this whitelist before being
    # interpolated into the UPDATE statement.
    EDITABLE_COLUMNS = frozenset({
        "filename", "segment_count", "segment_labels", "segment_probs",
        "mean_probability", "final_label", "label_text", "notes",
    })

    def __init__(self, db_path="audio_classification.db"):
        """Remember the database path and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file (created if missing).
        """
        self.db_path = db_path
        self.init_database()

    # ------------------------------------------------------------ helpers
    def _execute(self, query, params=()):
        """Run one write statement in its own transaction.

        Returns:
            ``(lastrowid, rowcount)`` of the executed statement.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on exception
                cursor = conn.execute(query, params)
                return cursor.lastrowid, cursor.rowcount
        finally:
            conn.close()

    def _query(self, query, params=()):
        """Run a SELECT and return every row as a ``{column: value}`` dict."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute(query, params)
            columns = [desc[0] for desc in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]
        finally:
            conn.close()

    @staticmethod
    def _list_to_text(values):
        """Serialise a sequence as the text repr of a plain Python list.

        NumPy scalars are unwrapped with ``.item()`` first, so NumPy >= 2
        stores ``[0.5]`` rather than ``[np.float64(0.5)]``. A string is
        stored unchanged.
        """
        if isinstance(values, str):
            return values
        return str([v.item() if hasattr(v, "item") else v for v in values])

    @staticmethod
    def _escape_like(text):
        """Escape LIKE wildcards so *text* is matched literally (ESCAPE '\\')."""
        return (text.replace("\\", "\\\\")
                    .replace("%", "\\%")
                    .replace("_", "\\_"))

    # ------------------------------------------------------------- schema
    def init_database(self):
        """Create the ``classification_results`` table if it does not exist."""
        self._execute('''
            CREATE TABLE IF NOT EXISTS classification_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT NOT NULL,
                segment_count INTEGER NOT NULL,
                segment_labels TEXT NOT NULL,
                segment_probs TEXT NOT NULL,
                mean_probability REAL NOT NULL,
                final_label INTEGER NOT NULL,
                label_text TEXT NOT NULL,
                create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                notes TEXT
            )
        ''')

    # -------------------------------------------------------------- CRUD
    def insert_result(self, result_data, notes=""):
        """Insert one classification result.

        Args:
            result_data: Mapping with the keys ``filename``, ``segment_count``,
                ``segment_labels``, ``segment_probs``, ``mean_probability``,
                ``final_label`` and ``label_text``. The two segment sequences
                are stored as the text repr of a Python list.
            notes: Free-form remark stored with the result.

        Returns:
            The row id of the new record.
        """
        row_id, _ = self._execute('''
            INSERT INTO classification_results
                (filename, segment_count, segment_labels, segment_probs,
                 mean_probability, final_label, label_text, notes)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            result_data['filename'],
            # Plain Python numbers: sqlite3 cannot bind e.g. numpy.int64.
            int(result_data['segment_count']),
            self._list_to_text(result_data['segment_labels']),
            self._list_to_text(result_data['segment_probs']),
            float(result_data['mean_probability']),
            int(result_data['final_label']),
            result_data['label_text'],
            notes,
        ))
        return row_id

    def get_all_results(self):
        """Return every record, newest first."""
        return self.search_results()

    def search_results(self, filename_filter="", label_filter=None, date_filter=None):
        """Return the records matching all given filters, newest first.

        Args:
            filename_filter: Substring the file name must contain (matched
                literally; ``%`` and ``_`` are not wildcards). Empty = any.
            label_filter: Required ``final_label`` value, or None for any.
            date_filter: ``YYYY-MM-DD`` the record's ``create_time`` date must
                equal, or a falsy value for any.
        """
        query = "SELECT * FROM classification_results WHERE 1=1"
        params = []
        if filename_filter:
            query += " AND filename LIKE ? ESCAPE '\\'"
            params.append(f"%{self._escape_like(filename_filter)}%")
        if label_filter is not None:
            query += " AND final_label = ?"
            params.append(label_filter)
        if date_filter:
            query += " AND DATE(create_time) = ?"
            params.append(date_filter)
        # CURRENT_TIMESTAMP has one-second resolution; id breaks ties so that
        # records inserted in the same second still come back newest first.
        query += " ORDER BY create_time DESC, id DESC"
        return self._query(query, params)

    def update_result(self, result_id, updates):
        """Update selected columns of one record.

        Args:
            result_id: Primary key of the row to update.
            updates: ``{column: new_value}``; every key must be listed in
                ``EDITABLE_COLUMNS``.

        Returns:
            True if a row was updated; False if no row has *result_id* or
            *updates* is empty.

        Raises:
            ValueError: If *updates* names a column that may not be edited.
        """
        if not updates:
            return False  # "UPDATE ... SET  WHERE" would be invalid SQL
        unknown = set(updates) - self.EDITABLE_COLUMNS
        if unknown:
            raise ValueError(f"Cannot update column(s): {', '.join(sorted(unknown))}")
        columns = list(updates)
        set_clause = ", ".join(f"{col} = ?" for col in columns)
        params = [updates[col] for col in columns] + [result_id]
        _, rowcount = self._execute(
            f"UPDATE classification_results SET {set_clause} WHERE id = ?", params)
        return rowcount > 0

    def delete_result(self, result_id):
        """Delete one record; return True if a row was removed."""
        _, rowcount = self._execute(
            "DELETE FROM classification_results WHERE id = ?", (result_id,))
        return rowcount > 0
# ---------------------- 特征提取模块(与训练集保持一致) ----------------------
class FeatureProcessor:
    """Impact detection and feature extraction shared with the training pipeline.

    The constants and feature formulas must match the ones used to train the
    model. The class only groups static helpers and holds no state.
    """
    WIN = 1024            # frame length (samples)
    OVERLAP = 512         # overlap between consecutive frames (samples)
    THRESHOLD = 0.1       # frame energy above which a frame counts as active
    SEG_LEN_S = 0.2       # length of each extracted segment (seconds)
    STEP = WIN - OVERLAP  # hop between consecutive frame starts (samples)
    MIN_GAP = 5           # active frames more than this many frames apart start a new impact

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the energy (sum of squares) of every frame of *signal*.

        Frames are WIN samples long with a hop of STEP samples. A signal
        shorter than one frame is zero-padded to WIN samples. It then yields
        one frame instead of making ``librosa.util.frame`` raise.
        """
        if len(signal) < FeatureProcessor.WIN:
            signal = np.pad(signal, (0, FeatureProcessor.WIN - len(signal)))
        frames = librosa.util.frame(signal, frame_length=FeatureProcessor.WIN,
                                    hop_length=FeatureProcessor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_impacts(energy: np.ndarray, min_gap: int = MIN_GAP) -> np.ndarray:
        """Return the frame index at which each impact starts.

        A frame is active when its energy exceeds THRESHOLD. The first active
        frame always starts an impact. A later active frame starts a new impact
        when it lies more than *min_gap* frames after the previous active frame.

        Args:
            energy: Per-frame energy, e.g. from frame_energy().
            min_gap: Gap (in frames) that separates two impacts.

        Returns:
            Integer array of start frame indices; empty if no frame is active.
        """
        idx = np.flatnonzero(energy > FeatureProcessor.THRESHOLD)
        if idx.size == 0:
            return np.array([], dtype=int)
        # Always keep the first active frame. Diffing against a prepended 0
        # would drop an impact that begins within the first min_gap frames.
        is_start = np.concatenate(([True], np.diff(idx) > min_gap))
        return idx[is_start]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, starts: np.ndarray) -> list[np.ndarray]:
        """Cut one segment of SEG_LEN_S seconds at each start frame.

        Args:
            signal: 1-D audio samples.
            sr: Sample rate in Hz.
            starts: Start frame indices (converted to samples via STEP).

        Returns:
            List of segments. A segment running past the end of *signal* is
            truncated, not padded.
        """
        seg_len = int(FeatureProcessor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in starts:
            start = int(frame_idx) * FeatureProcessor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(signal: np.ndarray, sr: int) -> np.ndarray:
        """Return the 5-dimensional feature vector used at training time.

        Features, in order:
        1. RMS energy.
        2. Dominant frequency (Hz).
        3. Spectral skewness.
        4. Mean of the first MFCC coefficient.
        5. Peak of the Hilbert envelope.

        An empty signal yields a zero vector. The formulas must stay identical
        to the training code to avoid a feature mismatch.
        """
        signal = signal.flatten()
        if len(signal) == 0:
            return np.zeros(5, dtype=np.float32)
        # 1. RMS energy
        rms = np.sqrt(np.mean(signal ** 2))
        # 2. Dominant frequency: location of the largest non-negative-frequency FFT bin
        fft = np.fft.fft(signal)
        freq = np.fft.fftfreq(len(signal), 1 / sr)
        positive_mask = freq >= 0
        freq = freq[positive_mask]
        fft_mag = np.abs(fft[positive_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
        # 3. Spectral skewness, weighted by FFT magnitude (the variable is
        #    called spec_power but holds the unsquared magnitude)
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) /
                         (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if spread > 0 else 0
        # 4. Mean of the first MFCC coefficient
        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
        # 5. Envelope peak (magnitude of the analytic signal from the Hilbert transform)
        env_peak = np.max(np.abs(hilbert(signal)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
# ---------------------- 录音线程 ----------------------
class AudioRecorder(QThread):
    """Background thread that records mono audio into a temporary WAV file.

    Signals:
        status_updated(str): status and error messages for the UI.
        progress_updated(int): elapsed time as a percentage of max_duration.
        recording_finished(str): path of the written WAV file.
    """
    status_updated = pyqtSignal(str)
    progress_updated = pyqtSignal(int)
    recording_finished = pyqtSignal(str)

    def __init__(self, max_duration=60):
        """Args:
            max_duration: Recording stops automatically after this many seconds.
        """
        super().__init__()
        self.max_duration = max_duration
        self.recording = False  # loop flag; cleared by stop() and when run() ends
        self.temp_file = "temp_audio.wav"

    def run(self):
        """Record until stop() is called or max_duration elapses, then save.

        Emits recording_finished(path) when at least one chunk was captured.
        Every failure is reported through status_updated instead of raising.
        """
        # 16-bit integer PCM: the ``wave`` module always writes an integer-PCM
        # header. float32 samples saved through it would be decoded by readers
        # as 32-bit integers, i.e. as heavily distorted audio.
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 22050
        CHUNK = 1024
        try:
            frames = self._capture(FORMAT, CHANNELS, RATE, CHUNK)
            if not frames:
                self.status_updated.emit("未检测到音频输入")
                return
            with wave.open(self.temp_file, 'wb') as wf:
                wf.setnchannels(CHANNELS)
                wf.setsampwidth(pyaudio.get_sample_size(FORMAT))
                wf.setframerate(RATE)
                wf.writeframes(b''.join(frames))
            duration = len(frames) * CHUNK / RATE  # seconds actually captured
            self.status_updated.emit(f"录制完成,时长: {duration:.1f}")
            self.recording_finished.emit(self.temp_file)
        except Exception as e:
            self.status_updated.emit(f"录音错误: {str(e)}")
        finally:
            self.recording = False

    def _capture(self, fmt, channels, rate, chunk):
        """Read raw chunks until stopped or timed out.

        The PortAudio stream and instance are always released, even if
        opening or reading the stream raises.
        """
        audio = pyaudio.PyAudio()
        try:
            stream = audio.open(
                format=fmt,
                channels=channels,
                rate=rate,
                input=True,
                frames_per_buffer=chunk
            )
            try:
                self.recording = True
                start_time = time.time()
                frames = []
                while self.recording:
                    elapsed = time.time() - start_time
                    if elapsed >= self.max_duration:
                        self.status_updated.emit(f"已达最大时长 {self.max_duration}")
                        break
                    self.progress_updated.emit(int((elapsed / self.max_duration) * 100))
                    # Losing one overflowed buffer is better than aborting
                    # the whole recording with an IOError.
                    frames.append(stream.read(chunk, exception_on_overflow=False))
                return frames
            finally:
                stream.stop_stream()
                stream.close()
        finally:
            audio.terminate()

    def stop(self):
        """Ask the recording loop to finish after the current chunk."""
        self.recording = False
# ---------------------- 音频处理与预测线程 ----------------------
class AudioProcessor(QThread):
    """Background thread that classifies one audio file.

    Pipeline:
    1. Load and peak-normalise the file.
    2. Detect impacts and cut fixed-length segments.
    3. Extract and scale per-segment features.
    4. Predict per segment.

    The file-level label comes from the mean positive-class probability.

    Signals:
        status_updated(str): progress and error messages.
        result_generated(dict): file-level result. Keys: filename,
            segment_count, segment_labels, segment_probs, mean_probability,
            final_label, label_text.
    """
    status_updated = pyqtSignal(str)
    result_generated = pyqtSignal(dict)

    def __init__(self, audio_path, model_path, scaler_path):
        """Args:
            audio_path: Audio file to classify.
            model_path: joblib file with a classifier exposing predict/predict_proba.
            scaler_path: joblib file with the fitted feature scaler.
        """
        super().__init__()
        self.audio_path = audio_path
        self.model_path = model_path
        self.scaler_path = scaler_path
        self.class_threshold = 0.5  # mean probability >= threshold -> label 1

    def run(self):
        """Classify the file; errors are reported via status_updated."""
        try:
            self.status_updated.emit("加载模型资源...")
            # NOTE: joblib.load unpickles arbitrary objects; only load model
            # files that come from a trusted source.
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            self.status_updated.emit(f"解析音频: {os.path.basename(self.audio_path)}")
            signal, sr = librosa.load(self.audio_path, sr=None, mono=True)
            peak = np.max(np.abs(signal)) if signal.size else 0.0
            if peak == 0:
                # Empty or completely silent file: nothing to normalise, and
                # no frame can exceed the impact energy threshold.
                self.status_updated.emit("未检测到有效敲击片段")
                return
            signal = signal / peak  # peak-normalise to [-1, 1]

            self.status_updated.emit("检测有效音频片段...")
            energy = FeatureProcessor.frame_energy(signal)
            impact_starts = FeatureProcessor.detect_impacts(energy)
            segments = FeatureProcessor.extract_segments(signal, sr, impact_starts)
            if not segments:
                self.status_updated.emit("未检测到有效敲击片段")
                return

            self.status_updated.emit(f"提取 {len(segments)} 个片段的特征...")
            X = np.vstack([FeatureProcessor.extract_features(seg, sr) for seg in segments])
            X_scaled = scaler.transform(X)

            predictions = model.predict(X_scaled)
            pred_proba = model.predict_proba(X_scaled)[:, 1]  # positive-class probability
            # Map both the -1/1 and the 0/1 label conventions onto 0/1;
            # testing only for -1 would turn every 0 label into 1.
            pred_labels = (np.asarray(predictions) > 0).astype(int)

            mean_prob = float(pred_proba.mean())
            final_label = 1 if mean_prob >= self.class_threshold else 0
            result = {
                "filename": os.path.basename(self.audio_path),
                "segment_count": len(segments),
                "segment_labels": pred_labels.tolist(),
                # Plain floats, so the values print (and are stored) as
                # numbers rather than np.float64(...) reprs.
                "segment_probs": [round(float(p), 4) for p in pred_proba],
                "mean_probability": round(mean_prob, 4),
                "final_label": final_label,
                "label_text": "空心" if final_label == 0 else "实心"  # assumes 0 = hollow, 1 = solid
            }
            self.result_generated.emit(result)
        except Exception as e:
            self.status_updated.emit(f"处理错误: {str(e)}")
# ---------------------- 数据库编辑对话框 ----------------------
class EditResultDialog(QDialog):
    """Modal form for creating or editing one classification record.

    Only the summary fields are editable; per-segment labels and
    probabilities are not shown.
    """

    def __init__(self, result_data=None, parent=None):
        """Args:
            result_data: Existing record (row dict) used to pre-fill the form;
                None opens an empty form.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.result_data = result_data or {}
        self.init_ui()

    def init_ui(self):
        """Lay out one form row per editable field plus OK/Cancel buttons."""
        self.setWindowTitle("编辑结果记录")
        self.setModal(True)
        self.resize(400, 300)
        layout = QFormLayout()
        data = self.result_data

        # Rows from the database may contain NULL (None) values. Fall back to
        # defaults so each Qt setter receives the type it expects.
        self.filename_edit = QLineEdit(data.get('filename') or '')
        layout.addRow("文件名:", self.filename_edit)

        self.segment_count_spin = QSpinBox()
        self.segment_count_spin.setRange(0, 1000)
        self.segment_count_spin.setValue(int(data.get('segment_count') or 0))
        layout.addRow("片段数量:", self.segment_count_spin)

        self.mean_prob_spin = QDoubleSpinBox()
        self.mean_prob_spin.setRange(0.0, 1.0)
        self.mean_prob_spin.setDecimals(4)
        self.mean_prob_spin.setSingleStep(0.01)
        self.mean_prob_spin.setValue(float(data.get('mean_probability') or 0.0))
        layout.addRow("平均概率:", self.mean_prob_spin)

        # Each item's userData is the integer stored in final_label.
        self.final_label_combo = QComboBox()
        self.final_label_combo.addItem("空心", 0)
        self.final_label_combo.addItem("实心", 1)
        index = self.final_label_combo.findData(data.get('final_label', 0))
        if index >= 0:
            self.final_label_combo.setCurrentIndex(index)
        layout.addRow("最终分类:", self.final_label_combo)

        self.notes_edit = QTextEdit()
        self.notes_edit.setMaximumHeight(80)
        # setPlainText (not setText): notes that look like HTML are kept
        # verbatim and read back unchanged by toPlainText().
        self.notes_edit.setPlainText(data.get('notes') or '')
        layout.addRow("备注:", self.notes_edit)

        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)
        layout.addRow(button_box)
        self.setLayout(layout)

    def accept(self):
        """Close with OK only when a file name has been entered."""
        if not self.filename_edit.text().strip():
            QMessageBox.warning(self, "警告", "文件名不能为空!")
            return
        super().accept()

    def get_updated_data(self):
        """Return the form contents as a dict of database column values."""
        return {
            'filename': self.filename_edit.text(),
            'segment_count': self.segment_count_spin.value(),
            'mean_probability': self.mean_prob_spin.value(),
            'final_label': self.final_label_combo.currentData(),
            'label_text': self.final_label_combo.currentText(),
            'notes': self.notes_edit.toPlainText()
        }
# ---------------------- 数据库汇总界面 ----------------------
class DatabaseUI(QWidget):
    """Database page: browse, search, add, edit and delete stored results.

    Two tabs share the same column layout. "数据浏览" lists every record and
    "搜索过滤" lists the records matching the search form.
    """

    # Column headers shared by both result tables; column 7 holds a button.
    TABLE_HEADERS = ["ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作"]

    def __init__(self, parent=None):
        """Args:
            parent: Main window; its switch_page(name) is used by the back button.
        """
        super().__init__(parent)
        # NOTE(review): this attribute shadows QWidget.parent(); kept for
        # compatibility with code that reads it.
        self.parent = parent
        self.db_manager = DatabaseManager()
        self._has_searched = False  # re-run the search after changes only once one was made
        self.init_ui()
        self.load_data()

    # ----------------------------------------------------------------- UI
    def init_ui(self):
        """Create the back button, the title and the two tabs."""
        layout = QVBoxLayout()
        layout.setContentsMargins(20, 20, 20, 20)

        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        layout.addWidget(btn_back)

        title = QLabel("数据库汇总管理")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        layout.addWidget(title)

        self.tab_widget = QTabWidget()
        self.browse_tab = QWidget()
        self.init_browse_tab()
        self.tab_widget.addTab(self.browse_tab, "数据浏览")
        self.search_tab = QWidget()
        self.init_search_tab()
        self.tab_widget.addTab(self.search_tab, "搜索过滤")
        layout.addWidget(self.tab_widget)

        self.setLayout(layout)

    def _create_result_table(self):
        """Build a read-only, single-row-selecting table with the standard columns."""
        table = QTableWidget()
        table.setColumnCount(len(self.TABLE_HEADERS))
        table.setHorizontalHeaderLabels(self.TABLE_HEADERS)
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        table.setSelectionMode(QTableWidget.SingleSelection)
        # Cells are display-only: text typed into them was never saved.
        table.setEditTriggers(QTableWidget.NoEditTriggers)
        return table

    def init_browse_tab(self):
        """Create the action buttons and the table listing every record."""
        layout = QVBoxLayout()

        btn_layout = QHBoxLayout()
        self.btn_refresh = QPushButton("刷新数据")
        self.btn_refresh.clicked.connect(self.load_data)
        self.btn_add = QPushButton("新增记录")
        self.btn_add.clicked.connect(self.add_result)
        self.btn_edit = QPushButton("编辑选中")
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_delete = QPushButton("删除选中")
        self.btn_delete.clicked.connect(self.delete_selected)
        for button in (self.btn_refresh, self.btn_add, self.btn_edit, self.btn_delete):
            btn_layout.addWidget(button)
        btn_layout.addStretch()
        layout.addLayout(btn_layout)

        self.table_widget = self._create_result_table()
        layout.addWidget(self.table_widget)
        self.browse_tab.setLayout(layout)

    def init_search_tab(self):
        """Create the search form and the table for its results."""
        layout = QVBoxLayout()

        form_layout = QFormLayout()
        self.search_filename = QLineEdit()
        self.search_filename.setPlaceholderText("输入文件名关键词")
        form_layout.addRow("文件名:", self.search_filename)
        self.search_label = QComboBox()
        self.search_label.addItem("全部", None)
        self.search_label.addItem("空心", 0)
        self.search_label.addItem("实心", 1)
        form_layout.addRow("分类结果:", self.search_label)
        self.search_date = QLineEdit()
        self.search_date.setPlaceholderText("YYYY-MM-DD")
        form_layout.addRow("创建日期:", self.search_date)
        layout.addLayout(form_layout)

        btn_search = QPushButton("搜索")
        btn_search.clicked.connect(self.perform_search)
        layout.addWidget(btn_search)

        self.search_table = self._create_result_table()
        layout.addWidget(self.search_table)
        self.search_tab.setLayout(layout)

    def showEvent(self, event):
        """Reload the browse table whenever the page is opened.

        Results saved from the record/upload pages then appear immediately.
        Spontaneous show events (e.g. restoring a minimised window) are ignored
        so the current selection is kept.
        """
        super().showEvent(event)
        if not event.spontaneous():
            self.load_data()

    # --------------------------------------------------------------- data
    def load_data(self):
        """Load every record into the browse table."""
        self.populate_table(self.table_widget, self.db_manager.get_all_results())

    def perform_search(self):
        """Run the search form's query and show the results."""
        self._has_searched = True
        filename_filter = self.search_filename.text().strip()
        label_filter = self.search_label.currentData()
        date_filter = self.search_date.text().strip() or None
        results = self.db_manager.search_results(filename_filter, label_filter, date_filter)
        self.populate_table(self.search_table, results)

    def _refresh_tables(self):
        """Reload the browse table and re-run the last search after a change."""
        self.load_data()
        if self._has_searched:
            self.perform_search()

    def populate_table(self, table, results):
        """Fill *table* with one row per record dict in *results*."""
        table.setRowCount(len(results))
        for row, result in enumerate(results):
            create_time = result['create_time']
            if create_time is None:
                create_time = ""
            elif not isinstance(create_time, str):
                create_time = create_time.strftime("%Y-%m-%d %H:%M:%S")
            cells = [
                str(result['id']),
                result['filename'],
                str(result['segment_count']),
                f"{result['mean_probability']:.4f}",
                result['label_text'],
                create_time,
                result.get('notes') or "",
            ]
            for col, text in enumerate(cells):
                table.setItem(row, col, QTableWidgetItem(text))
            btn_view = QPushButton("查看详情")
            # r=result binds this row's record now (avoids late binding).
            btn_view.clicked.connect(lambda checked, r=result: self.view_details(r))
            table.setCellWidget(row, 7, btn_view)

    # ------------------------------------------------------------ actions
    def _current_table(self):
        """Return the results table on the visible tab."""
        if self.tab_widget.currentWidget() is self.search_tab:
            return self.search_table
        return self.table_widget

    @staticmethod
    def _selected_row(table):
        """Return the selected row of *table*, or -1 when nothing is selected."""
        items = table.selectedItems()
        return items[0].row() if items else -1

    def add_result(self):
        """Create a record by hand from the edit dialog."""
        dialog = EditResultDialog(parent=self)
        if dialog.exec_() != QDialog.Accepted:
            return
        new_data = dialog.get_updated_data()
        # Hand-made records carry no per-segment data.
        new_data['segment_labels'] = []
        new_data['segment_probs'] = []
        try:
            self.db_manager.insert_result(new_data, new_data.get('notes', ''))
        except Exception as e:
            QMessageBox.warning(self, "错误", f"添加记录失败: {str(e)}")
            return
        self._refresh_tables()
        QMessageBox.information(self, "成功", "记录添加成功!")

    def edit_selected(self):
        """Open the edit dialog for the selected record and save the changes."""
        table = self._current_table()
        row = self._selected_row(table)
        if row < 0:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return
        result_id = int(table.item(row, 0).text())
        result_data = next(
            (r for r in self.db_manager.search_results() if r['id'] == result_id), None)
        if not result_data:
            QMessageBox.warning(self, "错误", "未找到选中的记录!")
            return
        dialog = EditResultDialog(result_data, self)
        if dialog.exec_() != QDialog.Accepted:
            return
        try:
            success = self.db_manager.update_result(result_id, dialog.get_updated_data())
        except Exception as e:
            QMessageBox.warning(self, "错误", f"更新记录失败: {str(e)}")
            return
        if success:
            self._refresh_tables()
            QMessageBox.information(self, "成功", "记录更新成功!")
        else:
            QMessageBox.warning(self, "错误", "更新记录失败!")

    def delete_selected(self):
        """Delete the selected record after confirmation."""
        table = self._current_table()
        row = self._selected_row(table)
        if row < 0:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return
        result_id = int(table.item(row, 0).text())
        filename = table.item(row, 1).text()
        reply = QMessageBox.question(
            self, "确认删除",
            f"确定要删除文件 '{filename}' 的记录吗?",
            QMessageBox.Yes | QMessageBox.No
        )
        if reply != QMessageBox.Yes:
            return
        try:
            success = self.db_manager.delete_result(result_id)
        except Exception as e:
            QMessageBox.warning(self, "错误", f"删除记录失败: {str(e)}")
            return
        if success:
            self._refresh_tables()
            QMessageBox.information(self, "成功", "记录删除成功!")
        else:
            QMessageBox.warning(self, "错误", "删除记录失败!")

    def view_details(self, result_data):
        """Show every stored field of one record in a message box."""
        details = (
            f"记录ID: {result_data['id']}\n"
            f"文件名: {result_data['filename']}\n"
            f"片段数量: {result_data['segment_count']}\n"
            f"片段标签: {result_data['segment_labels']}\n"
            f"片段概率: {result_data['segment_probs']}\n"
            f"平均概率: {result_data['mean_probability']:.4f}\n"
            f"最终分类: {result_data['label_text']}\n"
            f"创建时间: {result_data['create_time']}\n"
            f"备注: {result_data.get('notes') or ''}"
        )
        QMessageBox.information(self, "记录详情", details)
# ---------------------- 主界面组件 ----------------------
class MainMenu(QWidget):
    """Landing page: a title and one large navigation button per feature page."""

    def __init__(self, parent=None):
        """Args:
            parent: Main window; its switch_page(name) is called by the buttons.
        """
        super().__init__(parent)
        self.parent = parent
        self.init_ui()

    def init_ui(self):
        """Build the title label and the navigation buttons."""
        vbox = QVBoxLayout()
        vbox.setSpacing(20)

        heading = QLabel("音频分类系统")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 28px; font-weight: bold; margin: 30px 0;")
        vbox.addWidget(heading)

        # (button caption, target page name) in display order.
        nav_items = (
            ("录制音频", "record"),
            ("上传音频文件", "upload"),
            ("数据库管理", "database"),
        )
        for caption, page_name in nav_items:
            button = QPushButton(caption)
            button.setMinimumHeight(70)
            button.setStyleSheet("font-size: 18px;")
            # target=page_name binds the value now; a bare closure would see
            # only the last loop value.
            button.clicked.connect(
                lambda _checked=False, target=page_name: self.parent.switch_page(target))
            vbox.addWidget(button)

        vbox.addStretch()
        self.setLayout(vbox)
class InputPage(QWidget):
    """Page that records or uploads one audio file and classifies it.

    mode == "record": capture audio from the microphone with AudioRecorder.
    mode == "upload": choose an existing WAV file.

    Either way the file is analysed by AudioProcessor, and the result can be
    saved to the database.
    """

    def __init__(self, parent=None, mode="record"):
        """Args:
            parent: Main window; its switch_page(name) is used by the back button.
            mode: "record" or "upload".
        """
        super().__init__(parent)
        self.parent = parent
        self.mode = mode  # "record" or "upload"
        self.audio_path = ""
        self.recorder = None
        self.processor = None
        self.db_manager = DatabaseManager()
        self.init_ui()

    def init_ui(self):
        """Build the mode-specific controls, the status log and the analyse button."""
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)

        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        main_layout.addWidget(btn_back)

        title = QLabel("录制音频" if self.mode == "record" else "上传音频文件")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        main_layout.addWidget(title)

        if self.mode == "record":
            # Recording controls
            self.btn_start_rec = QPushButton("开始录音")
            self.btn_start_rec.clicked.connect(self.start_recording)
            self.btn_stop_rec = QPushButton("停止录音")
            self.btn_stop_rec.setEnabled(False)
            self.btn_stop_rec.clicked.connect(self.stop_recording)
            self.progress_bar = QProgressBar()
            self.progress_bar.setRange(0, 100)
            self.progress_bar.setValue(0)
            rec_layout = QHBoxLayout()
            rec_layout.addWidget(self.btn_start_rec)
            rec_layout.addWidget(self.btn_stop_rec)
            main_layout.addLayout(rec_layout)
            main_layout.addWidget(self.progress_bar)
        else:
            # Upload controls
            self.btn_browse = QPushButton("选择WAV文件")
            self.btn_browse.clicked.connect(self.browse_file)
            self.lbl_file = QLabel("未选择文件")
            self.lbl_file.setStyleSheet("color: #666;")
            main_layout.addWidget(self.btn_browse)
            main_layout.addWidget(self.lbl_file)

        # Status log
        self.status_display = QTextEdit()
        self.status_display.setReadOnly(True)
        self.status_display.setMinimumHeight(150)
        main_layout.addWidget(QLabel("状态信息:"))
        main_layout.addWidget(self.status_display)

        # Analyse button: enabled once an audio file is available
        self.btn_process = QPushButton("开始分析")
        self.btn_process.setEnabled(False)
        self.btn_process.clicked.connect(self.process_audio)
        main_layout.addWidget(self.btn_process)
        main_layout.addStretch()
        self.setLayout(main_layout)

    # ---------------------------------------------------------- recording
    def start_recording(self):
        """Start a new AudioRecorder thread and switch the buttons to "recording"."""
        self.status_display.append("开始录音...")
        self.recorder = AudioRecorder(max_duration=60)
        self.recorder.status_updated.connect(self.update_status)
        self.recorder.progress_updated.connect(self.progress_bar.setValue)
        self.recorder.recording_finished.connect(self.on_recording_finished)
        # The thread can also end on its own (time limit, device error), so
        # the buttons are reset when it finishes, not only on "stop".
        self.recorder.finished.connect(self._on_recorder_stopped)
        self.recorder.start()
        self.btn_start_rec.setEnabled(False)
        self.btn_stop_rec.setEnabled(True)

    def stop_recording(self):
        """Ask the recorder to stop; the buttons are reset when its thread ends."""
        if self.recorder is not None and self.recorder.isRunning():
            self.recorder.stop()
        else:
            self._on_recorder_stopped()

    def _on_recorder_stopped(self):
        """Re-enable "start" and disable "stop" once no recording is running."""
        self.btn_start_rec.setEnabled(True)
        self.btn_stop_rec.setEnabled(False)

    def on_recording_finished(self, file_path):
        """Remember the recorded file and allow it to be analysed."""
        self.audio_path = file_path
        self.btn_process.setEnabled(True)
        self.status_display.append(f"录音文件已保存: {file_path}")

    # ------------------------------------------------------------- upload
    def browse_file(self):
        """Let the user choose a WAV file to analyse."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择音频文件", "", "WAV文件 (*.wav)"
        )
        if file_path:
            self.audio_path = file_path
            self.lbl_file.setText(os.path.basename(file_path))
            self.btn_process.setEnabled(True)
            self.update_status(f"已选择文件: {file_path}")

    # ----------------------------------------------------------- analysis
    def process_audio(self):
        """Validate inputs and start an AudioProcessor thread for the current file."""
        if self.processor is not None and self.processor.isRunning():
            return  # never replace (and thereby destroy) a running thread
        if not self.audio_path or not os.path.exists(self.audio_path):
            QMessageBox.warning(self, "错误", "音频文件不存在")
            return
        # Model paths (adjust to the actual deployment)
        model_path = "pipeline_model.pkl"  # or "svm_model.pkl"
        scaler_path = "scaler.pkl"
        if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
            QMessageBox.warning(self, "错误", "模型文件或标准化器不存在")
            return
        self.processor = AudioProcessor(self.audio_path, model_path, scaler_path)
        self.processor.status_updated.connect(self.update_status)
        self.processor.result_generated.connect(self.show_result)
        # The processor returns without emitting a result on errors or when
        # no impact is found; re-enable the button whenever the thread ends.
        self.processor.finished.connect(self._on_processor_finished)
        self.processor.start()
        self.btn_process.setEnabled(False)
        self.update_status("开始分析音频...")

    def _on_processor_finished(self):
        """Allow another analysis once the processing thread has ended."""
        self.btn_process.setEnabled(True)

    def update_status(self, message):
        """Append a timestamped line to the status log and scroll to it."""
        self.status_display.append(f"[{time.strftime('%H:%M:%S')}] {message}")
        self.status_display.moveCursor(self.status_display.textCursor().End)

    def show_result(self, result):
        """Show the analysis result and optionally save it to the database."""
        msg = QMessageBox(self)
        msg.setWindowTitle("分析结果")
        msg.setIcon(QMessageBox.Information)
        msg.setText(
            f"文件名: {result['filename']}\n"
            f"有效片段数: {result['segment_count']}\n"
            f"平均概率: {result['mean_probability']}\n"
            f"最终分类: {result['label_text']}"
        )
        save_button = msg.addButton("保存到数据库", QMessageBox.AcceptRole)
        msg.addButton("仅查看", QMessageBox.RejectRole)
        msg.exec_()
        # With custom buttons exec_() returns an opaque value, so ask the box
        # which button was actually clicked.
        if msg.clickedButton() == save_button:
            notes, ok = QInputDialog.getText(
                self, "添加备注", "请输入备注信息(可选):",
                QLineEdit.Normal, ""
            )
            if ok:
                try:
                    self.db_manager.insert_result(result, notes)
                    self.update_status("结果已保存到数据库")
                except Exception as e:
                    QMessageBox.warning(self, "保存失败", f"保存到数据库失败: {str(e)}")
        self.update_status("分析完成")
        self.btn_process.setEnabled(True)
# ---------------------- 主窗口 ----------------------
class AudioClassifierApp(QMainWindow):
    """Main window: a QStackedWidget holding the menu and the feature pages."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        """Create all pages and stack them; the main menu is shown first."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(300, 300, 800, 600)
        self.stack = QStackedWidget()
        self.main_menu = MainMenu(self)
        self.record_page = InputPage(self, mode="record")
        self.upload_page = InputPage(self, mode="upload")
        self.database_page = DatabaseUI(self)
        for page in (self.main_menu, self.record_page, self.upload_page, self.database_page):
            self.stack.addWidget(page)
        self.setCentralWidget(self.stack)

    def switch_page(self, page_name):
        """Show the page registered under *page_name*.

        Known names: "main", "record", "upload", "database". Unknown names
        are ignored.
        """
        pages = {
            "main": self.main_menu,
            "record": self.record_page,
            "upload": self.upload_page,
            "database": self.database_page,
        }
        page = pages.get(page_name)
        if page is not None:
            self.stack.setCurrentWidget(page)

    def closeEvent(self, event):
        """Stop and join the worker threads before the window closes.

        Destroying a QThread that is still running aborts the process. An
        active recording is therefore asked to stop, and every running worker
        is waited for. A running analysis cannot be interrupted, so closing
        blocks until it completes.
        """
        for page in (self.record_page, self.upload_page):
            recorder = getattr(page, "recorder", None)
            if recorder is not None and recorder.isRunning():
                recorder.stop()
                recorder.wait()
            processor = getattr(page, "processor", None)
            if processor is not None and processor.isRunning():
                processor.wait()
        super().closeEvent(event)
if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window and
    # exit with the event loop's return code.
    qt_app = QApplication(sys.argv)
    main_window = AudioClassifierApp()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)