|
|
#!/usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
import json
import os
import sqlite3
import sys
import time
import wave
from contextlib import closing
from datetime import datetime

import joblib
import librosa
import numpy as np
import pyaudio
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QDialogButtonBox,
                             QDoubleSpinBox, QFileDialog, QFormLayout, QHBoxLayout,
                             QHeaderView, QInputDialog, QLabel, QLineEdit,
                             QMainWindow, QMessageBox, QProgressBar, QPushButton,
                             QSpinBox, QStackedWidget, QTableWidget,
                             QTableWidgetItem, QTabWidget, QTextEdit, QVBoxLayout,
                             QWidget)
from scipy.signal import find_peaks, hilbert
|
|
|
|
|
|
|
|
|
# ---------------------- 数据库管理模块 ----------------------
|
|
|
class DatabaseManager:
    """SQLite-backed store for audio classification results.

    Every public method opens its own short-lived connection and always closes
    it, even when a statement raises, so an instance holds no open handles
    between calls.
    """

    TABLE = "classification_results"

    # Columns that ``update_result`` may modify.  Column names cannot be bound
    # as SQL parameters, so they are validated against this whitelist before
    # being interpolated into the statement.
    UPDATABLE_COLUMNS = frozenset({
        "filename", "segment_count", "segment_labels", "segment_probs",
        "mean_probability", "final_label", "label_text", "notes",
    })

    def __init__(self, db_path="audio_classification.db"):
        """Remember *db_path* and make sure the results table exists."""
        self.db_path = db_path
        self.init_database()

    def _connect(self):
        """Open a new connection whose rows can be converted with ``dict()``."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @staticmethod
    def _serialize_list(values):
        """Serialize a sequence of numbers to a JSON list string.

        NumPy scalars (anything exposing ``.item()``) are converted to plain
        Python numbers first; ``str()`` of such a list would embed reprs like
        ``np.float64(0.5)`` under NumPy 2.  For plain ints/floats the output
        matches ``str(list)`` (e.g. ``"[0, 1]"``, ``"[0.25, 0.75]"``).
        """
        plain = [v.item() if hasattr(v, "item") else v for v in values]
        return json.dumps(plain)

    def init_database(self):
        """Create the results table if it does not exist yet."""
        # ``closing`` closes the connection; the inner ``conn`` context
        # commits on success and rolls back on error (it does not close).
        with closing(self._connect()) as conn, conn:
            conn.execute(f'''
                CREATE TABLE IF NOT EXISTS {self.TABLE} (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    segment_count INTEGER NOT NULL,
                    segment_labels TEXT NOT NULL,
                    segment_probs TEXT NOT NULL,
                    mean_probability REAL NOT NULL,
                    final_label INTEGER NOT NULL,
                    label_text TEXT NOT NULL,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    notes TEXT
                )
            ''')

    def insert_result(self, result_data, notes=""):
        """Insert one classification result and return its new row id.

        Args:
            result_data: dict with keys ``filename``, ``segment_count``,
                ``segment_labels``, ``segment_probs``, ``mean_probability``,
                ``final_label`` and ``label_text``.  The two list values are
                stored as JSON text.
            notes: free-form note stored alongside the result.

        Returns:
            The ``id`` of the inserted row.
        """
        params = (
            result_data['filename'],
            result_data['segment_count'],
            self._serialize_list(result_data['segment_labels']),
            self._serialize_list(result_data['segment_probs']),
            result_data['mean_probability'],
            result_data['final_label'],
            result_data['label_text'],
            notes,
        )
        with closing(self._connect()) as conn, conn:
            cursor = conn.execute(f'''
                INSERT INTO {self.TABLE}
                (filename, segment_count, segment_labels, segment_probs,
                 mean_probability, final_label, label_text, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', params)
            return cursor.lastrowid

    def get_all_results(self):
        """Return every stored result as a list of dicts, newest first."""
        return self.search_results()

    def search_results(self, filename_filter="", label_filter=None, date_filter=None):
        """Return results matching all given filters, newest first.

        Args:
            filename_filter: substring that must occur in ``filename``.
                ``%`` and ``_`` are matched literally, not as wildcards.
            label_filter: exact ``final_label`` value, or ``None`` for any.
            date_filter: ``YYYY-MM-DD`` string compared with
                ``DATE(create_time)``.  NOTE: ``create_time`` defaults to
                SQLite's CURRENT_TIMESTAMP, which is UTC.

        Returns:
            List of dicts keyed by column name.
        """
        clauses = []
        params = []

        if filename_filter:
            escaped = (filename_filter.replace("\\", "\\\\")
                       .replace("%", "\\%")
                       .replace("_", "\\_"))
            clauses.append("filename LIKE ? ESCAPE '\\'")
            params.append(f"%{escaped}%")

        if label_filter is not None:
            clauses.append("final_label = ?")
            params.append(label_filter)

        if date_filter:
            clauses.append("DATE(create_time) = ?")
            params.append(date_filter)

        query = f"SELECT * FROM {self.TABLE}"
        if clauses:
            query += " WHERE " + " AND ".join(clauses)
        # CURRENT_TIMESTAMP has one-second resolution; id breaks ties.
        query += " ORDER BY create_time DESC, id DESC"

        with closing(self._connect()) as conn:
            return [dict(row) for row in conn.execute(query, params)]

    def update_result(self, result_id, updates):
        """Update columns of the row with *result_id*.

        Args:
            result_id: primary key of the row to change.
            updates: mapping of column name to new value.  List values for
                ``segment_labels``/``segment_probs`` are stored as JSON text.

        Returns:
            True if a row was updated, False if none matched or *updates*
            is empty.

        Raises:
            ValueError: if *updates* names a column outside UPDATABLE_COLUMNS.
        """
        unknown = set(updates) - self.UPDATABLE_COLUMNS
        if unknown:
            raise ValueError(f"Unknown column(s): {', '.join(sorted(unknown))}")
        if not updates:
            return False

        values = dict(updates)
        for key in ("segment_labels", "segment_probs"):
            if key in values and not isinstance(values[key], str):
                values[key] = self._serialize_list(values[key])

        set_clause = ", ".join(f"{key} = ?" for key in values)
        query = f"UPDATE {self.TABLE} SET {set_clause} WHERE id = ?"

        with closing(self._connect()) as conn, conn:
            cursor = conn.execute(query, [*values.values(), result_id])
            return cursor.rowcount > 0

    def delete_result(self, result_id):
        """Delete the row with *result_id*; return True if a row was removed."""
        with closing(self._connect()) as conn, conn:
            cursor = conn.execute(f"DELETE FROM {self.TABLE} WHERE id = ?", (result_id,))
            return cursor.rowcount > 0
|
|
|
|
|
|
|
|
|
# ---------------------- 特征提取模块(与训练集保持一致) ----------------------
|
|
|
class FeatureProcessor:
    """Impact detection and 5-D feature extraction shared with training.

    The framing constants and feature formulas must stay identical to the
    ones used when the model and scaler were fitted; changing them silently
    invalidates both.
    """

    WIN = 1024            # frame length in samples
    OVERLAP = 512         # overlap between consecutive frames in samples
    THRESHOLD = 0.1       # frame energy above this marks an "active" frame
    SEG_LEN_S = 0.2       # length of each extracted segment in seconds
    STEP = WIN - OVERLAP  # hop length in samples
    GAP_FRAMES = 5        # active frames more than this far apart start a new impact

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the sum of squared samples of each WIN-long frame (hop STEP).

        Signals shorter than one frame yield an empty array instead of the
        error ``librosa.util.frame`` would raise.
        """
        if len(signal) < FeatureProcessor.WIN:
            return np.zeros(0)
        frames = librosa.util.frame(signal, frame_length=FeatureProcessor.WIN,
                                    hop_length=FeatureProcessor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_impacts(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices at which impacts start.

        A frame is active when its energy exceeds THRESHOLD.  The first
        active frame always starts an impact; every later active frame that
        is more than GAP_FRAMES after the previous active frame starts a new
        one.
        """
        active = np.flatnonzero(energy > FeatureProcessor.THRESHOLD)
        if active.size == 0:
            return np.array([], dtype=int)
        is_start = np.empty(active.size, dtype=bool)
        # The first active frame is always a start, even near the file
        # beginning (diffing against a prepended 0 would drop it when it lies
        # within the first GAP_FRAMES frames).
        is_start[0] = True
        is_start[1:] = np.diff(active) > FeatureProcessor.GAP_FRAMES
        return active[is_start]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, starts: np.ndarray) -> list[np.ndarray]:
        """Cut a SEG_LEN_S-second slice of *signal* at each start frame.

        Segments are truncated at the end of the signal, so the last one may
        be shorter than SEG_LEN_S.
        """
        seg_len = int(FeatureProcessor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in starts:
            start = int(frame_idx) * FeatureProcessor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(signal: np.ndarray, sr: int) -> np.ndarray:
        """Return the 5 training features of one segment as float32.

        Order: RMS energy, dominant frequency (Hz), spectral skewness,
        mean of the first MFCC coefficient, Hilbert-envelope peak.
        An empty segment yields five zeros.
        """
        signal = signal.flatten()
        if len(signal) == 0:
            return np.zeros(5, dtype=np.float32)

        # 1. RMS energy
        rms = np.sqrt(np.mean(signal ** 2))

        # 2. Dominant frequency: peak of the non-negative-frequency magnitude
        #    spectrum (full FFT + mask, kept as-is to match training).
        fft = np.fft.fft(signal)
        freq = np.fft.fftfreq(len(signal), 1 / sr)
        positive_mask = freq >= 0
        freq = freq[positive_mask]
        fft_mag = np.abs(fft[positive_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # 3. Spectral skewness, weighting frequencies by magnitude.
        spec_power = fft_mag
        total = np.sum(spec_power) + 1e-12
        centroid = np.sum(freq * spec_power) / total
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / total)
        skewness = (np.sum(((freq - centroid) ** 3) * spec_power)
                    / (total * (spread ** 3 + 1e-12))) if spread > 0 else 0

        # 4. Mean of the first MFCC coefficient.
        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        # 5. Peak of the analytic-signal (Hilbert) envelope.
        env_peak = np.max(np.abs(hilbert(signal)))

        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
|
|
|
|
|
|
|
|
|
# ---------------------- 录音线程 ----------------------
|
|
|
class AudioRecorder(QThread):
    """Background thread that records mono microphone audio into a WAV file.

    Signals:
        status_updated(str): human-readable status and error messages.
        progress_updated(int): elapsed time as a percentage of ``max_duration``.
        recording_finished(str): path of the written WAV file; emitted only
            when at least one chunk was captured and the file was saved.
    """

    status_updated = pyqtSignal(str)
    progress_updated = pyqtSignal(int)
    recording_finished = pyqtSignal(str)

    # Capture parameters.  Samples are 16-bit integers because the ``wave``
    # module always writes an integer-PCM header: 32-bit float samples written
    # through it would be decoded as 32-bit integers, i.e. noise.
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 22050
    CHUNK = 1024

    def __init__(self, max_duration=60):
        super().__init__()
        self.max_duration = max_duration  # seconds
        self.recording = False            # True while the capture loop runs
        self._stop_requested = False      # set by stop(); polled once per chunk
        self.temp_file = "temp_audio.wav"

    def run(self):
        """Record until stopped or the time limit, then write ``temp_file``."""
        try:
            frames, elapsed = self._capture()
            if not frames:
                self.status_updated.emit("未检测到音频输入")
                return
            self._save_wav(frames)
        except Exception as e:
            # Report through the UI instead of letting the error escape run().
            self.status_updated.emit(f"录音错误: {str(e)}")
            return

        self.status_updated.emit(f"录制完成,时长: {elapsed:.1f}秒")
        self.recording_finished.emit(self.temp_file)

    def _capture(self):
        """Read chunks until stop() is called or ``max_duration`` elapses.

        Returns:
            ``(frames, elapsed)``: the raw byte chunks and the elapsed time in
            seconds measured at the last loop check.  The stream and the
            PyAudio instance are released even if reading fails.
        """
        frames = []
        elapsed = 0.0
        audio = pyaudio.PyAudio()
        try:
            stream = audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                frames_per_buffer=self.CHUNK,
            )
            try:
                self.recording = True
                start_time = time.time()
                while not self._stop_requested:
                    elapsed = time.time() - start_time
                    if elapsed >= self.max_duration:
                        self.progress_updated.emit(100)
                        self.status_updated.emit(f"已达最大时长 {self.max_duration} 秒")
                        break
                    self.progress_updated.emit(int((elapsed / self.max_duration) * 100))
                    # A transient input overflow should drop samples rather
                    # than abort the whole recording.
                    frames.append(stream.read(self.CHUNK, exception_on_overflow=False))
            finally:
                self.recording = False
                stream.stop_stream()
                stream.close()
        finally:
            audio.terminate()
        return frames, elapsed

    def _save_wav(self, frames):
        """Write the captured chunks to ``temp_file`` as a PCM WAV file."""
        with wave.open(self.temp_file, 'wb') as wf:
            wf.setnchannels(self.CHANNELS)
            wf.setsampwidth(pyaudio.get_sample_size(self.FORMAT))
            wf.setframerate(self.RATE)
            wf.writeframes(b''.join(frames))

    def stop(self):
        """Ask the capture loop to finish after the current chunk."""
        self._stop_requested = True
        self.recording = False
|
|
|
|
|
|
|
|
|
# ---------------------- 音频处理与预测线程 ----------------------
|
|
|
class AudioProcessor(QThread):
    """Background thread that classifies one audio file.

    Loads the model and scaler with joblib, splits the file into impact
    segments, predicts each segment and averages the positive-class
    probabilities into a file-level label.

    Signals:
        status_updated(str): progress and error messages.
        result_generated(dict): the file-level result (see ``run``).
    """

    status_updated = pyqtSignal(str)
    result_generated = pyqtSignal(dict)

    def __init__(self, audio_path, model_path, scaler_path):
        super().__init__()
        self.audio_path = audio_path
        self.model_path = model_path
        self.scaler_path = scaler_path
        self.class_threshold = 0.5  # mean probability >= threshold -> label 1

    def run(self):
        """Classify ``audio_path`` and emit ``result_generated`` on success.

        The emitted dict has keys ``filename``, ``segment_count``,
        ``segment_labels`` (list of 0/1), ``segment_probs`` (list of floats),
        ``mean_probability``, ``final_label`` (0/1) and ``label_text``.
        Nothing is emitted when no segment is detected or an error occurs.
        """
        try:
            self.status_updated.emit("加载模型资源...")
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)
            # NOTE(review): if the model file is a Pipeline that already
            # contains a scaler, the scaler.transform below scales twice —
            # confirm against the training script.

            self.status_updated.emit(f"解析音频: {os.path.basename(self.audio_path)}")
            signal, sr = librosa.load(self.audio_path, sr=None, mono=True)

            # Peak-normalize; a silent or empty file has no usable peak and
            # would otherwise produce NaNs (or fail on an empty max()).
            peak = float(np.max(np.abs(signal))) if signal.size else 0.0
            if not peak > 0:
                self.status_updated.emit("未检测到有效敲击片段")
                return
            signal = signal / peak

            self.status_updated.emit("检测有效音频片段...")
            energy = FeatureProcessor.frame_energy(signal)
            impact_starts = FeatureProcessor.detect_impacts(energy)
            segments = FeatureProcessor.extract_segments(signal, sr, impact_starts)
            if not segments:
                self.status_updated.emit("未检测到有效敲击片段")
                return

            self.status_updated.emit(f"提取 {len(segments)} 个片段的特征...")
            X = np.vstack([FeatureProcessor.extract_features(seg, sr) for seg in segments])
            X_scaled = scaler.transform(X)

            predictions = np.asarray(model.predict(X_scaled))
            # Column 1 is the probability of model.classes_[1]; assumes a
            # scikit-learn-style estimator with sorted numeric classes, where
            # that is the positive class.
            pred_proba = np.asarray(model.predict_proba(X_scaled))[:, 1]
            # Map both {-1, 1} and {0, 1} label conventions onto {0, 1}.
            # (Testing only `== -1` turned a 0 label into 1.)
            pred_labels = (predictions > 0).astype(int)

            mean_prob = float(pred_proba.mean())
            final_label = 1 if mean_prob >= self.class_threshold else 0
            result = {
                "filename": os.path.basename(self.audio_path),
                "segment_count": len(segments),
                "segment_labels": pred_labels.tolist(),
                # Plain floats keep the dict free of NumPy scalar types.
                "segment_probs": [round(float(p), 4) for p in pred_proba],
                "mean_probability": round(mean_prob, 4),
                "final_label": final_label,
                "label_text": "空心" if final_label == 0 else "实心"  # assumes 0 = hollow, 1 = solid
            }
            self.result_generated.emit(result)

        except Exception as e:
            self.status_updated.emit(f"处理错误: {str(e)}")
|
|
|
|
|
|
|
|
|
# ---------------------- 数据库编辑对话框 ----------------------
|
|
|
class EditResultDialog(QDialog):
    """Modal form for creating or editing one stored classification result.

    ``result_data`` pre-fills the fields (an empty dict gives a blank form);
    ``get_updated_data`` returns the edited values after the dialog is
    accepted.
    """

    def __init__(self, result_data=None, parent=None):
        super().__init__(parent)
        self.result_data = result_data or {}
        self.init_ui()

    def init_ui(self):
        """Build the form widgets and OK/Cancel buttons."""
        self.setWindowTitle("编辑结果记录")
        self.setModal(True)
        self.resize(400, 300)

        layout = QFormLayout()

        # File name (a NULL/None value from the database shows as empty).
        self.filename_edit = QLineEdit(self.result_data.get('filename') or '')
        layout.addRow("文件名:", self.filename_edit)

        # Segment count.  The upper bound grows with the stored value so an
        # existing count above 1000 is not silently clamped on save.
        segment_count = int(self.result_data.get('segment_count') or 0)
        self.segment_count_spin = QSpinBox()
        self.segment_count_spin.setRange(0, max(1000, segment_count))
        self.segment_count_spin.setValue(segment_count)
        layout.addRow("片段数量:", self.segment_count_spin)

        # Mean probability in [0, 1].
        self.mean_prob_spin = QDoubleSpinBox()
        self.mean_prob_spin.setRange(0.0, 1.0)
        self.mean_prob_spin.setDecimals(4)
        self.mean_prob_spin.setSingleStep(0.01)
        self.mean_prob_spin.setValue(float(self.result_data.get('mean_probability') or 0.0))
        layout.addRow("平均概率:", self.mean_prob_spin)

        # Final label; item data holds the numeric label.
        self.final_label_combo = QComboBox()
        self.final_label_combo.addItem("空心", 0)
        self.final_label_combo.addItem("实心", 1)
        index = self.final_label_combo.findData(self.result_data.get('final_label', 0))
        if index >= 0:
            self.final_label_combo.setCurrentIndex(index)
        layout.addRow("最终分类:", self.final_label_combo)

        # Notes.  setPlainText (not setText) so markup-like text is not
        # interpreted as rich text and lost when read back via toPlainText().
        self.notes_edit = QTextEdit()
        self.notes_edit.setMaximumHeight(80)
        self.notes_edit.setPlainText(self.result_data.get('notes') or '')
        layout.addRow("备注:", self.notes_edit)

        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)
        layout.addRow(button_box)

        self.setLayout(layout)

    def get_updated_data(self):
        """Return the form contents keyed by database column name."""
        return {
            'filename': self.filename_edit.text(),
            'segment_count': self.segment_count_spin.value(),
            'mean_probability': self.mean_prob_spin.value(),
            'final_label': self.final_label_combo.currentData(),
            'label_text': self.final_label_combo.currentText(),
            'notes': self.notes_edit.toPlainText()
        }
|
|
|
|
|
|
|
|
|
# ---------------------- 数据库汇总界面 ----------------------
|
|
|
class DatabaseUI(QWidget):
    """Database page: browse, search, add, edit and delete stored results."""

    # Header labels shared by the browse and search tables; column 7 holds a
    # per-row "view details" button.
    COLUMN_HEADERS = ["ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作"]

    def __init__(self, parent=None):
        super().__init__(parent)
        # Main window providing switch_page().  NOTE: this attribute shadows
        # QWidget.parent(); the name is kept for compatibility.
        self.parent = parent
        self.db_manager = DatabaseManager()
        self.init_ui()
        self.load_data()

    def init_ui(self):
        """Build the back button, title and the browse/search tabs."""
        layout = QVBoxLayout()
        layout.setContentsMargins(20, 20, 20, 20)

        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        layout.addWidget(btn_back)

        title = QLabel("数据库汇总管理")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        layout.addWidget(title)

        self.tab_widget = QTabWidget()

        self.browse_tab = QWidget()
        self.init_browse_tab()
        self.tab_widget.addTab(self.browse_tab, "数据浏览")

        self.search_tab = QWidget()
        self.init_search_tab()
        self.tab_widget.addTab(self.search_tab, "搜索过滤")

        layout.addWidget(self.tab_widget)
        self.setLayout(layout)

    def _create_table(self):
        """Create a read-only, row-selecting results table."""
        table = QTableWidget()
        table.setColumnCount(len(self.COLUMN_HEADERS))
        table.setHorizontalHeaderLabels(self.COLUMN_HEADERS)
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        # Cells are display-only: in-cell edits would never reach the
        # database; changes go through EditResultDialog instead.
        table.setEditTriggers(QTableWidget.NoEditTriggers)
        return table

    def init_browse_tab(self):
        """Build the browse tab: action buttons above the full results table."""
        layout = QVBoxLayout()

        btn_layout = QHBoxLayout()

        self.btn_refresh = QPushButton("刷新数据")
        self.btn_refresh.clicked.connect(self.load_data)

        self.btn_add = QPushButton("新增记录")
        self.btn_add.clicked.connect(self.add_result)

        self.btn_edit = QPushButton("编辑选中")
        self.btn_edit.clicked.connect(self.edit_selected)

        self.btn_delete = QPushButton("删除选中")
        self.btn_delete.clicked.connect(self.delete_selected)

        for button in (self.btn_refresh, self.btn_add, self.btn_edit, self.btn_delete):
            btn_layout.addWidget(button)
        btn_layout.addStretch()
        layout.addLayout(btn_layout)

        self.table_widget = self._create_table()
        layout.addWidget(self.table_widget)
        self.browse_tab.setLayout(layout)

    def init_search_tab(self):
        """Build the search tab: filter form, search button, results table."""
        layout = QVBoxLayout()

        form_layout = QFormLayout()

        self.search_filename = QLineEdit()
        self.search_filename.setPlaceholderText("输入文件名关键词")
        form_layout.addRow("文件名:", self.search_filename)

        self.search_label = QComboBox()
        self.search_label.addItem("全部", None)
        self.search_label.addItem("空心", 0)
        self.search_label.addItem("实心", 1)
        form_layout.addRow("分类结果:", self.search_label)

        self.search_date = QLineEdit()
        self.search_date.setPlaceholderText("YYYY-MM-DD")
        form_layout.addRow("创建日期:", self.search_date)

        layout.addLayout(form_layout)

        btn_search = QPushButton("搜索")
        btn_search.clicked.connect(self.perform_search)
        layout.addWidget(btn_search)

        self.search_table = self._create_table()
        layout.addWidget(self.search_table)
        self.search_tab.setLayout(layout)

    def showEvent(self, event):
        """Reload the browse table whenever the page is shown.

        Results saved from the record/upload pages then appear without a
        manual refresh.
        """
        super().showEvent(event)
        self.load_data()

    def load_data(self):
        """Load all stored results into the browse table."""
        self.populate_table(self.table_widget, self.db_manager.get_all_results())

    def perform_search(self):
        """Run the search with the current form values."""
        filename_filter = self.search_filename.text().strip()
        label_filter = self.search_label.currentData()
        date_filter = self.search_date.text().strip() or None

        results = self.db_manager.search_results(filename_filter, label_filter, date_filter)
        self.populate_table(self.search_table, results)

    def _refresh_tables(self):
        """Reload the browse table, and the search table if it shows rows."""
        self.load_data()
        if self.search_table.rowCount() > 0:
            self.perform_search()

    def populate_table(self, table, results):
        """Fill *table* with one row per result dict."""
        table.setRowCount(len(results))

        for row, result in enumerate(results):
            create_time = result.get('create_time')
            if create_time is None:
                time_text = ""
            elif isinstance(create_time, str):
                time_text = create_time
            else:
                time_text = create_time.strftime("%Y-%m-%d %H:%M:%S")

            cells = [
                str(result['id']),
                result['filename'],
                str(result['segment_count']),
                f"{result['mean_probability']:.4f}",
                result['label_text'],
                time_text,
                result.get('notes') or "无",
            ]
            for column, text in enumerate(cells):
                table.setItem(row, column, QTableWidgetItem(text))

            btn_view = QPushButton("查看详情")
            # Bind the current row's dict as a default to avoid late binding.
            btn_view.clicked.connect(lambda checked=False, r=result: self.view_details(r))
            table.setCellWidget(row, 7, btn_view)

    def _current_table(self):
        """Return the results table on the currently visible tab."""
        if self.tab_widget.currentWidget() is self.search_tab:
            return self.search_table
        return self.table_widget

    def _selected_row(self, table):
        """Return the selected row of *table*, or None after warning the user."""
        row = table.currentRow()
        if row < 0 or table.item(row, 0) is None:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return None
        return row

    def add_result(self):
        """Create a record from an empty EditResultDialog."""
        dialog = EditResultDialog(parent=self)
        if dialog.exec_() != QDialog.Accepted:
            return

        new_data = dialog.get_updated_data()
        # Manually added records have no per-segment data.
        new_data['segment_labels'] = []
        new_data['segment_probs'] = []

        try:
            self.db_manager.insert_result(new_data, new_data.get('notes', ''))
        except Exception as e:
            QMessageBox.warning(self, "错误", f"添加记录失败: {str(e)}")
            return
        self._refresh_tables()
        QMessageBox.information(self, "成功", "记录添加成功!")

    def edit_selected(self):
        """Edit the selected record in a dialog and save the changes."""
        table = self._current_table()
        row = self._selected_row(table)
        if row is None:
            return

        result_id = int(table.item(row, 0).text())
        result_data = next((r for r in self.db_manager.search_results()
                            if r['id'] == result_id), None)
        if not result_data:
            QMessageBox.warning(self, "错误", "未找到选中的记录!")
            return

        dialog = EditResultDialog(result_data, self)
        if dialog.exec_() != QDialog.Accepted:
            return

        try:
            success = self.db_manager.update_result(result_id, dialog.get_updated_data())
        except Exception as e:
            QMessageBox.warning(self, "错误", f"更新记录失败: {str(e)}")
            return

        if success:
            self._refresh_tables()
            QMessageBox.information(self, "成功", "记录更新成功!")
        else:
            QMessageBox.warning(self, "错误", "更新记录失败!")

    def delete_selected(self):
        """Delete the selected record after confirmation."""
        table = self._current_table()
        row = self._selected_row(table)
        if row is None:
            return

        result_id = int(table.item(row, 0).text())
        filename = table.item(row, 1).text()

        reply = QMessageBox.question(
            self, "确认删除",
            f"确定要删除文件 '{filename}' 的记录吗?",
            QMessageBox.Yes | QMessageBox.No
        )
        if reply != QMessageBox.Yes:
            return

        try:
            success = self.db_manager.delete_result(result_id)
        except Exception as e:
            QMessageBox.warning(self, "错误", f"删除记录失败: {str(e)}")
            return

        if success:
            self._refresh_tables()
            QMessageBox.information(self, "成功", "记录删除成功!")
        else:
            QMessageBox.warning(self, "错误", "删除记录失败!")

    def view_details(self, result_data):
        """Show every stored field of *result_data* in a message box."""
        details = (
            f"记录ID: {result_data['id']}\n"
            f"文件名: {result_data['filename']}\n"
            f"片段数量: {result_data['segment_count']}\n"
            f"片段标签: {result_data['segment_labels']}\n"
            f"片段概率: {result_data['segment_probs']}\n"
            f"平均概率: {result_data['mean_probability']:.4f}\n"
            f"最终分类: {result_data['label_text']}\n"
            f"创建时间: {result_data['create_time']}\n"
            f"备注: {result_data.get('notes') or '无'}"
        )
        QMessageBox.information(self, "记录详情", details)
|
|
|
|
|
|
|
|
|
# ---------------------- 主界面组件 ----------------------
|
|
|
class MainMenu(QWidget):
    """Landing page: a title plus one large navigation button per feature."""

    # (button caption, page name passed to the main window's switch_page)
    _ENTRIES = (
        ("录制音频", "record"),
        ("上传音频文件", "upload"),
        ("数据库管理", "database"),
    )

    def __init__(self, parent=None):
        super().__init__(parent)
        # Main window providing switch_page(); shadows QWidget.parent().
        self.parent = parent
        self.init_ui()

    def init_ui(self):
        """Lay out the title and the navigation buttons vertically."""
        menu_layout = QVBoxLayout()
        menu_layout.setSpacing(20)

        heading = QLabel("音频分类系统")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 28px; font-weight: bold; margin: 30px 0;")
        menu_layout.addWidget(heading)

        for caption, target in self._ENTRIES:
            nav_button = QPushButton(caption)
            nav_button.setMinimumHeight(70)
            nav_button.setStyleSheet("font-size: 18px;")
            # `target` is bound as a default so each button keeps its own page.
            nav_button.clicked.connect(
                lambda _checked=False, page=target: self.parent.switch_page(page))
            menu_layout.addWidget(nav_button)

        menu_layout.addStretch()
        self.setLayout(menu_layout)
|
|
|
|
|
|
|
|
|
class InputPage(QWidget):
    """Page that obtains an audio file (recorded or uploaded) and classifies it.

    ``mode == "record"`` shows start/stop buttons and a progress bar; any
    other mode shows a WAV file picker.  Analysis runs in an AudioProcessor
    thread and the result can be saved to the database.
    """

    # Model artefacts handed to AudioProcessor; relative paths resolve against
    # the current working directory.  Adjust to your deployment.
    MODEL_PATH = "pipeline_model.pkl"  # or "svm_model.pkl"
    SCALER_PATH = "scaler.pkl"
    MAX_RECORD_SECONDS = 60

    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        # Main window providing switch_page(); shadows QWidget.parent().
        self.parent = parent
        self.mode = mode  # "record" or "upload"
        self.audio_path = ""
        self.recorder = None
        self.processor = None
        self.db_manager = DatabaseManager()
        self.init_ui()

    def init_ui(self):
        """Build the page widgets for the configured mode."""
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)

        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        main_layout.addWidget(btn_back)

        title = QLabel("录制音频" if self.mode == "record" else "上传音频文件")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        main_layout.addWidget(title)

        if self.mode == "record":
            self.btn_start_rec = QPushButton("开始录音")
            self.btn_start_rec.clicked.connect(self.start_recording)
            self.btn_stop_rec = QPushButton("停止录音")
            self.btn_stop_rec.setEnabled(False)
            self.btn_stop_rec.clicked.connect(self.stop_recording)

            self.progress_bar = QProgressBar()
            self.progress_bar.setRange(0, 100)
            self.progress_bar.setValue(0)

            rec_layout = QHBoxLayout()
            rec_layout.addWidget(self.btn_start_rec)
            rec_layout.addWidget(self.btn_stop_rec)
            main_layout.addLayout(rec_layout)
            main_layout.addWidget(self.progress_bar)
        else:
            self.btn_browse = QPushButton("选择WAV文件")
            self.btn_browse.clicked.connect(self.browse_file)
            self.lbl_file = QLabel("未选择文件")
            self.lbl_file.setStyleSheet("color: #666;")
            main_layout.addWidget(self.btn_browse)
            main_layout.addWidget(self.lbl_file)

        self.status_display = QTextEdit()
        self.status_display.setReadOnly(True)
        self.status_display.setMinimumHeight(150)
        main_layout.addWidget(QLabel("状态信息:"))
        main_layout.addWidget(self.status_display)

        self.btn_process = QPushButton("开始分析")
        self.btn_process.setEnabled(False)
        self.btn_process.clicked.connect(self.process_audio)
        main_layout.addWidget(self.btn_process)

        main_layout.addStretch()
        self.setLayout(main_layout)

    def start_recording(self):
        """Start a new AudioRecorder thread and switch the buttons to recording state."""
        self.status_display.append("开始录音...")
        self.progress_bar.setValue(0)

        self.recorder = AudioRecorder(max_duration=self.MAX_RECORD_SECONDS)
        self.recorder.status_updated.connect(self.update_status)
        self.recorder.progress_updated.connect(self.progress_bar.setValue)
        self.recorder.recording_finished.connect(self.on_recording_finished)
        # Reset the buttons however the thread ends: stop button, time
        # limit, or error.  Start is only re-enabled once the previous thread
        # has really finished, so a running QThread is never replaced.
        self.recorder.finished.connect(self._on_recorder_stopped)
        self.recorder.start()

        self.btn_start_rec.setEnabled(False)
        self.btn_stop_rec.setEnabled(True)

    def stop_recording(self):
        """Ask the running recorder to stop; buttons reset when it finishes."""
        if self.recorder is not None and self.recorder.isRunning():
            self.recorder.stop()
        self.btn_stop_rec.setEnabled(False)

    def _on_recorder_stopped(self):
        """Restore the idle recording buttons after the recorder thread ends."""
        self.btn_start_rec.setEnabled(True)
        self.btn_stop_rec.setEnabled(False)

    def on_recording_finished(self, file_path):
        """Remember the freshly recorded file and allow analysing it."""
        self.audio_path = file_path
        self.btn_process.setEnabled(True)
        self.status_display.append(f"录音文件已保存: {file_path}")

    def browse_file(self):
        """Let the user pick a WAV file to analyse."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择音频文件", "", "WAV文件 (*.wav)"
        )
        if file_path:
            self.audio_path = file_path
            self.lbl_file.setText(os.path.basename(file_path))
            self.btn_process.setEnabled(True)
            self.update_status(f"已选择文件: {file_path}")

    def process_audio(self):
        """Validate inputs and start an AudioProcessor thread."""
        if self.processor is not None and self.processor.isRunning():
            return

        if not self.audio_path or not os.path.exists(self.audio_path):
            QMessageBox.warning(self, "错误", "音频文件不存在")
            return

        if not (os.path.exists(self.MODEL_PATH) and os.path.exists(self.SCALER_PATH)):
            QMessageBox.warning(self, "错误", "模型文件或标准化器不存在")
            return

        self.processor = AudioProcessor(self.audio_path, self.MODEL_PATH, self.SCALER_PATH)
        self.processor.status_updated.connect(self.update_status)
        self.processor.result_generated.connect(self.show_result)
        # Re-enable the button even when no result is emitted (no segments
        # detected, or an error), which otherwise left it disabled forever.
        self.processor.finished.connect(self._on_processing_finished)
        self.processor.start()
        self.btn_process.setEnabled(False)
        self.update_status("开始分析音频...")

    def _on_processing_finished(self):
        """Allow another analysis once the processor thread has ended."""
        self.btn_process.setEnabled(bool(self.audio_path))

    def update_status(self, message):
        """Append a timestamped line to the status log and scroll to it."""
        self.status_display.append(f"[{time.strftime('%H:%M:%S')}] {message}")
        self.status_display.moveCursor(self.status_display.textCursor().End)

    def show_result(self, result):
        """Show the analysis result and optionally save it with a note."""
        msg = QMessageBox(self)
        msg.setWindowTitle("分析结果")
        msg.setIcon(QMessageBox.Information)
        msg.setText(
            f"文件名: {result['filename']}\n"
            f"有效片段数: {result['segment_count']}\n"
            f"平均概率: {result['mean_probability']}\n"
            f"最终分类: {result['label_text']}"
        )
        save_button = msg.addButton("保存到数据库", QMessageBox.AcceptRole)
        msg.addButton("仅查看", QMessageBox.RejectRole)
        msg.exec_()

        # With custom buttons exec_() returns an opaque value, so identify
        # the choice via clickedButton() instead.
        if msg.clickedButton() == save_button:
            notes, ok = QInputDialog.getText(
                self, "添加备注", "请输入备注信息(可选):",
                QLineEdit.Normal, ""
            )
            if ok:
                try:
                    self.db_manager.insert_result(result, notes)
                    self.update_status("结果已保存到数据库")
                except Exception as e:
                    QMessageBox.warning(self, "保存失败", f"保存到数据库失败: {str(e)}")

        self.update_status("分析完成")
        self.btn_process.setEnabled(True)
|
|
|
|
|
|
|
|
|
# ---------------------- 主窗口 ----------------------
|
|
|
class AudioClassifierApp(QMainWindow):
    """Main window hosting all pages in a QStackedWidget."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        """Create the pages and stack them; the main menu is shown first."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(300, 300, 800, 600)

        self.stack = QStackedWidget()
        self.main_menu = MainMenu(self)
        self.record_page = InputPage(self, mode="record")
        self.upload_page = InputPage(self, mode="upload")
        self.database_page = DatabaseUI(self)

        for page in (self.main_menu, self.record_page, self.upload_page, self.database_page):
            self.stack.addWidget(page)

        self.setCentralWidget(self.stack)

    def switch_page(self, page_name):
        """Show the page registered under *page_name*; unknown names are ignored."""
        pages = {
            "main": self.main_menu,
            "record": self.record_page,
            "upload": self.upload_page,
            "database": self.database_page,
        }
        page = pages.get(page_name)
        if page is not None:
            self.stack.setCurrentWidget(page)

    def closeEvent(self, event):
        """Stop and join worker threads before the window closes.

        Qt documents that deleting a QThread while it is still running
        crashes the program, so running recorders are asked to stop and all
        workers are waited for.
        """
        for page in (self.record_page, self.upload_page):
            recorder = getattr(page, "recorder", None)
            if recorder is not None and recorder.isRunning():
                recorder.stop()
                recorder.wait()
            processor = getattr(page, "processor", None)
            if processor is not None and processor.isRunning():
                processor.wait()
        super().closeEvent(event)
|
|
|
|
|
|
|
|
|
def main():
    """Create the Qt application, show the main window and run the event loop.

    Returns the event loop's exit code.
    """
    qt_app = QApplication(sys.argv)
    main_window = AudioClassifierApp()
    main_window.show()
    return qt_app.exec_()


if __name__ == "__main__":
    sys.exit(main())