Compare commits

...

13 Commits

Author SHA1 Message Date
lyd 887f7df042 汇报PPT及自评报告等
2 months ago
lyd 6eee692f79 系统相关文档
2 months ago
lyd db1eebe9c2 UML模型
2 months ago
lyd fcf5f755cb 系统相关文档
2 months ago
lyd 5260c67e2e UML模型
2 months ago
lyd cb2a69be4c 系统相关文档
2 months ago
lyd 3fa7867141 model(存放UML模型)
2 months ago
lyd bb7726cf15 01-SRC,存放源代码及数据库等,包含:readme.txt文件(系统简介及配置环境)
2 months ago
lyd d7f2df1b7d 模型训练
2 months ago
lyd 52dba77e65 特征提取,svm学习验证
2 months ago
lyd 7e28124d0a 进行特征提取并学习
2 months ago
宋昊天 186f63b42a #!/usr/bin/env python
2 months ago
宋昊天 6e1cca5545 #!/usr/bin/env python
2 months ago

@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
使用 75% 训练 / 25% 测试 的方式评估 SVM输出 ACC & AUC
"""
import pickle
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, roc_auc_score
# ---------- 1. 数据路径 ----------
# Location of the pickled feature dataset produced by the feature-extraction
# step.  NOTE(review): hard-coded, machine-specific Windows path — consider a
# CLI argument or environment variable before deploying elsewhere.
PKL_PATH = Path(r"D:\Python\空心检测\pythonProject\feature_dataset.pkl")
# ---------- 2. 读取特征 ----------
def load_pkl_matrix(path: Path):
    """Read a pickled dataset and return (feature matrix, labels).

    The pickle is expected to hold a dict with keys "matrix" and "label".
    """
    with path.open("rb") as fh:
        payload = pickle.load(fh)
    return payload["matrix"], payload["label"]
X, y = load_pkl_matrix(PKL_PATH)
y = y.ravel()  # flatten labels to shape (N,)

# ---------- 3. 75% / 25% split ----------
# Stratified split keeps the class balance identical in both partitions.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y, shuffle=True
)

# ---------- 4. Standardize + SVM ----------
# Fit the scaler on the training partition only, then apply it to both.
scaler = StandardScaler().fit(X_train)
X_train_std = scaler.transform(X_train)
X_test_std = scaler.transform(X_test)

svm = SVC(
    kernel="rbf",
    C=10,
    gamma="scale",
    probability=True,
    class_weight="balanced",
    random_state=42,
)
svm.fit(X_train_std, y_train)

# ---------- 5. Evaluate ----------
y_pred = svm.predict(X_test_std)
pos_col = list(svm.classes_).index(1)  # column of the positive class
y_proba_pos = svm.predict_proba(X_test_std)[:, pos_col]
acc = accuracy_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_proba_pos)

print("\n========== 评估结果 ==========")
print(f"样本总数: {len(y)} | 训练: {len(y_train)} 测试: {len(y_test)}")
print(f"ACC = {acc:.4f}")
print(f"AUC = {auc:.4f}")

Binary file not shown.

@ -0,0 +1,945 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import numpy as np
import joblib
import librosa
import pyaudio
import wave
import sqlite3
from datetime import datetime
from scipy.signal import hilbert, find_peaks
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
QTextEdit, QFileDialog, QVBoxLayout, QHBoxLayout,
QWidget, QProgressBar, QStackedWidget, QMessageBox,
QTableWidget, QTableWidgetItem, QHeaderView,
QLineEdit, QDialog, QDialogButtonBox, QFormLayout,
QComboBox, QSpinBox, QDoubleSpinBox, QTabWidget, QInputDialog)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
# ---------------------- 数据库管理模块 ----------------------
class DatabaseManager:
    """SQLite-backed store for audio classification results.

    Every public method opens its own short-lived connection and closes it
    in a ``finally`` block, so failed queries no longer leak connections,
    and per-call connections keep the class usable from worker threads.
    """

    def __init__(self, db_path="audio_classification.db"):
        # Path of the SQLite file; the schema is created eagerly.
        self.db_path = db_path
        self.init_database()

    @staticmethod
    def _rows_as_dicts(cursor):
        """Materialize the cursor's result set as a list of column->value dicts."""
        columns = [desc[0] for desc in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def init_database(self):
        """Create the ``classification_results`` table if it is missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS classification_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    segment_count INTEGER NOT NULL,
                    segment_labels TEXT NOT NULL,
                    segment_probs TEXT NOT NULL,
                    mean_probability REAL NOT NULL,
                    final_label INTEGER NOT NULL,
                    label_text TEXT NOT NULL,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    notes TEXT
                )
            ''')
            conn.commit()
        finally:
            conn.close()

    def insert_result(self, result_data, notes=""):
        """Insert one classification result and return its new row id.

        result_data must carry the keys filename, segment_count,
        segment_labels, segment_probs, mean_probability, final_label and
        label_text; list-valued fields are stored via ``str(...)``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO classification_results
                (filename, segment_count, segment_labels, segment_probs,
                 mean_probability, final_label, label_text, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                result_data['filename'],
                result_data['segment_count'],
                str(result_data['segment_labels']),
                str(result_data['segment_probs']),
                result_data['mean_probability'],
                result_data['final_label'],
                result_data['label_text'],
                notes
            ))
            conn.commit()
            # Capture before close: the original read lastrowid after
            # conn.close(), which relies on an implementation detail.
            return cursor.lastrowid
        finally:
            conn.close()

    def get_all_results(self):
        """Return every stored result, newest first, as dicts."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT * FROM classification_results
                ORDER BY create_time DESC
            ''')
            return self._rows_as_dicts(cursor)
        finally:
            conn.close()

    def search_results(self, filename_filter="", label_filter=None, date_filter=None):
        """Return results matching the optional filters, newest first.

        filename_filter: substring match on filename ('' = any).
        label_filter: exact match on final_label (None = any).
        date_filter: 'YYYY-MM-DD' match on the creation date (None = any).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            query = '''
                SELECT * FROM classification_results
                WHERE 1=1
            '''
            params = []
            if filename_filter:
                query += " AND filename LIKE ?"
                params.append(f'%{filename_filter}%')
            if label_filter is not None:
                query += " AND final_label = ?"
                params.append(label_filter)
            if date_filter:
                query += " AND DATE(create_time) = ?"
                params.append(date_filter)
            query += " ORDER BY create_time DESC"
            cursor.execute(query, params)
            return self._rows_as_dicts(cursor)
        finally:
            conn.close()

    def update_result(self, result_id, updates):
        """Update the given columns of one record; return True on success.

        NOTE: the column names in *updates* are interpolated into the SQL,
        so they must come from trusted (internal) callers only; all values
        go through placeholders.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            set_clause = ", ".join(f"{key} = ?" for key in updates)
            query = f"UPDATE classification_results SET {set_clause} WHERE id = ?"
            params = list(updates.values())
            params.append(result_id)
            cursor.execute(query, params)
            conn.commit()
            # Capture before close (original read rowcount after close).
            return cursor.rowcount > 0
        finally:
            conn.close()

    def delete_result(self, result_id):
        """Delete one record by id; return True if a row was removed."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM classification_results WHERE id = ?", (result_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()
# ---------------------- 特征提取模块(与训练集保持一致) ----------------------
class FeatureProcessor:
    """Feature extractor; mirrors the window sizes and the 5-D feature
    vector used at training time so inference features stay compatible."""

    WIN = 1024            # analysis window length (samples)
    OVERLAP = 512         # overlap between consecutive windows (samples)
    THRESHOLD = 0.1       # frame-energy threshold for detecting a knock
    SEG_LEN_S = 0.2       # length of each extracted segment (seconds)
    STEP = WIN - OVERLAP  # hop size between frames (samples)

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the per-frame energy (sum of squared samples) of *signal*."""
        frames = librosa.util.frame(signal, frame_length=FeatureProcessor.WIN,
                                    hop_length=FeatureProcessor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_impacts(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices where a new knock segment starts.

        A loud frame (energy > THRESHOLD) counts as a start only when it is
        more than 5 frames after the previous loud frame.
        """
        idx = np.where(energy > FeatureProcessor.THRESHOLD)[0]
        if idx.size == 0:
            return np.array([])
        # A gap of more than 5 frames marks the start of a new segment.
        flags = np.diff(np.concatenate(([0], idx))) > 5
        return idx[flags]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, starts: np.ndarray) -> list[np.ndarray]:
        """Cut fixed-length (SEG_LEN_S seconds) slices at each start frame,
        clipped to the end of the signal."""
        seg_len = int(FeatureProcessor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in starts:
            start = frame_idx * FeatureProcessor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(signal: np.ndarray, sr: int) -> np.ndarray:
        """Compute the 5-D feature vector matching the training pipeline.

        Features: [RMS energy, dominant frequency, spectral skewness,
        mean of the first MFCC coefficient, Hilbert-envelope peak].
        Returns zeros for an empty segment.
        """
        signal = signal.flatten()
        if len(signal) == 0:
            return np.zeros(5, dtype=np.float32)
        # 1. RMS energy
        rms = np.sqrt(np.mean(signal ** 2))
        # 2. Dominant frequency (FFT magnitude peak over non-negative bins)
        fft = np.fft.fft(signal)
        freq = np.fft.fftfreq(len(signal), 1 / sr)
        positive_mask = freq >= 0
        freq = freq[positive_mask]
        fft_mag = np.abs(fft[positive_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
        # 3. Spectral skewness (third moment around the spectral centroid;
        #    1e-12 terms guard against division by zero)
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) /
                         (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if spread > 0 else 0
        # 4. Mean of the first MFCC coefficient
        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
        # 5. Envelope peak via the Hilbert transform
        env_peak = np.max(np.abs(hilbert(signal)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
# ---------------------- 录音线程 ----------------------
class AudioRecorder(QThread):
    """Background thread that records microphone audio to a temp WAV file.

    Signals:
        status_updated: human-readable progress / error messages.
        progress_updated: 0-100 percentage of the maximum duration.
        recording_finished: emitted with the temp file path on success.
    """
    status_updated = pyqtSignal(str)
    progress_updated = pyqtSignal(int)
    recording_finished = pyqtSignal(str)

    def __init__(self, max_duration=60):
        super().__init__()
        self.max_duration = max_duration  # hard stop, seconds
        self.recording = False            # cleared by stop() to end the loop
        self.temp_file = "temp_audio.wav"

    def run(self):
        # Audio parameters compatible with the feature-extraction pipeline.
        FORMAT = pyaudio.paFloat32
        CHANNELS = 1
        RATE = 22050
        CHUNK = 1024
        audio = None
        stream = None
        # Defined up front: the original raised NameError on the summary
        # line if the loop body never ran (e.g. stop() before first read).
        elapsed = 0.0
        try:
            audio = pyaudio.PyAudio()
            # Query while the instance is alive; used when writing the WAV.
            sample_width = audio.get_sample_size(FORMAT)
            stream = audio.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                frames_per_buffer=CHUNK
            )
            self.recording = True
            start_time = time.time()
            frames = []
            while self.recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.status_updated.emit(f"已达最大时长 {self.max_duration}")
                    break
                self.progress_updated.emit(int((elapsed / self.max_duration) * 100))
                # Don't raise on input overflow: dropping a chunk beats
                # aborting the whole recording (matches the sibling recorder).
                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
            if frames:
                with wave.open(self.temp_file, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(sample_width)
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                self.status_updated.emit(f"录制完成,时长: {elapsed:.1f}")
                self.recording_finished.emit(self.temp_file)
            else:
                self.status_updated.emit("未检测到音频输入")
        except Exception as e:
            self.status_updated.emit(f"录音错误: {str(e)}")
        finally:
            # Release the audio device even when recording fails midway
            # (the original only cleaned up on the success path).
            if stream is not None:
                try:
                    stream.stop_stream()
                    stream.close()
                except Exception:
                    pass
            if audio is not None:
                audio.terminate()

    def stop(self):
        """Ask the recording loop to finish; safe to call from the GUI thread."""
        self.recording = False
# ---------------------- 音频处理与预测线程 ----------------------
class AudioProcessor(QThread):
    """Background thread: load model artifacts, segment a WAV file and
    emit a file-level classification result dict.

    Signals:
        status_updated: progress / error text for the UI log.
        result_generated: result dict (filename, per-segment labels and
            probabilities, mean probability, final 0/1 label and its text).
    """
    status_updated = pyqtSignal(str)
    result_generated = pyqtSignal(dict)

    def __init__(self, audio_path, model_path, scaler_path):
        super().__init__()
        self.audio_path = audio_path    # WAV file to classify
        self.model_path = model_path    # joblib-serialized classifier
        self.scaler_path = scaler_path  # joblib-serialized StandardScaler
        self.class_threshold = 0.5      # mean-probability cut-off for label 1

    def run(self):
        try:
            # Load the classifier and the scaler fitted at training time.
            self.status_updated.emit("加载模型资源...")
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)
            # Read the audio file (native sample rate, mono).
            self.status_updated.emit(f"解析音频: {os.path.basename(self.audio_path)}")
            signal, sr = librosa.load(self.audio_path, sr=None, mono=True)
            # Peak-normalize, guarding against an all-zero (silent) file
            # which would otherwise divide by zero and yield NaN features.
            # (The sibling implementation already has this guard.)
            peak = np.max(np.abs(signal))
            if peak > 0:
                signal = signal / peak
            # Locate candidate knock segments by frame energy.
            self.status_updated.emit("检测有效音频片段...")
            energy = FeatureProcessor.frame_energy(signal)
            impact_starts = FeatureProcessor.detect_impacts(energy)
            segments = FeatureProcessor.extract_segments(signal, sr, impact_starts)
            if not segments:
                self.status_updated.emit("未检测到有效敲击片段")
                return
            # Per-segment features, standardized with the training scaler.
            self.status_updated.emit(f"提取 {len(segments)} 个片段的特征...")
            features = [FeatureProcessor.extract_features(seg, sr) for seg in segments]
            X = np.vstack(features)
            X_scaled = scaler.transform(X)
            # Predict; map -1/1 style labels (e.g. from an SVM) onto 0/1.
            predictions = model.predict(X_scaled)
            pred_proba = model.predict_proba(X_scaled)[:, 1]  # positive-class probability
            pred_labels = np.where(predictions == -1, 0, 1)
            # File-level decision: average of the segment probabilities.
            mean_prob = pred_proba.mean()
            final_label = 1 if mean_prob >= self.class_threshold else 0
            result = {
                "filename": os.path.basename(self.audio_path),
                "segment_count": len(segments),
                "segment_labels": pred_labels.tolist(),
                "segment_probs": [round(p, 4) for p in pred_proba],
                "mean_probability": round(mean_prob, 4),
                "final_label": final_label,
                "label_text": "空心" if final_label == 0 else "实心"  # assumed mapping: 0=hollow, 1=solid
            }
            self.result_generated.emit(result)
        except Exception as e:
            self.status_updated.emit(f"处理错误: {str(e)}")
# ---------------------- 数据库编辑对话框 ----------------------
class EditResultDialog(QDialog):
    """Modal form for creating or editing one classification record."""

    def __init__(self, result_data=None, parent=None):
        super().__init__(parent)
        # Existing record to edit; an empty dict means "create new".
        self.result_data = result_data or {}
        self.init_ui()

    def init_ui(self):
        """Build the form fields, pre-filled from self.result_data."""
        self.setWindowTitle("编辑结果记录")
        self.setModal(True)
        self.resize(400, 300)
        layout = QFormLayout()
        # File name
        self.filename_edit = QLineEdit(self.result_data.get('filename', ''))
        layout.addRow("文件名:", self.filename_edit)
        # Number of detected segments
        self.segment_count_spin = QSpinBox()
        self.segment_count_spin.setRange(0, 1000)
        self.segment_count_spin.setValue(self.result_data.get('segment_count', 0))
        layout.addRow("片段数量:", self.segment_count_spin)
        # Mean positive-class probability (0..1, 4 decimals)
        self.mean_prob_spin = QDoubleSpinBox()
        self.mean_prob_spin.setRange(0.0, 1.0)
        self.mean_prob_spin.setDecimals(4)
        self.mean_prob_spin.setSingleStep(0.01)
        self.mean_prob_spin.setValue(self.result_data.get('mean_probability', 0.0))
        layout.addRow("平均概率:", self.mean_prob_spin)
        # Final classification; item data carries the numeric label (0/1).
        self.final_label_combo = QComboBox()
        self.final_label_combo.addItem("空心", 0)
        self.final_label_combo.addItem("实心", 1)
        current_label = self.result_data.get('final_label', 0)
        index = self.final_label_combo.findData(current_label)
        if index >= 0:
            self.final_label_combo.setCurrentIndex(index)
        layout.addRow("最终分类:", self.final_label_combo)
        # Free-form notes
        self.notes_edit = QTextEdit()
        self.notes_edit.setMaximumHeight(80)
        self.notes_edit.setText(self.result_data.get('notes', ''))
        layout.addRow("备注:", self.notes_edit)
        # OK / Cancel
        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)
        layout.addRow(button_box)
        self.setLayout(layout)

    def get_updated_data(self):
        """Return the form's current values as a column->value dict."""
        return {
            'filename': self.filename_edit.text(),
            'segment_count': self.segment_count_spin.value(),
            'mean_probability': self.mean_prob_spin.value(),
            'final_label': self.final_label_combo.currentData(),
            'label_text': self.final_label_combo.currentText(),
            'notes': self.notes_edit.toPlainText()
        }
# ---------------------- 数据库汇总界面 ----------------------
class DatabaseUI(QWidget):
    """Management page for stored classification results.

    Two tabs share the same 8-column table layout: "数据浏览" lists every
    record, "搜索过滤" lists a filtered subset.  Edit/delete act on the
    table belonging to the currently visible tab.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.parent = parent              # main window, provides switch_page()
        self.db_manager = DatabaseManager()
        self.init_ui()
        self.load_data()

    def init_ui(self):
        """Build the page: back button, title, and the two tabs."""
        layout = QVBoxLayout()
        layout.setContentsMargins(20, 20, 20, 20)
        # Navigation back to the main menu.
        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        layout.addWidget(btn_back)
        title = QLabel("数据库汇总管理")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        layout.addWidget(title)
        self.tab_widget = QTabWidget()
        # Tab 1: browse all records.
        self.browse_tab = QWidget()
        self.init_browse_tab()
        self.tab_widget.addTab(self.browse_tab, "数据浏览")
        # Tab 2: search / filter.
        self.search_tab = QWidget()
        self.init_search_tab()
        self.tab_widget.addTab(self.search_tab, "搜索过滤")
        layout.addWidget(self.tab_widget)
        self.setLayout(layout)

    def init_browse_tab(self):
        """Build the browse tab: CRUD buttons above the full results table."""
        layout = QVBoxLayout()
        btn_layout = QHBoxLayout()
        self.btn_refresh = QPushButton("刷新数据")
        self.btn_refresh.clicked.connect(self.load_data)
        self.btn_add = QPushButton("新增记录")
        self.btn_add.clicked.connect(self.add_result)
        self.btn_edit = QPushButton("编辑选中")
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_delete = QPushButton("删除选中")
        self.btn_delete.clicked.connect(self.delete_selected)
        for btn in (self.btn_refresh, self.btn_add, self.btn_edit, self.btn_delete):
            btn_layout.addWidget(btn)
        btn_layout.addStretch()
        layout.addLayout(btn_layout)
        self.table_widget = self._make_result_table()
        layout.addWidget(self.table_widget)
        self.browse_tab.setLayout(layout)

    def init_search_tab(self):
        """Build the search tab: filter form, search button, result table."""
        layout = QVBoxLayout()
        form_layout = QFormLayout()
        self.search_filename = QLineEdit()
        self.search_filename.setPlaceholderText("输入文件名关键词")
        form_layout.addRow("文件名:", self.search_filename)
        self.search_label = QComboBox()
        self.search_label.addItem("全部", None)
        self.search_label.addItem("空心", 0)
        self.search_label.addItem("实心", 1)
        form_layout.addRow("分类结果:", self.search_label)
        self.search_date = QLineEdit()
        self.search_date.setPlaceholderText("YYYY-MM-DD")
        form_layout.addRow("创建日期:", self.search_date)
        layout.addLayout(form_layout)
        btn_search = QPushButton("搜索")
        btn_search.clicked.connect(self.perform_search)
        layout.addWidget(btn_search)
        self.search_table = self._make_result_table()
        layout.addWidget(self.search_table)
        self.search_tab.setLayout(layout)

    def _make_result_table(self):
        """Create one 8-column results table (identical in both tabs)."""
        table = QTableWidget()
        table.setColumnCount(8)
        table.setHorizontalHeaderLabels([
            "ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作"
        ])
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        return table

    def _current_table(self):
        """Return the results table of the currently visible tab.

        Fixes the original lookup via ``layout().itemAt(1).widget()``: on
        the search tab that layout holds (form, button, table), so item 1
        was the search *button*, not the table.
        """
        if self.tab_widget.currentWidget() is self.search_tab:
            return self.search_table
        return self.table_widget

    def load_data(self):
        """(Re)load every stored record into the browse table."""
        results = self.db_manager.get_all_results()
        self.populate_table(self.table_widget, results)

    def perform_search(self):
        """Run a filtered query and show the matches in the search table."""
        filename_filter = self.search_filename.text().strip()
        label_filter = self.search_label.currentData()
        date_filter = self.search_date.text().strip() or None
        results = self.db_manager.search_results(filename_filter, label_filter, date_filter)
        self.populate_table(self.search_table, results)

    def populate_table(self, table, results):
        """Fill *table* with one row per result dict."""
        table.setRowCount(len(results))
        for row, result in enumerate(results):
            table.setItem(row, 0, QTableWidgetItem(str(result['id'])))
            table.setItem(row, 1, QTableWidgetItem(result['filename']))
            table.setItem(row, 2, QTableWidgetItem(str(result['segment_count'])))
            table.setItem(row, 3, QTableWidgetItem(f"{result['mean_probability']:.4f}"))
            table.setItem(row, 4, QTableWidgetItem(result['label_text']))
            # create_time is a string when read straight from SQLite but may
            # be a datetime when constructed in code; render both.
            create_time = result['create_time']
            if isinstance(create_time, str):
                table.setItem(row, 5, QTableWidgetItem(create_time))
            else:
                table.setItem(row, 5, QTableWidgetItem(create_time.strftime("%Y-%m-%d %H:%M:%S")))
            notes = result.get('notes', '')
            table.setItem(row, 6, QTableWidgetItem(notes if notes else ""))
            # Per-row detail button; bind the row's dict via a default arg.
            btn_view = QPushButton("查看详情")
            btn_view.clicked.connect(lambda checked, r=result: self.view_details(r))
            table.setCellWidget(row, 7, btn_view)

    def add_result(self):
        """Open an empty edit dialog and insert the entered record."""
        dialog = EditResultDialog()
        if dialog.exec_() == QDialog.Accepted:
            new_data = dialog.get_updated_data()
            # Manual entries carry no per-segment data.
            new_data['segment_labels'] = []
            new_data['segment_probs'] = []
            try:
                self.db_manager.insert_result(new_data, new_data.get('notes', ''))
                self.load_data()
                QMessageBox.information(self, "成功", "记录添加成功!")
            except Exception as e:
                QMessageBox.warning(self, "错误", f"添加记录失败: {str(e)}")

    def edit_selected(self):
        """Edit the record selected in the visible tab's table."""
        current_table = self._current_table()
        current_row = current_table.currentRow()
        if current_row < 0:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return
        result_id = int(current_table.item(current_row, 0).text())
        results = self.db_manager.search_results()
        result_data = next((r for r in results if r['id'] == result_id), None)
        if not result_data:
            QMessageBox.warning(self, "错误", "未找到选中的记录!")
            return
        dialog = EditResultDialog(result_data)
        if dialog.exec_() == QDialog.Accepted:
            updated_data = dialog.get_updated_data()
            try:
                success = self.db_manager.update_result(result_id, updated_data)
                if success:
                    self.load_data()
                    # Keep the search view in sync when editing from it.
                    if current_table is self.search_table:
                        self.perform_search()
                    QMessageBox.information(self, "成功", "记录更新成功!")
                else:
                    QMessageBox.warning(self, "错误", "更新记录失败!")
            except Exception as e:
                QMessageBox.warning(self, "错误", f"更新记录失败: {str(e)}")

    def delete_selected(self):
        """Delete the record selected in the visible tab's table (confirmed)."""
        current_table = self._current_table()
        current_row = current_table.currentRow()
        if current_row < 0:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return
        result_id = int(current_table.item(current_row, 0).text())
        filename = current_table.item(current_row, 1).text()
        # Fix: the confirmation used to show a literal placeholder instead
        # of the selected file's name (the captured variable was unused).
        reply = QMessageBox.question(
            self, "确认删除",
            f"确定要删除文件 '{filename}' 的记录吗?",
            QMessageBox.Yes | QMessageBox.No
        )
        if reply == QMessageBox.Yes:
            try:
                success = self.db_manager.delete_result(result_id)
                if success:
                    self.load_data()
                    if current_table is self.search_table:
                        self.perform_search()
                    QMessageBox.information(self, "成功", "记录删除成功!")
                else:
                    QMessageBox.warning(self, "错误", "删除记录失败!")
            except Exception as e:
                QMessageBox.warning(self, "错误", f"删除记录失败: {str(e)}")

    def view_details(self, result_data):
        """Show every stored field of one record in a message box."""
        details = (
            f"记录ID: {result_data['id']}\n"
            f"文件名: {result_data['filename']}\n"
            f"片段数量: {result_data['segment_count']}\n"
            f"片段标签: {result_data['segment_labels']}\n"
            f"片段概率: {result_data['segment_probs']}\n"
            f"平均概率: {result_data['mean_probability']:.4f}\n"
            f"最终分类: {result_data['label_text']}\n"
            f"创建时间: {result_data['create_time']}\n"
            f"备注: {result_data.get('notes', '')}"
        )
        QMessageBox.information(self, "记录详情", details)
# ---------------------- 主界面组件 ----------------------
class MainMenu(QWidget):
    """Landing page: application title plus one button per sub-page."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.parent = parent  # main window, provides switch_page()
        self.init_ui()

    def init_ui(self):
        """Lay out the title and the three navigation buttons."""
        layout = QVBoxLayout()
        layout.setSpacing(20)

        title = QLabel("音频分类系统")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 28px; font-weight: bold; margin: 30px 0;")
        layout.addWidget(title)

        # One navigation button per sub-page, built from a shared spec;
        # the default argument binds the page name at definition time.
        for text, page in (("录制音频", "record"),
                           ("上传音频文件", "upload"),
                           ("数据库管理", "database")):
            btn = QPushButton(text)
            btn.setMinimumHeight(70)
            btn.setStyleSheet("font-size: 18px;")
            btn.clicked.connect(lambda _=False, p=page: self.parent.switch_page(p))
            layout.addWidget(btn)

        layout.addStretch()
        self.setLayout(layout)
class InputPage(QWidget):
    """Audio input page, reused for both recording and file upload.

    mode selects the UI variant: "record" shows start/stop buttons and a
    progress bar backed by an AudioRecorder thread; "upload" shows a file
    picker.  Analysis itself is delegated to an AudioProcessor thread.
    """

    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        self.parent = parent          # main window, provides switch_page()
        self.mode = mode              # "record" or "upload"
        self.audio_path = ""          # path of the WAV file to analyze
        self.recorder = None          # AudioRecorder thread (record mode only)
        self.db_manager = DatabaseManager()  # for saving results on demand
        self.init_ui()

    def init_ui(self):
        """Build the page according to self.mode."""
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)
        # Navigation back to the main menu.
        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        main_layout.addWidget(btn_back)
        title = QLabel("录制音频" if self.mode == "record" else "上传音频文件")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        main_layout.addWidget(title)
        if self.mode == "record":
            # Recording controls: start/stop plus a duration progress bar.
            self.btn_start_rec = QPushButton("开始录音")
            self.btn_start_rec.clicked.connect(self.start_recording)
            self.btn_stop_rec = QPushButton("停止录音")
            self.btn_stop_rec.setEnabled(False)
            self.btn_stop_rec.clicked.connect(self.stop_recording)
            self.progress_bar = QProgressBar()
            self.progress_bar.setRange(0, 100)
            self.progress_bar.setValue(0)
            rec_layout = QHBoxLayout()
            rec_layout.addWidget(self.btn_start_rec)
            rec_layout.addWidget(self.btn_stop_rec)
            main_layout.addLayout(rec_layout)
            main_layout.addWidget(self.progress_bar)
        else:
            # Upload controls: file browser plus the chosen file's name.
            self.btn_browse = QPushButton("选择WAV文件")
            self.btn_browse.clicked.connect(self.browse_file)
            self.lbl_file = QLabel("未选择文件")
            self.lbl_file.setStyleSheet("color: #666;")
            main_layout.addWidget(self.btn_browse)
            main_layout.addWidget(self.lbl_file)
        # Shared status log (read-only).
        self.status_display = QTextEdit()
        self.status_display.setReadOnly(True)
        self.status_display.setMinimumHeight(150)
        main_layout.addWidget(QLabel("状态信息:"))
        main_layout.addWidget(self.status_display)
        # Analysis trigger; enabled once an audio file is available.
        self.btn_process = QPushButton("开始分析")
        self.btn_process.setEnabled(False)
        self.btn_process.clicked.connect(self.process_audio)
        main_layout.addWidget(self.btn_process)
        main_layout.addStretch()
        self.setLayout(main_layout)

    def start_recording(self):
        """Launch the recorder thread and flip the button states."""
        self.status_display.append("开始录音...")
        self.recorder = AudioRecorder(max_duration=60)
        self.recorder.status_updated.connect(self.update_status)
        self.recorder.progress_updated.connect(self.progress_bar.setValue)
        self.recorder.recording_finished.connect(self.on_recording_finished)
        self.recorder.start()
        self.btn_start_rec.setEnabled(False)
        self.btn_stop_rec.setEnabled(True)

    def stop_recording(self):
        """Ask the recorder thread to stop and restore the button states."""
        if self.recorder and self.recorder.recording:
            self.recorder.stop()
            self.btn_start_rec.setEnabled(True)
            self.btn_stop_rec.setEnabled(False)

    def on_recording_finished(self, file_path):
        """Recorder callback: remember the temp file and allow analysis."""
        self.audio_path = file_path
        self.btn_process.setEnabled(True)
        self.status_display.append(f"录音文件已保存: {file_path}")

    def browse_file(self):
        """Let the user pick a WAV file; enable analysis once chosen."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择音频文件", "", "WAV文件 (*.wav)"
        )
        if file_path:
            self.audio_path = file_path
            self.lbl_file.setText(os.path.basename(file_path))
            self.btn_process.setEnabled(True)
            self.update_status(f"已选择文件: {file_path}")

    def process_audio(self):
        """Validate paths and start the AudioProcessor worker thread."""
        if not self.audio_path or not os.path.exists(self.audio_path):
            QMessageBox.warning(self, "错误", "音频文件不存在")
            return
        # Model artifact paths (adjust to the deployment layout).
        model_path = "pipeline_model.pkl"  # or "svm_model.pkl"
        scaler_path = "scaler.pkl"
        if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
            QMessageBox.warning(self, "错误", "模型文件或标准化器不存在")
            return
        # Keep a reference on self so the thread is not garbage-collected.
        self.processor = AudioProcessor(self.audio_path, model_path, scaler_path)
        self.processor.status_updated.connect(self.update_status)
        self.processor.result_generated.connect(self.show_result)
        self.processor.start()
        self.btn_process.setEnabled(False)
        self.update_status("开始分析音频...")

    def update_status(self, message):
        """Append a timestamped line to the status log and keep it scrolled."""
        self.status_display.append(f"[{time.strftime('%H:%M:%S')}] {message}")
        # Scroll so the newest message stays visible.
        self.status_display.moveCursor(self.status_display.textCursor().End)

    def show_result(self, result):
        """Show the analysis summary and optionally store it in the database."""
        msg = QMessageBox()
        msg.setWindowTitle("分析结果")
        msg.setIcon(QMessageBox.Information)
        text = (
            f"文件名: {result['filename']}\n"
            f"有效片段数: {result['segment_count']}\n"
            f"平均概率: {result['mean_probability']}\n"
            f"最终分类: {result['label_text']}"
        )
        msg.setText(text)
        # First button saves, second only dismisses.
        msg.addButton("保存到数据库", QMessageBox.AcceptRole)
        msg.addButton("仅查看", QMessageBox.RejectRole)
        reply = msg.exec_()
        # NOTE(review): comparing exec_()'s return value with 0 relies on
        # button ordering for custom-role buttons; msg.clickedButton()
        # would be the robust check — confirm before changing.
        if reply == 0:  # save to the database
            notes, ok = QInputDialog.getText(
                self, "添加备注", "请输入备注信息(可选):",
                QLineEdit.Normal, ""
            )
            if ok:
                try:
                    self.db_manager.insert_result(result, notes)
                    self.update_status("结果已保存到数据库")
                except Exception as e:
                    QMessageBox.warning(self, "保存失败", f"保存到数据库失败: {str(e)}")
        self.update_status("分析完成")
        self.btn_process.setEnabled(True)
# ---------------------- 主窗口 ----------------------
class AudioClassifierApp(QMainWindow):
    """Top-level window hosting every page inside a QStackedWidget."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        """Build all pages once and stack them; pages call switch_page()."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(300, 300, 800, 600)
        self.stack = QStackedWidget()
        self.main_menu = MainMenu(self)
        self.record_page = InputPage(self, mode="record")
        self.upload_page = InputPage(self, mode="upload")
        self.database_page = DatabaseUI(self)
        for page in (self.main_menu, self.record_page,
                     self.upload_page, self.database_page):
            self.stack.addWidget(page)
        self.setCentralWidget(self.stack)

    def switch_page(self, page_name):
        """Show the page registered under *page_name*; unknown names are ignored."""
        pages = {
            "main": self.main_menu,
            "record": self.record_page,
            "upload": self.upload_page,
            "database": self.database_page,
        }
        target = pages.get(page_name)
        if target is not None:
            self.stack.setCurrentWidget(target)
if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window
    # and hand control to the event loop; its return code becomes the
    # process exit status.
    app = QApplication(sys.argv)
    window = AudioClassifierApp()
    window.show()
    sys.exit(app.exec_())

@ -0,0 +1,630 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import numpy as np
import joblib
import librosa
import pyaudio
import wave
from scipy.signal import hilbert
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
QComboBox) # >>> CHANGED: add QComboBox
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
# 特征提取类
class FeatureExtractor:
    """Knock-detection feature extractor; must match the training pipeline.

    Produces the 5-D feature vector the model was trained on:
    [RMS, dominant frequency, spectral skewness, mean MFCC[0], envelope peak].
    """
    WIN = 1024            # analysis window length (samples)
    OVERLAP = 512         # overlap between consecutive windows (samples)
    THRESHOLD = 0.1       # frame-energy threshold for a "hit"
    SEG_LEN_S = 0.2       # extracted segment length (seconds)
    STEP = WIN - OVERLAP  # hop size between frames (samples)

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Per-frame energy (sum of squared samples) of *signal*."""
        frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Frame indices where a new hit starts.

        A loud frame (energy > THRESHOLD) starts a hit only when it is more
        than 5 frames after the previous loud frame.
        """
        idx = np.where(energy > FeatureExtractor.THRESHOLD)[0]
        if idx.size == 0:
            return np.array([])
        # A gap of more than 5 frames marks a new hit.
        flags = np.diff(np.concatenate(([0], idx))) > 5
        return idx[flags]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray) -> list[np.ndarray]:
        """Cut SEG_LEN_S-second slices starting at each detected hit frame,
        clipped to the end of the signal."""
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            s = frame_idx * FeatureExtractor.STEP
            e = min(s + seg_len, len(signal))
            segments.append(signal[s:e])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Compute the 5-D feature vector; zeros for an empty segment."""
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)
        # 1. RMS energy
        rms = np.sqrt(np.mean(x ** 2))
        # 2. Dominant frequency (FFT magnitude peak over non-negative bins)
        fft = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive_freq_mask = freq >= 0
        freq = freq[positive_freq_mask]
        fft_mag = np.abs(fft[positive_freq_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
        # 3. Spectral skewness (third moment around the spectral centroid;
        #    1e-12 terms guard against division by zero)
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0
        # 4. Mean of the first MFCC coefficient
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
        # 5. Envelope peak via the Hilbert transform
        env_peak = np.max(np.abs(hilbert(x)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# 录音线程
class RecordThread(QThread):
    """Microphone capture thread writing 16-bit mono WAV to a temp file.

    Signals:
        update_signal: status / error text for the UI log.
        finish_signal: temp WAV path once a non-empty recording is saved.
        progress_signal: 0-100 fraction of max_duration.
        level_signal: live input level, 0-100, for a meter widget.
    """
    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    level_signal = pyqtSignal(int)  # live level, 0-100

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration       # hard stop, seconds
        self.is_recording = False              # cleared externally to end the loop
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index       # None -> system default input

    def run(self):
        # Widely supported capture parameters.
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024
        p = None
        stream = None
        try:
            p = pyaudio.PyAudio()
            # Query while this instance is alive.  The original created a
            # *second* PyAudio instance just for get_sample_size() when
            # writing the WAV and never terminated it (resource leak).
            sample_width = p.get_sample_size(FORMAT)  # 2 bytes for paInt16
            # Log available input devices to help debugging.
            try:
                device_log = []
                for i in range(p.get_device_count()):
                    info = p.get_device_info_by_index(i)
                    if int(info.get('maxInputChannels', 0)) > 0:
                        device_log.append(f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} sr={int(info.get('defaultSampleRate', 0))}")
                if device_log:
                    self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log))
            except Exception:
                pass
            stream = p.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                input_device_index=self.device_index,  # None -> system default
                frames_per_buffer=CHUNK
            )
            self.is_recording = True
            start_time = time.time()
            frames = []
            max_int16 = 32767.0  # int16 full scale, for level estimation
            while self.is_recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration}")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
                # Estimate the input level: chunk RMS scaled to 0-100.
                try:
                    chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
                    rms = np.sqrt(np.mean(chunk_np ** 2))
                    level = int(np.clip(rms * 100 * 5, 0, 100))  # gain factor 5: tune per environment
                    self.level_signal.emit(level)
                except Exception:
                    pass
            # Release the device before writing the file; null the handles
            # so the finally-block cleanup is a no-op on this path.
            stream.stop_stream()
            stream.close()
            stream = None
            p.terminate()
            p = None
            if len(frames) > 0:
                with wave.open(self.temp_file, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(sample_width)
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                self.update_signal.emit(f"录制完成,时长: {time.time() - start_time:.1f}")
                self.finish_signal.emit(self.temp_file)
            else:
                self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)")
        except Exception as e:
            err_msg = f"录制错误: {str(e)}"
            try:
                if p is not None:
                    err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)"
            finally:
                self.update_signal.emit(err_msg)
        finally:
            # Idempotent cleanup for the error paths above.
            try:
                if stream is not None:
                    stream.stop_stream()
                    stream.close()
            except Exception:
                pass
            try:
                if p is not None:
                    p.terminate()
            except Exception:
                pass
# 处理线程
class ProcessThread(QThread):
    """Worker thread: segment a WAV file and classify it with a saved model."""
    update_signal = pyqtSignal(str)   # status / error text for the UI
    finish_signal = pyqtSignal(dict)  # summary dict emitted on success

    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path = wav_path        # audio file to classify
        self.model_path = model_path    # joblib-serialized classifier
        self.scaler_path = scaler_path  # joblib-serialized scaler

    def run(self):
        try:
            self.update_signal.emit("加载模型和标准化器...")
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)
            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
            # Peak-normalize (guarded against a silent file) so level
            # differences and clipping don't skew the features.
            if np.max(np.abs(sig)) > 0:
                sig = sig / np.max(np.abs(sig))
            self.update_signal.emit("提取特征...")
            ene = FeatureExtractor.frame_energy(sig)
            hit_starts = FeatureExtractor.detect_hits(ene)
            segments = FeatureExtractor.extract_segments(sig, sr, hit_starts)
            if not segments:
                self.update_signal.emit("未检测到有效片段!")
                return
            feats = [FeatureExtractor.extract_features(seg, sr) for seg in segments]
            X = np.vstack(feats)
            self.update_signal.emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}")
            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)
            # Classifiers without predict_proba fall back to a sigmoid of
            # the decision function, or to raw predictions as a last resort.
            if hasattr(model, "predict_proba"):
                y_proba = model.predict_proba(X_std)[:, 1]
            else:
                if hasattr(model, "decision_function"):
                    df = model.decision_function(X_std)
                    y_proba = 1 / (1 + np.exp(-df))
                else:
                    y_proba = (y_pred.astype(float) + 0.0)
            # File-level label: mean segment probability thresholded at 0.5.
            self.finish_signal.emit({
                "filename": os.path.basename(self.wav_path),
                "segments": len(segments),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(p), 4) for p in y_proba],
                "mean_prob": round(float(np.mean(y_proba)), 4),
                "final_label": int(np.mean(y_proba) >= 0.5)
            })
        except Exception as e:
            self.update_signal.emit(f"错误: {str(e)}")
# 第一层界面:主菜单
# First-level screen: the main menu with one button per input mode.
class MainMenuWidget(QWidget):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.parent = parent
        self.init_ui()
    def init_ui(self):
        """Lay out the title and the two mode-selection buttons."""
        layout = QVBoxLayout()
        title = QLabel("音频分类器")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        layout.addWidget(title)
        # Both buttons route through the parent window's screen switcher.
        for caption, mode in (("采集音频", "record"), ("上传外部WAV文件", "upload")):
            btn = QPushButton(caption)
            btn.setMinimumHeight(60)
            btn.setStyleSheet("font-size: 16px;")
            # Bind `mode` as a default so each button keeps its own value.
            btn.clicked.connect(lambda _=False, m=mode: self.parent.switch_to_input(m))
            layout.addWidget(btn)
        layout.addStretch(1)
        self.setLayout(layout)
# 第二层界面:输入界面(录音或上传文件)
class InputWidget(QWidget):
    """Second-level screen: acquire audio (push-to-talk recording or a WAV
    upload), then run the classifier over it via ProcessThread."""
    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        self.parent = parent
        self.mode = mode  # "record" or "upload"
        self.wav_path = ""
        self.record_thread = None
        self.device_index = None  # >>> NEW: index of the currently selected input device
        self.init_ui()
    def init_ui(self):
        """Assemble the shared chrome plus the mode-specific section."""
        layout = QVBoxLayout()
        # Back button
        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back_btn)
        if self.mode == "record":
            self.setup_record_ui(layout)
        else:
            self.setup_upload_ui(layout)
        # Model / scaler path inputs
        model_layout = QHBoxLayout()
        self.model_path = QLineEdit("svm_model.pkl")
        self.scaler_path = QLineEdit("scaler.pkl")
        model_layout.addWidget(QLabel("模型路径:"))
        model_layout.addWidget(self.model_path)
        model_layout.addWidget(QLabel("标准化器路径:"))
        model_layout.addWidget(self.scaler_path)
        layout.addLayout(model_layout)
        # Process button
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setMinimumHeight(50)
        self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;")
        self.process_btn.clicked.connect(self.start_process)
        self.process_btn.setEnabled(False)  # disabled until audio is available
        layout.addWidget(self.process_btn)
        # Log area
        self.log_area = QTextEdit()
        self.log_area.setReadOnly(True)
        layout.addWidget(QLabel("日志:"))
        layout.addWidget(self.log_area)
        self.setLayout(layout)
    def setup_record_ui(self, layout):
        """Recording-mode widgets: device picker, hold-to-record button, meters."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)
        # >>> NEW: microphone selection
        device_layout = QHBoxLayout()
        device_layout.addWidget(QLabel("输入设备:"))
        self.device_combo = QComboBox()
        self.refresh_devices()
        self.device_combo.currentIndexChanged.connect(self.on_device_changed)
        device_layout.addWidget(self.device_combo)
        refresh_btn = QPushButton("刷新设备")
        refresh_btn.clicked.connect(self.refresh_devices)
        device_layout.addWidget(refresh_btn)
        layout.addLayout(device_layout)
        # Recording controls (press-and-hold)
        record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)")
        record_hint.setAlignment(Qt.AlignCenter)
        self.record_btn = QPushButton("按住录音")
        self.record_btn.setMinimumHeight(80)
        self.record_btn.setStyleSheet("""
            QPushButton {
                background-color: #ff4d4d;
                color: white;
                font-size: 18px;
                border-radius: 10px;
            }
            QPushButton:pressed {
                background-color: #cc0000;
            }
        """)
        # Push-to-talk: press starts the thread, release stops it.
        self.record_btn.mousePressEvent = self.start_recording
        self.record_btn.mouseReleaseEvent = self.stop_recording
        self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu)
        # >>> NEW: live input level and recording progress
        self.mic_level = QProgressBar()
        self.mic_level.setRange(0, 100)
        self.mic_level.setFormat("麦克风电平:%p%")
        self.record_progress = QProgressBar()
        self.record_progress.setRange(0, 100)
        self.record_progress.setValue(0)
        self.record_duration_label = QLabel("录音时长: 0.0秒")
        self.record_duration_label.setAlignment(Qt.AlignCenter)
        layout.addWidget(record_hint)
        layout.addWidget(self.record_btn)
        layout.addWidget(self.mic_level)  # >>> NEW
        layout.addWidget(self.record_progress)
        layout.addWidget(self.record_duration_label)
        # Recording duration timer (ticks every 100 ms while recording)
        self.record_timer = QTimer(self)
        self.record_timer.timeout.connect(self.update_record_duration)
        self.record_start_time = 0
    def refresh_devices(self):
        """Enumerate devices that have input channels and fill the combo box."""
        self.device_combo.clear()
        try:
            p = pyaudio.PyAudio()
            default_host_api = p.get_host_api_info_by_index(0)
            default_input_index = default_host_api.get("defaultInputDevice", -1)
            found = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    name = info.get('name', f"Device {i}")
                    sr = int(info.get('defaultSampleRate', 0))
                    label = f"[{i}] {name} (sr={sr})"
                    self.device_combo.addItem(label, i)
                    found.append(i)
            p.terminate()
            # Preselect the system default input device when present
            if default_input_index in found:
                idx = found.index(default_input_index)
                self.device_combo.setCurrentIndex(idx)
                self.device_index = default_input_index
            elif found:
                self.device_combo.setCurrentIndex(0)
                self.device_index = self.device_combo.currentData()
            else:
                self.device_index = None
        except Exception:
            # NOTE(review): enumeration failures are swallowed; recording then
            # falls back to the system default device (device_index=None).
            self.device_index = None
    def on_device_changed(self, _):
        """Combo-box slot: store the PyAudio index of the selected device."""
        self.device_index = self.device_combo.currentData()
    def setup_upload_ui(self, layout):
        """Upload-mode widgets: read-only path field plus a browse button."""
        title = QLabel("上传WAV文件")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)
        # File selection row
        file_layout = QHBoxLayout()
        self.file_path = QLineEdit()
        self.file_path.setReadOnly(True)
        self.browse_btn = QPushButton("浏览WAV文件")
        self.browse_btn.clicked.connect(self.browse_file)
        file_layout.addWidget(self.file_path)
        file_layout.addWidget(self.browse_btn)
        layout.addLayout(file_layout)
    def start_recording(self, event):
        """Mouse-press handler: spawn a RecordThread (60 s cap)."""
        if event.button() == Qt.LeftButton:
            if not self.record_thread or not self.record_thread.isRunning():
                # >>> CHANGED: pass the selected device_index through
                self.record_thread = RecordThread(max_duration=60, device_index=self.device_index)
                self.record_thread.update_signal.connect(self.update_log)
                self.record_thread.finish_signal.connect(self.on_recording_finish)
                self.record_thread.progress_signal.connect(self.record_progress.setValue)
                self.record_thread.level_signal.connect(self.mic_level.setValue)  # >>> NEW
                self.record_thread.start()
                self.record_start_time = time.time()
                self.record_timer.start(100)
                self.update_log("开始录音...(松开按钮结束)")
    def stop_recording(self, event):
        """Mouse-release handler: ask the thread to stop and reset the meters."""
        if event.button() == Qt.LeftButton and self.record_thread and self.record_thread.isRunning():
            self.record_thread.is_recording = False
            self.record_timer.stop()
            self.record_duration_label.setText("录音时长: 0.0秒")
            self.record_progress.setValue(0)
            # Drop the level meter back to zero
            self.mic_level.setValue(0)
    def update_record_duration(self):
        """Timer slot: refresh the elapsed-time label."""
        elapsed = time.time() - self.record_start_time
        self.record_duration_label.setText(f"录音时长: {elapsed:.1f}秒")
    def on_recording_finish(self, temp_file):
        """RecordThread finished: remember the temp WAV and enable processing."""
        self.wav_path = temp_file
        self.process_btn.setEnabled(True)
    def browse_file(self):
        """Pick a WAV file from disk and enable processing."""
        file, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if file:
            self.file_path.setText(file)
            self.wav_path = file
            self.process_btn.setEnabled(True)
    def start_process(self):
        """Validate the chosen paths, then run ProcessThread over the audio."""
        if self.mode == "upload":
            self.wav_path = self.file_path.text()
        model_path = self.model_path.text()
        scaler_path = self.scaler_path.text()
        if not self.wav_path or not os.path.exists(self.wav_path):
            self.log_area.append("请先录音或选择有效的WAV文件")
            return
        if not os.path.exists(model_path) or not os.path.exists(scaler_path):
            self.log_area.append("模型文件不存在请先运行train_model.py训练模型")
            return
        self.log_area.clear()
        self.process_btn.setEnabled(False)
        self.thread = ProcessThread(self.wav_path, model_path, scaler_path)
        self.thread.update_signal.connect(self.update_log)
        self.thread.finish_signal.connect(self.on_process_finish)
        self.thread.start()
    def update_log(self, msg):
        """Append one message to the log box."""
        self.log_area.append(msg)
    def on_process_finish(self, result):
        """Hand the result dict to the main window's result screen."""
        self.parent.switch_to_result(result)
# 第三层界面:结果显示界面
# Third-level screen: renders one processing result as read-only text.
class ResultWidget(QWidget):
    def __init__(self, parent=None, result=None):
        super().__init__(parent)
        self.parent = parent
        self.result = result
        self.init_ui()
    def init_ui(self):
        """Build the back button, title and the formatted result report."""
        layout = QVBoxLayout()
        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.parent.switch_to_input_from_result)
        layout.addWidget(back_btn)
        title = QLabel("处理结果")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)
        self.result_area = QTextEdit()
        self.result_area.setReadOnly(True)
        if self.result:
            # Assemble the report as a list of lines, then join once.
            lines = [
                f"文件名: {self.result['filename']}\n",
                f"片段数: {self.result['segments']}\n",
                "预测结果:\n",
            ]
            pairs = zip(self.result['predictions'], self.result['probabilities'])
            for i, (pred, prob) in enumerate(pairs):
                lines.append(f"  片段{i + 1}: 标签={pred} (概率={prob})\n")
            lines.append(f"\n平均概率: {self.result['mean_prob']}\n")
            lines.append(f"最终结果: {'空心' if self.result['final_label'] else '实心'}")
            self.result_area.setText("".join(lines))
        layout.addWidget(self.result_area)
        self.setLayout(layout)
# 主窗口
class AudioClassifierGUI(QMainWindow):
    """Main window: hosts the three screens in a QStackedWidget."""
    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"  # which input screen to return to
        self.process_result = None  # last processing result dict
        self.init_ui()
    def init_ui(self):
        """Create the stacked pages and show the main menu."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 800, 600)
        # Stacked container for all screens
        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)
        # Build the pages
        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)
        # Register them (index order: 0=menu, 1=record, 2=upload, 3=result)
        self.stacked_widget.addWidget(self.main_menu_widget)
        self.stacked_widget.addWidget(self.record_input_widget)
        self.stacked_widget.addWidget(self.upload_input_widget)
        self.stacked_widget.addWidget(self.result_widget)
        # Show the main menu first
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)
    def switch_to_input(self, mode):
        """Jump to the record/upload screen and remember the mode."""
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)
    def switch_to_main_menu(self):
        """Return to the first screen."""
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)
    def switch_to_result(self, result):
        """Rebuild the result page with fresh data and show it."""
        self.process_result = result
        self.result_widget = ResultWidget(self, result)
        # Swap out the old result page (always stacked at index 3).
        # NOTE(review): the removed widget is never deleteLater()'d, so old
        # result pages stay parented until the window closes — confirm intended.
        self.stacked_widget.removeWidget(self.stacked_widget.widget(3))
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)
    def switch_to_input_from_result(self):
        """Back from the result page to whichever input screen was used."""
        if self.current_input_mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)
if __name__ == "__main__":
    # Fail fast with an install hint: recording is impossible without PyAudio.
    try:
        import pyaudio  # ensure the recording backend is installed
    except ImportError:
        print("请先安装pyaudio: pip install pyaudio")
        sys.exit(1)
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    sys.exit(app.exec_())

@ -0,0 +1,917 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import csv
import sqlite3
from datetime import datetime
import numpy as np
import joblib
import librosa
import pyaudio
import wave
from scipy.signal import hilbert
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QPushButton, QLabel,
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
QComboBox, QTableWidget, QTableWidgetItem, QMessageBox,
QDialog, QFormLayout, QDialogButtonBox
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
# ======================
# 本地 SQLite 数据库(含 CRUD
# ======================
class DatabaseManager:
    """Local SQLite store for classification results, with full CRUD.

    Tables:
        runs         -- one row per processed audio file.
        run_segments -- per-segment label/probability, linked via run_id.

    Every method opens its own short-lived connection and closes it in a
    ``finally`` block, so a failing query can no longer leak the handle
    (the original code called ``con.close()`` only on the success path).
    """
    def __init__(self, db_path="results.db"):
        self.db_path = db_path
        self._ensure_schema()
    def _connect(self):
        """Open a fresh connection; each caller closes it in a finally block."""
        return sqlite3.connect(self.db_path)
    def _ensure_schema(self):
        """Create the two tables on first use (idempotent)."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            CREATE TABLE IF NOT EXISTS runs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT,
                segments INTEGER,
                mean_prob REAL,
                final_label INTEGER,
                created_at TEXT
            );
            """)
            cur.execute("""
            CREATE TABLE IF NOT EXISTS run_segments (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                run_id INTEGER,
                seg_index INTEGER,
                label INTEGER,
                proba REAL,
                FOREIGN KEY(run_id) REFERENCES runs(id)
            );
            """)
            con.commit()
        finally:
            con.close()
    # ---------- write a full result dict ----------
    def insert_result(self, result_dict):
        """Persist one ProcessThread result dict; returns the new run id."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute(
                "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)",
                (
                    result_dict.get("filename"),
                    int(result_dict.get("segments", 0)),
                    float(result_dict.get("mean_prob", 0.0)),
                    int(result_dict.get("final_label", 0)),
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                )
            )
            run_id = cur.lastrowid
            preds = result_dict.get("predictions", [])
            probas = result_dict.get("probabilities", [])
            # seg_index is stored 1-based.
            for i, (lab, pr) in enumerate(zip(preds, probas), start=1):
                cur.execute(
                    "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
                    (run_id, i, int(lab), float(pr))
                )
            con.commit()
            return run_id
        finally:
            con.close()
    # ---------- R (query list / single row) ----------
    def fetch_recent_runs(self, limit=50):
        """Most recent runs, newest first."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            SELECT id, created_at, filename, segments, mean_prob, final_label
            FROM runs ORDER BY id DESC LIMIT ?
            """, (limit,))
            return cur.fetchall()
        finally:
            con.close()
    def get_run(self, run_id: int):
        """Single run row by id, or None when it does not exist."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            SELECT id, created_at, filename, segments, mean_prob, final_label
            FROM runs WHERE id = ?
            """, (run_id,))
            return cur.fetchone()
        finally:
            con.close()
    def search_runs(self, keyword: str, limit=100):
        """Runs whose filename contains *keyword* (LIKE match), newest first."""
        kw = f"%{keyword}%"
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            SELECT id, created_at, filename, segments, mean_prob, final_label
            FROM runs
            WHERE filename LIKE ?
            ORDER BY id DESC LIMIT ?
            """, (kw, limit))
            return cur.fetchall()
        finally:
            con.close()
    # ---------- C/U/D (runs) ----------
    def create_run(self, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str = None):
        """Insert a run row manually; returns the new id."""
        created_at = created_at or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute(
                "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)",
                (filename, int(segments), float(mean_prob), int(final_label), created_at)
            )
            run_id = cur.lastrowid
            con.commit()
            return run_id
        finally:
            con.close()
    def update_run(self, run_id: int, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str):
        """Overwrite every editable column of one run row."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            UPDATE runs
            SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=?
            WHERE id=?
            """, (filename, int(segments), float(mean_prob), int(final_label), created_at, int(run_id)))
            con.commit()
        finally:
            con.close()
    def delete_run(self, run_id: int):
        """Delete a run and its segments (children first, FK has no cascade)."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("DELETE FROM run_segments WHERE run_id=?", (int(run_id),))
            cur.execute("DELETE FROM runs WHERE id=?", (int(run_id),))
            con.commit()
        finally:
            con.close()
    # ---------- C/R/U/D (segments) ----------
    def list_segments(self, run_id: int):
        """All segment rows of a run, ordered by seg_index."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            SELECT id, seg_index, label, proba FROM run_segments
            WHERE run_id=? ORDER BY seg_index ASC
            """, (int(run_id),))
            return cur.fetchall()
        finally:
            con.close()
    def add_segment(self, run_id: int, seg_index: int, label: int, proba: float):
        """Insert one segment row; returns the new segment id."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute(
                "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
                (int(run_id), int(seg_index), int(label), float(proba))
            )
            seg_id = cur.lastrowid
            con.commit()
            return seg_id
        finally:
            con.close()
    def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float):
        """Overwrite one segment row."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("""
            UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=?
            """, (int(seg_index), int(label), float(proba), int(seg_id)))
            con.commit()
        finally:
            con.close()
    def delete_segment(self, seg_id: int):
        """Delete one segment row by id."""
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),))
            con.commit()
        finally:
            con.close()
    # ---------- export ----------
    def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000):
        """Dump the most recent *limit* runs to a UTF-8 CSV; returns the path."""
        rows = self.fetch_recent_runs(limit)
        headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"]
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(headers)
            for r in rows:
                writer.writerow(r)
        return csv_path
DB = DatabaseManager("results.db")
# ======================
# 特征提取(同前)
# ======================
class FeatureExtractor:
    """Static helpers: hit detection and 5-dim feature vectors from audio."""
    WIN = 1024        # analysis frame length (samples)
    OVERLAP = 512     # frame overlap (samples)
    THRESHOLD = 0.1   # short-time energy threshold for a "hit"
    SEG_LEN_S = 0.2   # seconds of signal kept after each hit
    STEP = WIN - OVERLAP
    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Short-time energy per frame (length WIN, hop STEP)."""
        framed = librosa.util.frame(
            signal,
            frame_length=FeatureExtractor.WIN,
            hop_length=FeatureExtractor.STEP,
        )
        return np.sum(framed ** 2, axis=0)
    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Frame indices where a new hit starts: energy above THRESHOLD and
        more than 5 frames after the previous above-threshold frame."""
        above = np.where(energy > FeatureExtractor.THRESHOLD)[0]
        if above.size == 0:
            return np.array([])
        gaps = np.diff(np.concatenate(([0], above)))
        return above[gaps > 5]
    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray):
        """Cut a SEG_LEN_S-second slice starting at every detected hit."""
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        out = []
        for frame_idx in hit_starts:
            start = frame_idx * FeatureExtractor.STEP
            stop = min(start + seg_len, len(signal))
            out.append(signal[start:stop])
        return out
    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return [rms, main_freq, spectral skewness, mean MFCC-0, envelope peak]."""
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)
        rms = np.sqrt(np.mean(x ** 2))
        spectrum = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        keep = freq >= 0
        freq = freq[keep]
        spec_power = np.abs(spectrum[keep])
        main_freq = freq[np.argmax(spec_power)] if len(spec_power) > 0 else 0
        total = np.sum(spec_power) + 1e-12
        centroid = np.sum(freq * spec_power) / total
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / total)
        if spread > 0:
            skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (total * (spread ** 3 + 1e-12))
        else:
            skewness = 0
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
        env_peak = np.max(np.abs(hilbert(x)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# ======================
# 录音线程(同前)
# ======================
class RecordThread(QThread):
    """Background microphone recorder.

    Signals:
        update_signal(str)   -- log / error messages.
        finish_signal(str)   -- path of the written temp WAV on success.
        progress_signal(int) -- percentage of max_duration elapsed.
        level_signal(int)    -- rough input level 0-100 for the UI meter.
    """
    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    level_signal = pyqtSignal(int)
    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        self.is_recording = False
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index  # None -> system default input
    def run(self):
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024
        p, stream = None, None
        frames = []
        try:
            p = pyaudio.PyAudio()
            # Grab the sample width from the live instance instead of creating
            # a throwaway PyAudio() later (the original never terminated it).
            sample_width = p.get_sample_size(FORMAT)
            stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True,
                            input_device_index=self.device_index, frames_per_buffer=CHUNK)
            self.is_recording = True
            start_time = time.time()
            max_int16 = 32767.0  # int16 full scale, for level estimation
            while self.is_recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration}秒")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
                # RMS of the chunk mapped onto a 0-100 meter value.
                chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
                rms = np.sqrt(np.mean(chunk_np ** 2))
                self.level_signal.emit(int(np.clip(rms * 500, 0, 100)))
            if frames:
                wf = wave.open(self.temp_file, 'wb')
                try:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(sample_width)
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                finally:
                    wf.close()
                self.finish_signal.emit(self.temp_file)
            else:
                self.update_signal.emit("未录到有效音频。")
        except Exception as e:
            self.update_signal.emit(f"录制错误: {e}")
        finally:
            # Always release PortAudio resources, even when open/read raised
            # (mirrors the cleanup done by the GUI-only variant of this thread;
            # the original only cleaned up on the success path).
            try:
                if stream is not None:
                    stream.stop_stream()
                    stream.close()
            except Exception:
                pass
            try:
                if p is not None:
                    p.terminate()
            except Exception:
                pass
# ======================
# 推理处理线程(写入 DB
# ======================
class ProcessThread(QThread):
    """Inference thread: extract features from a WAV, classify each segment,
    persist the result into the local DB and emit it as a dict."""
    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)
    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path, self.model_path, self.scaler_path = wav_path, model_path, scaler_path
    def run(self):
        try:
            self.update_signal.emit("加载模型和标准化器...")
            model, scaler = joblib.load(self.model_path), joblib.load(self.scaler_path)
            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
            # Peak-normalize; the 1e-9 floor avoids division by zero on silence.
            sig = sig / max(np.max(np.abs(sig)), 1e-9)
            ene = FeatureExtractor.frame_energy(sig)
            hits = FeatureExtractor.detect_hits(ene)
            segs = FeatureExtractor.extract_segments(sig, sr, hits)
            if not segs:
                self.update_signal.emit("未检测到有效片段!")
                return
            feats = [FeatureExtractor.extract_features(seg, sr) for seg in segs]
            X = np.vstack(feats)
            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)
            if hasattr(model, "predict_proba"):
                y_proba = model.predict_proba(X_std)[:, 1]
            elif hasattr(model, "decision_function"):
                # Fallback for margin classifiers: squash decision scores into
                # (0, 1) — same degradation path as the GUI-only variant of
                # this thread; previously such models fell through to hard labels.
                y_proba = 1 / (1 + np.exp(-model.decision_function(X_std)))
            else:
                y_proba = (y_pred.astype(float) + 0.0)
            result = {
                "filename": os.path.basename(self.wav_path),
                "segments": len(segs),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(p), 4) for p in y_proba],
                "mean_prob": round(float(np.mean(y_proba)), 4),
                "final_label": int(np.mean(y_proba) >= 0.5)
            }
            DB.insert_result(result)
            self.finish_signal.emit(result)
        except Exception as e:
            self.update_signal.emit(f"错误: {e}")
# ======================
# 对话框:编辑 run
# ======================
class RunEditDialog(QDialog):
    """Modal editor for a single `runs` row.

    run_row: (id, created_at, filename, segments, mean_prob, final_label),
    or None when creating a new record.
    """
    def __init__(self, parent=None, run_row=None):
        super().__init__(parent)
        self.setWindowTitle("编辑记录" if run_row else "新增记录")
        self.resize(380, 240)
        self.run_row = run_row
        self.ed_id = QLineEdit()
        self.ed_id.setReadOnly(True)  # primary key is never hand-edited
        self.ed_created = QLineEdit(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        self.ed_filename = QLineEdit()
        self.ed_segments = QLineEdit("0")
        self.ed_mean_prob = QLineEdit("0.0")
        self.ed_final_label = QLineEdit("0")  # 0 = solid, 1 = hollow
        editors = (self.ed_id, self.ed_created, self.ed_filename,
                   self.ed_segments, self.ed_mean_prob, self.ed_final_label)
        captions = ("ID", "时间(created_at)", "文件名(filename)",
                    "片段数(segments)", "平均概率(mean_prob)", "最终标签(0=实 1=空)")
        form = QFormLayout()
        for caption, editor in zip(captions, editors):
            form.addRow(caption, editor)
        # Pre-fill from the existing row when editing (row order matches editors).
        if run_row:
            for editor, value in zip(editors, run_row):
                editor.setText(str(value))
        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(self.accept)
        btns.rejected.connect(self.reject)
        lay = QVBoxLayout()
        lay.addLayout(form)
        lay.addWidget(btns)
        self.setLayout(lay)
    def values(self):
        """Collect the edited fields, coercing numeric text to int/float."""
        raw_id = self.ed_id.text().strip()
        return dict(
            id=(int(raw_id) if raw_id else None),
            created_at=self.ed_created.text().strip(),
            filename=self.ed_filename.text().strip(),
            segments=int(float(self.ed_segments.text() or 0)),
            mean_prob=float(self.ed_mean_prob.text() or 0.0),
            final_label=int(float(self.ed_final_label.text() or 0))
        )
# ======================
# 对话框:管理 segments
# ======================
class SegmentManagerDialog(QDialog):
    """CRUD dialog for the run_segments rows belonging to one run."""
    def __init__(self, parent=None, run_id=None):
        super().__init__(parent)
        self.setWindowTitle(f"片段管理 - run_id={run_id}")
        self.resize(560, 420)
        self.run_id = run_id
        self.table = QTableWidget(0, 4)
        self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"])
        self.btn_add = QPushButton("新增片段")
        self.btn_edit = QPushButton("编辑所选")
        self.btn_del = QPushButton("删除所选")
        self.btn_close = QPushButton("关闭")
        hl = QHBoxLayout()
        for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close):
            hl.addWidget(b)
        lay = QVBoxLayout()
        lay.addWidget(self.table)
        lay.addLayout(hl)
        self.setLayout(lay)
        self.btn_add.clicked.connect(self.add_segment)
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_del.clicked.connect(self.delete_selected)
        self.btn_close.clicked.connect(self.accept)
        self.load_segments()
    def load_segments(self):
        """Reload the table contents from the database."""
        rows = DB.list_segments(self.run_id)
        self.table.setRowCount(len(rows))
        for i, (sid, idx, label, proba) in enumerate(rows):
            self.table.setItem(i, 0, QTableWidgetItem(str(sid)))
            self.table.setItem(i, 1, QTableWidgetItem(str(idx)))
            self.table.setItem(i, 2, QTableWidgetItem(str(label)))
            self.table.setItem(i, 3, QTableWidgetItem(str(proba)))
        self.table.resizeColumnsToContents()
    def _get_selected_seg_id(self):
        """DB id of the selected row, or None when nothing is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        return int(self.table.item(r, 0).text())
    def _prompt_segment(self, title, idx_text, label_text, proba_text):
        """Show a small form pre-filled with the given texts.

        Returns the edited (seg_index, label, proba) texts when the user
        accepts, otherwise None. Shared by add_segment/edit_selected to
        avoid duplicating the dialog construction.
        """
        d = QDialog(self)
        d.setWindowTitle(title)
        form = QFormLayout()
        ed_index = QLineEdit(idx_text)
        ed_label = QLineEdit(label_text)
        ed_proba = QLineEdit(proba_text)
        form.addRow("seg_index", ed_index)
        form.addRow("label(0/1)", ed_label)
        form.addRow("proba(0~1)", ed_proba)
        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(d.accept)
        btns.rejected.connect(d.reject)
        v = QVBoxLayout()
        v.addLayout(form)
        v.addWidget(btns)
        d.setLayout(v)
        if d.exec_() == QDialog.Accepted:
            return ed_index.text(), ed_label.text(), ed_proba.text()
        return None
    def add_segment(self):
        """Insert a new segment row for this run."""
        values = self._prompt_segment("新增片段", "1", "0", "0.0")
        if values is not None:
            idx_t, label_t, proba_t = values
            DB.add_segment(self.run_id, int(float(idx_t or 1)), int(float(label_t or 0)),
                           float(proba_t or 0.0))
            self.load_segments()
    def edit_selected(self):
        """Edit the currently selected segment row."""
        seg_id = self._get_selected_seg_id()
        if not seg_id:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        # Current values double as fallbacks for blanked-out fields.
        row = self.table.currentRow()
        cur_idx = self.table.item(row, 1).text()
        cur_label = self.table.item(row, 2).text()
        cur_proba = self.table.item(row, 3).text()
        values = self._prompt_segment("编辑片段", cur_idx, cur_label, cur_proba)
        if values is not None:
            idx_t, label_t, proba_t = values
            DB.update_segment(seg_id, int(float(idx_t or cur_idx)),
                              int(float(label_t or cur_label)),
                              float(proba_t or cur_proba))
            self.load_segments()
    def delete_selected(self):
        """Delete the selected segment after confirmation (irreversible)."""
        seg_id = self._get_selected_seg_id()
        if not seg_id:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        if QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!",
                                QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes:
            DB.delete_segment(seg_id)
            self.load_segments()
# ======================
# 主菜单
# ======================
class MainMenuWidget(QWidget):
    """Main menu: three entry buttons routed through the parent window."""
    def __init__(self, parent):
        super().__init__(parent)
        layout = QVBoxLayout()
        title = QLabel("音频分类器")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        layout.addWidget(title)
        actions = [
            ("采集音频", lambda: parent.switch_to_input("record")),
            ("上传WAV文件", lambda: parent.switch_to_input("upload")),
            ("历史记录 / 导出 / CRUD", parent.switch_to_history),
        ]
        for caption, handler in actions:
            button = QPushButton(caption)
            button.clicked.connect(handler)
            button.setMinimumHeight(60)
            layout.addWidget(button)
        layout.addStretch(1)
        self.setLayout(layout)
# ======================
# 输入界面(录音 / 上传)
# ======================
class InputWidget(QWidget):
def __init__(self, parent, mode="record"):
super().__init__(parent)
self.parent, self.mode = parent, mode
self.wav_path, self.record_thread = "", None
self.device_index = None
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu)
layout.addWidget(back)
if self.mode == "record": self.setup_record(layout)
else: self.setup_upload(layout)
self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
path_layout = QHBoxLayout()
path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path)
path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path)
layout.addLayout(path_layout)
self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False)
self.process_btn.clicked.connect(self.start_process)
layout.addWidget(self.process_btn)
self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log)
self.setLayout(layout)
def setup_record(self, layout):
title = QLabel("音频采集"); title.setAlignment(Qt.AlignCenter)
layout.addWidget(title)
dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:"))
self.dev_combo = QComboBox(); self.refresh_devices()
self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData()))
dev_layout.addWidget(self.dev_combo)
btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices)
dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout)
self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80)
self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec
self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
self.level.setFormat("电平:%p%"); self.progress.setRange(0,100)
for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w)
self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur)
def setup_upload(self, layout):
h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True)
b = QPushButton("浏览"); b.clicked.connect(self.browse)
h.addWidget(self.file); h.addWidget(b); layout.addLayout(h)
def refresh_devices(self):
self.dev_combo.clear(); p = pyaudio.PyAudio()
for i in range(p.get_device_count()):
info = p.get_device_info_by_index(i)
if int(info.get('maxInputChannels',0))>0:
self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
if self.dev_combo.count()>0:
self.dev_combo.setCurrentIndex(0)
self.device_index = self.dev_combo.currentData()
p.terminate()
def start_rec(self,e):
if e.button()==Qt.LeftButton:
self.record_thread=RecordThread(device_index=self.device_index)
self.record_thread.update_signal.connect(self.log.append)
self.record_thread.finish_signal.connect(self.finish_rec)
self.record_thread.progress_signal.connect(self.progress.setValue)
self.record_thread.level_signal.connect(self.level.setValue)
self.record_thread.start(); self.start=time.time(); self.timer.start(100)
def stop_rec(self,e):
if e.button()==Qt.LeftButton and self.record_thread:
self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0)
def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s")
def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True)
def browse(self):
f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)")
if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True)
    def start_process(self):
        # Validate the chosen file, then hand it to ProcessThread together with
        # the model/scaler paths; the finish signal switches to the result page.
        if self.mode=="upload": self.wav_path=self.file.text()
        if not os.path.exists(self.wav_path): self.log.append("无效文件"); return
        self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text())
        self.thread.update_signal.connect(self.log.append)
        self.thread.finish_signal.connect(self.parent.switch_to_result)
        self.thread.start()
# ======================
# 历史记录(含 CRUD 按钮)
# ======================
class HistoryWidget(QWidget):
    """History page: lists past classification runs with search, CSV export
    and full CRUD, backed by the module-level DB helper."""
    def __init__(self, parent):
        super().__init__(parent)
        self.parent=parent
        self.init_ui()
    def init_ui(self):
        # Layout: back button, toolbar (search + actions), results table.
        l=QVBoxLayout()
        back=QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu); l.addWidget(back)
        top=QHBoxLayout()
        self.search_edit=QLineEdit(); self.search_edit.setPlaceholderText("按文件名搜索...")
        btn_search=QPushButton("搜索"); btn_search.clicked.connect(self.search)
        btn_refresh=QPushButton("刷新"); btn_refresh.clicked.connect(self.load)
        btn_export=QPushButton("导出CSV"); btn_export.clicked.connect(self.export_csv)
        btn_add=QPushButton("新增记录"); btn_add.clicked.connect(self.add_run)
        btn_edit=QPushButton("编辑所选"); btn_edit.clicked.connect(self.edit_selected)
        btn_del=QPushButton("删除所选"); btn_del.clicked.connect(self.delete_selected)
        btn_seg=QPushButton("片段管理"); btn_seg.clicked.connect(self.open_segments)
        for w in [self.search_edit, btn_search, btn_refresh, btn_export, btn_add, btn_edit, btn_del, btn_seg]:
            top.addWidget(w)
        l.addLayout(top)
        self.table=QTableWidget(0,6)
        self.table.setHorizontalHeaderLabels(["ID","时间","文件名","片段数","平均概率","最终标签"])
        l.addWidget(self.table)
        self.setLayout(l)
        self.load()
    def _current_run_id(self):
        # Run id of the selected table row, or None when nothing is selected.
        r=self.table.currentRow()
        if r<0: return None
        return int(self.table.item(r,0).text())
    # ---- data loading / search ----
    def load(self):
        # Show the 200 most recent runs.
        rows=DB.fetch_recent_runs(200)
        self._fill(rows)
    def search(self):
        # Filter by filename keyword; an empty keyword falls back to recent runs.
        kw=self.search_edit.text().strip()
        rows=DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200)
        self._fill(rows)
    def _fill(self, rows):
        # Populate the table; each row is
        # (id, created_at, filename, segment_count, mean_prob, final_label).
        self.table.setRowCount(len(rows))
        for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows):
            self.table.setItem(r_idx, 0, QTableWidgetItem(str(rid)))
            self.table.setItem(r_idx, 1, QTableWidgetItem(str(created_at)))
            self.table.setItem(r_idx, 2, QTableWidgetItem(str(fn)))
            self.table.setItem(r_idx, 3, QTableWidgetItem(str(segs)))
            self.table.setItem(r_idx, 4, QTableWidgetItem(str(mean_prob)))
            # Label 1 is shown as solid ("实心"), 0 as hollow ("空心").
            self.table.setItem(r_idx, 5, QTableWidgetItem("实心" if int(final_label) else "空心"))
        self.table.resizeColumnsToContents()
    # ---- export ----
    def export_csv(self):
        # Dump all runs to CSV via the DB helper and tell the user where it went.
        p=DB.export_runs_to_csv()
        QMessageBox.information(self,"导出完成",f"已保存到 {os.path.abspath(p)}")
    # ---- CRUD: run ----
    def add_run(self):
        # Create a run record from a blank edit dialog.
        dlg=RunEditDialog(self, run_row=None)
        if dlg.exec_()==QDialog.Accepted:
            v=dlg.values()
            DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
            self.load()
    def edit_selected(self):
        # Edit the selected run via the dialog pre-filled with its current row.
        rid=self._current_run_id()
        if not rid:
            QMessageBox.information(self,"提示","请先选择一行记录。")
            return
        row=DB.get_run(rid)
        dlg=RunEditDialog(self, run_row=row)
        if dlg.exec_()==QDialog.Accepted:
            v=dlg.values()
            DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
            self.load()
    def delete_selected(self):
        # Delete the selected run (and its segments) after confirmation.
        rid=self._current_run_id()
        if not rid:
            QMessageBox.information(self,"提示","请先选择一行记录。")
            return
        if QMessageBox.question(self,"确认","确认删除所选记录及其所有片段?此操作不可恢复!",
                                QMessageBox.Yes|QMessageBox.No)==QMessageBox.Yes:
            DB.delete_run(rid)
            self.load()
    # ---- segment management ----
    def open_segments(self):
        # Open the per-run segment manager dialog for the selected run.
        rid=self._current_run_id()
        if not rid:
            QMessageBox.information(self,"提示","请先选择一行记录。")
            return
        dlg=SegmentManagerDialog(self, run_id=rid)
        dlg.exec_()
        # Segment edits may change counts/probabilities; reload the table.
        # (The run row itself is kept in sync manually via edit_selected.)
        self.load()
# ======================
# 结果页(简)
# ======================
class ResultWidget(QWidget):
    """Simple result page: shows the classification summary for one file."""
    def __init__(self, parent, result=None):
        # result: dict with keys filename, segments, mean_prob, final_label
        # (or None for the initial empty placeholder page).
        super().__init__(parent); self.parent=parent
        layout=QVBoxLayout(); b=QPushButton("返回"); b.clicked.connect(self.parent.switch_to_main_menu); layout.addWidget(b)
        t=QLabel("处理结果"); t.setAlignment(Qt.AlignCenter); layout.addWidget(t)
        area=QTextEdit(); area.setReadOnly(True)
        if result:
            # final_label truthy = solid ("实心"), otherwise hollow ("空心").
            txt=f"文件:{result['filename']}\n片段:{result['segments']}\n平均概率:{result['mean_prob']}\n最终:{'实心' if result['final_label'] else '空心'}"
            area.setText(txt)
        layout.addWidget(area); self.setLayout(layout)
# ======================
# 主窗口
# ======================
class AudioClassifierGUI(QMainWindow):
    """Main window: owns every page inside a QStackedWidget and routes navigation.

    Pages: main menu, record input, upload input, result, history.  The result
    page is rebuilt for each new result (see switch_to_result).
    """
    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"  # "record" or "upload"
        self.process_result = None  # last classification result dict
        self.init_ui()
    def init_ui(self):
        # Build all pages once and stack them; navigation just switches pages.
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 1000, 700)
        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)
        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)
        self.history_widget = HistoryWidget(self)
        self.stacked_widget.addWidget(self.main_menu_widget)
        self.stacked_widget.addWidget(self.record_input_widget)
        self.stacked_widget.addWidget(self.upload_input_widget)
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.addWidget(self.history_widget)
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)
    def switch_to_input(self, mode):
        """Show the record or upload page depending on *mode*."""
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)
    def switch_to_main_menu(self):
        """Return to the main menu page."""
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)
    def switch_to_result(self, result):
        """Replace the result page with a fresh one showing *result*, and show it.

        Bug fix: the old code removed the widget at hard-coded index 3.  After
        the first replacement the page indices shift (the new result page is
        appended at the end), so a second call removed the *history* page
        instead.  Removing the tracked result-widget object is
        index-independent and always correct.
        """
        self.process_result = result
        stale = self.result_widget
        self.result_widget = ResultWidget(self, result)
        self.stacked_widget.removeWidget(stale)
        stale.deleteLater()  # let Qt reclaim the detached page
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)
    def switch_to_history(self):
        """Reload the history table and show the history page."""
        self.history_widget.load()
        self.stacked_widget.setCurrentWidget(self.history_widget)
# ======================
# 上传/录音 界面(放最后,避免打断阅读)
# ======================
class InputWidget(QWidget):
    """Recording / upload page: captures or selects a WAV and launches processing."""
    def __init__(self, parent, mode="record"):
        # mode: "record" builds the microphone UI, anything else the upload UI.
        super().__init__(parent)
        self.parent, self.mode = parent, mode
        self.wav_path, self.record_thread = "", None
        self.device_index = None  # PyAudio input-device index chosen in the combo
        self.init_ui()
    def init_ui(self):
        # Shared chrome: back button, mode-specific section, model paths, log.
        layout = QVBoxLayout()
        back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back)
        if self.mode == "record": self.setup_record(layout)
        else: self.setup_upload(layout)
        self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
        path_layout = QHBoxLayout()
        path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path)
        path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path)
        layout.addLayout(path_layout)
        self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False)
        self.process_btn.clicked.connect(self.start_process)
        layout.addWidget(self.process_btn)
        self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log)
        self.setLayout(layout)
    def setup_record(self, layout):
        # Microphone UI: device picker, push-to-talk button, level/progress bars.
        title = QLabel("音频采集"); title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)
        dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:"))
        self.dev_combo = QComboBox(); self.refresh_devices()
        self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData()))
        dev_layout.addWidget(self.dev_combo)
        btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices)
        dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout)
        self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80)
        # Push-to-talk: press starts recording, release stops it.
        # NOTE(review): assigning over mousePressEvent bypasses Qt's normal
        # button event handling; a QPushButton subclass would be cleaner.
        self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec
        self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
        self.level.setFormat("电平:%p%"); self.progress.setRange(0,100)
        for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w)
        # Timer refreshes the elapsed-time label every 100 ms while recording.
        self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur)
    def setup_upload(self, layout):
        # Upload UI: read-only path field plus a browse button.
        h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True)
        b = QPushButton("浏览"); b.clicked.connect(self.browse)
        h.addWidget(self.file); h.addWidget(b); layout.addLayout(h)
    def refresh_devices(self):
        # Repopulate the combo with all PyAudio devices that can record.
        self.dev_combo.clear(); p = pyaudio.PyAudio()
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if int(info.get('maxInputChannels',0))>0:
                self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
        if self.dev_combo.count()>0:
            self.dev_combo.setCurrentIndex(0)
            self.device_index = self.dev_combo.currentData()
        p.terminate()
    def start_rec(self,e):
        # Left-button press: spin up a RecordThread wired to the UI widgets.
        if e.button()==Qt.LeftButton:
            self.record_thread=RecordThread(device_index=self.device_index)
            self.record_thread.update_signal.connect(self.log.append)
            self.record_thread.finish_signal.connect(self.finish_rec)
            self.record_thread.progress_signal.connect(self.progress.setValue)
            self.record_thread.level_signal.connect(self.level.setValue)
            self.record_thread.start(); self.start=time.time(); self.timer.start(100)
    def stop_rec(self,e):
        # Left-button release: ask the thread to stop and reset the meters.
        if e.button()==Qt.LeftButton and self.record_thread:
            self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0)
    def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s")  # elapsed recording time
    def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True)  # recording saved; allow processing
    def browse(self):
        # Pick a WAV file; a valid choice enables processing.
        f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)")
        if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True)
    def start_process(self):
        # Validate the chosen file, then hand it to ProcessThread; the finish
        # signal switches the main window to the result page.
        if self.mode=="upload": self.wav_path=self.file.text()
        if not os.path.exists(self.wav_path): self.log.append("无效文件"); return
        self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text())
        self.thread.update_signal.connect(self.log.append)
        self.thread.finish_signal.connect(self.parent.switch_to_result)
        self.thread.start()
if __name__ == "__main__":
    # Fail fast with an install hint if the recording backend is missing.
    try:
        import pyaudio
    except ImportError:
        print("请先安装pyaudio: pip install pyaudio")
        sys.exit(1)
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    sys.exit(app.exec_())

Binary file not shown.

@ -0,0 +1,180 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
有标注学习音频特征提取读取瓷空1.wav标注为提取五维特征+标签保存MAT/PKL适配深度学习
"""
from pathlib import Path
import numpy as np
import scipy.io.wavfile as wav
from scipy.io import savemat
from scipy.signal import hilbert
import librosa
import matplotlib.pyplot as plt
import os
import pickle # 用于保存PKL文件
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK text in plots
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
# ---------- Parameters: path of the labelled training audio; the label is inferred from the filename ----------
WAV_FILE = r"D:\SummerSchool\sample\瓷空1.wav"  # labelled training WAV (filename marks the class)
WIN_SIZE = 1024  # frame length (matches the test-audio script)
OVERLAP = 512  # frame overlap (matches the test-audio script)
STEP = WIN_SIZE - OVERLAP  # hop size between frames
THRESH = 0.01  # energy threshold, lowered so knocks are reliably detected
SEG_LEN_SEC = 0.2  # length of each extracted segment in seconds
# Label map: hollow ("空") -> 0, solid ("实") -> 1.
# NOTE(review): both keys below appear as EMPTY strings — the original
# "空"/"实" characters were apparently lost in an encoding step, so the two
# entries collide. TODO: restore the intended keys from the original source.
LABEL_MAP = {"": 0, "": 1}
# Output paths (same directory as the audio; "_train_" marks the training set)
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_train_features.mat"
OUT_PKL = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_train_features.pkl"
# ---------- 工具函数(完全复用之前的逻辑,确保特征一致性) ----------
def segment_signal(signal: np.ndarray, fs: int):
    """Split a recording into fixed-length knock segments using frame energy.

    A frame marks the start of a new knock when its energy exceeds THRESH and
    it lies more than 5 frames after the previous energetic frame.  Returns a
    list of 1-D arrays, each at most SEG_LEN_SEC seconds long.
    """
    if signal.ndim > 1:
        signal = signal[:, 0]  # stereo -> keep the left channel
    peak = np.max(np.abs(signal)) + 1e-12
    signal = signal / peak  # amplitude-normalise so THRESH is scale-free
    frames = librosa.util.frame(signal, frame_length=WIN_SIZE, hop_length=STEP).T
    frame_energy = np.sum(frames ** 2, axis=1)
    active = np.where(frame_energy > THRESH)[0]
    if active.size == 0:
        return []
    # A gap of more than 5 frames separates distinct knocks.
    is_new_hit = np.diff(np.concatenate(([0], active))) > 5
    onsets = active[is_new_hit]
    samples_per_seg = int(round(SEG_LEN_SEC * fs))
    out = []
    for onset_frame in onsets:
        first = onset_frame * STEP
        last = min(first + samples_per_seg, len(signal))
        out.append(signal[first:last])
    return out
def extract_features(sig: np.ndarray, fs: int):
    """Compute the 5-D feature vector for one knock segment (float32):
    [RMS, dominant frequency, spectral skewness, mean first MFCC, envelope peak].
    """
    sig = sig.flatten()
    if sig.size == 0:
        # Empty segment: return a zero vector so downstream stacking never fails.
        return np.zeros(5, dtype=np.float32)
    # 1) RMS energy.
    rms = float(np.sqrt(np.mean(sig ** 2)))
    # 2) Dominant frequency = location of the spectral peak.
    n = sig.size
    freqs = np.fft.rfftfreq(n, d=1 / fs)
    mags = np.abs(np.fft.rfft(sig))
    dominant = freqs[np.argmax(mags)]
    # 3) Spectral skewness around the spectral centroid.
    total = np.sum(mags) + 1e-12
    centroid = np.sum(freqs * mags) / total
    spread = np.sqrt(np.sum(((freqs - centroid) ** 2) * mags) / total)
    skewness = np.sum(((freqs - centroid) ** 3) * mags) / (total * (spread ** 3 + 1e-12))
    # 4) Mean of the first MFCC coefficient (timbre cue); 0.0 on any failure
    #    (segment too short, librosa unavailable, ...).
    try:
        mfcc = librosa.feature.mfcc(y=sig, sr=fs, n_mfcc=13)
        mfcc_mean = float(np.mean(mfcc[0, :]))
    except Exception:
        mfcc_mean = 0.0
    # 5) Peak of the Hilbert amplitude envelope (decay-sharpness proxy).
    env_peak = float(np.max(np.abs(hilbert(sig))))
    return np.array([rms, dominant, skewness, mfcc_mean, env_peak], dtype=np.float32)
# ---------- 主程序(核心:自动识别标签+特征+标签保存) ----------
def main():
    """Extract labelled 5-D features from WAV_FILE and save them as MAT + PKL."""
    # 1. Validate the audio path.
    wav_path = Path(WAV_FILE)
    if not wav_path.exists():
        print(f"❌ 错误:音频文件 {WAV_FILE} 不存在!")
        return
    if wav_path.suffix != ".wav":
        print(f"❌ 错误:{wav_path.name} 不是WAV格式")
        return
    # 2. Load the audio (librosa keeps the native sample rate; forced mono).
    audio, sr = librosa.load(wav_path, sr=None, mono=True)
    print(f"✅ 成功读取学习音频:{wav_path.name}")
    print(f" 采样率:{sr} Hz | 音频长度:{len(audio)/sr:.2f}")
    # 3. Segment the recording into knock events.
    segments = segment_signal(audio, sr)
    if len(segments) == 0:
        print(f"⚠️ 未检测到有效敲击片段可尝试降低THRESH当前{THRESH})或检查音频是否有敲击声。")
        return
    print(f"✅ 检测到 {len(segments)} 个有效敲击片段")
    # 4. Derive the label from the filename.
    features_list = []
    labels_list = []
    file_stem = wav_path.stem
    # Bug fix: the original tested `"" in file_stem`, which is always True in
    # Python (the "空"/"实" marker characters were lost, see LABEL_MAP above),
    # so every file silently received the first label.  Test the intended
    # markers explicitly: "空" (hollow) -> 0, "实" (solid) -> 1.
    if "空" in file_stem:
        label = 0
        print(f"✅ 自动识别标注:{file_stem} → 标签 {label}(空)")
    elif "实" in file_stem:
        label = 1
        print(f"✅ 自动识别标注:{file_stem} → 标签 {label}(实)")
    else:
        print(f"⚠️ 文件名 {file_stem} 不含'空'或'实'手动指定标签为0")
        label = 0  # manual fallback; adjust when labelling other files
    # One feature vector per segment; every segment inherits the file label.
    for i, seg in enumerate(segments, 1):
        feat = extract_features(seg, sr)
        features_list.append(feat)
        labels_list.append(label)
        print(f" 片段{i:02d}特征提取完成维度5")
    # 5. Stack into an (N, 5) feature matrix and an (N, 1) label column.
    features_matrix = np.vstack(features_list)
    labels_array = np.array(labels_list, dtype=np.int8).reshape(-1, 1)
    print(f"\n✅ 特征与标签整理完成")
    print(f" 特征矩阵形状:{features_matrix.shape}(行=片段数,列=5维特征")
    print(f" 标签矩阵形状:{labels_array.shape}(行=片段数,列=1")
    # 6. Save MAT (MATLAB-compatible; same keys as the merged training set).
    savemat(OUT_MAT, {
        "matrix": features_matrix,
        "label": labels_array
    })
    print(f"✅ MAT文件已保存{OUT_MAT}")
    # 7. Save PKL (read by the Python training scripts).
    with open(OUT_PKL, "wb") as f:
        pickle.dump({
            "matrix": features_matrix,
            "label": labels_array
        }, f)
    print(f"✅ PKL文件已保存{OUT_PKL}")
    # 8. Optional visualisation of the per-segment feature curves.
    plt.figure(figsize=(12, 8))
    feature_names = ["RMS能量", "主频Hz", "频谱偏度", "MFCC均值", "包络峰值"]
    for i in range(5):
        plt.subplot(2, 3, i+1)
        plt.plot(range(1, len(features_matrix)+1), features_matrix[:, i], "-o", color="#1f77b4", linewidth=1.5, markersize=4)
        plt.xlabel("片段编号", fontsize=10)
        plt.ylabel("特征值", fontsize=10)
        plt.title(f"特征{i+1}{feature_names[i]}", fontsize=11, fontweight="bold")
        plt.grid(True, alpha=0.3)
    # Summary panel: file name, label, segment count.
    plt.subplot(2, 3, 6)
    plt.text(0.5, 0.6, f"音频文件:{wav_path.name}", ha="center", fontsize=11)
    plt.text(0.5, 0.4, f"标注标签:{label}{'空' if label==0 else '实'}", ha="center", fontsize=11)
    plt.text(0.5, 0.2, f"有效片段数:{len(features_matrix)}", ha="center", fontsize=11)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
main()

@ -0,0 +1,134 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
无标注测试音频特征提取读取单个WAV提取五维特征保存为MAT和PKL无标签
"""
from pathlib import Path
import numpy as np
import scipy.io.wavfile as wav
from scipy.io import savemat
from scipy.signal import hilbert
import librosa
import matplotlib.pyplot as plt
import os
import pickle # 用于保存PKL文件
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK text in plots
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
# ---------- Parameters (same framing setup as the training extractor) ----------
WAV_FILE = r"D:\SummerSchool\test2.wav"  # unlabelled test WAV to featurise
WIN_SIZE = 1024  # frame length
OVERLAP = 512  # frame overlap
STEP = WIN_SIZE - OVERLAP  # hop size between frames
THRESH = 0.01  # energy threshold (lowered so segments are reliably found)
SEG_LEN_SEC = 0.2  # length of each extracted segment in seconds
# Output paths; saved next to the WAV file by default
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_features.mat"
OUT_PKL = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_features.pkl"
# ---------- 工具函数(完全保留原特征提取逻辑,确保和训练集一致) ----------
def segment_signal(signal: np.ndarray, fs: int):
    """Split the signal into fixed-length knock segments based on frame energy."""
    if signal.ndim > 1:  # stereo -> mono (left channel)
        signal = signal[:, 0]
    signal = signal / (np.max(np.abs(signal)) + 1e-12)  # amplitude-normalise
    # Frame the signal and compute per-frame energy.
    frames = librosa.util.frame(signal, frame_length=WIN_SIZE, hop_length=STEP).T
    energy = np.sum(frames ** 2, axis=1)
    # Frames above the threshold mark candidate knock activity.
    idx = np.where(energy > THRESH)[0]
    if idx.size == 0:
        return []
    hit_mask = np.diff(np.concatenate(([0], idx))) > 5  # gap > 5 frames = new knock
    hit_starts = idx[hit_mask]
    seg_len = int(round(SEG_LEN_SEC * fs))
    segments = []
    for h in hit_starts:
        start = h * STEP
        end = min(start + seg_len, len(signal))  # clamp at the end of the signal
        segments.append(signal[start:end])
    return segments
def extract_features(sig: np.ndarray, fs: int):
    """Compute the 5-D feature vector for one segment:
    [RMS, dominant frequency, spectral skewness, mean first MFCC, envelope peak].

    Consistency fix: return float32 like the training-set extractor does,
    instead of the float64 default, so test features match the training dtype
    without relying on a later cast.
    """
    sig = sig.flatten()
    if sig.size == 0:
        # Empty segment: return a zero vector instead of crashing downstream.
        return np.zeros(5, dtype=np.float32)
    # 1. RMS energy.
    rms = np.sqrt(np.mean(sig ** 2))
    # 2. Dominant frequency (location of the spectral peak).
    L = sig.size
    f = np.fft.rfftfreq(L, d=1 / fs)
    Y = np.abs(np.fft.rfft(sig))
    main_freq = f[np.argmax(Y)]
    # 3. Spectral skewness around the spectral centroid.
    P = Y
    centroid = np.sum(f * P) / (np.sum(P) + 1e-12)
    spread = np.sqrt(np.sum(((f - centroid) ** 2) * P) / (np.sum(P) + 1e-12))
    skewness = np.sum(((f - centroid) ** 3) * P) / ((np.sum(P) + 1e-12) * (spread ** 3 + 1e-12))
    # 4. Mean of the first MFCC coefficient (timbre cue); 0.0 on failure.
    try:
        mfccs = librosa.feature.mfcc(y=sig, sr=fs, n_mfcc=13)
        mfcc_mean = float(np.mean(mfccs[0, :]))
    except Exception:
        mfcc_mean = 0.0
    # 5. Peak of the Hilbert amplitude envelope (decay behaviour).
    env_peak = np.max(np.abs(hilbert(sig)))
    return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
# ---------- 主程序(核心:去掉标签,只提特征+保存) ----------
def main():
    """Featurise the unlabelled test WAV and save matrix-only MAT/PKL files."""
    # 1. Validate the input path.
    wav_path = Path(WAV_FILE)
    if not (wav_path.exists() and wav_path.suffix == ".wav"):
        print(f"❌ 错误:{WAV_FILE} 不存在或不是WAV文件")
        return
    # Load with librosa (robust to odd WAV encodings); keep native sample rate.
    y, fs = librosa.load(wav_path, sr=None, mono=True)
    print(f"✅ 成功读取音频:{wav_path.name},采样率:{fs} Hz")
    # 2. Segment into knock events.
    segments = segment_signal(y, fs)
    if len(segments) == 0:
        print(f"⚠️ 未检测到有效音频片段尝试再降低THRESH当前{THRESH}")
        return
    print(f"✅ 检测到 {len(segments)} 个有效片段")
    # 3. One 5-D feature vector per segment -> (N, 5) float32 matrix.
    features = [extract_features(seg, fs) for seg in segments]
    features_matrix = np.vstack(features).astype(np.float32)
    print(f"✅ 提取特征完成,特征矩阵形状:{features_matrix.shape}(行=片段数,列=5维特征")
    # 4. Save MAT (no label key — this is unlabelled test data).
    savemat(OUT_MAT, {"matrix": features_matrix})
    print(f"✅ MAT文件已保存{OUT_MAT}")
    # 5. Save PKL (same structure as the training set, minus 'label').
    with open(OUT_PKL, "wb") as f:
        pickle.dump({"matrix": features_matrix}, f)
    print(f"✅ PKL文件已保存{OUT_PKL}")
    # Optional: plot each feature across segments.
    plt.figure(figsize=(10, 6))
    feature_names = ["RMS", "主频(Hz)", "频谱偏度", "MFCC均值", "包络峰值"]
    for i in range(5):
        plt.subplot(2, 3, i+1)
        plt.plot(range(1, len(features_matrix)+1), features_matrix[:, i], "-o", linewidth=1.5)
        plt.xlabel("片段编号")
        plt.ylabel("特征值")
        plt.title(f"特征:{feature_names[i]}")
        plt.grid(True)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
main()

@ -0,0 +1,30 @@
系统简介
本系统是一个基于SVM机器学习的墙体声纹检测系统支持通过录制或上传 WAV 格式音频文件,实现音频片段的自动切分、特征提取与分类,并将结果存储于本地数据库中,方便后续查看、管理与导出。主要应用场景包括对特定音频(如敲击声等)的分类识别,支持 "实心" 和 "空心" 两类标签的识别与管理。
核心功能
音频输入:支持两种输入方式
实时录音:通过麦克风采集音频,支持设备选择与录音时长监控
文件上传:上传本地 WAV 格式音频文件
音频处理与分类
自动切分:基于音频能量特征切分有效片段
特征提取:提取音频片段的 RMS、主频、频谱偏度等多维特征
模型预测:使用预训练的 SVM 模型进行分类输出每个片段的标签0/1和概率值
数据管理
本地存储:采用 SQLite 数据库存储分类结果(含文件名、片段数、平均概率等)及片段详情
历史记录:支持查询、搜索、编辑、删除历史记录
片段管理:可查看、新增、编辑、删除特定记录的音频片段
可视化交互:通过直观的 GUI 界面操作,包括主菜单、输入界面、结果展示、历史管理等模块
配置环境
依赖库
本系统依赖以下 Python 库,建议使用 Python 3.9 + 版本:
GUI 框架PyQt5用于构建图形界面
音频处理librosa音频加载、特征提取、pyaudio录音功能
数值计算numpy数组处理、特征矩阵运算
机器学习scikit-learnSVM 模型、数据标准化)
模型存储joblib保存 / 加载训练好的模型)
数据库sqlite3本地数据库操作Python 标准库)
可视化matplotlib特征可视化可选
数据序列化pickle特征数据存储与读取
SVM模型的开发依赖于Python下的Scikit-learn、Librosa、NumPy等一系列科学计算与音频处理库。这些库本身存在复杂的底层依赖,在不同操作系统上通过传统的pip安装方式极易出现依赖冲突、编译失败等问题。
为彻底解决环境配置的复杂性,建议选用Anaconda:其预编译的二进制包和强大的环境管理功能,能够自动处理库与库之间的依赖关系。

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@ -0,0 +1,91 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
基于交叉验证代码的训练脚本
"""
import pickle
import numpy as np
from pathlib import Path
from sklearn.svm import SVC
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scipy.stats import randint, loguniform
import joblib
def load_dataset(pkl_path):
    """Read a pickled dataset and return (feature matrix, labels)."""
    with open(pkl_path, 'rb') as fh:
        payload = pickle.load(fh)
    return payload['matrix'], payload['label']
def train_cross_validated_model(dataset_path, pipeline_save_path, scaler_save_path):
    """Tune and train the SelectKBest+SVM pipeline; export it with its scaler.

    Bug fix: the original fitted the pipeline on *raw* features but exported a
    StandardScaler for the GUI, which standardises features before calling the
    pipeline at inference time — a train/serve mismatch.  The scaler is now
    fitted first and the pipeline is trained on standardised features, the
    same order used by the cross-validation script this file is derived from.
    """
    # Load features/labels.
    X_train, y_train = load_dataset(dataset_path)
    print(f"加载数据集:{X_train.shape[0]}个样本,{X_train.shape[1]}维特征")
    # Map {0,1} labels to {-1,+1}, the convention used by the CV script.
    y_train_signed = np.where(y_train == 0, -1, 1)
    # Fit the scaler first; the pipeline is trained on standardised data so the
    # GUI can apply this exact scaler before prediction.
    scaler = StandardScaler().fit(X_train)
    X_train_std = scaler.transform(X_train)
    # Pipeline: mutual-information feature selection followed by an RBF SVM.
    pipe = Pipeline([
        ('sel', SelectKBest(mutual_info_classif)),
        ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
    ])
    # Search space (same distributions as the CV script).
    n_features = X_train_std.shape[1]
    param_dist = {
        'sel__k': randint(1, n_features + 1),
        'svm__C': loguniform(1e-3, 1e3),
        'svm__gamma': loguniform(1e-6, 1e1)
    }
    cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    search = RandomizedSearchCV(
        pipe,
        param_distributions=param_dist,
        n_iter=30,  # modest budget to keep training fast
        scoring='roc_auc',
        cv=cv_inner,
        n_jobs=-1,
        random_state=42,
        verbose=1
    )
    print("开始随机搜索优化...")
    search.fit(X_train_std, y_train_signed)
    best_params = search.best_params_
    print(f"\n最佳参数: {best_params}")
    print(f"最佳交叉验证AUC: {search.best_score_:.4f}")
    # best_estimator_ is already refit on the full training data by
    # RandomizedSearchCV (refit=True), so no extra fit() call is needed.
    final_model = search.best_estimator_
    # Persist pipeline and scaler for the GUI.
    joblib.dump(final_model, pipeline_save_path)
    joblib.dump(scaler, scaler_save_path)
    print(f"流水线模型已保存至: {pipeline_save_path}")
    print(f"标准化器已保存至: {scaler_save_path}")
    return final_model, scaler
if __name__ == "__main__":
    # Training pickle in; pipeline + scaler out (filenames must match the GUI).
    DATASET_PATH = r"D:\Python\空心检测\pythonProject\feature_dataset.pkl"
    PIPELINE_PATH = "pipeline_model.pkl"  # must match the path configured in the GUI
    SCALER_PATH = "scaler.pkl"
    train_cross_validated_model(DATASET_PATH, PIPELINE_PATH, SCALER_PATH)

@ -0,0 +1,55 @@
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import joblib
# 加载数据集使用你之前合并的feature_dataset00.pkl
def load_dataset(pkl_path):
    """Return (feature matrix, labels) from a pickled dataset file."""
    with open(pkl_path, 'rb') as src:
        bundle = pickle.load(src)
    return bundle['matrix'], bundle['label']
# 训练模型
def train_and_save_model(dataset_path, model_save_path, scaler_save_path):
    """Train an RBF SVM on the merged dataset and persist model + scaler.

    Prints hold-out accuracy on a stratified 20% test split.  The scaler is
    fitted on the training split only and must be applied before the SVM at
    inference time (the GUI loads both saved files).
    """
    # Load features/labels.
    X, y = load_dataset(dataset_path)
    print(f"加载数据集:{X.shape[0]}个样本,{X.shape[1]}维特征")
    # Stratified 80/20 split keeps the class balance in both subsets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # Standardise with training statistics only (no test-set leakage).
    scaler = StandardScaler()
    X_train_std = scaler.fit_transform(X_train)
    X_test_std = scaler.transform(X_test)
    # RBF SVM; balanced class weights; probability=True so the GUI can read probabilities.
    svm = SVC(kernel='rbf', class_weight='balanced', probability=True, random_state=42)
    svm.fit(X_train_std, y_train)
    # Hold-out evaluation.
    y_pred = svm.predict(X_test_std)
    print(f"模型准确率:{accuracy_score(y_test, y_pred):.4f}")
    # Persist model and scaler.
    joblib.dump(svm, model_save_path)
    joblib.dump(scaler, scaler_save_path)
    print(f"模型已保存至:{model_save_path}")
    print(f"标准化器已保存至:{scaler_save_path}")
if __name__ == "__main__":
    # Replace with your dataset path.
    DATASET_PATH = r"D:\SummerSchool\mat_cv\mat_cv\feature_dataset.pkl"
    # Model output paths — must match the paths configured in the GUI code.
    MODEL_PATH = "svm_model.pkl"
    SCALER_PATH = "scaler.pkl"
    train_and_save_model(DATASET_PATH, MODEL_PATH, SCALER_PATH)

@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
"""
交叉验证最小改动版与原结构一致仅在第6段训练后新增模型与scaler导出
"""
from pathlib import Path
import pickle
import numpy as np
import matplotlib.pyplot as plt
# 新增/补全的 import
import joblib
from scipy.stats import randint, loguniform, norm
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.pipeline import Pipeline
# ---------- 1. 数据路径 ----------
BASE_DIR = Path(r'D:\SummerSchool\mat_cv\mat_cv')
TRAIN_PKL = BASE_DIR / 'cv10_train.pkl'
TEST_FILES = [BASE_DIR / 'cv10_test.pkl'] # 也可放多个测试集 pkl 文件
# ---------- 2. 工具 ----------
def load_pkl_matrix(path: Path):
    """Load a pickled dataset; returns (matrix, label) where label may be None."""
    with open(path, 'rb') as fh:
        bundle = pickle.load(fh)
    return bundle['matrix'], bundle.get('label')
# ---------- 3. Load the training set ----------
X_train, y_train = load_pkl_matrix(TRAIN_PKL)
if y_train is None:
    raise ValueError('训练集缺少 label 字段')
y_train = y_train.ravel()
# Map {0,1} labels to {-1,+1} for the signed-label convention used below.
y_train_signed = np.where(y_train == 0, -1, 1)
# ---------- 4. Standardise (this scaler is exported later for the GUI) ----------
scaler = StandardScaler().fit(X_train)
X_train_std = scaler.transform(X_train)
n_features = X_train_std.shape[1]
# ---------- 5. Hyper-parameter search with RandomizedSearchCV ----------
# Pipeline: mutual-information feature selection followed by an RBF SVM.
pipe = Pipeline([
    ('sel', SelectKBest(mutual_info_classif)),
    ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
])
param_dist = {
    'sel__k': randint(1, n_features + 1),
    'svm__C': loguniform(1e-3, 1e3),
    'svm__gamma': loguniform(1e-6, 1e1)
}
cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
search = RandomizedSearchCV(
    pipe,
    param_distributions=param_dist,
    n_iter=60,  # number of sampled parameter settings
    scoring='roc_auc',
    cv=cv_inner,
    n_jobs=-1,
    random_state=42,
    verbose=1
)
search.fit(X_train_std, y_train_signed)
best_params = search.best_params_
print("\n▶ RandomizedSearch 最佳参数:", best_params)
print(f" 内层 5-折 AUC ≈ {search.best_score_:.4f}")
# ---------- 6. Fit the final pipeline on the full training set ----------
final_model = search.best_estimator_
final_model.fit(X_train_std, y_train_signed)
# ---------- 6.5 Export model and scaler (consumed by the GUI) ----------
# Written to BASE_DIR; adjust the paths if needed.
model_out = BASE_DIR / 'svm_model.pkl'
scaler_out = BASE_DIR / 'scaler.pkl'
joblib.dump(final_model, model_out)
joblib.dump(scaler, scaler_out)
print(f"\n✅ 已导出模型与标尺:\n 模型: {model_out}\n 标尺: {scaler_out}\n has_predict_proba: {hasattr(final_model, 'predict_proba')}")
# ---------- 7. Outer 5-fold cross-validation (performance estimate) ----------
cv_outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_auc = cross_val_score(final_model, X_train_std, y_train_signed,
                         cv=cv_outer, scoring='roc_auc', n_jobs=-1)
cv_acc = cross_val_score(final_model, X_train_std, y_train_signed,
                         cv=cv_outer, scoring='accuracy', n_jobs=-1)
print('\n========== 外层 5-折交叉验证 ==========')
print(f'AUC = {cv_auc.mean():.4f} ± {cv_auc.std():.4f}')
print(f'ACC = {cv_acc.mean():.4f} ± {cv_acc.std():.4f}')
# ---------- 8. Inference on the held-out test file(s) ----------
THRESHOLD = 0.5  # file-level decision threshold on the mean probability
Z = norm.ppf(0.975)  # multiplier for a 95% normal-approximation CI
infer_results = []
print('\n========== 推断结果 ==========')
for pkl_path in TEST_FILES:
    X_test, _ = load_pkl_matrix(pkl_path)
    X_test_std = scaler.transform(X_test)
    pred_signed = final_model.predict(X_test_std)
    proba_pos = final_model.predict_proba(X_test_std)[:, 1]
    # Map signed predictions back to {0,1}.
    pred_label = np.where(pred_signed == -1, 0, 1)
    # File-level fusion: mean probability with a 95% CI over the segments.
    mean_p = proba_pos.mean()
    sem_p = proba_pos.std(ddof=1) / np.sqrt(len(proba_pos)) if len(proba_pos) > 1 else 0.0
    ci_low, ci_high = mean_p - Z * sem_p, mean_p + Z * sem_p
    file_label = int(mean_p >= THRESHOLD)
    print(f'\n▶ 文件: {pkl_path.name} (样本 {len(pred_label)})')
    for i, (lbl, prob) in enumerate(zip(pred_label, proba_pos), 1):
        print(f' Sample {i:02d}: pred={lbl} prob(1)={prob:.4f}')
    print(' ---- 文件级融合 ----')
    print(f' mean_prob(1) = {mean_p:.4f} (95% CI {ci_low:.4f} ~ {ci_high:.4f})')
    print(f' Final label = {file_label} (阈值 {THRESHOLD})')
    infer_results.append(dict(
        file=pkl_path.name,
        pred=pred_label.tolist(),
        prob=proba_pos.tolist(),
        mean_prob=float(mean_p),
        ci_low=float(ci_low),
        ci_high=float(ci_high),
        final_label=int(file_label)
    ))
# Print the ground-truth labels of the first test file, if present.
try:
    print("TEST_FILES 标签:", load_pkl_matrix(TEST_FILES[0])[1])
except Exception:
    pass
# ---------- 9. Save results & visualise ----------
out_pkl = BASE_DIR / 'infer_results.pkl'
with open(out_pkl, 'wb') as f:
    pickle.dump(infer_results, f)
print(f'\n所有文件结果已保存到: {out_pkl}')
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK text in the figure
plt.rcParams['axes.unicode_minus'] = False
labels = [r['file'] for r in infer_results]
means = [r['mean_prob'] for r in infer_results]
# Asymmetric error bars derived from the CI bounds.
yerr = [(r['mean_prob'] - r['ci_low'], r['ci_high'] - r['mean_prob'])
        for r in infer_results]
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(range(len(means)), means,
       yerr=np.array(yerr).T, capsize=5, alpha=0.8)
ax.axhline(THRESHOLD, color='red', ls='--', label=f'阈值 {THRESHOLD}')
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=15)
ax.set_ylim(0, 1)
ax.set_ylabel('mean_prob(空心=0)')
ax.set_title('文件级空心概率 (±95% CI)')
ax.legend()
plt.tight_layout()
desktop = Path.home() / 'Desktop'
save_path = desktop / 'infer_summary.png'
fig.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'可视化图已保存至: {save_path}')
plt.show()

@ -0,0 +1,73 @@
import os
import pickle
import numpy as np
from pathlib import Path
def merge_pkl_files(input_dir, output_path="feature_dataset.pkl"):
    """Merge every feature PKL in *input_dir* into one {matrix, label} dataset.

    Each input file must hold a dict with 'matrix' (N x D features) and
    'label' (N labels).  Files with missing fields, mismatched counts, or a
    different feature dimension are skipped with a warning; the merged
    matrix/label pair is pickled to *output_path*.

    Robustness fix: labels pass through np.asarray() before ravel(), so
    files whose labels were saved as plain Python lists no longer crash
    with AttributeError.
    """
    all_features = []
    all_labels = []
    pkl_files = list(Path(input_dir).glob("*.pkl"))
    if not pkl_files:
        raise FileNotFoundError(f"在目录 {input_dir} 中未找到PKL文件")
    # Skip previously merged datasets and inference outputs.
    pkl_files = [f for f in pkl_files if f.name not in ["feature_dataset.pkl", "infer_results.pkl"]]
    print(f"发现 {len(pkl_files)} 个有效PKL文件开始合并...")
    for file in pkl_files:
        try:
            with open(file, "rb") as f:
                data = pickle.load(f)
            if "matrix" not in data or "label" not in data:
                print(f"跳过 {file.name}:缺少'matrix''label'字段")
                continue
            features = data["matrix"]
            # Accept ndarray OR plain-list labels; flatten to 1-D int64.
            labels = np.asarray(data["label"]).ravel().astype(np.int64)
            # Feature count must match label count.
            if len(features) != len(labels):
                print(f"跳过 {file.name}:特征({len(features)})与标签({len(labels)})数量不匹配")
                continue
            # Feature dimension must match what has been collected so far.
            if all_features and features.shape[1] != all_features[0].shape[1]:
                print(f"跳过 {file.name}:特征维度与已有数据不一致(现有{all_features[0].shape[1]}维,当前{features.shape[1]}维)")
                continue
            all_features.append(features)
            all_labels.append(labels)
            print(f"已加载 {file.name}{len(features)} 条样本(特征{features.shape[1]}维)")
        except Exception as e:
            # Best-effort merge: a corrupt file is reported and skipped.
            print(f"处理 {file.name} 时出错:{str(e)},已跳过")
    if not all_features:
        raise ValueError("没有有效数据可合并,请检查输入文件")
    # Stack features row-wise; concatenate the 1-D label arrays.
    merged_matrix = np.vstack(all_features)
    merged_label = np.concatenate(all_labels, axis=0)
    print("\n合并结果:")
    print(f"总样本数:{len(merged_matrix)}")
    print(f"特征矩阵形状:{merged_matrix.shape}")
    print(f"标签分布:{np.bincount(merged_label)} (索引对应标签值)")
    with open(output_path, "wb") as f:
        pickle.dump({"matrix": merged_matrix, "label": merged_label}, f)
    print(f"\n已成功保存至 {output_path}")
if __name__ == "__main__":
    # Merge every per-recording PKL in this directory into one dataset file.
    INPUT_DIRECTORY = r"D:\SummerSchool\mat_cv\mat_cv-02"
    OUTPUT_FILE = r"D:\SummerSchool\mat_cv\mat_cv-02\feature_dataset00.pkl"  # absolute output path
    merge_pkl_files(INPUT_DIRECTORY, OUTPUT_FILE)

Binary file not shown.

After

Width:  |  Height:  |  Size: 530 KiB

@ -0,0 +1,180 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
有标注学习音频特征提取读取瓷空1.wav标注为提取五维特征+标签保存MAT/PKL适配深度学习
"""
from pathlib import Path
import numpy as np
import scipy.io.wavfile as wav
from scipy.io import savemat
from scipy.signal import hilbert
import librosa
import matplotlib.pyplot as plt
import os
import pickle # 用于保存PKL文件
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK glyphs in figures
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with CJK fonts
# ---------- Parameters: labeled training audio; the label is derived from the file name ----------
WAV_FILE = r"D:\SummerSchool\sample\瓷空1.wav"  # annotated training WAV (name carries the label)
WIN_SIZE = 1024  # frame length in samples (same as the test-audio script)
OVERLAP = 512  # frame overlap in samples (same as the test-audio script)
STEP = WIN_SIZE - OVERLAP  # hop between consecutive frames
THRESH = 0.01  # frame-energy threshold for knock detection (kept low on purpose)
SEG_LEN_SEC = 0.2  # duration of each extracted knock segment, seconds
# Intended label mapping: hollow ("空") -> 0, solid ("实") -> 1.
# NOTE(review): both keys below are empty strings -- the intended "空"/"实"
# characters appear to have been lost in an encoding step, so this literal
# collapses to {"": 1}. Confirm and restore the original keys.
LABEL_MAP = {"": 0, "": 1}
# Output paths: written next to the WAV, tagged "_train_" to mark training data
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_train_features.mat"
OUT_PKL = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_train_features.pkl"
# ---------- 工具函数(完全复用之前的逻辑,确保特征一致性) ----------
def segment_signal(signal: np.ndarray, fs: int):
    """Split *signal* into fixed-length knock segments located by frame energy.

    Mirrors the segmentation used by the test-audio pipeline so the
    extracted features stay directly comparable.

    Parameters
    ----------
    signal : np.ndarray
        Mono or multi-channel waveform; only channel 0 is used.
    fs : int
        Sample rate in Hz.

    Returns
    -------
    list[np.ndarray]
        One peak-normalized slice of SEG_LEN_SEC seconds per detected knock
        (the final slice may be shorter if it reaches the end of the signal).
    """
    # Collapse stereo input to its first channel.
    if signal.ndim > 1:
        signal = signal[:, 0]
    # Peak-normalize so THRESH acts on relative, not absolute, amplitude.
    signal = signal / (np.max(np.abs(signal)) + 1e-12)
    # Frame the waveform and compute per-frame energies.
    framed = librosa.util.frame(signal, frame_length=WIN_SIZE, hop_length=STEP).T
    frame_energy = np.sum(framed ** 2, axis=1)
    loud_frames = np.where(frame_energy > THRESH)[0]
    if loud_frames.size == 0:
        return []
    # A gap of more than 5 frames between loud frames marks a new knock,
    # which keeps consecutive loud frames from being counted twice.
    new_hit = np.diff(np.concatenate(([0], loud_frames))) > 5
    onsets = loud_frames[new_hit]
    samples_per_seg = int(round(SEG_LEN_SEC * fs))
    out = []
    for onset_frame in onsets:
        first = onset_frame * STEP
        last = min(first + samples_per_seg, len(signal))
        out.append(signal[first:last])
    return out
def extract_features(sig: np.ndarray, fs: int):
    """Compute the 5-dim feature vector for one knock segment.

    Features, in order: RMS energy, dominant frequency (Hz), spectral
    skewness, mean of the first MFCC coefficient, Hilbert-envelope peak.
    Identical to the test-audio pipeline so train/test features match.

    Returns a float32 array of shape (5,); an empty segment yields zeros.
    """
    sig = sig.flatten()
    if sig.size == 0:
        # Guard against empty slices: all-zero feature vector.
        return np.zeros(5, dtype=np.float32)

    # 1) RMS: overall energy of the knock.
    rms = np.sqrt(np.mean(sig ** 2))

    # 2) Dominant frequency: bin of the spectrum's largest magnitude.
    n = sig.size
    freqs = np.fft.rfftfreq(n, d=1 / fs)
    mags = np.abs(np.fft.rfft(sig))
    main_freq = freqs[np.argmax(mags)]

    # 3) Spectral skewness: asymmetry of the magnitude spectrum around its
    #    centroid -- a key cue for separating hollow vs. solid knocks.
    total = np.sum(mags) + 1e-12
    centroid = np.sum(freqs * mags) / total
    spread = np.sqrt(np.sum(((freqs - centroid) ** 2) * mags) / total)
    skewness = np.sum(((freqs - centroid) ** 3) * mags) / (total * (spread ** 3 + 1e-12))

    # 4) Mean of MFCC coefficient 0 (timbre); falls back to 0.0 when the
    #    segment is too short for librosa's analysis.
    try:
        mfcc_mean = float(np.mean(librosa.feature.mfcc(y=sig, sr=fs, n_mfcc=13)[0, :]))
    except Exception:
        mfcc_mean = 0.0

    # 5) Peak of the Hilbert amplitude envelope (attack/decay strength).
    env_peak = np.max(np.abs(hilbert(sig)))

    # Unified float32 output for downstream learning frameworks.
    return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
# ---------- 主程序(核心:自动识别标签+特征+标签保存) ----------
def main():
    """Extract labeled 5-dim features from the annotated training WAV.

    Reads WAV_FILE, segments it into knock events, extracts one feature
    vector per segment, infers the class label from the file name
    ("空" -> 0 hollow, "实" -> 1 solid), and saves matrix+label to both
    MAT and PKL files, plus a diagnostic plot.
    """
    # 1. Validate the input path.
    wav_path = Path(WAV_FILE)
    if not wav_path.exists():
        print(f"❌ 错误:音频文件 {WAV_FILE} 不存在!")
        return
    if wav_path.suffix != ".wav":
        print(f"❌ 错误:{wav_path.name} 不是WAV格式!")
        return
    # 2. Load the audio with librosa (native sample rate, mono).
    audio, sr = librosa.load(wav_path, sr=None, mono=True)
    print(f"✅ 成功读取学习音频:{wav_path.name}")
    print(f"   采样率:{sr} Hz | 音频长度:{len(audio)/sr:.2f} 秒")
    # 3. Segment into knock events.
    segments = segment_signal(audio, sr)
    if len(segments) == 0:
        print(f"⚠️ 未检测到有效敲击片段!可尝试降低THRESH(当前{THRESH})或检查音频是否有敲击声。")
        return
    print(f"✅ 检测到 {len(segments)} 个有效敲击片段")
    # 4. Derive the label from the file name.
    #    BUGFIX: the membership tests previously checked for the empty string
    #    ('"" in file_stem' is always True) and LABEL_MAP's garbled duplicate
    #    "" keys collapsed to {"": 1}, so every file was labeled 1. Restored
    #    the intended markers and mapping documented above: 空 -> 0, 实 -> 1.
    features_list = []
    labels_list = []
    file_stem = wav_path.stem  # file name without extension, e.g. "瓷空1"
    if "空" in file_stem:
        label = 0  # hollow
        print(f"✅ 自动识别标注:{file_stem} → 标签 {label}(空)")
    elif "实" in file_stem:
        label = 1  # solid
        print(f"✅ 自动识别标注:{file_stem} → 标签 {label}(实)")
    else:
        print(f"⚠️ 文件名 {file_stem} 不含'空'或'实',手动指定标签为0!")
        label = 0  # manual fallback; adjust if the recording is actually solid
    # Extract one feature vector per segment; every segment shares the label.
    for i, seg in enumerate(segments, 1):
        feat = extract_features(seg, sr)
        features_list.append(feat)
        labels_list.append(label)
        print(f"   片段{i:02d}特征提取完成(维度:5)")
    # 5. Stack into matrices suitable for model training.
    features_matrix = np.vstack(features_list)  # shape (n_segments, 5)
    labels_array = np.array(labels_list, dtype=np.int8).reshape(-1, 1)  # shape (n_segments, 1)
    print(f"\n✅ 特征与标签整理完成!")
    print(f"   特征矩阵形状:{features_matrix.shape}(行=片段数,列=5维特征)")
    print(f"   标签矩阵形状:{labels_array.shape}(行=片段数,列=1)")
    # 6. Save as MAT (MATLAB / other toolchains).
    savemat(OUT_MAT, {
        "matrix": features_matrix,  # same key layout as the merged training set
        "label": labels_array
    })
    print(f"✅ MAT文件已保存:{OUT_MAT}")
    # 7. Save as PKL (Python-side training).
    with open(OUT_PKL, "wb") as f:
        pickle.dump({
            "matrix": features_matrix,
            "label": labels_array
        }, f)
    print(f"✅ PKL文件已保存:{OUT_PKL}")
    # 8. Optional diagnostic plot of each feature across segments.
    plt.figure(figsize=(12, 8))
    feature_names = ["RMS能量", "主频(Hz)", "频谱偏度", "MFCC均值", "包络峰值"]
    for i in range(5):
        plt.subplot(2, 3, i + 1)
        plt.plot(range(1, len(features_matrix) + 1), features_matrix[:, i], "-o", color="#1f77b4", linewidth=1.5, markersize=4)
        plt.xlabel("片段编号", fontsize=10)
        plt.ylabel("特征值", fontsize=10)
        plt.title(f"特征{i+1}:{feature_names[i]}", fontsize=11, fontweight="bold")
        plt.grid(True, alpha=0.3)
    # Summary panel: file name / label / segment count.
    plt.subplot(2, 3, 6)
    plt.text(0.5, 0.6, f"音频文件:{wav_path.name}", ha="center", fontsize=11)
    plt.text(0.5, 0.4, f"标注标签:{label}({'空' if label == 0 else '实'})", ha="center", fontsize=11)
    plt.text(0.5, 0.2, f"有效片段数:{len(features_matrix)}", ha="center", fontsize=11)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
# Script entry point.
if __name__ == "__main__":
    main()

@ -0,0 +1,134 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
无标注测试音频特征提取读取单个WAV提取五维特征保存为MAT和PKL无标签
"""
from pathlib import Path
import numpy as np
import scipy.io.wavfile as wav
from scipy.io import savemat
from scipy.signal import hilbert
import librosa
import matplotlib.pyplot as plt
import os
import pickle # 用于保存PKL文件
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK glyphs in figures
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with CJK fonts
# ---------- Parameters (unlabeled test audio; same framing as the training script) ----------
WAV_FILE = r"D:\SummerSchool\test2.wav"  # path to the unlabeled test WAV
WIN_SIZE = 1024  # frame length in samples
OVERLAP = 512  # frame overlap in samples
STEP = WIN_SIZE - OVERLAP  # hop between consecutive frames
THRESH = 0.01  # frame-energy threshold (deliberately low so knocks are not missed)
SEG_LEN_SEC = 0.2  # duration of each extracted segment, seconds
# Output paths: written next to the WAV as <stem>_features.{mat,pkl}
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_features.mat"
OUT_PKL = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_features.pkl"
# ---------- 工具函数(完全保留原特征提取逻辑,确保和训练集一致) ----------
def segment_signal(signal: np.ndarray, fs: int):
    """Locate knocks by frame energy and return fixed-length segments.

    Stereo input is reduced to channel 0 and the waveform is peak
    normalized before framing, so THRESH acts on relative energy.
    Returns a list of SEG_LEN_SEC-second slices; the last one may be
    shorter if it reaches the end of the signal.
    """
    mono = signal[:, 0] if signal.ndim > 1 else signal
    mono = mono / (np.max(np.abs(mono)) + 1e-12)
    # Per-frame energies over overlapping windows.
    windows = librosa.util.frame(mono, frame_length=WIN_SIZE, hop_length=STEP).T
    energies = np.sum(windows ** 2, axis=1)
    active = np.where(energies > THRESH)[0]
    if active.size == 0:
        return []
    # A frame more than 5 hops after the previous active frame starts a new knock.
    starts = active[np.diff(np.concatenate(([0], active))) > 5]
    span = int(round(SEG_LEN_SEC * fs))
    return [mono[s * STEP:min(s * STEP + span, len(mono))] for s in starts]
def extract_features(sig: np.ndarray, fs: int):
    """Return the 5-dim vector [RMS, main_freq, skewness, mfcc_mean, env_peak].

    Matches the training-set feature definition exactly so the trained
    model can score these vectors. Empty input yields a zero vector.
    """
    sig = sig.flatten()
    if sig.size == 0:
        return np.zeros(5)
    # RMS energy.
    rms = np.sqrt(np.mean(sig ** 2))
    # Dominant frequency: argmax of the magnitude spectrum.
    freqs = np.fft.rfftfreq(sig.size, d=1 / fs)
    mags = np.abs(np.fft.rfft(sig))
    main_freq = freqs[np.argmax(mags)]
    # Spectral skewness about the spectral centroid.
    total = np.sum(mags) + 1e-12
    centroid = np.sum(freqs * mags) / total
    spread = np.sqrt(np.sum(((freqs - centroid) ** 2) * mags) / total)
    skewness = np.sum(((freqs - centroid) ** 3) * mags) / (total * (spread ** 3 + 1e-12))
    # First MFCC coefficient's mean; 0.0 when the segment is too short
    # for librosa's analysis.
    try:
        mfcc_mean = float(np.mean(librosa.feature.mfcc(y=sig, sr=fs, n_mfcc=13)[0, :]))
    except Exception:
        mfcc_mean = 0.0
    # Peak of the Hilbert amplitude envelope.
    env_peak = np.max(np.abs(hilbert(sig)))
    return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# ---------- 主程序(核心:去掉标签,只提特征+保存) ----------
def main():
    """Extract unlabeled 5-dim features from WAV_FILE and save MAT/PKL.

    Pipeline: validate the path -> load audio -> segment knocks by energy
    -> extract one feature vector per segment -> save the feature matrix
    (no labels) for later model inference -> plot the features.
    """
    # 1. Validate the audio file path.
    wav_path = Path(WAV_FILE)
    if not (wav_path.exists() and wav_path.suffix == ".wav"):
        print(f"❌ 错误:{WAV_FILE} 不存在或不是WAV文件!")
        return
    # Load with librosa (broad format support), keeping the native rate.
    y, fs = librosa.load(wav_path, sr=None, mono=True)
    print(f"✅ 成功读取音频:{wav_path.name},采样率:{fs} Hz")
    # 2. Segment into knock events.
    segments = segment_signal(y, fs)
    if len(segments) == 0:
        print(f"⚠️ 未检测到有效音频片段!尝试再降低THRESH(当前{THRESH})")
        return
    print(f"✅ 检测到 {len(segments)} 个有效片段")
    # 3. Extract the 5-dim feature vector for every segment.
    features = [extract_features(seg, fs) for seg in segments]
    features_matrix = np.vstack(features).astype(np.float32)  # shape (n_segments, 5)
    print(f"✅ 提取特征完成,特征矩阵形状:{features_matrix.shape}(行=片段数,列=5维特征")
    # 4. Save as MAT (MATLAB-compatible); features only, no label key.
    savemat(OUT_MAT, {"matrix": features_matrix})
    print(f"✅ MAT文件已保存:{OUT_MAT}")
    # 5. Save as PKL with the same dict layout as the training set, minus 'label'.
    with open(OUT_PKL, "wb") as f:
        pickle.dump({"matrix": features_matrix}, f)
    print(f"✅ PKL文件已保存:{OUT_PKL}")
    # Optional: plot each feature across segments for a quick sanity check.
    plt.figure(figsize=(10, 6))
    feature_names = ["RMS", "主频(Hz)", "频谱偏度", "MFCC均值", "包络峰值"]
    for i in range(5):
        plt.subplot(2, 3, i+1)
        plt.plot(range(1, len(features_matrix)+1), features_matrix[:, i], "-o", linewidth=1.5)
        plt.xlabel("片段编号")
        plt.ylabel("特征值")
        plt.title(f"特征:{feature_names[i]}")
        plt.grid(True)
    plt.tight_layout()
    plt.show()
# Script entry point.
if __name__ == "__main__":
    main()

@ -0,0 +1,55 @@
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import joblib
# Dataset loader for the merged feature pickle (e.g. feature_dataset00.pkl).
def load_dataset(pkl_path):
    """Load a pickled feature dataset and return (matrix, label).

    The file must contain a dict with 'matrix' (N x D features) and
    'label' (N labels), as produced by the merge/extraction scripts.
    """
    with open(pkl_path, "rb") as fh:
        payload = pickle.load(fh)
    return payload["matrix"], payload["label"]
# Model training.
def train_and_save_model(dataset_path, model_save_path, scaler_save_path):
    """Train an RBF-SVM on the pickled dataset and persist model + scaler.

    Performs a stratified 80/20 split, standardizes features on the
    training fold, reports hold-out accuracy, then saves the fitted SVC
    and StandardScaler with joblib for later GUI/inference use.
    """
    features, labels = load_dataset(dataset_path)
    print(f"加载数据集:{features.shape[0]}个样本,{features.shape[1]}维特征")
    # Stratified split keeps the class ratio identical in both folds.
    tr_x, te_x, tr_y, te_y = train_test_split(
        features, labels, test_size=0.2, random_state=42, stratify=labels
    )
    # Fit the scaler on the training fold only, to avoid leakage.
    scaler = StandardScaler()
    tr_x_std = scaler.fit_transform(tr_x)
    te_x_std = scaler.transform(te_x)
    # probability=True enables predict_proba for downstream scoring.
    svm = SVC(kernel='rbf', class_weight='balanced', probability=True, random_state=42)
    svm.fit(tr_x_std, tr_y)
    # Hold-out evaluation.
    y_pred = svm.predict(te_x_std)
    print(f"模型准确率:{accuracy_score(te_y, y_pred):.4f}")
    # Persist both artifacts; the scaler must be reused at inference time.
    joblib.dump(svm, model_save_path)
    joblib.dump(scaler, scaler_save_path)
    print(f"模型已保存至:{model_save_path}")
    print(f"标准化器已保存至:{scaler_save_path}")
# Script entry point: train on the merged dataset and write the model
# artifacts to the working directory (paths expected by the GUI).
if __name__ == "__main__":
    # Replace with the path to your merged feature dataset.
    DATASET_PATH = r"D:\SummerSchool\mat_cv\mat_cv\feature_dataset.pkl"
    # Save paths -- must match the paths configured in the GUI code.
    MODEL_PATH = "svm_model.pkl"
    SCALER_PATH = "scaler.pkl"
    train_and_save_model(DATASET_PATH, MODEL_PATH, SCALER_PATH)

@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
"""
交叉验证最小改动版与原结构一致仅在第6段训练后新增模型与scaler导出
"""
from pathlib import Path
import pickle
import numpy as np
import matplotlib.pyplot as plt
# 新增/补全的 import
import joblib
from scipy.stats import randint, loguniform, norm
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.pipeline import Pipeline
# ---------- 1. Data paths ----------
BASE_DIR = Path(r'D:\SummerSchool\mat_cv\mat_cv')
TRAIN_PKL = BASE_DIR / 'cv10_train.pkl'
TEST_FILES = [BASE_DIR / 'cv10_test.pkl']  # may list several test-set pkl files
# ---------- 2. Helpers ----------
def load_pkl_matrix(path: Path):
    """Read a pickled dataset; return (matrix, label) — label is None if absent."""
    with open(path, 'rb') as fh:
        bundle = pickle.load(fh)
    return bundle['matrix'], bundle.get('label')
# ---------- 3. Load the training set ----------
X_train, y_train = load_pkl_matrix(TRAIN_PKL)
if y_train is None:
    raise ValueError('训练集缺少 label 字段')
y_train = y_train.ravel()
# Remap labels {0,1} -> {-1,+1}: the signed convention used throughout below.
y_train_signed = np.where(y_train == 0, -1, 1)
# ---------- 4. Standardization ----------
# The scaler is fitted on the full training set; the same scaler object is
# exported below and must be reused at inference time.
scaler = StandardScaler().fit(X_train)
X_train_std = scaler.transform(X_train)
n_features = X_train_std.shape[1]
# ---------- 5. Hyper-parameter search (RandomizedSearchCV) ----------
# Pipeline: mutual-information feature selection, then an RBF SVM.
pipe = Pipeline([
    ('sel', SelectKBest(mutual_info_classif)),
    ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
])
param_dist = {
    'sel__k': randint(1, n_features + 1),  # number of features kept
    'svm__C': loguniform(1e-3, 1e3),       # SVM regularization strength
    'svm__gamma': loguniform(1e-6, 1e1)    # RBF kernel width
}
cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
search = RandomizedSearchCV(
    pipe,
    param_distributions=param_dist,
    n_iter=60,  # number of sampled parameter settings
    scoring='roc_auc',
    cv=cv_inner,
    n_jobs=-1,
    random_state=42,
    verbose=1
)
search.fit(X_train_std, y_train_signed)
best_params = search.best_params_
print("\n▶ RandomizedSearch 最佳参数:", best_params)
print(f"   内层 5-折 AUC ≈ {search.best_score_:.4f}")
# ---------- 6. Train the final pipeline ----------
final_model = search.best_estimator_
# NOTE(review): with the default refit=True, best_estimator_ is already
# fitted on the full training data; this second fit is redundant but harmless.
final_model.fit(X_train_std, y_train_signed)
# ---------- 6.5 Export model and scaler (consumed by the GUI) ----------
# Written into BASE_DIR; adjust the paths if needed.
model_out = BASE_DIR / 'svm_model.pkl'
scaler_out = BASE_DIR / 'scaler.pkl'
joblib.dump(final_model, model_out)
joblib.dump(scaler, scaler_out)
print(f"\n✅ 已导出模型与标尺:\n   模型: {model_out}\n   标尺: {scaler_out}\n   has_predict_proba: {hasattr(final_model, 'predict_proba')}")
# ---------- 7. Outer 5-fold cross-validation ----------
cv_outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_auc = cross_val_score(final_model, X_train_std, y_train_signed,
                         cv=cv_outer, scoring='roc_auc', n_jobs=-1)
cv_acc = cross_val_score(final_model, X_train_std, y_train_signed,
                         cv=cv_outer, scoring='accuracy', n_jobs=-1)
print('\n========== 外层 5-折交叉验证 ==========')
print(f'AUC = {cv_auc.mean():.4f} ± {cv_auc.std():.4f}')
print(f'ACC = {cv_acc.mean():.4f} ± {cv_acc.std():.4f}')
# ---------- 8. Inference on the test file(s) ----------
THRESHOLD = 0.5      # file-level decision threshold on the mean probability
Z = norm.ppf(0.975)  # z-score for a 95% confidence interval
infer_results = []
print('\n========== 推断结果 ==========')
for pkl_path in TEST_FILES:
    X_test, _ = load_pkl_matrix(pkl_path)
    # Reuse the training-set scaler on the test features.
    X_test_std = scaler.transform(X_test)
    pred_signed = final_model.predict(X_test_std)
    # Column 1 of predict_proba corresponds to the positive class (+1).
    proba_pos = final_model.predict_proba(X_test_std)[:, 1]
    # Map signed predictions {-1,+1} back to {0,1}.
    pred_label = np.where(pred_signed == -1, 0, 1)
    # File-level fusion: mean probability with a normal-approximation CI.
    mean_p = proba_pos.mean()
    sem_p = proba_pos.std(ddof=1) / np.sqrt(len(proba_pos)) if len(proba_pos) > 1 else 0.0
    ci_low, ci_high = mean_p - Z * sem_p, mean_p + Z * sem_p
    file_label = int(mean_p >= THRESHOLD)
    print(f'\n▶ 文件: {pkl_path.name} (样本 {len(pred_label)})')
    for i, (lbl, prob) in enumerate(zip(pred_label, proba_pos), 1):
        print(f'  Sample {i:02d}: pred={lbl} prob(1)={prob:.4f}')
    print('  ---- 文件级融合 ----')
    print(f'  mean_prob(1) = {mean_p:.4f} (95% CI {ci_low:.4f} ~ {ci_high:.4f})')
    print(f'  Final label = {file_label} (阈值 {THRESHOLD})')
    infer_results.append(dict(
        file=pkl_path.name,
        pred=pred_label.tolist(),
        prob=proba_pos.tolist(),
        mean_prob=float(mean_p),
        ci_low=float(ci_low),
        ci_high=float(ci_high),
        final_label=int(file_label)
    ))
# Print the raw labels stored in the test file, if any (best effort).
try:
    print("TEST_FILES 标签:", load_pkl_matrix(TEST_FILES[0])[1])
except Exception:
    pass
# ---------- 9. Save results & visualize ----------
out_pkl = BASE_DIR / 'infer_results.pkl'
with open(out_pkl, 'wb') as f:
    pickle.dump(infer_results, f)
print(f'\n所有文件结果已保存到: {out_pkl}')
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK glyphs in the figure
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
labels = [r['file'] for r in infer_results]
means = [r['mean_prob'] for r in infer_results]
# Asymmetric error bars: distance from the mean to each CI bound.
yerr = [(r['mean_prob'] - r['ci_low'], r['ci_high'] - r['mean_prob'])
        for r in infer_results]
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(range(len(means)), means,
       yerr=np.array(yerr).T, capsize=5, alpha=0.8)
ax.axhline(THRESHOLD, color='red', ls='--', label=f'阈值 {THRESHOLD}')
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=15)
ax.set_ylim(0, 1)
# NOTE(review): the plotted value is the mean probability of class 1, but the
# axis label text reads "空心=0" -- confirm which class the probability refers to.
ax.set_ylabel('mean_prob(空心=0)')
ax.set_title('文件级空心概率 (±95% CI)')
ax.legend()
plt.tight_layout()
# Save the summary figure to the user's desktop.
desktop = Path.home() / 'Desktop'
save_path = desktop / 'infer_summary.png'
fig.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'可视化图已保存至: {save_path}')
plt.show()
Loading…
Cancel
Save