Compare commits

3 Commits

Author  SHA1        Message  Date
zhd     9ca7ceb0f3  1        2 months ago
zhd     1d5e4c5e09  2        2 months ago
zhd     6d7b4fafb9  1        2 months ago

@@ -1,945 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import numpy as np
import joblib
import librosa
import pyaudio
import wave
import sqlite3
from datetime import datetime
from scipy.signal import hilbert, find_peaks
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
QTextEdit, QFileDialog, QVBoxLayout, QHBoxLayout,
QWidget, QProgressBar, QStackedWidget, QMessageBox,
QTableWidget, QTableWidgetItem, QHeaderView,
QLineEdit, QDialog, QDialogButtonBox, QFormLayout,
QComboBox, QSpinBox, QDoubleSpinBox, QTabWidget, QInputDialog)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
# ---------------------- Database management module ----------------------
class DatabaseManager:
"""Database manager responsible for storing and querying result data."""
def __init__(self, db_path="audio_classification.db"):
self.db_path = db_path
self.init_database()
def init_database(self):
"""初始化数据库表结构"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Create the results table
cursor.execute('''
CREATE TABLE IF NOT EXISTS classification_results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
filename TEXT NOT NULL,
segment_count INTEGER NOT NULL,
segment_labels TEXT NOT NULL,
segment_probs TEXT NOT NULL,
mean_probability REAL NOT NULL,
final_label INTEGER NOT NULL,
label_text TEXT NOT NULL,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
notes TEXT
)
''')
conn.commit()
conn.close()
def insert_result(self, result_data, notes=""):
"""插入新的分类结果"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO classification_results
(filename, segment_count, segment_labels, segment_probs,
mean_probability, final_label, label_text, notes)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
result_data['filename'],
result_data['segment_count'],
str(result_data['segment_labels']),
str(result_data['segment_probs']),
result_data['mean_probability'],
result_data['final_label'],
result_data['label_text'],
notes
))
row_id = cursor.lastrowid
conn.commit()
conn.close()
return row_id
def get_all_results(self):
"""获取所有结果记录"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM classification_results
ORDER BY create_time DESC
''')
results = []
columns = [desc[0] for desc in cursor.description]
for row in cursor.fetchall():
results.append(dict(zip(columns, row)))
conn.close()
return results
def search_results(self, filename_filter="", label_filter=None, date_filter=None):
"""根据条件搜索结果"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
query = '''
SELECT * FROM classification_results
WHERE 1=1
'''
params = []
if filename_filter:
query += " AND filename LIKE ?"
params.append(f'%{filename_filter}%')
if label_filter is not None:
query += " AND final_label = ?"
params.append(label_filter)
if date_filter:
query += " AND DATE(create_time) = ?"
params.append(date_filter)
query += " ORDER BY create_time DESC"
cursor.execute(query, params)
results = []
columns = [desc[0] for desc in cursor.description]
for row in cursor.fetchall():
results.append(dict(zip(columns, row)))
conn.close()
return results
def update_result(self, result_id, updates):
"""更新结果记录"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
set_clause = ", ".join([f"{key} = ?" for key in updates.keys()])
query = f"UPDATE classification_results SET {set_clause} WHERE id = ?"
params = list(updates.values())
params.append(result_id)
cursor.execute(query, params)
conn.commit()
affected = cursor.rowcount
conn.close()
return affected > 0
def delete_result(self, result_id):
"""删除结果记录"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("DELETE FROM classification_results WHERE id = ?", (result_id,))
conn.commit()
conn.close()
return cursor.rowcount > 0
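# A minimal sketch of how DatabaseManager is meant to be used (file name and
# values are hypothetical); it mirrors the result dict AudioProcessor emits
# further below and is handy for a quick check from a REPL.
def _demo_database_roundtrip(db_path="demo_results.db"):
    db = DatabaseManager(db_path)
    sample = {
        "filename": "example.wav",       # hypothetical file
        "segment_count": 2,
        "segment_labels": [0, 1],
        "segment_probs": [0.31, 0.87],
        "mean_probability": 0.59,
        "final_label": 1,
        "label_text": "实心",
    }
    db.insert_result(sample, notes="manual smoke test")
    return db.search_results(filename_filter="example")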
# ---------------------- Feature extraction module (consistent with the training set) ----------------------
class FeatureProcessor:
"""Feature extractor; uses the same feature dimensions and computations as at training time."""
WIN = 1024
OVERLAP = 512
THRESHOLD = 0.1
SEG_LEN_S = 0.2
STEP = WIN - OVERLAP
@staticmethod
def frame_energy(signal: np.ndarray) -> np.ndarray:
"""计算帧能量"""
frames = librosa.util.frame(signal, frame_length=FeatureProcessor.WIN,
hop_length=FeatureProcessor.STEP)
return np.sum(frames ** 2, axis=0)
@staticmethod
def detect_impacts(energy: np.ndarray) -> np.ndarray:
"""检测有效敲击片段起始点"""
idx = np.where(energy > FeatureProcessor.THRESHOLD)[0]
if idx.size == 0:
return np.array([])
# A gap of more than 5 frames is treated as the start of a new segment
flags = np.diff(np.concatenate(([0], idx))) > 5
return idx[flags]
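# Worked example of the rule above (illustrative values): if the energy exceeds
# THRESHOLD at frames idx = [10, 11, 12, 40, 41], the differences between
# consecutive entries (with a leading 0 prepended) are [10, 1, 1, 28, 1], so
# only frames 10 and 40 (gaps > 5) are reported as segment starts.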
@staticmethod
def extract_segments(signal: np.ndarray, sr: int, starts: np.ndarray) -> list[np.ndarray]:
"""切分固定长度的音频片段"""
seg_len = int(FeatureProcessor.SEG_LEN_S * sr)
segments = []
for frame_idx in starts:
start = frame_idx * FeatureProcessor.STEP
end = min(start + seg_len, len(signal))
segments.append(signal[start:end])
return segments
@staticmethod
def extract_features(signal: np.ndarray, sr: int) -> np.ndarray:
"""提取与训练集一致的5维特征修正原GUI可能的特征不匹配问题"""
signal = signal.flatten()
if len(signal) == 0:
return np.zeros(5, dtype=np.float32)
# 1. RMS energy
rms = np.sqrt(np.mean(signal ** 2))
# 2. Dominant frequency (spectral peak)
fft = np.fft.fft(signal)
freq = np.fft.fftfreq(len(signal), 1 / sr)
positive_mask = freq >= 0
freq = freq[positive_mask]
fft_mag = np.abs(fft[positive_mask])
main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
# 3. Spectral skewness
spec_power = fft_mag
centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) /
(np.sum(spec_power) + 1e-12))
skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
(np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if spread > 0 else 0
# 4. Mean of the first MFCC coefficient
mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13)
mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
# 5. Envelope peak (Hilbert transform)
env_peak = np.max(np.abs(hilbert(signal)))
return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
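# A minimal sketch of the offline feature pipeline these helpers implement
# (the WAV path is a hypothetical example); it mirrors the steps AudioProcessor
# runs below and returns one 5-dimensional row per detected impact.
def _demo_extract_feature_matrix(wav_path="example.wav"):
    signal, sr = librosa.load(wav_path, sr=None, mono=True)
    energy = FeatureProcessor.frame_energy(signal)
    starts = FeatureProcessor.detect_impacts(energy)
    segments = FeatureProcessor.extract_segments(signal, sr, starts)
    if not segments:
        return np.empty((0, 5), dtype=np.float32)
    # Rows are [rms, main_freq, skewness, mfcc_mean, env_peak]
    return np.vstack([FeatureProcessor.extract_features(seg, sr) for seg in segments])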
# ---------------------- Recording thread ----------------------
class AudioRecorder(QThread):
status_updated = pyqtSignal(str)
progress_updated = pyqtSignal(int)
recording_finished = pyqtSignal(str)
def __init__(self, max_duration=60):
super().__init__()
self.max_duration = max_duration
self.recording = False
self.temp_file = "temp_audio.wav"
def run(self):
# Audio parameters (compatible with the feature extraction)
FORMAT = pyaudio.paFloat32
CHANNELS = 1
RATE = 22050
CHUNK = 1024
try:
audio = pyaudio.PyAudio()
stream = audio.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=CHUNK
)
self.recording = True
start_time = time.time()
frames = []
while self.recording:
elapsed = time.time() - start_time
if elapsed >= self.max_duration:
self.status_updated.emit(f"已达最大时长 {self.max_duration}")
break
self.progress_updated.emit(int((elapsed / self.max_duration) * 100))
data = stream.read(CHUNK)
frames.append(data)
# Stop recording and save
stream.stop_stream()
stream.close()
audio.terminate()
if frames:
with wave.open(self.temp_file, 'wb') as wf:
wf.setnchannels(CHANNELS)
wf.setsampwidth(audio.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
self.status_updated.emit(f"录制完成,时长: {elapsed:.1f}")
self.recording_finished.emit(self.temp_file)
else:
self.status_updated.emit("未检测到音频输入")
except Exception as e:
self.status_updated.emit(f"录音错误: {str(e)}")
def stop(self):
self.recording = False
# ---------------------- Audio processing and prediction thread ----------------------
class AudioProcessor(QThread):
status_updated = pyqtSignal(str)
result_generated = pyqtSignal(dict)
def __init__(self, audio_path, model_path, scaler_path):
super().__init__()
self.audio_path = audio_path
self.model_path = model_path
self.scaler_path = scaler_path
self.class_threshold = 0.5 # classification threshold
def run(self):
try:
# Load the model and scaler
self.status_updated.emit("加载模型资源...")
model = joblib.load(self.model_path)
scaler = joblib.load(self.scaler_path)
# Read the audio file
self.status_updated.emit(f"解析音频: {os.path.basename(self.audio_path)}")
signal, sr = librosa.load(self.audio_path, sr=None, mono=True)
peak = np.max(np.abs(signal))
if peak > 0:
    signal = signal / peak  # normalize; guard against silent (all-zero) input
# Extract segments
self.status_updated.emit("检测有效音频片段...")
energy = FeatureProcessor.frame_energy(signal)
impact_starts = FeatureProcessor.detect_impacts(energy)
segments = FeatureProcessor.extract_segments(signal, sr, impact_starts)
if not segments:
self.status_updated.emit("未检测到有效敲击片段")
return
# Extract features and predict
self.status_updated.emit(f"提取 {len(segments)} 个片段的特征...")
features = [FeatureProcessor.extract_features(seg, sr) for seg in segments]
X = np.vstack(features)
# Standardize the features
X_scaled = scaler.transform(X)
# Model prediction (map labels -1/1 to 0/1)
predictions = model.predict(X_scaled)
pred_proba = model.predict_proba(X_scaled)[:, 1] # positive-class probability
pred_labels = np.where(predictions == -1, 0, 1) # unify labels to 0/1
# Compute the file-level result
mean_prob = pred_proba.mean()
final_label = 1 if mean_prob >= self.class_threshold else 0
result = {
"filename": os.path.basename(self.audio_path),
"segment_count": len(segments),
"segment_labels": pred_labels.tolist(),
"segment_probs": [round(p, 4) for p in pred_proba],
"mean_probability": round(mean_prob, 4),
"final_label": final_label,
"label_text": "空心" if final_label == 0 else "实心" # 假设0=空心1=实心
}
self.result_generated.emit(result)
except Exception as e:
self.status_updated.emit(f"处理错误: {str(e)}")
# ---------------------- Database edit dialog ----------------------
class EditResultDialog(QDialog):
def __init__(self, result_data=None, parent=None):
super().__init__(parent)
self.result_data = result_data or {}
self.init_ui()
def init_ui(self):
self.setWindowTitle("编辑结果记录")
self.setModal(True)
self.resize(400, 300)
layout = QFormLayout()
# 文件名
self.filename_edit = QLineEdit(self.result_data.get('filename', ''))
layout.addRow("文件名:", self.filename_edit)
# 片段数量
self.segment_count_spin = QSpinBox()
self.segment_count_spin.setRange(0, 1000)
self.segment_count_spin.setValue(self.result_data.get('segment_count', 0))
layout.addRow("片段数量:", self.segment_count_spin)
# 平均概率
self.mean_prob_spin = QDoubleSpinBox()
self.mean_prob_spin.setRange(0.0, 1.0)
self.mean_prob_spin.setDecimals(4)
self.mean_prob_spin.setSingleStep(0.01)
self.mean_prob_spin.setValue(self.result_data.get('mean_probability', 0.0))
layout.addRow("平均概率:", self.mean_prob_spin)
# 最终标签
self.final_label_combo = QComboBox()
self.final_label_combo.addItem("空心", 0)
self.final_label_combo.addItem("实心", 1)
current_label = self.result_data.get('final_label', 0)
index = self.final_label_combo.findData(current_label)
if index >= 0:
self.final_label_combo.setCurrentIndex(index)
layout.addRow("最终分类:", self.final_label_combo)
# 备注
self.notes_edit = QTextEdit()
self.notes_edit.setMaximumHeight(80)
self.notes_edit.setText(self.result_data.get('notes', ''))
layout.addRow("备注:", self.notes_edit)
# 按钮
button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
button_box.accepted.connect(self.accept)
button_box.rejected.connect(self.reject)
layout.addRow(button_box)
self.setLayout(layout)
def get_updated_data(self):
"""获取更新后的数据"""
return {
'filename': self.filename_edit.text(),
'segment_count': self.segment_count_spin.value(),
'mean_probability': self.mean_prob_spin.value(),
'final_label': self.final_label_combo.currentData(),
'label_text': self.final_label_combo.currentText(),
'notes': self.notes_edit.toPlainText()
}
# ---------------------- Database summary page ----------------------
class DatabaseUI(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self.parent = parent
self.db_manager = DatabaseManager()
self.init_ui()
self.load_data()
def init_ui(self):
layout = QVBoxLayout()
layout.setContentsMargins(20, 20, 20, 20)
# 返回按钮
btn_back = QPushButton("← 返回主菜单")
btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
layout.addWidget(btn_back)
# 标题
title = QLabel("数据库汇总管理")
title.setStyleSheet("font-size: 20px; margin: 15px 0;")
layout.addWidget(title)
# 创建标签页
self.tab_widget = QTabWidget()
# 数据浏览标签页
self.browse_tab = QWidget()
self.init_browse_tab()
self.tab_widget.addTab(self.browse_tab, "数据浏览")
# 搜索标签页
self.search_tab = QWidget()
self.init_search_tab()
self.tab_widget.addTab(self.search_tab, "搜索过滤")
layout.addWidget(self.tab_widget)
self.setLayout(layout)
def init_browse_tab(self):
layout = QVBoxLayout()
# 操作按钮
btn_layout = QHBoxLayout()
self.btn_refresh = QPushButton("刷新数据")
self.btn_refresh.clicked.connect(self.load_data)
self.btn_add = QPushButton("新增记录")
self.btn_add.clicked.connect(self.add_result)
self.btn_edit = QPushButton("编辑选中")
self.btn_edit.clicked.connect(self.edit_selected)
self.btn_delete = QPushButton("删除选中")
self.btn_delete.clicked.connect(self.delete_selected)
btn_layout.addWidget(self.btn_refresh)
btn_layout.addWidget(self.btn_add)
btn_layout.addWidget(self.btn_edit)
btn_layout.addWidget(self.btn_delete)
btn_layout.addStretch()
layout.addLayout(btn_layout)
# 数据表格
self.table_widget = QTableWidget()
self.table_widget.setColumnCount(8)
self.table_widget.setHorizontalHeaderLabels([
"ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作"
])
self.table_widget.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
self.table_widget.setSelectionBehavior(QTableWidget.SelectRows)
layout.addWidget(self.table_widget)
self.browse_tab.setLayout(layout)
def init_search_tab(self):
layout = QVBoxLayout()
# 搜索条件表单
form_layout = QFormLayout()
self.search_filename = QLineEdit()
self.search_filename.setPlaceholderText("输入文件名关键词")
form_layout.addRow("文件名:", self.search_filename)
self.search_label = QComboBox()
self.search_label.addItem("全部", None)
self.search_label.addItem("空心", 0)
self.search_label.addItem("实心", 1)
form_layout.addRow("分类结果:", self.search_label)
self.search_date = QLineEdit()
self.search_date.setPlaceholderText("YYYY-MM-DD")
form_layout.addRow("创建日期:", self.search_date)
layout.addLayout(form_layout)
# 搜索按钮
btn_search = QPushButton("搜索")
btn_search.clicked.connect(self.perform_search)
layout.addWidget(btn_search)
# 搜索结果表格
self.search_table = QTableWidget()
self.search_table.setColumnCount(8)
self.search_table.setHorizontalHeaderLabels([
"ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作"
])
self.search_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
self.search_table.setSelectionBehavior(QTableWidget.SelectRows)
layout.addWidget(self.search_table)
self.search_tab.setLayout(layout)
def load_data(self):
"""加载所有数据到浏览表格"""
results = self.db_manager.get_all_results()
self.populate_table(self.table_widget, results)
def perform_search(self):
"""执行搜索操作"""
filename_filter = self.search_filename.text().strip()
label_filter = self.search_label.currentData()
date_filter = self.search_date.text().strip() or None
results = self.db_manager.search_results(filename_filter, label_filter, date_filter)
self.populate_table(self.search_table, results)
def populate_table(self, table, results):
"""填充表格数据"""
table.setRowCount(len(results))
for row, result in enumerate(results):
# ID
table.setItem(row, 0, QTableWidgetItem(str(result['id'])))
# 文件名
table.setItem(row, 1, QTableWidgetItem(result['filename']))
# 片段数
table.setItem(row, 2, QTableWidgetItem(str(result['segment_count'])))
# 平均概率
table.setItem(row, 3, QTableWidgetItem(f"{result['mean_probability']:.4f}"))
# 分类结果
table.setItem(row, 4, QTableWidgetItem(result['label_text']))
# 创建时间
create_time = result['create_time']
if isinstance(create_time, str):
table.setItem(row, 5, QTableWidgetItem(create_time))
else:
table.setItem(row, 5, QTableWidgetItem(create_time.strftime("%Y-%m-%d %H:%M:%S")))
# 备注
notes = result.get('notes', '')
table.setItem(row, 6, QTableWidgetItem(notes if notes else ""))
# 操作按钮
btn_view = QPushButton("查看详情")
btn_view.clicked.connect(lambda checked, r=result: self.view_details(r))
table.setCellWidget(row, 7, btn_view)
def add_result(self):
"""新增结果记录"""
dialog = EditResultDialog()
if dialog.exec_() == QDialog.Accepted:
new_data = dialog.get_updated_data()
# 设置一些默认值
new_data['segment_labels'] = []
new_data['segment_probs'] = []
try:
self.db_manager.insert_result(new_data, new_data.get('notes', ''))
self.load_data()
QMessageBox.information(self, "成功", "记录添加成功!")
except Exception as e:
QMessageBox.warning(self, "错误", f"添加记录失败: {str(e)}")
def edit_selected(self):
"""编辑选中的记录"""
current_table = self.tab_widget.currentWidget().layout().itemAt(1).widget()
current_row = current_table.currentRow()
if current_row < 0:
QMessageBox.warning(self, "警告", "请先选择一条记录!")
return
result_id = int(current_table.item(current_row, 0).text())
results = self.db_manager.search_results()
result_data = next((r for r in results if r['id'] == result_id), None)
if not result_data:
QMessageBox.warning(self, "错误", "未找到选中的记录!")
return
dialog = EditResultDialog(result_data)
if dialog.exec_() == QDialog.Accepted:
updated_data = dialog.get_updated_data()
try:
success = self.db_manager.update_result(result_id, updated_data)
if success:
self.load_data()
if current_table == self.search_table:
self.perform_search()
QMessageBox.information(self, "成功", "记录更新成功!")
else:
QMessageBox.warning(self, "错误", "更新记录失败!")
except Exception as e:
QMessageBox.warning(self, "错误", f"更新记录失败: {str(e)}")
def delete_selected(self):
"""删除选中的记录"""
current_table = self.tab_widget.currentWidget().layout().itemAt(1).widget()
current_row = current_table.currentRow()
if current_row < 0:
QMessageBox.warning(self, "警告", "请先选择一条记录!")
return
result_id = int(current_table.item(current_row, 0).text())
filename = current_table.item(current_row, 1).text()
reply = QMessageBox.question(
self, "确认删除",
f"确定要删除文件 '{filename}' 的记录吗?",
QMessageBox.Yes | QMessageBox.No
)
if reply == QMessageBox.Yes:
try:
success = self.db_manager.delete_result(result_id)
if success:
self.load_data()
if current_table == self.search_table:
self.perform_search()
QMessageBox.information(self, "成功", "记录删除成功!")
else:
QMessageBox.warning(self, "错误", "删除记录失败!")
except Exception as e:
QMessageBox.warning(self, "错误", f"删除记录失败: {str(e)}")
def view_details(self, result_data):
"""查看记录详情"""
details = (
f"记录ID: {result_data['id']}\n"
f"文件名: {result_data['filename']}\n"
f"片段数量: {result_data['segment_count']}\n"
f"片段标签: {result_data['segment_labels']}\n"
f"片段概率: {result_data['segment_probs']}\n"
f"平均概率: {result_data['mean_probability']:.4f}\n"
f"最终分类: {result_data['label_text']}\n"
f"创建时间: {result_data['create_time']}\n"
f"备注: {result_data.get('notes', '')}"
)
QMessageBox.information(self, "记录详情", details)
# ---------------------- Main UI components ----------------------
class MainMenu(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self.parent = parent
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
layout.setSpacing(20)
# 标题
title = QLabel("音频分类系统")
title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 28px; font-weight: bold; margin: 30px 0;")
layout.addWidget(title)
# 功能按钮
btn_record = QPushButton("录制音频")
btn_record.setMinimumHeight(70)
btn_record.setStyleSheet("font-size: 18px;")
btn_record.clicked.connect(lambda: self.parent.switch_page("record"))
btn_upload = QPushButton("上传音频文件")
btn_upload.setMinimumHeight(70)
btn_upload.setStyleSheet("font-size: 18px;")
btn_upload.clicked.connect(lambda: self.parent.switch_page("upload"))
btn_database = QPushButton("数据库管理")
btn_database.setMinimumHeight(70)
btn_database.setStyleSheet("font-size: 18px;")
btn_database.clicked.connect(lambda: self.parent.switch_page("database"))
layout.addWidget(btn_record)
layout.addWidget(btn_upload)
layout.addWidget(btn_database)
layout.addStretch()
self.setLayout(layout)
class InputPage(QWidget):
def __init__(self, parent=None, mode="record"):
super().__init__(parent)
self.parent = parent
self.mode = mode # "record" 或 "upload"
self.audio_path = ""
self.recorder = None
self.db_manager = DatabaseManager() # database manager for saving results
self.init_ui()
def init_ui(self):
# 主布局
main_layout = QVBoxLayout()
main_layout.setContentsMargins(20, 20, 20, 20)
# 返回按钮
btn_back = QPushButton("← 返回主菜单")
btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
main_layout.addWidget(btn_back)
# 标题
title = QLabel("录制音频" if self.mode == "record" else "上传音频文件")
title.setStyleSheet("font-size: 20px; margin: 15px 0;")
main_layout.addWidget(title)
# 操作区域
if self.mode == "record":
# 录音控制
self.btn_start_rec = QPushButton("开始录音")
self.btn_start_rec.clicked.connect(self.start_recording)
self.btn_stop_rec = QPushButton("停止录音")
self.btn_stop_rec.setEnabled(False)
self.btn_stop_rec.clicked.connect(self.stop_recording)
# 进度条
self.progress_bar = QProgressBar()
self.progress_bar.setRange(0, 100)
self.progress_bar.setValue(0)
# 录音布局
rec_layout = QHBoxLayout()
rec_layout.addWidget(self.btn_start_rec)
rec_layout.addWidget(self.btn_stop_rec)
main_layout.addLayout(rec_layout)
main_layout.addWidget(self.progress_bar)
else:
# 上传控制
self.btn_browse = QPushButton("选择WAV文件")
self.btn_browse.clicked.connect(self.browse_file)
self.lbl_file = QLabel("未选择文件")
self.lbl_file.setStyleSheet("color: #666;")
main_layout.addWidget(self.btn_browse)
main_layout.addWidget(self.lbl_file)
# 状态显示
self.status_display = QTextEdit()
self.status_display.setReadOnly(True)
self.status_display.setMinimumHeight(150)
main_layout.addWidget(QLabel("状态信息:"))
main_layout.addWidget(self.status_display)
# 处理按钮
self.btn_process = QPushButton("开始分析")
self.btn_process.setEnabled(False)
self.btn_process.clicked.connect(self.process_audio)
main_layout.addWidget(self.btn_process)
main_layout.addStretch()
self.setLayout(main_layout)
def start_recording(self):
self.status_display.append("开始录音...")
self.recorder = AudioRecorder(max_duration=60)
self.recorder.status_updated.connect(self.update_status)
self.recorder.progress_updated.connect(self.progress_bar.setValue)
self.recorder.recording_finished.connect(self.on_recording_finished)
self.recorder.start()
self.btn_start_rec.setEnabled(False)
self.btn_stop_rec.setEnabled(True)
def stop_recording(self):
if self.recorder and self.recorder.recording:
self.recorder.stop()
self.btn_start_rec.setEnabled(True)
self.btn_stop_rec.setEnabled(False)
def on_recording_finished(self, file_path):
self.audio_path = file_path
self.btn_process.setEnabled(True)
self.status_display.append(f"录音文件已保存: {file_path}")
def browse_file(self):
file_path, _ = QFileDialog.getOpenFileName(
self, "选择音频文件", "", "WAV文件 (*.wav)"
)
if file_path:
self.audio_path = file_path
self.lbl_file.setText(os.path.basename(file_path))
self.btn_process.setEnabled(True)
self.update_status(f"已选择文件: {file_path}")
def process_audio(self):
if not self.audio_path or not os.path.exists(self.audio_path):
QMessageBox.warning(self, "错误", "音频文件不存在")
return
# Model paths (adjust to your setup)
model_path = "pipeline_model.pkl" # or "svm_model.pkl"
scaler_path = "scaler.pkl"
if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
QMessageBox.warning(self, "错误", "模型文件或标准化器不存在")
return
# 启动处理线程
self.processor = AudioProcessor(self.audio_path, model_path, scaler_path)
self.processor.status_updated.connect(self.update_status)
self.processor.result_generated.connect(self.show_result)
self.processor.start()
self.btn_process.setEnabled(False)
self.update_status("开始分析音频...")
def update_status(self, message):
self.status_display.append(f"[{time.strftime('%H:%M:%S')}] {message}")
# Auto-scroll to the bottom
self.status_display.moveCursor(self.status_display.textCursor().End)
def show_result(self, result):
# Show the result dialog, with an option to save to the database
msg = QMessageBox()
msg.setWindowTitle("分析结果")
msg.setIcon(QMessageBox.Information)
text = (
f"文件名: {result['filename']}\n"
f"有效片段数: {result['segment_count']}\n"
f"平均概率: {result['mean_probability']}\n"
f"最终分类: {result['label_text']}"
)
msg.setText(text)
# Add a "save to database" button; with custom buttons, read the choice via clickedButton()
save_btn = msg.addButton("保存到数据库", QMessageBox.AcceptRole)
msg.addButton("仅查看", QMessageBox.RejectRole)
msg.exec_()
if msg.clickedButton() is save_btn: # save to database
notes, ok = QInputDialog.getText(
self, "添加备注", "请输入备注信息(可选):",
QLineEdit.Normal, ""
)
if ok:
try:
self.db_manager.insert_result(result, notes)
self.update_status("结果已保存到数据库")
except Exception as e:
QMessageBox.warning(self, "保存失败", f"保存到数据库失败: {str(e)}")
self.update_status("分析完成")
self.btn_process.setEnabled(True)
# ---------------------- Main window ----------------------
class AudioClassifierApp(QMainWindow):
def __init__(self):
super().__init__()
self.init_ui()
def init_ui(self):
self.setWindowTitle("音频分类器")
self.setGeometry(300, 300, 800, 600)
# Stacked widget manages the pages
self.stack = QStackedWidget()
self.main_menu = MainMenu(self)
self.record_page = InputPage(self, mode="record")
self.upload_page = InputPage(self, mode="upload")
self.database_page = DatabaseUI(self) # new database page
self.stack.addWidget(self.main_menu)
self.stack.addWidget(self.record_page)
self.stack.addWidget(self.upload_page)
self.stack.addWidget(self.database_page)
self.setCentralWidget(self.stack)
def switch_page(self, page_name):
"""切换页面"""
if page_name == "main":
self.stack.setCurrentWidget(self.main_menu)
elif page_name == "record":
self.stack.setCurrentWidget(self.record_page)
elif page_name == "upload":
self.stack.setCurrentWidget(self.upload_page)
elif page_name == "database":
self.stack.setCurrentWidget(self.database_page)
if __name__ == "__main__":
app = QApplication(sys.argv)
window = AudioClassifierApp()
window.show()
sys.exit(app.exec_())

@@ -1,630 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import numpy as np
import joblib
import librosa
import pyaudio
import wave
from scipy.signal import hilbert
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
QComboBox) # >>> CHANGED: add QComboBox
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
# Feature extraction class
class FeatureExtractor:
WIN = 1024
OVERLAP = 512
THRESHOLD = 0.1
SEG_LEN_S = 0.2
STEP = WIN - OVERLAP
@staticmethod
def frame_energy(signal: np.ndarray) -> np.ndarray:
frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP)
return np.sum(frames ** 2, axis=0)
@staticmethod
def detect_hits(energy: np.ndarray) -> np.ndarray:
idx = np.where(energy > FeatureExtractor.THRESHOLD)[0]
if idx.size == 0:
return np.array([])
flags = np.diff(np.concatenate(([0], idx))) > 5
return idx[flags]
@staticmethod
def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray) -> list[np.ndarray]:
seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
segments = []
for frame_idx in hit_starts:
s = frame_idx * FeatureExtractor.STEP
e = min(s + seg_len, len(signal))
segments.append(signal[s:e])
return segments
@staticmethod
def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
x = x.flatten()
if len(x) == 0:
return np.zeros(5)
# 1. RMS energy
rms = np.sqrt(np.mean(x ** 2))
# 2. Dominant frequency (spectral peak frequency)
fft = np.fft.fft(x)
freq = np.fft.fftfreq(len(x), 1 / sr)
positive_freq_mask = freq >= 0
freq = freq[positive_freq_mask]
fft_mag = np.abs(fft[positive_freq_mask])
main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
# 3. Spectral skewness
spec_power = fft_mag
centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12))
skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
(np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0
# 4. Mean of the first MFCC coefficient
mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
# 5. Envelope peak (Hilbert transform)
env_peak = np.max(np.abs(hilbert(x)))
return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# Recording thread
class RecordThread(QThread):
update_signal = pyqtSignal(str)
finish_signal = pyqtSignal(str)
progress_signal = pyqtSignal(int)
level_signal = pyqtSignal(int) # >>> NEW: real-time level (0-100)
def __init__(self, max_duration=60, device_index=None):
super().__init__()
self.max_duration = max_duration
self.is_recording = False
self.temp_file = "temp_recording.wav"
self.device_index = device_index # >>> NEW
def run(self):
# >>> CHANGED: more broadly compatible audio parameters
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100
CHUNK = 1024
p = None
stream = None
try:
p = pyaudio.PyAudio()
# Log device info (helps with troubleshooting)
try:
device_log = []
for i in range(p.get_device_count()):
info = p.get_device_info_by_index(i)
if int(info.get('maxInputChannels', 0)) > 0:
device_log.append(f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} sr={int(info.get('defaultSampleRate', 0))}")
if device_log:
self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log))
except Exception as _:
pass
# >>> CHANGED: pass input_device_index (if provided)
stream = p.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
input_device_index=self.device_index, # may be None; the system default is then used
frames_per_buffer=CHUNK
)
self.is_recording = True
start_time = time.time()
frames = []
# Full-scale value used for level estimation
max_int16 = 32767.0
while self.is_recording:
elapsed = time.time() - start_time
if elapsed >= self.max_duration:
self.update_signal.emit(f"已达最大时长 {self.max_duration}")
break
self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
data = stream.read(CHUNK, exception_on_overflow=False) # >>> CHANGED
frames.append(data)
# >>> NEW: compute the level (RMS)
# Convert the bytes to np.int16, normalize to [-1, 1], compute the RMS, and emit it to the UI
try:
chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
rms = np.sqrt(np.mean(chunk_np ** 2))
# Simple mapping to 0-100
level = int(np.clip(rms * 100 * 5, 0, 100)) # gain factor of 5; adjust for your environment
self.level_signal.emit(level)
except Exception:
pass
if stream is not None:
stream.stop_stream()
stream.close()
if p is not None:
p.terminate()
if len(frames) > 0:
wf = wave.open(self.temp_file, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(pyaudio.get_sample_size(FORMAT)) # 2 bytes for paInt16
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()
self.update_signal.emit(f"录制完成,时长: {time.time() - start_time:.1f}")
self.finish_signal.emit(self.temp_file)
else:
self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)")
except Exception as e:
# >>> NEW: on open failure, point the user at the device list
err_msg = f"录制错误: {str(e)}"
try:
if p is not None:
err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)"
finally:
self.update_signal.emit(err_msg)
finally:
try:
if stream is not None:
stream.stop_stream()
stream.close()
except Exception:
pass
try:
if p is not None:
p.terminate()
except Exception:
pass
# Processing thread
class ProcessThread(QThread):
update_signal = pyqtSignal(str)
finish_signal = pyqtSignal(dict)
def __init__(self, wav_path, model_path, scaler_path):
super().__init__()
self.wav_path = wav_path
self.model_path = model_path
self.scaler_path = scaler_path
def run(self):
try:
self.update_signal.emit("加载模型和标准化器...")
model = joblib.load(self.model_path)
scaler = joblib.load(self.scaler_path)
self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
# Normalize to reduce the impact of near-silence / clipping
if np.max(np.abs(sig)) > 0:
sig = sig / np.max(np.abs(sig))
self.update_signal.emit("提取特征...")
ene = FeatureExtractor.frame_energy(sig)
hit_starts = FeatureExtractor.detect_hits(ene)
segments = FeatureExtractor.extract_segments(sig, sr, hit_starts)
if not segments:
self.update_signal.emit("未检测到有效片段!")
return
feats = [FeatureExtractor.extract_features(seg, sr) for seg in segments]
X = np.vstack(feats)
self.update_signal.emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}")
X_std = scaler.transform(X)
y_pred = model.predict(X_std)
# Handle classifiers without predict_proba
if hasattr(model, "predict_proba"):
y_proba = model.predict_proba(X_std)[:, 1]
else:
# Fallback: map the decision function into (0, 1)
if hasattr(model, "decision_function"):
df = model.decision_function(X_std)
y_proba = 1 / (1 + np.exp(-df))
else:
y_proba = (y_pred.astype(float) + 0.0)
self.finish_signal.emit({
"filename": os.path.basename(self.wav_path),
"segments": len(segments),
"predictions": y_pred.tolist(),
"probabilities": [round(float(p), 4) for p in y_proba],
"mean_prob": round(float(np.mean(y_proba)), 4),
"final_label": int(np.mean(y_proba) >= 0.5)
})
except Exception as e:
self.update_signal.emit(f"错误: {str(e)}")
# Level-1 UI: main menu
class MainMenuWidget(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self.parent = parent
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
title = QLabel("音频分类器")
title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
layout.addWidget(title)
# 添加按钮
record_btn = QPushButton("采集音频")
record_btn.setMinimumHeight(60)
record_btn.setStyleSheet("font-size: 16px;")
record_btn.clicked.connect(lambda: self.parent.switch_to_input("record"))
upload_btn = QPushButton("上传外部WAV文件")
upload_btn.setMinimumHeight(60)
upload_btn.setStyleSheet("font-size: 16px;")
upload_btn.clicked.connect(lambda: self.parent.switch_to_input("upload"))
layout.addWidget(record_btn)
layout.addWidget(upload_btn)
layout.addStretch(1)
self.setLayout(layout)
# Level-2 UI: input page (record or upload a file)
class InputWidget(QWidget):
def __init__(self, parent=None, mode="record"):
super().__init__(parent)
self.parent = parent
self.mode = mode
self.wav_path = ""
self.record_thread = None
self.device_index = None # >>> NEW: index of the currently selected input device
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 返回按钮
back_btn = QPushButton("返回")
back_btn.clicked.connect(self.parent.switch_to_main_menu)
layout.addWidget(back_btn)
if self.mode == "record":
self.setup_record_ui(layout)
else:
self.setup_upload_ui(layout)
# 模型路径
model_layout = QHBoxLayout()
self.model_path = QLineEdit("svm_model.pkl")
self.scaler_path = QLineEdit("scaler.pkl")
model_layout.addWidget(QLabel("模型路径:"))
model_layout.addWidget(self.model_path)
model_layout.addWidget(QLabel("标准化器路径:"))
model_layout.addWidget(self.scaler_path)
layout.addLayout(model_layout)
# 处理按钮
self.process_btn = QPushButton("开始处理")
self.process_btn.setMinimumHeight(50)
self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;")
self.process_btn.clicked.connect(self.start_process)
self.process_btn.setEnabled(False) # initially disabled
layout.addWidget(self.process_btn)
# 日志区域
self.log_area = QTextEdit()
self.log_area.setReadOnly(True)
layout.addWidget(QLabel("日志:"))
layout.addWidget(self.log_area)
self.setLayout(layout)
def setup_record_ui(self, layout):
title = QLabel("音频采集")
title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
layout.addWidget(title)
# >>> NEW: microphone selection
device_layout = QHBoxLayout()
device_layout.addWidget(QLabel("输入设备:"))
self.device_combo = QComboBox()
self.refresh_devices()
self.device_combo.currentIndexChanged.connect(self.on_device_changed)
device_layout.addWidget(self.device_combo)
refresh_btn = QPushButton("刷新设备")
refresh_btn.clicked.connect(self.refresh_devices)
device_layout.addWidget(refresh_btn)
layout.addLayout(device_layout)
# 录音控制
record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)")
record_hint.setAlignment(Qt.AlignCenter)
self.record_btn = QPushButton("按住录音")
self.record_btn.setMinimumHeight(80)
self.record_btn.setStyleSheet("""
QPushButton {
background-color: #ff4d4d;
color: white;
font-size: 18px;
border-radius: 10px;
}
QPushButton:pressed {
background-color: #cc0000;
}
""")
self.record_btn.mousePressEvent = self.start_recording
self.record_btn.mouseReleaseEvent = self.stop_recording
self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu)
# >>> NEW: real-time level and recording progress
self.mic_level = QProgressBar()
self.mic_level.setRange(0, 100)
self.mic_level.setFormat("麦克风电平:%p%")
self.record_progress = QProgressBar()
self.record_progress.setRange(0, 100)
self.record_progress.setValue(0)
self.record_duration_label = QLabel("录音时长: 0.0秒")
self.record_duration_label.setAlignment(Qt.AlignCenter)
layout.addWidget(record_hint)
layout.addWidget(self.record_btn)
layout.addWidget(self.mic_level) # >>> NEW
layout.addWidget(self.record_progress)
layout.addWidget(self.record_duration_label)
# 录音计时器
self.record_timer = QTimer(self)
self.record_timer.timeout.connect(self.update_record_duration)
self.record_start_time = 0
def refresh_devices(self):
"""枚举有输入通道的设备,并填充到下拉框"""
self.device_combo.clear()
try:
p = pyaudio.PyAudio()
default_host_api = p.get_host_api_info_by_index(0)
default_input_index = default_host_api.get("defaultInputDevice", -1)
found = []
for i in range(p.get_device_count()):
info = p.get_device_info_by_index(i)
if int(info.get('maxInputChannels', 0)) > 0:
name = info.get('name', f"Device {i}")
sr = int(info.get('defaultSampleRate', 0))
label = f"[{i}] {name} (sr={sr})"
self.device_combo.addItem(label, i)
found.append(i)
p.terminate()
# Select the default input device (if present)
if default_input_index in found:
idx = found.index(default_input_index)
self.device_combo.setCurrentIndex(idx)
self.device_index = default_input_index
elif found:
self.device_combo.setCurrentIndex(0)
self.device_index = self.device_combo.currentData()
else:
self.device_index = None
except Exception:
self.device_index = None
def on_device_changed(self, _):
self.device_index = self.device_combo.currentData()
def setup_upload_ui(self, layout):
title = QLabel("上传WAV文件")
title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
layout.addWidget(title)
# 文件选择
file_layout = QHBoxLayout()
self.file_path = QLineEdit()
self.file_path.setReadOnly(True)
self.browse_btn = QPushButton("浏览WAV文件")
self.browse_btn.clicked.connect(self.browse_file)
file_layout.addWidget(self.file_path)
file_layout.addWidget(self.browse_btn)
layout.addLayout(file_layout)
def start_recording(self, event):
if event.button() == Qt.LeftButton:
if not self.record_thread or not self.record_thread.isRunning():
# >>> CHANGED: pass device_index through
self.record_thread = RecordThread(max_duration=60, device_index=self.device_index)
self.record_thread.update_signal.connect(self.update_log)
self.record_thread.finish_signal.connect(self.on_recording_finish)
self.record_thread.progress_signal.connect(self.record_progress.setValue)
self.record_thread.level_signal.connect(self.mic_level.setValue) # >>> NEW
self.record_thread.start()
self.record_start_time = time.time()
self.record_timer.start(100)
self.update_log("开始录音...(松开按钮结束)")
def stop_recording(self, event):
if event.button() == Qt.LeftButton and self.record_thread and self.record_thread.isRunning():
self.record_thread.is_recording = False
self.record_timer.stop()
self.record_duration_label.setText("录音时长: 0.0秒")
self.record_progress.setValue(0)
# Let the level meter drop back
self.mic_level.setValue(0)
def update_record_duration(self):
elapsed = time.time() - self.record_start_time
self.record_duration_label.setText(f"录音时长: {elapsed:.1f}")
def on_recording_finish(self, temp_file):
self.wav_path = temp_file
self.process_btn.setEnabled(True)
def browse_file(self):
file, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
if file:
self.file_path.setText(file)
self.wav_path = file
self.process_btn.setEnabled(True)
def start_process(self):
if self.mode == "upload":
self.wav_path = self.file_path.text()
model_path = self.model_path.text()
scaler_path = self.scaler_path.text()
if not self.wav_path or not os.path.exists(self.wav_path):
self.log_area.append("请先录音或选择有效的WAV文件")
return
if not os.path.exists(model_path) or not os.path.exists(scaler_path):
self.log_area.append("模型文件不存在请先运行train_model.py训练模型")
return
self.log_area.clear()
self.process_btn.setEnabled(False)
self.thread = ProcessThread(self.wav_path, model_path, scaler_path)
self.thread.update_signal.connect(self.update_log)
self.thread.finish_signal.connect(self.on_process_finish)
self.thread.start()
def update_log(self, msg):
self.log_area.append(msg)
def on_process_finish(self, result):
self.parent.switch_to_result(result)
# Level-3 UI: result display page
class ResultWidget(QWidget):
def __init__(self, parent=None, result=None):
super().__init__(parent)
self.parent = parent
self.result = result
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 返回按钮
back_btn = QPushButton("返回")
back_btn.clicked.connect(self.parent.switch_to_input_from_result)
layout.addWidget(back_btn)
title = QLabel("处理结果")
title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
layout.addWidget(title)
# 结果显示
self.result_area = QTextEdit()
self.result_area.setReadOnly(True)
if self.result:
res = f"文件名: {self.result['filename']}\n"
res += f"片段数: {self.result['segments']}\n"
res += "预测结果:\n"
for i, (pred, prob) in enumerate(zip(self.result['predictions'], self.result['probabilities'])):
res += f" 片段{i + 1}: 标签={pred} (概率={prob})\n"
res += f"\n平均概率: {self.result['mean_prob']}\n"
res += f"最终结果: {'空心' if self.result['final_label'] else '实心'}"
self.result_area.setText(res)
layout.addWidget(self.result_area)
self.setLayout(layout)
# Main window
class AudioClassifierGUI(QMainWindow):
def __init__(self):
super().__init__()
self.current_input_mode = "record" # 记录当前输入模式
self.process_result = None # 存储处理结果
self.init_ui()
def init_ui(self):
self.setWindowTitle("音频分类器")
self.setGeometry(100, 100, 800, 600)
# 创建堆叠窗口
self.stacked_widget = QStackedWidget()
self.setCentralWidget(self.stacked_widget)
# Create the three UI levels
self.main_menu_widget = MainMenuWidget(self)
self.record_input_widget = InputWidget(self, "record")
self.upload_input_widget = InputWidget(self, "upload")
self.result_widget = ResultWidget(self)
# 添加到堆叠窗口
self.stacked_widget.addWidget(self.main_menu_widget)
self.stacked_widget.addWidget(self.record_input_widget)
self.stacked_widget.addWidget(self.upload_input_widget)
self.stacked_widget.addWidget(self.result_widget)
# 显示主菜单
self.stacked_widget.setCurrentWidget(self.main_menu_widget)
def switch_to_input(self, mode):
self.current_input_mode = mode
if mode == "record":
self.stacked_widget.setCurrentWidget(self.record_input_widget)
else:
self.stacked_widget.setCurrentWidget(self.upload_input_widget)
def switch_to_main_menu(self):
self.stacked_widget.setCurrentWidget(self.main_menu_widget)
def switch_to_result(self, result):
self.process_result = result
self.result_widget = ResultWidget(self, result)
# Remove the old result page and add the new one
self.stacked_widget.removeWidget(self.stacked_widget.widget(3))
self.stacked_widget.addWidget(self.result_widget)
self.stacked_widget.setCurrentWidget(self.result_widget)
def switch_to_input_from_result(self):
if self.current_input_mode == "record":
self.stacked_widget.setCurrentWidget(self.record_input_widget)
else:
self.stacked_widget.setCurrentWidget(self.upload_input_widget)
if __name__ == "__main__":
try:
import pyaudio # make sure it is installed
except ImportError:
print("请先安装pyaudio: pip install pyaudio")
sys.exit(1)
app = QApplication(sys.argv)
window = AudioClassifierGUI()
window.show()
sys.exit(app.exec_())

@@ -1,917 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import csv
import sqlite3
from datetime import datetime
import numpy as np
import joblib
import librosa
import pyaudio
import wave
from scipy.signal import hilbert
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QPushButton, QLabel,
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
QComboBox, QTableWidget, QTableWidgetItem, QMessageBox,
QDialog, QFormLayout, QDialogButtonBox
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
# ======================
# Local SQLite database (with CRUD)
# ======================
class DatabaseManager:
def __init__(self, db_path="results.db"):
self.db_path = db_path
self._ensure_schema()
def _connect(self):
return sqlite3.connect(self.db_path)
def _ensure_schema(self):
con = self._connect()
cur = con.cursor()
cur.execute("""
CREATE TABLE IF NOT EXISTS runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
filename TEXT,
segments INTEGER,
mean_prob REAL,
final_label INTEGER,
created_at TEXT
);
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS run_segments (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id INTEGER,
seg_index INTEGER,
label INTEGER,
proba REAL,
FOREIGN KEY(run_id) REFERENCES runs(id)
);
""")
con.commit()
con.close()
# ---------- Original write path ----------
def insert_result(self, result_dict):
con = self._connect()
cur = con.cursor()
cur.execute(
"INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)",
(
result_dict.get("filename"),
int(result_dict.get("segments", 0)),
float(result_dict.get("mean_prob", 0.0)),
int(result_dict.get("final_label", 0)),
datetime.now().strftime("%Y-%m-%d %H:%M:%S")
)
)
run_id = cur.lastrowid
preds = result_dict.get("predictions", [])
probas = result_dict.get("probabilities", [])
for i, (lab, pr) in enumerate(zip(preds, probas), start=1):
cur.execute(
"INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
(run_id, i, int(lab), float(pr))
)
con.commit()
con.close()
return run_id
# ---------- Read (list / single record) ----------
def fetch_recent_runs(self, limit=50):
con = self._connect()
cur = con.cursor()
cur.execute("""
SELECT id, created_at, filename, segments, mean_prob, final_label
FROM runs ORDER BY id DESC LIMIT ?
""", (limit,))
rows = cur.fetchall()
con.close()
return rows
def get_run(self, run_id: int):
con = self._connect()
cur = con.cursor()
cur.execute("""
SELECT id, created_at, filename, segments, mean_prob, final_label
FROM runs WHERE id = ?
""", (run_id,))
row = cur.fetchone()
con.close()
return row
def search_runs(self, keyword: str, limit=100):
kw = f"%{keyword}%"
con = self._connect()
cur = con.cursor()
cur.execute("""
SELECT id, created_at, filename, segments, mean_prob, final_label
FROM runs
WHERE filename LIKE ?
ORDER BY id DESC LIMIT ?
""", (kw, limit))
rows = cur.fetchall()
con.close()
return rows
# ---------- Create / Update / Delete (runs) ----------
def create_run(self, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str = None):
created_at = created_at or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
con = self._connect()
cur = con.cursor()
cur.execute(
"INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)",
(filename, int(segments), float(mean_prob), int(final_label), created_at)
)
run_id = cur.lastrowid
con.commit()
con.close()
return run_id
def update_run(self, run_id: int, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str):
con = self._connect()
cur = con.cursor()
cur.execute("""
UPDATE runs
SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=?
WHERE id=?
""", (filename, int(segments), float(mean_prob), int(final_label), created_at, int(run_id)))
con.commit()
con.close()
def delete_run(self, run_id: int):
con = self._connect()
cur = con.cursor()
# Delete child-table rows first
cur.execute("DELETE FROM run_segments WHERE run_id=?", (int(run_id),))
cur.execute("DELETE FROM runs WHERE id=?", (int(run_id),))
con.commit()
con.close()
# ---------- Create / Read / Update / Delete (segments) ----------
def list_segments(self, run_id: int):
con = self._connect()
cur = con.cursor()
cur.execute("""
SELECT id, seg_index, label, proba FROM run_segments
WHERE run_id=? ORDER BY seg_index ASC
""", (int(run_id),))
rows = cur.fetchall()
con.close()
return rows
def add_segment(self, run_id: int, seg_index: int, label: int, proba: float):
con = self._connect()
cur = con.cursor()
cur.execute(
"INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
(int(run_id), int(seg_index), int(label), float(proba))
)
seg_id = cur.lastrowid
con.commit()
con.close()
return seg_id
def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float):
con = self._connect()
cur = con.cursor()
cur.execute("""
UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=?
""", (int(seg_index), int(label), float(proba), int(seg_id)))
con.commit()
con.close()
def delete_segment(self, seg_id: int):
con = self._connect()
cur = con.cursor()
cur.execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),))
con.commit()
con.close()
# ---------- Export ----------
def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000):
rows = self.fetch_recent_runs(limit)
headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"]
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow(headers)
for r in rows:
writer.writerow(r)
return csv_path
DB = DatabaseManager("results.db")
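# A small usage sketch of the module-level DB handle above (ids, paths and
# values are hypothetical):
#   run_id = DB.create_run("example.wav", segments=3, mean_prob=0.42, final_label=0)
#   DB.add_segment(run_id, seg_index=1, label=0, proba=0.42)
#   DB.update_run(run_id, "example.wav", 3, 0.42, 1, "2024-01-01 12:00:00")
#   DB.export_runs_to_csv("results_export.csv", limit=1000)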
# ======================
# Feature extraction (same as before)
# ======================
class FeatureExtractor:
WIN = 1024
OVERLAP = 512
THRESHOLD = 0.1
SEG_LEN_S = 0.2
STEP = WIN - OVERLAP
@staticmethod
def frame_energy(signal: np.ndarray) -> np.ndarray:
frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP)
return np.sum(frames ** 2, axis=0)
@staticmethod
def detect_hits(energy: np.ndarray) -> np.ndarray:
idx = np.where(energy > FeatureExtractor.THRESHOLD)[0]
if idx.size == 0:
return np.array([])
flags = np.diff(np.concatenate(([0], idx))) > 5
return idx[flags]
@staticmethod
def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray):
seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
segments = []
for frame_idx in hit_starts:
s = frame_idx * FeatureExtractor.STEP
e = min(s + seg_len, len(signal))
segments.append(signal[s:e])
return segments
@staticmethod
def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
x = x.flatten()
if len(x) == 0:
return np.zeros(5)
rms = np.sqrt(np.mean(x ** 2))
fft = np.fft.fft(x)
freq = np.fft.fftfreq(len(x), 1 / sr)
positive = freq >= 0
freq, fft_mag = freq[positive], np.abs(fft[positive])
main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
spec_power = fft_mag
centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12))
skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
(np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0
mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
env_peak = np.max(np.abs(hilbert(x)))
return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# ======================
# Recording thread (same as before)
# ======================
class RecordThread(QThread):
update_signal = pyqtSignal(str)
finish_signal = pyqtSignal(str)
progress_signal = pyqtSignal(int)
level_signal = pyqtSignal(int)
def __init__(self, max_duration=60, device_index=None):
super().__init__()
self.max_duration = max_duration
self.is_recording = False
self.temp_file = "temp_recording.wav"
self.device_index = device_index
def run(self):
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100
CHUNK = 1024
p, stream = None, None
try:
p = pyaudio.PyAudio()
stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True,
input_device_index=self.device_index, frames_per_buffer=CHUNK)
self.is_recording = True
start_time = time.time()
frames = []
max_int16 = 32767.0
while self.is_recording:
elapsed = time.time() - start_time
if elapsed >= self.max_duration:
self.update_signal.emit(f"已达最大时长 {self.max_duration}")
break
self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
data = stream.read(CHUNK, exception_on_overflow=False)
frames.append(data)
chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
rms = np.sqrt(np.mean(chunk_np ** 2))
self.level_signal.emit(int(np.clip(rms * 500, 0, 100)))
stream.stop_stream(); stream.close(); p.terminate()
if frames:
wf = wave.open(self.temp_file, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(pyaudio.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()
self.finish_signal.emit(self.temp_file)
else:
self.update_signal.emit("未录到有效音频。")
except Exception as e:
self.update_signal.emit(f"录制错误: {e}")
# ======================
# Inference / processing thread (writes to the DB)
# ======================
class ProcessThread(QThread):
update_signal = pyqtSignal(str)
finish_signal = pyqtSignal(dict)
def __init__(self, wav_path, model_path, scaler_path):
super().__init__()
self.wav_path, self.model_path, self.scaler_path = wav_path, model_path, scaler_path
def run(self):
try:
self.update_signal.emit("加载模型和标准化器...")
model, scaler = joblib.load(self.model_path), joblib.load(self.scaler_path)
self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
sig = sig / max(np.max(np.abs(sig)), 1e-9)
ene = FeatureExtractor.frame_energy(sig)
hits = FeatureExtractor.detect_hits(ene)
segs = FeatureExtractor.extract_segments(sig, sr, hits)
if not segs:
self.update_signal.emit("未检测到有效片段!")
return
feats = [FeatureExtractor.extract_features(seg, sr) for seg in segs]
X = np.vstack(feats)
X_std = scaler.transform(X)
y_pred = model.predict(X_std)
if hasattr(model, "predict_proba"):
y_proba = model.predict_proba(X_std)[:, 1]
else:
y_proba = (y_pred.astype(float) + 0.0)
result = {
"filename": os.path.basename(self.wav_path),
"segments": len(segs),
"predictions": y_pred.tolist(),
"probabilities": [round(float(p), 4) for p in y_proba],
"mean_prob": round(float(np.mean(y_proba)), 4),
"final_label": int(np.mean(y_proba) >= 0.5)
}
DB.insert_result(result)
self.finish_signal.emit(result)
except Exception as e:
self.update_signal.emit(f"错误: {e}")
# ======================
# Dialog: edit a run
# ======================
class RunEditDialog(QDialog):
def __init__(self, parent=None, run_row=None):
"""
run_row: (id, created_at, filename, segments, mean_prob, final_label) None新增
"""
super().__init__(parent)
self.setWindowTitle("编辑记录" if run_row else "新增记录")
self.resize(380, 240)
self.run_row = run_row
self.ed_id = QLineEdit(); self.ed_id.setReadOnly(True)
self.ed_created = QLineEdit(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
self.ed_filename = QLineEdit()
self.ed_segments = QLineEdit("0")
self.ed_mean_prob = QLineEdit("0.0")
self.ed_final_label = QLineEdit("0") # 0=实心 1=空心
form = QFormLayout()
form.addRow("ID", self.ed_id)
form.addRow("时间(created_at)", self.ed_created)
form.addRow("文件名(filename)", self.ed_filename)
form.addRow("片段数(segments)", self.ed_segments)
form.addRow("平均概率(mean_prob)", self.ed_mean_prob)
form.addRow("最终标签(0=实 1=空)", self.ed_final_label)
if run_row:
self.ed_id.setText(str(run_row[0]))
self.ed_created.setText(str(run_row[1]))
self.ed_filename.setText(str(run_row[2]))
self.ed_segments.setText(str(run_row[3]))
self.ed_mean_prob.setText(str(run_row[4]))
self.ed_final_label.setText(str(run_row[5]))
btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
btns.accepted.connect(self.accept)
btns.rejected.connect(self.reject)
lay = QVBoxLayout()
lay.addLayout(form)
lay.addWidget(btns)
self.setLayout(lay)
def values(self):
return dict(
id=(int(self.ed_id.text()) if self.ed_id.text().strip() else None),
created_at=self.ed_created.text().strip(),
filename=self.ed_filename.text().strip(),
segments=int(float(self.ed_segments.text() or 0)),
mean_prob=float(self.ed_mean_prob.text() or 0.0),
final_label=int(float(self.ed_final_label.text() or 0))
)
# ======================
# Dialog: manage segments
# ======================
class SegmentManagerDialog(QDialog):
def __init__(self, parent=None, run_id=None):
super().__init__(parent)
self.setWindowTitle(f"片段管理 - run_id={run_id}")
self.resize(560, 420)
self.run_id = run_id
self.table = QTableWidget(0, 4)
self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"])
self.btn_add = QPushButton("新增片段")
self.btn_edit = QPushButton("编辑所选")
self.btn_del = QPushButton("删除所选")
self.btn_close = QPushButton("关闭")
hl = QHBoxLayout()
for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close):
hl.addWidget(b)
lay = QVBoxLayout()
lay.addWidget(self.table)
lay.addLayout(hl)
self.setLayout(lay)
self.btn_add.clicked.connect(self.add_segment)
self.btn_edit.clicked.connect(self.edit_selected)
self.btn_del.clicked.connect(self.delete_selected)
self.btn_close.clicked.connect(self.accept)
self.load_segments()
def load_segments(self):
rows = DB.list_segments(self.run_id)
self.table.setRowCount(len(rows))
for i, (sid, idx, label, proba) in enumerate(rows):
self.table.setItem(i, 0, QTableWidgetItem(str(sid)))
self.table.setItem(i, 1, QTableWidgetItem(str(idx)))
self.table.setItem(i, 2, QTableWidgetItem(str(label)))
self.table.setItem(i, 3, QTableWidgetItem(str(proba)))
self.table.resizeColumnsToContents()
def _get_selected_seg_id(self):
r = self.table.currentRow()
if r < 0:
return None
return int(self.table.item(r, 0).text())
def add_segment(self):
d = QDialog(self)
d.setWindowTitle("新增片段")
form = QFormLayout()
ed_index = QLineEdit("1")
ed_label = QLineEdit("0")
ed_proba = QLineEdit("0.0")
form.addRow("seg_index", ed_index)
form.addRow("label(0/1)", ed_label)
form.addRow("proba(0~1)", ed_proba)
btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
btns.accepted.connect(d.accept); btns.rejected.connect(d.reject)
v = QVBoxLayout(); v.addLayout(form); v.addWidget(btns); d.setLayout(v)
if d.exec_() == QDialog.Accepted:
DB.add_segment(self.run_id, int(float(ed_index.text() or 1)), int(float(ed_label.text() or 0)),
float(ed_proba.text() or 0.0))
self.load_segments()
def edit_selected(self):
seg_id = self._get_selected_seg_id()
if not seg_id:
QMessageBox.information(self, "提示", "请先选择一行片段。")
return
# Read the current values
row = self.table.currentRow()
cur_idx = self.table.item(row, 1).text()
cur_label = self.table.item(row, 2).text()
cur_proba = self.table.item(row, 3).text()
d = QDialog(self)
d.setWindowTitle("编辑片段")
form = QFormLayout()
ed_index = QLineEdit(cur_idx)
ed_label = QLineEdit(cur_label)
ed_proba = QLineEdit(cur_proba)
form.addRow("seg_index", ed_index)
form.addRow("label(0/1)", ed_label)
form.addRow("proba(0~1)", ed_proba)
btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
btns.accepted.connect(d.accept); btns.rejected.connect(d.reject)
v = QVBoxLayout(); v.addLayout(form); v.addWidget(btns); d.setLayout(v)
if d.exec_() == QDialog.Accepted:
DB.update_segment(seg_id, int(float(ed_index.text() or cur_idx)),
int(float(ed_label.text() or cur_label)),
float(ed_proba.text() or cur_proba))
self.load_segments()
def delete_selected(self):
seg_id = self._get_selected_seg_id()
if not seg_id:
QMessageBox.information(self, "提示", "请先选择一行片段。")
return
if QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!",
QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes:
DB.delete_segment(seg_id)
self.load_segments()
# ======================
# Main menu
# ======================
class MainMenuWidget(QWidget):
def __init__(self, parent):
super().__init__(parent)
layout = QVBoxLayout()
title = QLabel("音频分类器"); title.setAlignment(Qt.AlignCenter)
title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
layout.addWidget(title)
btn1 = QPushButton("采集音频"); btn1.clicked.connect(lambda: parent.switch_to_input("record"))
btn2 = QPushButton("上传WAV文件"); btn2.clicked.connect(lambda: parent.switch_to_input("upload"))
btn3 = QPushButton("历史记录 / 导出 / CRUD"); btn3.clicked.connect(parent.switch_to_history)
for b in [btn1, btn2, btn3]:
b.setMinimumHeight(60); layout.addWidget(b)
layout.addStretch(1); self.setLayout(layout)
# ======================
# Input page (record / upload)
# ======================
class InputWidget(QWidget):
def __init__(self, parent, mode="record"):
super().__init__(parent)
self.parent, self.mode = parent, mode
self.wav_path, self.record_thread = "", None
self.device_index = None
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu)
layout.addWidget(back)
if self.mode == "record": self.setup_record(layout)
else: self.setup_upload(layout)
self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
path_layout = QHBoxLayout()
path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path)
path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path)
layout.addLayout(path_layout)
self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False)
self.process_btn.clicked.connect(self.start_process)
layout.addWidget(self.process_btn)
self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log)
self.setLayout(layout)
def setup_record(self, layout):
title = QLabel("音频采集"); title.setAlignment(Qt.AlignCenter)
layout.addWidget(title)
dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:"))
self.dev_combo = QComboBox(); self.refresh_devices()
self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData()))
dev_layout.addWidget(self.dev_combo)
btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices)
dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout)
self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80)
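        # Press-and-hold recording: the button's mouse press/release events are
        # overridden below to start and stop the RecordThread.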
self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec
self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
self.level.setFormat("电平:%p%"); self.progress.setRange(0,100)
for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w)
self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur)
def setup_upload(self, layout):
h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True)
b = QPushButton("浏览"); b.clicked.connect(self.browse)
h.addWidget(self.file); h.addWidget(b); layout.addLayout(h)
def refresh_devices(self):
self.dev_combo.clear(); p = pyaudio.PyAudio()
for i in range(p.get_device_count()):
info = p.get_device_info_by_index(i)
if int(info.get('maxInputChannels',0))>0:
self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
if self.dev_combo.count()>0:
self.dev_combo.setCurrentIndex(0)
self.device_index = self.dev_combo.currentData()
p.terminate()
def start_rec(self,e):
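        # A left-button press creates a RecordThread, wires its signals to the
        # log / level / progress widgets, and starts the duration timer.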
if e.button()==Qt.LeftButton:
self.record_thread=RecordThread(device_index=self.device_index)
self.record_thread.update_signal.connect(self.log.append)
self.record_thread.finish_signal.connect(self.finish_rec)
self.record_thread.progress_signal.connect(self.progress.setValue)
self.record_thread.level_signal.connect(self.level.setValue)
self.record_thread.start(); self.start=time.time(); self.timer.start(100)
def stop_rec(self,e):
if e.button()==Qt.LeftButton and self.record_thread:
self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0)
def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s")
def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True)
def browse(self):
f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)")
if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True)
def start_process(self):
if self.mode=="upload": self.wav_path=self.file.text()
if not os.path.exists(self.wav_path): self.log.append("无效文件"); return
self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text())
self.thread.update_signal.connect(self.log.append)
self.thread.finish_signal.connect(self.parent.switch_to_result)
self.thread.start()
# ======================
# History (with CRUD buttons)
# ======================
class HistoryWidget(QWidget):
def __init__(self, parent):
super().__init__(parent)
self.parent=parent
self.init_ui()
def init_ui(self):
l=QVBoxLayout()
back=QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu); l.addWidget(back)
top=QHBoxLayout()
self.search_edit=QLineEdit(); self.search_edit.setPlaceholderText("按文件名搜索...")
btn_search=QPushButton("搜索"); btn_search.clicked.connect(self.search)
btn_refresh=QPushButton("刷新"); btn_refresh.clicked.connect(self.load)
btn_export=QPushButton("导出CSV"); btn_export.clicked.connect(self.export_csv)
btn_add=QPushButton("新增记录"); btn_add.clicked.connect(self.add_run)
btn_edit=QPushButton("编辑所选"); btn_edit.clicked.connect(self.edit_selected)
btn_del=QPushButton("删除所选"); btn_del.clicked.connect(self.delete_selected)
btn_seg=QPushButton("片段管理"); btn_seg.clicked.connect(self.open_segments)
for w in [self.search_edit, btn_search, btn_refresh, btn_export, btn_add, btn_edit, btn_del, btn_seg]:
top.addWidget(w)
l.addLayout(top)
self.table=QTableWidget(0,6)
self.table.setHorizontalHeaderLabels(["ID","时间","文件名","片段数","平均概率","最终标签"])
l.addWidget(self.table)
self.setLayout(l)
self.load()
def _current_run_id(self):
r=self.table.currentRow()
if r<0: return None
return int(self.table.item(r,0).text())
    # ---- Data loading / search ----
def load(self):
rows=DB.fetch_recent_runs(200)
self._fill(rows)
def search(self):
kw=self.search_edit.text().strip()
rows=DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200)
self._fill(rows)
def _fill(self, rows):
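        # Rows arrive in header order: ID, time, filename, segment count,
        # mean probability, final label (1 is shown as "空心"/hollow, 0 as "实心"/solid).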
self.table.setRowCount(len(rows))
for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows):
self.table.setItem(r_idx, 0, QTableWidgetItem(str(rid)))
self.table.setItem(r_idx, 1, QTableWidgetItem(str(created_at)))
self.table.setItem(r_idx, 2, QTableWidgetItem(str(fn)))
self.table.setItem(r_idx, 3, QTableWidgetItem(str(segs)))
self.table.setItem(r_idx, 4, QTableWidgetItem(str(mean_prob)))
self.table.setItem(r_idx, 5, QTableWidgetItem("空心" if int(final_label) else "实心"))
self.table.resizeColumnsToContents()
    # ---- Export ----
def export_csv(self):
p=DB.export_runs_to_csv()
QMessageBox.information(self,"导出完成",f"已保存到 {os.path.abspath(p)}")
# ---- CRUD: run ----
def add_run(self):
dlg=RunEditDialog(self, run_row=None)
if dlg.exec_()==QDialog.Accepted:
v=dlg.values()
DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
self.load()
def edit_selected(self):
rid=self._current_run_id()
if not rid:
QMessageBox.information(self,"提示","请先选择一行记录。")
return
row=DB.get_run(rid)
dlg=RunEditDialog(self, run_row=row)
if dlg.exec_()==QDialog.Accepted:
v=dlg.values()
DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
self.load()
def delete_selected(self):
rid=self._current_run_id()
if not rid:
QMessageBox.information(self,"提示","请先选择一行记录。")
return
if QMessageBox.question(self,"确认","确认删除所选记录及其所有片段?此操作不可恢复!",
QMessageBox.Yes|QMessageBox.No)==QMessageBox.Yes:
DB.delete_run(rid)
self.load()
    # ---- Segment management ----
def open_segments(self):
rid=self._current_run_id()
if not rid:
QMessageBox.information(self,"提示","请先选择一行记录。")
return
dlg=SegmentManagerDialog(self, run_id=rid)
dlg.exec_()
        # After segments change, the run's segment count / mean probability could be
        # re-synced here if needed; manual editing is kept as the primary workflow.
self.load()
# ======================
# Result page (minimal)
# ======================
class ResultWidget(QWidget):
def __init__(self, parent, result=None):
super().__init__(parent); self.parent=parent
layout=QVBoxLayout(); b=QPushButton("返回"); b.clicked.connect(self.parent.switch_to_main_menu); layout.addWidget(b)
t=QLabel("处理结果"); t.setAlignment(Qt.AlignCenter); layout.addWidget(t)
area=QTextEdit(); area.setReadOnly(True)
if result:
txt=f"文件:{result['filename']}\n片段:{result['segments']}\n平均概率:{result['mean_prob']}\n最终:{'空心' if result['final_label'] else '实心'}"
area.setText(txt)
layout.addWidget(area); self.setLayout(layout)
# ======================
# Main window
# ======================
class AudioClassifierGUI(QMainWindow):
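    # Main window: hosts every page in a QStackedWidget and switches between them.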
def __init__(self):
super().__init__()
self.current_input_mode = "record"
self.process_result = None
self.init_ui()
def init_ui(self):
self.setWindowTitle("音频分类器")
self.setGeometry(100, 100, 1000, 700)
self.stacked_widget = QStackedWidget()
self.setCentralWidget(self.stacked_widget)
self.main_menu_widget = MainMenuWidget(self)
self.record_input_widget = InputWidget(self, "record")
self.upload_input_widget = InputWidget(self, "upload")
self.result_widget = ResultWidget(self)
self.history_widget = HistoryWidget(self)
self.stacked_widget.addWidget(self.main_menu_widget) # 0
self.stacked_widget.addWidget(self.record_input_widget) # 1
self.stacked_widget.addWidget(self.upload_input_widget) # 2
self.stacked_widget.addWidget(self.result_widget) # 3
self.stacked_widget.addWidget(self.history_widget) # 4
self.stacked_widget.setCurrentWidget(self.main_menu_widget)
def switch_to_input(self, mode):
self.current_input_mode = mode
if mode == "record":
self.stacked_widget.setCurrentWidget(self.record_input_widget)
else:
self.stacked_widget.setCurrentWidget(self.upload_input_widget)
def switch_to_main_menu(self):
self.stacked_widget.setCurrentWidget(self.main_menu_widget)
    def switch_to_result(self, result):
        self.process_result = result
        # Rebuild the result page, removing the old widget by reference:
        # a hard-coded index 3 is only correct the first time, because the
        # stack indices shift after each removal.
        old_result = self.result_widget
        self.result_widget = ResultWidget(self, result)
        self.stacked_widget.removeWidget(old_result)
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)
def switch_to_history(self):
self.history_widget.load()
self.stacked_widget.setCurrentWidget(self.history_widget)
if __name__ == "__main__":
try:
import pyaudio
except ImportError:
print("请先安装pyaudio: pip install pyaudio")
sys.exit(1)
app = QApplication(sys.argv)
window = AudioClassifierGUI()
window.show()
sys.exit(app.exec_())

Binary file not shown.

@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WAV转MAT预处理就绪将WAV音频预处理后保存为MAT文件含原始音频、采样率、切分片段等未提取特征
"""
from pathlib import Path
import numpy as np
from scipy.io import savemat
import librosa
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese text in figures correctly
plt.rcParams['axes.unicode_minus'] = False    # render minus signs correctly
# ---------- Parameters ----------
WAV_FILE = r"D:\SummerSchool\sample\瓷空1.wav"  # input WAV file path
WIN_SIZE = 1024              # frame length (must match the feature-extraction stage)
OVERLAP = 512                # frame overlap (must match the feature-extraction stage)
STEP = WIN_SIZE - OVERLAP    # hop size between frames
THRESH = 0.01                # energy threshold (used for segment detection)
SEG_LEN_SEC = 0.2            # length of each extracted segment (seconds)
# Output MAT path (saved next to the audio file with a "_preprocessed" suffix)
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_preprocessed.mat"
# ---------- Core preprocessing ----------
def preprocess_audio(signal: np.ndarray, fs: int):
    """Audio preprocessing: stereo-to-mono conversion + amplitude normalization"""
    # Stereo to mono: librosa.load(..., mono=False) returns (channels, samples),
    # so keep the first channel
    if signal.ndim > 1:
        signal = signal[0]
    # Normalize to remove amplitude differences between recordings
    signal = signal / (np.max(np.abs(signal)) + 1e-12)
    return signal
def segment_preprocessed_signal(signal: np.ndarray, fs: int):
    """Energy-based segmentation of the preprocessed audio (input for feature extraction)"""
    # Split the signal into frames and compute per-frame energy
    frames = librosa.util.frame(signal, frame_length=WIN_SIZE, hop_length=STEP).T
    energy = np.sum(frames ** 2, axis=1)
    # Keep frames above the energy threshold and locate segment start points
    idx = np.where(energy > THRESH)[0]
    if idx.size == 0:
        return [], energy  # no valid segments: return an empty list plus the energy values
    # A gap of more than 5 frames between hits marks the start of a new segment
    hit_mask = np.diff(np.concatenate(([0], idx))) > 5
    hit_starts = idx[hit_mask]
    # Cut fixed-length segments starting at each detected frame
    seg_len = int(round(SEG_LEN_SEC * fs))
    segments = []
    for start_frame in hit_starts:
        start_sample = start_frame * STEP
        end_sample = min(start_sample + seg_len, len(signal))
        segments.append(signal[start_sample:end_sample])
    return segments, energy
# ---------- Main program: WAV -> preprocessing -> save MAT ----------
def main():
    # 1. Check the input file
wav_path = Path(WAV_FILE)
if not wav_path.exists():
print(f"❌ 错误:音频文件 {WAV_FILE} 不存在!")
return
if wav_path.suffix != ".wav":
print(f"❌ 错误:{wav_path.name} 不是WAV格式")
return
    # 2. Load the raw audio (keep the original channel layout for later comparison)
    raw_audio, sr = librosa.load(wav_path, sr=None, mono=False)
print(f"✅ 成功读取音频:{wav_path.name}")
print(f" 原始采样率:{sr} Hz | 原始长度:{len(raw_audio) / sr:.2f} 秒 | 原始声道数:{raw_audio.ndim}")
    # 3. Run the preprocessing (stereo to mono + normalization)
processed_audio = preprocess_audio(raw_audio, sr)
print(f"✅ 音频预处理完成(双声道转单声道 + 归一化)")
print(f" 预处理后长度:{len(processed_audio) / sr:.2f} 秒 | 声道数1")
    # 4. Segment the preprocessed audio (prepares the input for feature extraction)
segments, frame_energy = segment_preprocessed_signal(processed_audio, sr)
if len(segments) == 0:
print(f"⚠️ 未检测到有效片段可降低THRESH值重试")
else:
print(f"✅ 检测到 {len(segments)} 个有效片段(每个片段约 {SEG_LEN_SEC} 秒)")
    # 5. Save everything to a MAT file (all preprocessed data, no features yet)
    savemat(OUT_MAT, {
        "raw_audio": raw_audio,              # raw audio (original information preserved)
        "processed_audio": processed_audio,  # preprocessed audio (mono + normalized)
        "sample_rate": sr,                   # sample rate
        "segments": segments,                # segmented clips (direct input for feature extraction)
        "frame_energy": frame_energy,        # per-frame energy (basis for the segmentation)
        "params": {                          # preprocessing parameters (keeps feature extraction consistent)
"win_size": WIN_SIZE,
"overlap": OVERLAP,
"step": STEP,
"threshold": THRESH,
"seg_len_sec": SEG_LEN_SEC
}
})
print(f"✅ 预处理数据已保存为MAT文件{OUT_MAT}")
    # 6. Visualize the preprocessing results (optional)
    plt.figure(figsize=(12, 8))
    # Raw audio waveform
    plt.subplot(2, 2, 1)
    plt.plot(raw_audio[:min(10000, len(raw_audio))])  # show only the first 10000 samples
plt.title("原始音频波形")
plt.xlabel("采样点")
plt.ylabel("幅值")
plt.grid(True, alpha=0.3)
    # Preprocessed audio waveform
plt.subplot(2, 2, 2)
plt.plot(processed_audio[:min(10000, len(processed_audio))])
plt.title("预处理后音频波形(单声道+归一化)")
plt.xlabel("采样点")
plt.ylabel("归一化幅值")
plt.grid(True, alpha=0.3)
    # Per-frame energy distribution
plt.subplot(2, 2, 3)
plt.plot(frame_energy)
plt.axhline(y=THRESH, color='r', linestyle='--', label=f'阈值={THRESH}')
plt.title("帧能量分布")
plt.xlabel("帧索引")
plt.ylabel("能量值")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
main()

@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WAV转MAT预处理就绪将WAV音频预处理后保存为MAT文件含原始音频采样率切分片段等未提取特征
"""
from pathlib import Path
import numpy as np
from scipy.io import savemat
import librosa
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese text in figures correctly
plt.rcParams['axes.unicode_minus'] = False    # render minus signs correctly
# ---------- Parameters ----------
WAV_FILE = r"D:\SummerSchool\sample\瓷空1.wav"  # input WAV file path
WIN_SIZE = 1024              # frame length (must match the feature-extraction stage)
OVERLAP = 512                # frame overlap (must match the feature-extraction stage)
STEP = WIN_SIZE - OVERLAP    # hop size between frames
THRESH = 0.01                # energy threshold (used for segment detection)
SEG_LEN_SEC = 0.2            # length of each extracted segment (seconds)
# Output MAT path (saved next to the audio file with a "_preprocessed" suffix)
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_preprocessed.mat"
# ---------- Core preprocessing ----------
def preprocess_audio(signal: np.ndarray, fs: int):
    """Audio preprocessing: stereo-to-mono conversion + amplitude normalization"""
    # Stereo to mono: librosa.load(..., mono=False) returns (channels, samples),
    # so keep the first channel
    if signal.ndim > 1:
        signal = signal[0]
    # Normalize to remove amplitude differences between recordings
    signal = signal / (np.max(np.abs(signal)) + 1e-12)
    return signal
def segment_preprocessed_signal(signal: np.ndarray, fs: int):
    """Energy-based segmentation of the preprocessed audio (input for feature extraction)"""
    # Split the signal into frames and compute per-frame energy
    frames = librosa.util.frame(signal, frame_length=WIN_SIZE, hop_length=STEP).T
    energy = np.sum(frames ** 2, axis=1)
    # Keep frames above the energy threshold and locate segment start points
    idx = np.where(energy > THRESH)[0]
    if idx.size == 0:
        return [], energy  # no valid segments: return an empty list plus the energy values
    # A gap of more than 5 frames between hits marks the start of a new segment
    hit_mask = np.diff(np.concatenate(([0], idx))) > 5
    hit_starts = idx[hit_mask]
    # Cut fixed-length segments starting at each detected frame
    seg_len = int(round(SEG_LEN_SEC * fs))
    segments = []
    for start_frame in hit_starts:
        start_sample = start_frame * STEP
        end_sample = min(start_sample + seg_len, len(signal))
        segments.append(signal[start_sample:end_sample])
    return segments, energy
# ---------- Main program: WAV -> preprocessing -> save MAT ----------
def main():
    # 1. Check the input file
wav_path = Path(WAV_FILE)
if not wav_path.exists():
print(f"❌ 错误:音频文件 {WAV_FILE} 不存在!")
return
if wav_path.suffix != ".wav":
print(f"❌ 错误:{wav_path.name} 不是WAV格式")
return
    # 2. Load the raw audio (keep the original channel layout for later comparison)
    raw_audio, sr = librosa.load(wav_path, sr=None, mono=False)
print(f"✅ 成功读取音频:{wav_path.name}")
print(f" 原始采样率:{sr} Hz | 原始长度:{len(raw_audio) / sr:.2f} 秒 | 原始声道数:{raw_audio.ndim}")
    # 3. Run the preprocessing (stereo to mono + normalization)
processed_audio = preprocess_audio(raw_audio, sr)
print(f"✅ 音频预处理完成(双声道转单声道 + 归一化)")
print(f" 预处理后长度:{len(processed_audio) / sr:.2f} 秒 | 声道数1")
    # 4. Segment the preprocessed audio (prepares the input for feature extraction)
segments, frame_energy = segment_preprocessed_signal(processed_audio, sr)
if len(segments) == 0:
print(f"⚠️ 未检测到有效片段可降低THRESH值重试")
else:
print(f"✅ 检测到 {len(segments)} 个有效片段(每个片段约 {SEG_LEN_SEC} 秒)")
    # 5. Save everything to a MAT file (all preprocessed data, no features yet)
    savemat(OUT_MAT, {
        "raw_audio": raw_audio,              # raw audio (original information preserved)
        "processed_audio": processed_audio,  # preprocessed audio (mono + normalized)
        "sample_rate": sr,                   # sample rate
        "segments": segments,                # segmented clips (direct input for feature extraction)
        "frame_energy": frame_energy,        # per-frame energy (basis for the segmentation)
        "params": {                          # preprocessing parameters (keeps feature extraction consistent)
"win_size": WIN_SIZE,
"overlap": OVERLAP,
"step": STEP,
"threshold": THRESH,
"seg_len_sec": SEG_LEN_SEC
}
})
print(f"✅ 预处理数据已保存为MAT文件{OUT_MAT}")
    # 6. Visualize the preprocessing results (optional)
    plt.figure(figsize=(12, 8))
    # Raw audio waveform
    plt.subplot(2, 2, 1)
    plt.plot(raw_audio[:min(10000, len(raw_audio))])  # show only the first 10000 samples
plt.title("原始音频波形")
plt.xlabel("采样点")
plt.ylabel("幅值")
plt.grid(True, alpha=0.3)
    # Preprocessed audio waveform
plt.subplot(2, 2, 2)
plt.plot(processed_audio[:min(10000, len(processed_audio))])
plt.title("预处理后音频波形(单声道+归一化)")
plt.xlabel("采样点")
plt.ylabel("归一化幅值")
plt.grid(True, alpha=0.3)
    # Per-frame energy distribution
plt.subplot(2, 2, 3)
plt.plot(frame_energy)
plt.axhline(y=THRESH, color='r', linestyle='--', label=f'阈值={THRESH}')
plt.title("帧能量分布")
plt.xlabel("帧索引")
plt.ylabel("能量值")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
main()
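# A minimal usage sketch (not part of the original script; names taken from the
# savemat call above) of how a downstream feature-extraction step might reload
# the saved MAT file; the exact shape of "segments" depends on how savemat
# stores the variable-length clip list:
#   from scipy.io import loadmat
#   data = loadmat(OUT_MAT, squeeze_me=True)
#   sr = int(data["sample_rate"])
#   segments = data["segments"]   # object/cell-like array of preprocessed clips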