#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Evaluate an SVM with a 75% train / 25% test split (reports ACC & AUC)."""

import pickle
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, roc_auc_score

# ---------- 1. Data path ----------
PKL_PATH = Path(r"D:\Python\空心检测\pythonProject\feature_dataset.pkl")


def load_pkl_matrix(path: Path):
    """Load the feature matrix and label vector from a pickle file.

    The pickle is expected to hold a dict with keys ``matrix`` and ``label``.
    NOTE: only unpickle trusted files — pickle can execute arbitrary code.
    """
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data["matrix"], data["label"]


def main() -> None:
    """Run the 75/25 split evaluation and print ACC / AUC."""
    # ---------- 2. Read features ----------
    X, y = load_pkl_matrix(PKL_PATH)
    y = np.asarray(y).ravel()  # ensure shape (N,) even if stored as a list

    # ---------- 3. 75% / 25% split ----------
    # Stratified split keeps the class balance identical in both parts.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42, stratify=y, shuffle=True
    )

    # ---------- 4. Standardize + SVM ----------
    # Fit the scaler on the training part only to avoid test-set leakage.
    scaler = StandardScaler().fit(X_train)
    X_train_std = scaler.transform(X_train)
    X_test_std = scaler.transform(X_test)

    svm = SVC(
        kernel="rbf",
        C=10,
        gamma="scale",
        probability=True,         # needed for predict_proba / AUC
        class_weight="balanced",  # compensate class imbalance
        random_state=42,
    )
    svm.fit(X_train_std, y_train)

    # ---------- 5. Evaluate ----------
    y_pred = svm.predict(X_test_std)
    # Column of the positive class (label == 1), independent of class order.
    y_proba_pos = svm.predict_proba(X_test_std)[:, list(svm.classes_).index(1)]

    acc = accuracy_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_proba_pos)

    print("\n========== 评估结果 ==========")
    print(f"样本总数: {len(y)} | 训练: {len(y_train)} 测试: {len(y_test)}")
    print(f"ACC = {acc:.4f}")
    print(f"AUC = {auc:.4f}")


if __name__ == "__main__":
    # Previously all of this ran at import time; the guard makes the module
    # importable (e.g. to reuse load_pkl_matrix) without side effects.
    main()
class DatabaseManager:
    """SQLite-backed store for audio classification results.

    Every method opens its own short-lived connection and always closes it
    in a ``finally`` block (the original leaked the connection whenever an
    SQL call raised).
    """

    def __init__(self, db_path="audio_classification.db"):
        # db_path: path of the SQLite file; created on first use.
        self.db_path = db_path
        self.init_database()

    def init_database(self):
        """Create the results table if it does not exist yet."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS classification_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    segment_count INTEGER NOT NULL,
                    segment_labels TEXT NOT NULL,
                    segment_probs TEXT NOT NULL,
                    mean_probability REAL NOT NULL,
                    final_label INTEGER NOT NULL,
                    label_text TEXT NOT NULL,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    notes TEXT
                )
            ''')
            conn.commit()
        finally:
            conn.close()

    def insert_result(self, result_data, notes=""):
        """Insert a classification result dict; return the new row id."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute('''
                INSERT INTO classification_results
                (filename, segment_count, segment_labels, segment_probs,
                 mean_probability, final_label, label_text, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                result_data['filename'],
                result_data['segment_count'],
                str(result_data['segment_labels']),
                str(result_data['segment_probs']),
                result_data['mean_probability'],
                result_data['final_label'],
                result_data['label_text'],
                notes,
            ))
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def get_all_results(self):
        """Return every stored result as a list of dicts, newest first."""
        # Same query as an unfiltered search; keep one code path.
        return self.search_results()

    def search_results(self, filename_filter="", label_filter=None, date_filter=None):
        """Return results matching the given optional filters, newest first.

        filename_filter: substring match on the filename column.
        label_filter:    exact final_label value (0/1), or None for all.
        date_filter:     'YYYY-MM-DD' string matched against create_time's date.
        """
        query = "SELECT * FROM classification_results WHERE 1=1"
        params = []
        if filename_filter:
            query += " AND filename LIKE ?"
            params.append(f'%{filename_filter}%')
        if label_filter is not None:
            query += " AND final_label = ?"
            params.append(label_filter)
        if date_filter:
            query += " AND DATE(create_time) = ?"
            params.append(date_filter)
        query += " ORDER BY create_time DESC"

        conn = sqlite3.connect(self.db_path)
        try:
            conn.row_factory = sqlite3.Row  # rows become name-addressable
            rows = conn.execute(query, params).fetchall()
            return [dict(row) for row in rows]
        finally:
            conn.close()

    def update_result(self, result_id, updates):
        """Update the given columns of one record; return True on success.

        NOTE: column names are interpolated into the SQL, so ``updates`` keys
        must come from trusted code, never from user input.
        """
        set_clause = ", ".join(f"{key} = ?" for key in updates)
        query = f"UPDATE classification_results SET {set_clause} WHERE id = ?"
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute(query, [*updates.values(), result_id])
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    def delete_result(self, result_id):
        """Delete one record by id; return True if a row was removed."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute(
                "DELETE FROM classification_results WHERE id = ?", (result_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()
class FeatureProcessor:
    """Feature extractor matching the training-time features (5 dims)."""

    WIN = 1024            # analysis window length (samples)
    OVERLAP = 512         # window overlap (samples)
    THRESHOLD = 0.1       # frame-energy threshold for a knock
    SEG_LEN_S = 0.2       # extracted segment length (seconds)
    STEP = WIN - OVERLAP  # hop between consecutive frames

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Per-frame energy (sum of squares) over sliding windows.

        Returns an empty array for signals shorter than one window (the
        original crashed inside librosa.util.frame in that case); otherwise
        the pure-NumPy framing below produces the same frame count and sums
        as the previous librosa-based version.
        """
        signal = np.asarray(signal)
        if signal.size < FeatureProcessor.WIN:
            return np.array([])
        windows = np.lib.stride_tricks.sliding_window_view(
            signal, FeatureProcessor.WIN)[::FeatureProcessor.STEP]
        return np.sum(windows ** 2, axis=1)

    @staticmethod
    def detect_impacts(energy: np.ndarray) -> np.ndarray:
        """Indices of frames that start a new impact.

        A frame counts as a new impact when it is above THRESHOLD and more
        than 5 frames after the previous above-threshold frame.
        NOTE(review): an impact inside the first 5 frames is never reported
        (its diff against the prepended 0 is <= 5) — kept as-is to stay
        consistent with how the training data was segmented.
        """
        idx = np.where(energy > FeatureProcessor.THRESHOLD)[0]
        if idx.size == 0:
            return np.array([])
        flags = np.diff(np.concatenate(([0], idx))) > 5
        return idx[flags]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, starts: np.ndarray) -> list[np.ndarray]:
        """Cut a fixed-length window starting at each detected impact frame."""
        seg_len = int(FeatureProcessor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in starts:
            start = frame_idx * FeatureProcessor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(signal: np.ndarray, sr: int) -> np.ndarray:
        """Return [rms, main_freq, spectral skewness, mfcc0 mean, env peak]."""
        signal = signal.flatten()
        if len(signal) == 0:
            return np.zeros(5, dtype=np.float32)

        # 1. RMS energy.
        rms = np.sqrt(np.mean(signal ** 2))

        # 2. Dominant frequency (spectral peak over non-negative bins).
        fft = np.fft.fft(signal)
        freq = np.fft.fftfreq(len(signal), 1 / sr)
        positive_mask = freq >= 0
        freq = freq[positive_mask]
        fft_mag = np.abs(fft[positive_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # 3. Spectral skewness around the centroid (epsilons avoid 0/0).
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) /
                         (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if spread > 0 else 0

        # 4. Mean of the first MFCC coefficient.
        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        # 5. Envelope peak via the Hilbert transform.
        env_peak = np.max(np.abs(hilbert(signal)))

        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)
class AudioRecorder(QThread):
    """Background microphone recorder that writes a temporary WAV file.

    Signals:
        status_updated(str): human-readable progress / error messages.
        progress_updated(int): 0-100 percentage of max_duration elapsed.
        recording_finished(str): path of the saved WAV file.
    """
    status_updated = pyqtSignal(str)
    progress_updated = pyqtSignal(int)
    recording_finished = pyqtSignal(str)

    def __init__(self, max_duration=60):
        super().__init__()
        self.max_duration = max_duration  # hard cap in seconds
        self.recording = False
        self.temp_file = "temp_audio.wav"

    def run(self):
        # Audio parameters compatible with the feature-extraction stage.
        FORMAT = pyaudio.paFloat32
        CHANNELS = 1
        RATE = 22050
        CHUNK = 1024

        audio = None
        stream = None
        elapsed = 0.0  # defined up-front: the loop may exit before assigning it
        frames = []
        try:
            audio = pyaudio.PyAudio()
            sample_width = audio.get_sample_size(FORMAT)  # capture before teardown
            stream = audio.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                frames_per_buffer=CHUNK,
            )

            self.recording = True
            start_time = time.time()

            while self.recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.status_updated.emit(f"已达最大时长 {self.max_duration} 秒")
                    break
                self.progress_updated.emit(int((elapsed / self.max_duration) * 100))
                # Don't raise on input overflow: dropping a chunk is better
                # than aborting the whole recording.
                frames.append(stream.read(CHUNK, exception_on_overflow=False))

            if frames:
                with wave.open(self.temp_file, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(sample_width)
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                self.status_updated.emit(f"录制完成,时长: {elapsed:.1f}秒")
                self.recording_finished.emit(self.temp_file)
            else:
                self.status_updated.emit("未检测到音频输入")

        except Exception as e:
            self.status_updated.emit(f"录音错误: {str(e)}")
        finally:
            # Always release the audio device, even on error (the original
            # leaked the stream / PyAudio instance when an exception fired).
            if stream is not None:
                try:
                    stream.stop_stream()
                    stream.close()
                except Exception:
                    pass
            if audio is not None:
                audio.terminate()

    def stop(self):
        """Ask the recording loop to finish; safe to call from the GUI thread."""
        self.recording = False
class AudioProcessor(QThread):
    """Worker thread: load the model, extract features from a WAV, predict.

    Signals:
        status_updated(str): progress / error messages for the GUI log.
        result_generated(dict): file-level classification result.
    """
    status_updated = pyqtSignal(str)
    result_generated = pyqtSignal(dict)

    def __init__(self, audio_path, model_path, scaler_path):
        super().__init__()
        self.audio_path = audio_path
        self.model_path = model_path
        self.scaler_path = scaler_path
        self.class_threshold = 0.5  # file-level decision threshold on mean prob

    def run(self):
        try:
            # Load model and scaler (joblib pickles — trusted local files only).
            self.status_updated.emit("加载模型资源...")
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            # Read the audio file.
            self.status_updated.emit(f"解析音频: {os.path.basename(self.audio_path)}")
            signal, sr = librosa.load(self.audio_path, sr=None, mono=True)
            # Guard against all-zero / empty audio: the original divided by
            # zero here when the file was silent.
            peak = np.max(np.abs(signal)) if signal.size else 0.0
            if peak > 0:
                signal = signal / peak

            # Detect and cut the knock segments.
            self.status_updated.emit("检测有效音频片段...")
            energy = FeatureProcessor.frame_energy(signal)
            impact_starts = FeatureProcessor.detect_impacts(energy)
            segments = FeatureProcessor.extract_segments(signal, sr, impact_starts)

            if not segments:
                self.status_updated.emit("未检测到有效敲击片段")
                return

            # Extract features and standardize them.
            self.status_updated.emit(f"提取 {len(segments)} 个片段的特征...")
            X = np.vstack([FeatureProcessor.extract_features(seg, sr) for seg in segments])
            X_scaled = scaler.transform(X)

            predictions = model.predict(X_scaled)
            # Probability of the positive class: look up its column instead of
            # assuming it is column 1 (classes may be ordered [-1, 1], [0, 1], ...).
            classes = list(getattr(model, "classes_", [0, 1]))
            pos_col = classes.index(1) if 1 in classes else len(classes) - 1
            pred_proba = model.predict_proba(X_scaled)[:, pos_col]
            pred_labels = np.where(predictions == -1, 0, 1)  # normalize -1/1 to 0/1

            # Aggregate per-segment probabilities into a file-level decision.
            mean_prob = float(pred_proba.mean())
            final_label = 1 if mean_prob >= self.class_threshold else 0
            result = {
                "filename": os.path.basename(self.audio_path),
                "segment_count": len(segments),
                "segment_labels": pred_labels.tolist(),
                "segment_probs": [round(float(p), 4) for p in pred_proba],
                "mean_probability": round(mean_prob, 4),
                "final_label": final_label,
                # 0 = hollow, 1 = solid — assumed mapping, confirm with training labels.
                "label_text": "空心" if final_label == 0 else "实心",
            }

            self.result_generated.emit(result)

        except Exception as e:
            self.status_updated.emit(f"处理错误: {str(e)}")
class EditResultDialog(QDialog):
    """Modal dialog for creating or editing one classification record."""

    def __init__(self, result_data=None, parent=None):
        super().__init__(parent)
        self.result_data = result_data or {}
        self.init_ui()

    def init_ui(self):
        """Build the form; fields are pre-filled from self.result_data."""
        self.setWindowTitle("编辑结果记录")
        self.setModal(True)
        self.resize(400, 300)

        form = QFormLayout()
        data = self.result_data

        # File name (free text).
        self.filename_edit = QLineEdit(data.get('filename', ''))
        form.addRow("文件名:", self.filename_edit)

        # Number of audio segments.
        count_spin = QSpinBox()
        count_spin.setRange(0, 1000)
        count_spin.setValue(data.get('segment_count', 0))
        self.segment_count_spin = count_spin
        form.addRow("片段数量:", count_spin)

        # Mean positive-class probability in [0, 1].
        prob_spin = QDoubleSpinBox()
        prob_spin.setRange(0.0, 1.0)
        prob_spin.setDecimals(4)
        prob_spin.setSingleStep(0.01)
        prob_spin.setValue(data.get('mean_probability', 0.0))
        self.mean_prob_spin = prob_spin
        form.addRow("平均概率:", prob_spin)

        # Final class: item data 0 = hollow, 1 = solid.
        combo = QComboBox()
        combo.addItem("空心", 0)
        combo.addItem("实心", 1)
        pos = combo.findData(data.get('final_label', 0))
        if pos >= 0:
            combo.setCurrentIndex(pos)
        self.final_label_combo = combo
        form.addRow("最终分类:", combo)

        # Free-form notes.
        notes = QTextEdit()
        notes.setMaximumHeight(80)
        notes.setText(data.get('notes', ''))
        self.notes_edit = notes
        form.addRow("备注:", notes)

        buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        form.addRow(buttons)

        self.setLayout(form)

    def get_updated_data(self):
        """Return the edited values as a plain dict."""
        combo = self.final_label_combo
        return {
            'filename': self.filename_edit.text(),
            'segment_count': self.segment_count_spin.value(),
            'mean_probability': self.mean_prob_spin.value(),
            'final_label': combo.currentData(),
            'label_text': combo.currentText(),
            'notes': self.notes_edit.toPlainText(),
        }
class DatabaseUI(QWidget):
    """Database page: browse, search, edit and delete stored results."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.parent = parent  # main window, provides switch_page()
        self.db_manager = DatabaseManager()
        self.init_ui()
        self.load_data()

    def init_ui(self):
        """Build the page: back button, title, and the two tabs."""
        layout = QVBoxLayout()
        layout.setContentsMargins(20, 20, 20, 20)

        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        layout.addWidget(btn_back)

        title = QLabel("数据库汇总管理")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        layout.addWidget(title)

        self.tab_widget = QTabWidget()

        self.browse_tab = QWidget()
        self.init_browse_tab()
        self.tab_widget.addTab(self.browse_tab, "数据浏览")

        self.search_tab = QWidget()
        self.init_search_tab()
        self.tab_widget.addTab(self.search_tab, "搜索过滤")

        layout.addWidget(self.tab_widget)
        self.setLayout(layout)

    def _make_result_table(self):
        """Create an empty 8-column results table with the standard headers."""
        table = QTableWidget()
        table.setColumnCount(8)
        table.setHorizontalHeaderLabels([
            "ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作"
        ])
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        return table

    def _current_table(self):
        """Table widget of the currently visible tab.

        The original fetched ``layout().itemAt(1).widget()``, which on the
        search tab returns the search *button* (the table is item 2 there),
        so edit/delete silently broke on that tab. Resolve by tab identity.
        """
        if self.tab_widget.currentWidget() is self.search_tab:
            return self.search_table
        return self.table_widget

    def init_browse_tab(self):
        """Build the browse tab: action buttons plus the full-data table."""
        layout = QVBoxLayout()

        btn_layout = QHBoxLayout()
        self.btn_refresh = QPushButton("刷新数据")
        self.btn_refresh.clicked.connect(self.load_data)
        self.btn_add = QPushButton("新增记录")
        self.btn_add.clicked.connect(self.add_result)
        self.btn_edit = QPushButton("编辑选中")
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_delete = QPushButton("删除选中")
        self.btn_delete.clicked.connect(self.delete_selected)
        for btn in (self.btn_refresh, self.btn_add, self.btn_edit, self.btn_delete):
            btn_layout.addWidget(btn)
        btn_layout.addStretch()
        layout.addLayout(btn_layout)

        self.table_widget = self._make_result_table()
        layout.addWidget(self.table_widget)
        self.browse_tab.setLayout(layout)

    def init_search_tab(self):
        """Build the search tab: filter form, search button, results table."""
        layout = QVBoxLayout()

        form_layout = QFormLayout()
        self.search_filename = QLineEdit()
        self.search_filename.setPlaceholderText("输入文件名关键词")
        form_layout.addRow("文件名:", self.search_filename)

        self.search_label = QComboBox()
        self.search_label.addItem("全部", None)
        self.search_label.addItem("空心", 0)
        self.search_label.addItem("实心", 1)
        form_layout.addRow("分类结果:", self.search_label)

        self.search_date = QLineEdit()
        self.search_date.setPlaceholderText("YYYY-MM-DD")
        form_layout.addRow("创建日期:", self.search_date)

        layout.addLayout(form_layout)

        btn_search = QPushButton("搜索")
        btn_search.clicked.connect(self.perform_search)
        layout.addWidget(btn_search)

        self.search_table = self._make_result_table()
        layout.addWidget(self.search_table)
        self.search_tab.setLayout(layout)

    def load_data(self):
        """Load all records into the browse table."""
        results = self.db_manager.get_all_results()
        self.populate_table(self.table_widget, results)

    def perform_search(self):
        """Run a filtered query and show the hits in the search table."""
        filename_filter = self.search_filename.text().strip()
        label_filter = self.search_label.currentData()
        date_filter = self.search_date.text().strip() or None

        results = self.db_manager.search_results(filename_filter, label_filter, date_filter)
        self.populate_table(self.search_table, results)

    def populate_table(self, table, results):
        """Fill `table` with one row per result dict."""
        table.setRowCount(len(results))

        for row, result in enumerate(results):
            table.setItem(row, 0, QTableWidgetItem(str(result['id'])))
            table.setItem(row, 1, QTableWidgetItem(result['filename']))
            table.setItem(row, 2, QTableWidgetItem(str(result['segment_count'])))
            table.setItem(row, 3, QTableWidgetItem(f"{result['mean_probability']:.4f}"))
            table.setItem(row, 4, QTableWidgetItem(result['label_text']))

            # create_time may come back as a string or a datetime-like object.
            create_time = result['create_time']
            if isinstance(create_time, str):
                table.setItem(row, 5, QTableWidgetItem(create_time))
            else:
                table.setItem(row, 5, QTableWidgetItem(create_time.strftime("%Y-%m-%d %H:%M:%S")))

            notes = result.get('notes', '')
            table.setItem(row, 6, QTableWidgetItem(notes if notes else "无"))

            # Default-argument binding (r=result) avoids the late-binding
            # closure trap inside this loop.
            btn_view = QPushButton("查看详情")
            btn_view.clicked.connect(lambda checked, r=result: self.view_details(r))
            table.setCellWidget(row, 7, btn_view)

    def add_result(self):
        """Open an empty edit dialog and insert the record on accept."""
        dialog = EditResultDialog()
        if dialog.exec_() == QDialog.Accepted:
            new_data = dialog.get_updated_data()
            # Segment-level details are unknown for manually added rows.
            new_data['segment_labels'] = []
            new_data['segment_probs'] = []

            try:
                self.db_manager.insert_result(new_data, new_data.get('notes', ''))
                self.load_data()
                QMessageBox.information(self, "成功", "记录添加成功!")
            except Exception as e:
                QMessageBox.warning(self, "错误", f"添加记录失败: {str(e)}")

    def edit_selected(self):
        """Edit the record selected in the visible tab's table."""
        current_table = self._current_table()
        current_row = current_table.currentRow()

        if current_row < 0:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return

        result_id = int(current_table.item(current_row, 0).text())
        results = self.db_manager.search_results()
        result_data = next((r for r in results if r['id'] == result_id), None)

        if not result_data:
            QMessageBox.warning(self, "错误", "未找到选中的记录!")
            return

        dialog = EditResultDialog(result_data)
        if dialog.exec_() == QDialog.Accepted:
            updated_data = dialog.get_updated_data()
            try:
                if self.db_manager.update_result(result_id, updated_data):
                    self.load_data()
                    if current_table is self.search_table:
                        self.perform_search()
                    QMessageBox.information(self, "成功", "记录更新成功!")
                else:
                    QMessageBox.warning(self, "错误", "更新记录失败!")
            except Exception as e:
                QMessageBox.warning(self, "错误", f"更新记录失败: {str(e)}")

    def delete_selected(self):
        """Delete the record selected in the visible tab's table."""
        current_table = self._current_table()
        current_row = current_table.currentRow()

        if current_row < 0:
            QMessageBox.warning(self, "警告", "请先选择一条记录!")
            return

        result_id = int(current_table.item(current_row, 0).text())
        filename = current_table.item(current_row, 1).text()

        # Fixed: the original prompt contained the literal text '(unknown)'
        # instead of interpolating the selected filename.
        reply = QMessageBox.question(
            self, "确认删除",
            f"确定要删除文件 '{filename}' 的记录吗?",
            QMessageBox.Yes | QMessageBox.No
        )

        if reply == QMessageBox.Yes:
            try:
                if self.db_manager.delete_result(result_id):
                    self.load_data()
                    if current_table is self.search_table:
                        self.perform_search()
                    QMessageBox.information(self, "成功", "记录删除成功!")
                else:
                    QMessageBox.warning(self, "错误", "删除记录失败!")
            except Exception as e:
                QMessageBox.warning(self, "错误", f"删除记录失败: {str(e)}")

    def view_details(self, result_data):
        """Show every stored field of one record in a message box."""
        details = (
            f"记录ID: {result_data['id']}\n"
            f"文件名: {result_data['filename']}\n"
            f"片段数量: {result_data['segment_count']}\n"
            f"片段标签: {result_data['segment_labels']}\n"
            f"片段概率: {result_data['segment_probs']}\n"
            f"平均概率: {result_data['mean_probability']:.4f}\n"
            f"最终分类: {result_data['label_text']}\n"
            f"创建时间: {result_data['create_time']}\n"
            f"备注: {result_data.get('notes', '无')}"
        )
        QMessageBox.information(self, "记录详情", details)
class InputPage(QWidget):
    """Second-level page: record audio or upload a WAV file, then analyze it."""

    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        self.parent = parent          # main window, provides switch_page()
        self.mode = mode              # "record" or "upload"
        self.audio_path = ""
        self.recorder = None
        self.db_manager = DatabaseManager()  # for saving analysis results
        self.init_ui()

    def init_ui(self):
        """Build the page; the input area depends on self.mode."""
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)

        btn_back = QPushButton("← 返回主菜单")
        btn_back.clicked.connect(lambda: self.parent.switch_page("main"))
        main_layout.addWidget(btn_back)

        title = QLabel("录制音频" if self.mode == "record" else "上传音频文件")
        title.setStyleSheet("font-size: 20px; margin: 15px 0;")
        main_layout.addWidget(title)

        if self.mode == "record":
            # Recording controls + progress bar.
            self.btn_start_rec = QPushButton("开始录音")
            self.btn_start_rec.clicked.connect(self.start_recording)
            self.btn_stop_rec = QPushButton("停止录音")
            self.btn_stop_rec.setEnabled(False)
            self.btn_stop_rec.clicked.connect(self.stop_recording)

            self.progress_bar = QProgressBar()
            self.progress_bar.setRange(0, 100)
            self.progress_bar.setValue(0)

            rec_layout = QHBoxLayout()
            rec_layout.addWidget(self.btn_start_rec)
            rec_layout.addWidget(self.btn_stop_rec)
            main_layout.addLayout(rec_layout)
            main_layout.addWidget(self.progress_bar)
        else:
            # File-picker controls.
            self.btn_browse = QPushButton("选择WAV文件")
            self.btn_browse.clicked.connect(self.browse_file)
            self.lbl_file = QLabel("未选择文件")
            self.lbl_file.setStyleSheet("color: #666;")
            main_layout.addWidget(self.btn_browse)
            main_layout.addWidget(self.lbl_file)

        # Status log (shared by both modes).
        self.status_display = QTextEdit()
        self.status_display.setReadOnly(True)
        self.status_display.setMinimumHeight(150)
        main_layout.addWidget(QLabel("状态信息:"))
        main_layout.addWidget(self.status_display)

        self.btn_process = QPushButton("开始分析")
        self.btn_process.setEnabled(False)
        self.btn_process.clicked.connect(self.process_audio)
        main_layout.addWidget(self.btn_process)

        main_layout.addStretch()
        self.setLayout(main_layout)

    def start_recording(self):
        """Start the background recorder thread (max 60 s)."""
        self.status_display.append("开始录音...")
        self.recorder = AudioRecorder(max_duration=60)
        self.recorder.status_updated.connect(self.update_status)
        self.recorder.progress_updated.connect(self.progress_bar.setValue)
        self.recorder.recording_finished.connect(self.on_recording_finished)
        self.recorder.start()
        self.btn_start_rec.setEnabled(False)
        self.btn_stop_rec.setEnabled(True)

    def stop_recording(self):
        """Ask the recorder thread to stop and reset the buttons."""
        if self.recorder and self.recorder.recording:
            self.recorder.stop()
        self.btn_start_rec.setEnabled(True)
        self.btn_stop_rec.setEnabled(False)

    def on_recording_finished(self, file_path):
        """Recorder callback: remember the file and enable analysis."""
        self.audio_path = file_path
        self.btn_process.setEnabled(True)
        self.status_display.append(f"录音文件已保存: {file_path}")

    def browse_file(self):
        """Let the user pick a WAV file and enable analysis."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择音频文件", "", "WAV文件 (*.wav)"
        )
        if file_path:
            self.audio_path = file_path
            self.lbl_file.setText(os.path.basename(file_path))
            self.btn_process.setEnabled(True)
            self.update_status(f"已选择文件: {file_path}")

    def process_audio(self):
        """Validate inputs and launch the analysis worker thread."""
        if not self.audio_path or not os.path.exists(self.audio_path):
            QMessageBox.warning(self, "错误", "音频文件不存在")
            return

        # Model paths — adjust to the deployed artifacts if needed.
        model_path = "pipeline_model.pkl"  # or "svm_model.pkl"
        scaler_path = "scaler.pkl"

        if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
            QMessageBox.warning(self, "错误", "模型文件或标准化器不存在")
            return

        self.processor = AudioProcessor(self.audio_path, model_path, scaler_path)
        self.processor.status_updated.connect(self.update_status)
        self.processor.result_generated.connect(self.show_result)
        self.processor.start()
        self.btn_process.setEnabled(False)
        self.update_status("开始分析音频...")

    def update_status(self, message):
        """Append a timestamped line to the status log and scroll down."""
        self.status_display.append(f"[{time.strftime('%H:%M:%S')}] {message}")
        self.status_display.moveCursor(self.status_display.textCursor().End)

    def show_result(self, result):
        """Show the analysis result and optionally persist it."""
        msg = QMessageBox()
        msg.setWindowTitle("分析结果")
        msg.setIcon(QMessageBox.Information)
        msg.setText(
            f"文件名: {result['filename']}\n"
            f"有效片段数: {result['segment_count']}\n"
            f"平均概率: {result['mean_probability']}\n"
            f"最终分类: {result['label_text']}"
        )

        save_btn = msg.addButton("保存到数据库", QMessageBox.AcceptRole)
        msg.addButton("仅查看", QMessageBox.RejectRole)
        msg.exec_()

        # Fixed: the original compared exec_()'s return value to 0, which is
        # not a documented contract for custom role buttons; clickedButton()
        # identifies the pressed button reliably.
        if msg.clickedButton() is save_btn:
            notes, ok = QInputDialog.getText(
                self, "添加备注", "请输入备注信息(可选):",
                QLineEdit.Normal, ""
            )
            if ok:
                try:
                    self.db_manager.insert_result(result, notes)
                    self.update_status("结果已保存到数据库")
                except Exception as e:
                    QMessageBox.warning(self, "保存失败", f"保存到数据库失败: {str(e)}")

        self.update_status("分析完成")
        self.btn_process.setEnabled(True)
class AudioClassifierApp(QMainWindow):
    """Top-level window hosting all pages inside a QStackedWidget."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        """Create all pages and register them in the page stack."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(300, 300, 800, 600)

        self.stack = QStackedWidget()
        self.main_menu = MainMenu(self)
        self.record_page = InputPage(self, mode="record")
        self.upload_page = InputPage(self, mode="upload")
        self.database_page = DatabaseUI(self)

        # Name -> widget map drives both registration and switching;
        # insertion order matches the original addWidget order.
        self._pages = {
            "main": self.main_menu,
            "record": self.record_page,
            "upload": self.upload_page,
            "database": self.database_page,
        }
        for page in self._pages.values():
            self.stack.addWidget(page)

        self.setCentralWidget(self.stack)

    def switch_page(self, page_name):
        """Show the page registered under page_name; unknown names are ignored."""
        page = self._pages.get(page_name)
        if page is not None:
            self.stack.setCurrentWidget(page)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = AudioClassifierApp()
    window.show()
    sys.exit(app.exec_())
class FeatureExtractor:
    """Feature extraction identical to training: 5 features per segment."""

    WIN = 1024            # frame length in samples
    OVERLAP = 512         # frame overlap in samples
    THRESHOLD = 0.1       # energy threshold for a knock
    SEG_LEN_S = 0.2       # segment length in seconds
    STEP = WIN - OVERLAP  # hop between consecutive frames

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Sum-of-squares energy per sliding frame.

        Returns an empty array when the signal is shorter than one frame
        (librosa.util.frame raised in that case); otherwise the pure-NumPy
        framing below yields the same frame count and sums as the previous
        librosa-based version.
        """
        signal = np.asarray(signal)
        if signal.size < FeatureExtractor.WIN:
            return np.array([])
        frames = np.lib.stride_tricks.sliding_window_view(
            signal, FeatureExtractor.WIN)[::FeatureExtractor.STEP]
        return np.sum(frames ** 2, axis=1)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Frame indices that start a new hit (gap of more than 5 frames)."""
        idx = np.where(energy > FeatureExtractor.THRESHOLD)[0]
        if idx.size == 0:
            return np.array([])
        # NOTE(review): a hit inside the first 5 frames is skipped because of
        # the prepended 0 — kept as-is for consistency with training.
        flags = np.diff(np.concatenate(([0], idx))) > 5
        return idx[flags]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray) -> list[np.ndarray]:
        """Cut a fixed-length window starting at each detected hit frame."""
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            s = frame_idx * FeatureExtractor.STEP
            e = min(s + seg_len, len(signal))
            segments.append(signal[s:e])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return [rms, main_freq, spectral skewness, mfcc0 mean, env peak]."""
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)

        # 1. RMS energy.
        rms = np.sqrt(np.mean(x ** 2))

        # 2. Dominant frequency (spectral peak over non-negative bins).
        fft = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive_freq_mask = freq >= 0
        freq = freq[positive_freq_mask]
        fft_mag = np.abs(fft[positive_freq_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # 3. Spectral skewness around the centroid (epsilons avoid 0/0).
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) /
                         (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0

        # 4. Mean of the first MFCC coefficient.
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        # 5. Envelope peak via the Hilbert transform.
        env_peak = np.max(np.abs(hilbert(x)))

        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
class RecordThread(QThread):
    """Microphone capture thread with live level metering.

    Signals:
        update_signal(str): progress / diagnostic messages.
        finish_signal(str): path of the saved WAV file.
        progress_signal(int): 0-100 percentage of max_duration.
        level_signal(int): live input level, 0-100.
    """
    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    level_signal = pyqtSignal(int)  # live input level (0-100)

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        self.is_recording = False
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index  # None -> system default input

    def run(self):
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024

        p = None
        stream = None
        frames = []
        start_time = time.time()
        try:
            p = pyaudio.PyAudio()
            self._log_input_devices(p)

            stream = p.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                input_device_index=self.device_index,  # may be None -> default
                frames_per_buffer=CHUNK,
            )

            self.is_recording = True
            start_time = time.time()
            max_int16 = 32767.0  # int16 full scale, for level estimation

            while self.is_recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒")
                    break

                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
                self._emit_level(data, max_int16)

            if len(frames) > 0:
                # Fixed: pyaudio.get_sample_size is a module-level helper; the
                # original created (and never terminated) a second PyAudio
                # instance just to obtain the sample width.
                with wave.open(self.temp_file, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(pyaudio.get_sample_size(FORMAT))
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                self.update_signal.emit(f"录制完成,时长: {time.time() - start_time:.1f}秒")
                self.finish_signal.emit(self.temp_file)
            else:
                self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)")

        except Exception as e:
            err_msg = f"录制错误: {str(e)}"
            if p is not None:
                err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)"
            self.update_signal.emit(err_msg)
        finally:
            # Single cleanup path: release the stream and device exactly once
            # (the original closed them in the body AND in finally).
            try:
                if stream is not None:
                    stream.stop_stream()
                    stream.close()
            except Exception:
                pass
            try:
                if p is not None:
                    p.terminate()
            except Exception:
                pass

    def _log_input_devices(self, p):
        """Emit a list of available input devices (best effort, for debugging)."""
        try:
            device_log = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    device_log.append(
                        f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} sr={int(info.get('defaultSampleRate', 0))}")
            if device_log:
                self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log))
        except Exception:
            pass

    def _emit_level(self, data, max_int16):
        """Compute the RMS of one chunk and emit it as a 0-100 level."""
        try:
            chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
            rms = np.sqrt(np.mean(chunk_np ** 2))
            # Simple mapping to 0-100; the gain of 5 can be tuned per environment.
            level = int(np.clip(rms * 100 * 5, 0, 100))
            self.level_signal.emit(level)
        except Exception:
            pass
[FeatureExtractor.extract_features(seg, sr) for seg in segments] + X = np.vstack(feats) + self.update_signal.emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}") + + X_std = scaler.transform(X) + y_pred = model.predict(X_std) + # 兼容无 predict_proba 的分类器 + if hasattr(model, "predict_proba"): + y_proba = model.predict_proba(X_std)[:, 1] + else: + # 退化:用决策函数映射到(0,1) + if hasattr(model, "decision_function"): + df = model.decision_function(X_std) + y_proba = 1 / (1 + np.exp(-df)) + else: + y_proba = (y_pred.astype(float) + 0.0) + + self.finish_signal.emit({ + "filename": os.path.basename(self.wav_path), + "segments": len(segments), + "predictions": y_pred.tolist(), + "probabilities": [round(float(p), 4) for p in y_proba], + "mean_prob": round(float(np.mean(y_proba)), 4), + "final_label": int(np.mean(y_proba) >= 0.5) + }) + + except Exception as e: + self.update_signal.emit(f"错误: {str(e)}") + + +# 第一层界面:主菜单 +class MainMenuWidget(QWidget): + def __init__(self, parent=None): + super().__init__(parent) + self.parent = parent + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + + title = QLabel("音频分类器") + title.setAlignment(Qt.AlignCenter) + title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;") + layout.addWidget(title) + + # 添加按钮 + record_btn = QPushButton("采集音频") + record_btn.setMinimumHeight(60) + record_btn.setStyleSheet("font-size: 16px;") + record_btn.clicked.connect(lambda: self.parent.switch_to_input("record")) + + upload_btn = QPushButton("上传外部WAV文件") + upload_btn.setMinimumHeight(60) + upload_btn.setStyleSheet("font-size: 16px;") + upload_btn.clicked.connect(lambda: self.parent.switch_to_input("upload")) + + layout.addWidget(record_btn) + layout.addWidget(upload_btn) + layout.addStretch(1) + + self.setLayout(layout) + + +# 第二层界面:输入界面(录音或上传文件) +class InputWidget(QWidget): + def __init__(self, parent=None, mode="record"): + super().__init__(parent) + self.parent = parent + self.mode = mode + self.wav_path = "" + self.record_thread = 
None + self.device_index = None # >>> NEW: 当前选中的输入设备索引 + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + + # 返回按钮 + back_btn = QPushButton("返回") + back_btn.clicked.connect(self.parent.switch_to_main_menu) + layout.addWidget(back_btn) + + if self.mode == "record": + self.setup_record_ui(layout) + else: + self.setup_upload_ui(layout) + + # 模型路径 + model_layout = QHBoxLayout() + self.model_path = QLineEdit("svm_model.pkl") + self.scaler_path = QLineEdit("scaler.pkl") + model_layout.addWidget(QLabel("模型路径:")) + model_layout.addWidget(self.model_path) + model_layout.addWidget(QLabel("标准化器路径:")) + model_layout.addWidget(self.scaler_path) + layout.addLayout(model_layout) + + # 处理按钮 + self.process_btn = QPushButton("开始处理") + self.process_btn.setMinimumHeight(50) + self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;") + self.process_btn.clicked.connect(self.start_process) + self.process_btn.setEnabled(False) # 初始不可用 + layout.addWidget(self.process_btn) + + # 日志区域 + self.log_area = QTextEdit() + self.log_area.setReadOnly(True) + layout.addWidget(QLabel("日志:")) + layout.addWidget(self.log_area) + + self.setLayout(layout) + + def setup_record_ui(self, layout): + title = QLabel("音频采集") + title.setAlignment(Qt.AlignCenter) + title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;") + layout.addWidget(title) + + # >>> NEW: 麦克风选择 + device_layout = QHBoxLayout() + device_layout.addWidget(QLabel("输入设备:")) + self.device_combo = QComboBox() + self.refresh_devices() + self.device_combo.currentIndexChanged.connect(self.on_device_changed) + device_layout.addWidget(self.device_combo) + refresh_btn = QPushButton("刷新设备") + refresh_btn.clicked.connect(self.refresh_devices) + device_layout.addWidget(refresh_btn) + layout.addLayout(device_layout) + + # 录音控制 + record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)") + record_hint.setAlignment(Qt.AlignCenter) + + self.record_btn = QPushButton("按住录音") + 
self.record_btn.setMinimumHeight(80) + self.record_btn.setStyleSheet(""" + QPushButton { + background-color: #ff4d4d; + color: white; + font-size: 18px; + border-radius: 10px; + } + QPushButton:pressed { + background-color: #cc0000; + } + """) + self.record_btn.mousePressEvent = self.start_recording + self.record_btn.mouseReleaseEvent = self.stop_recording + self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu) + + # >>> NEW: 实时电平与录音进度 + self.mic_level = QProgressBar() + self.mic_level.setRange(0, 100) + self.mic_level.setFormat("麦克风电平:%p%") + + self.record_progress = QProgressBar() + self.record_progress.setRange(0, 100) + self.record_progress.setValue(0) + + self.record_duration_label = QLabel("录音时长: 0.0秒") + self.record_duration_label.setAlignment(Qt.AlignCenter) + + layout.addWidget(record_hint) + layout.addWidget(self.record_btn) + layout.addWidget(self.mic_level) # >>> NEW + layout.addWidget(self.record_progress) + layout.addWidget(self.record_duration_label) + + # 录音计时器 + self.record_timer = QTimer(self) + self.record_timer.timeout.connect(self.update_record_duration) + self.record_start_time = 0 + + def refresh_devices(self): + """枚举有输入通道的设备,并填充到下拉框""" + self.device_combo.clear() + try: + p = pyaudio.PyAudio() + default_host_api = p.get_host_api_info_by_index(0) + default_input_index = default_host_api.get("defaultInputDevice", -1) + found = [] + for i in range(p.get_device_count()): + info = p.get_device_info_by_index(i) + if int(info.get('maxInputChannels', 0)) > 0: + name = info.get('name', f"Device {i}") + sr = int(info.get('defaultSampleRate', 0)) + label = f"[{i}] {name} (sr={sr})" + self.device_combo.addItem(label, i) + found.append(i) + p.terminate() + + # 选中默认输入设备(若存在) + if default_input_index in found: + idx = found.index(default_input_index) + self.device_combo.setCurrentIndex(idx) + self.device_index = default_input_index + elif found: + self.device_combo.setCurrentIndex(0) + self.device_index = self.device_combo.currentData() + else: + 
self.device_index = None + except Exception: + self.device_index = None + + def on_device_changed(self, _): + self.device_index = self.device_combo.currentData() + + def setup_upload_ui(self, layout): + title = QLabel("上传WAV文件") + title.setAlignment(Qt.AlignCenter) + title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;") + layout.addWidget(title) + + # 文件选择 + file_layout = QHBoxLayout() + self.file_path = QLineEdit() + self.file_path.setReadOnly(True) + self.browse_btn = QPushButton("浏览WAV文件") + self.browse_btn.clicked.connect(self.browse_file) + file_layout.addWidget(self.file_path) + file_layout.addWidget(self.browse_btn) + layout.addLayout(file_layout) + + def start_recording(self, event): + if event.button() == Qt.LeftButton: + if not self.record_thread or not self.record_thread.isRunning(): + # >>> CHANGED: 传入 device_index + self.record_thread = RecordThread(max_duration=60, device_index=self.device_index) + self.record_thread.update_signal.connect(self.update_log) + self.record_thread.finish_signal.connect(self.on_recording_finish) + self.record_thread.progress_signal.connect(self.record_progress.setValue) + self.record_thread.level_signal.connect(self.mic_level.setValue) # >>> NEW + self.record_thread.start() + + self.record_start_time = time.time() + self.record_timer.start(100) + self.update_log("开始录音...(松开按钮结束)") + + def stop_recording(self, event): + if event.button() == Qt.LeftButton and self.record_thread and self.record_thread.isRunning(): + self.record_thread.is_recording = False + self.record_timer.stop() + self.record_duration_label.setText("录音时长: 0.0秒") + self.record_progress.setValue(0) + # 让电平逐步回落 + self.mic_level.setValue(0) + + def update_record_duration(self): + elapsed = time.time() - self.record_start_time + self.record_duration_label.setText(f"录音时长: {elapsed:.1f}秒") + + def on_recording_finish(self, temp_file): + self.wav_path = temp_file + self.process_btn.setEnabled(True) + + def browse_file(self): + file, _ = 
QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)") + if file: + self.file_path.setText(file) + self.wav_path = file + self.process_btn.setEnabled(True) + + def start_process(self): + if self.mode == "upload": + self.wav_path = self.file_path.text() + + model_path = self.model_path.text() + scaler_path = self.scaler_path.text() + + if not self.wav_path or not os.path.exists(self.wav_path): + self.log_area.append("请先录音或选择有效的WAV文件!") + return + if not os.path.exists(model_path) or not os.path.exists(scaler_path): + self.log_area.append("模型文件不存在!请先运行train_model.py训练模型") + return + + self.log_area.clear() + self.process_btn.setEnabled(False) + + self.thread = ProcessThread(self.wav_path, model_path, scaler_path) + self.thread.update_signal.connect(self.update_log) + self.thread.finish_signal.connect(self.on_process_finish) + self.thread.start() + + def update_log(self, msg): + self.log_area.append(msg) + + def on_process_finish(self, result): + self.parent.switch_to_result(result) + + +# 第三层界面:结果显示界面 +class ResultWidget(QWidget): + def __init__(self, parent=None, result=None): + super().__init__(parent) + self.parent = parent + self.result = result + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + + # 返回按钮 + back_btn = QPushButton("返回") + back_btn.clicked.connect(self.parent.switch_to_input_from_result) + layout.addWidget(back_btn) + + title = QLabel("处理结果") + title.setAlignment(Qt.AlignCenter) + title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;") + layout.addWidget(title) + + # 结果显示 + self.result_area = QTextEdit() + self.result_area.setReadOnly(True) + + if self.result: + res = f"文件名: {self.result['filename']}\n" + res += f"片段数: {self.result['segments']}\n" + res += "预测结果:\n" + for i, (pred, prob) in enumerate(zip(self.result['predictions'], self.result['probabilities'])): + res += f" 片段{i + 1}: 标签={pred} (概率={prob})\n" + res += f"\n平均概率: {self.result['mean_prob']}\n" + res += f"最终结果: {'空心' if 
self.result['final_label'] else '实心'}" + self.result_area.setText(res) + + layout.addWidget(self.result_area) + self.setLayout(layout) + + +# 主窗口 +class AudioClassifierGUI(QMainWindow): + def __init__(self): + super().__init__() + self.current_input_mode = "record" # 记录当前输入模式 + self.process_result = None # 存储处理结果 + self.init_ui() + + def init_ui(self): + self.setWindowTitle("音频分类器") + self.setGeometry(100, 100, 800, 600) + + # 创建堆叠窗口 + self.stacked_widget = QStackedWidget() + self.setCentralWidget(self.stacked_widget) + + # 创建三个界面 + self.main_menu_widget = MainMenuWidget(self) + self.record_input_widget = InputWidget(self, "record") + self.upload_input_widget = InputWidget(self, "upload") + self.result_widget = ResultWidget(self) + + # 添加到堆叠窗口 + self.stacked_widget.addWidget(self.main_menu_widget) + self.stacked_widget.addWidget(self.record_input_widget) + self.stacked_widget.addWidget(self.upload_input_widget) + self.stacked_widget.addWidget(self.result_widget) + + # 显示主菜单 + self.stacked_widget.setCurrentWidget(self.main_menu_widget) + + def switch_to_input(self, mode): + self.current_input_mode = mode + if mode == "record": + self.stacked_widget.setCurrentWidget(self.record_input_widget) + else: + self.stacked_widget.setCurrentWidget(self.upload_input_widget) + + def switch_to_main_menu(self): + self.stacked_widget.setCurrentWidget(self.main_menu_widget) + + def switch_to_result(self, result): + self.process_result = result + self.result_widget = ResultWidget(self, result) + # 移除旧的结果界面并添加新的 + self.stacked_widget.removeWidget(self.stacked_widget.widget(3)) + self.stacked_widget.addWidget(self.result_widget) + self.stacked_widget.setCurrentWidget(self.result_widget) + + def switch_to_input_from_result(self): + if self.current_input_mode == "record": + self.stacked_widget.setCurrentWidget(self.record_input_widget) + else: + self.stacked_widget.setCurrentWidget(self.upload_input_widget) + + +if __name__ == "__main__": + try: + import pyaudio # 确保已安装 + except 
ImportError: + print("请先安装pyaudio: pip install pyaudio") + sys.exit(1) + + app = QApplication(sys.argv) + window = AudioClassifierGUI() + window.show() + sys.exit(app.exec_()) diff --git a/01src/audio_classifier_gui_03.py b/01src/audio_classifier_gui_03.py new file mode 100644 index 0000000..4182315 --- /dev/null +++ b/01src/audio_classifier_gui_03.py @@ -0,0 +1,917 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import sys +import os +import time +import csv +import sqlite3 +from datetime import datetime +import numpy as np +import joblib +import librosa +import pyaudio +import wave +from scipy.signal import hilbert + +from PyQt5.QtWidgets import ( + QApplication, QMainWindow, QPushButton, QLabel, + QLineEdit, QTextEdit, QFileDialog, QVBoxLayout, + QHBoxLayout, QWidget, QProgressBar, QStackedWidget, + QComboBox, QTableWidget, QTableWidgetItem, QMessageBox, + QDialog, QFormLayout, QDialogButtonBox +) +from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer + + +# ====================== +# 本地 SQLite 数据库(含 CRUD) +# ====================== +class DatabaseManager: + def __init__(self, db_path="results.db"): + self.db_path = db_path + self._ensure_schema() + + def _connect(self): + return sqlite3.connect(self.db_path) + + def _ensure_schema(self): + con = self._connect() + cur = con.cursor() + cur.execute(""" + CREATE TABLE IF NOT EXISTS runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + filename TEXT, + segments INTEGER, + mean_prob REAL, + final_label INTEGER, + created_at TEXT + ); + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS run_segments ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id INTEGER, + seg_index INTEGER, + label INTEGER, + proba REAL, + FOREIGN KEY(run_id) REFERENCES runs(id) + ); + """) + con.commit() + con.close() + + # ---------- 原有写入 ---------- + def insert_result(self, result_dict): + con = self._connect() + cur = con.cursor() + cur.execute( + "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, 
?)",
+            (
+                result_dict.get("filename"),
+                int(result_dict.get("segments", 0)),
+                float(result_dict.get("mean_prob", 0.0)),
+                int(result_dict.get("final_label", 0)),
+                datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            )
+        )
+        run_id = cur.lastrowid
+        # One run_segments row per (label, probability) pair; seg_index is 1-based.
+        preds = result_dict.get("predictions", [])
+        probas = result_dict.get("probabilities", [])
+        for i, (lab, pr) in enumerate(zip(preds, probas), start=1):
+            cur.execute(
+                "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
+                (run_id, i, int(lab), float(pr))
+            )
+        con.commit()
+        con.close()
+        return run_id
+
+    # ---------- R: read (recent list / single row / keyword search) ----------
+    def fetch_recent_runs(self, limit=50):
+        """Return up to *limit* most recent runs, newest first, as raw row tuples."""
+        con = self._connect()
+        cur = con.cursor()
+        cur.execute("""
+            SELECT id, created_at, filename, segments, mean_prob, final_label
+            FROM runs ORDER BY id DESC LIMIT ?
+        """, (limit,))
+        rows = cur.fetchall()
+        con.close()
+        return rows
+
+    def get_run(self, run_id: int):
+        """Return the single run row with id *run_id*, or None when absent."""
+        con = self._connect()
+        cur = con.cursor()
+        cur.execute("""
+            SELECT id, created_at, filename, segments, mean_prob, final_label
+            FROM runs WHERE id = ?
+        """, (run_id,))
+        row = cur.fetchone()
+        con.close()
+        return row
+
+    def search_runs(self, keyword: str, limit=100):
+        """Return runs whose filename contains *keyword* (SQL LIKE), newest first."""
+        kw = f"%{keyword}%"
+        con = self._connect()
+        cur = con.cursor()
+        cur.execute("""
+            SELECT id, created_at, filename, segments, mean_prob, final_label
+            FROM runs
+            WHERE filename LIKE ?
+            ORDER BY id DESC LIMIT ? 
+ """, (kw, limit)) + rows = cur.fetchall() + con.close() + return rows + + # ---------- C/U/D(run) ---------- + def create_run(self, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str = None): + created_at = created_at or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + con = self._connect() + cur = con.cursor() + cur.execute( + "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)", + (filename, int(segments), float(mean_prob), int(final_label), created_at) + ) + run_id = cur.lastrowid + con.commit() + con.close() + return run_id + + def update_run(self, run_id: int, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str): + con = self._connect() + cur = con.cursor() + cur.execute(""" + UPDATE runs + SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=? + WHERE id=? + """, (filename, int(segments), float(mean_prob), int(final_label), created_at, int(run_id))) + con.commit() + con.close() + + def delete_run(self, run_id: int): + con = self._connect() + cur = con.cursor() + # 先删子表 + cur.execute("DELETE FROM run_segments WHERE run_id=?", (int(run_id),)) + cur.execute("DELETE FROM runs WHERE id=?", (int(run_id),)) + con.commit() + con.close() + + # ---------- C/R/U/D(segments) ---------- + def list_segments(self, run_id: int): + con = self._connect() + cur = con.cursor() + cur.execute(""" + SELECT id, seg_index, label, proba FROM run_segments + WHERE run_id=? 
ORDER BY seg_index ASC + """, (int(run_id),)) + rows = cur.fetchall() + con.close() + return rows + + def add_segment(self, run_id: int, seg_index: int, label: int, proba: float): + con = self._connect() + cur = con.cursor() + cur.execute( + "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)", + (int(run_id), int(seg_index), int(label), float(proba)) + ) + seg_id = cur.lastrowid + con.commit() + con.close() + return seg_id + + def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float): + con = self._connect() + cur = con.cursor() + cur.execute(""" + UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=? + """, (int(seg_index), int(label), float(proba), int(seg_id))) + con.commit() + con.close() + + def delete_segment(self, seg_id: int): + con = self._connect() + cur = con.cursor() + cur.execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),)) + con.commit() + con.close() + + # ---------- 导出 ---------- + def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000): + rows = self.fetch_recent_runs(limit) + headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(headers) + for r in rows: + writer.writerow(r) + return csv_path + + +DB = DatabaseManager("results.db") + + +# ====================== +# 特征提取(同前) +# ====================== +class FeatureExtractor: + WIN = 1024 + OVERLAP = 512 + THRESHOLD = 0.1 + SEG_LEN_S = 0.2 + STEP = WIN - OVERLAP + + @staticmethod + def frame_energy(signal: np.ndarray) -> np.ndarray: + frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP) + return np.sum(frames ** 2, axis=0) + + @staticmethod + def detect_hits(energy: np.ndarray) -> np.ndarray: + idx = np.where(energy > FeatureExtractor.THRESHOLD)[0] + if idx.size == 0: + return np.array([]) + flags = np.diff(np.concatenate(([0], 
idx))) > 5 + return idx[flags] + + @staticmethod + def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray): + seg_len = int(FeatureExtractor.SEG_LEN_S * sr) + segments = [] + for frame_idx in hit_starts: + s = frame_idx * FeatureExtractor.STEP + e = min(s + seg_len, len(signal)) + segments.append(signal[s:e]) + return segments + + @staticmethod + def extract_features(x: np.ndarray, sr: int) -> np.ndarray: + x = x.flatten() + if len(x) == 0: + return np.zeros(5) + rms = np.sqrt(np.mean(x ** 2)) + fft = np.fft.fft(x) + freq = np.fft.fftfreq(len(x), 1 / sr) + positive = freq >= 0 + freq, fft_mag = freq[positive], np.abs(fft[positive]) + main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0 + spec_power = fft_mag + centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12) + spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12)) + skewness = np.sum(((freq - centroid) ** 3) * spec_power) / ( + (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0 + mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13) + mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0 + env_peak = np.max(np.abs(hilbert(x))) + return np.array([rms, main_freq, skewness, mfcc_mean, env_peak]) + + +# ====================== +# 录音线程(同前) +# ====================== +class RecordThread(QThread): + update_signal = pyqtSignal(str) + finish_signal = pyqtSignal(str) + progress_signal = pyqtSignal(int) + level_signal = pyqtSignal(int) + + def __init__(self, max_duration=60, device_index=None): + super().__init__() + self.max_duration = max_duration + self.is_recording = False + self.temp_file = "temp_recording.wav" + self.device_index = device_index + + def run(self): + FORMAT = pyaudio.paInt16 + CHANNELS = 1 + RATE = 44100 + CHUNK = 1024 + p, stream = None, None + try: + p = pyaudio.PyAudio() + stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, + input_device_index=self.device_index, 
frames_per_buffer=CHUNK) + self.is_recording = True + start_time = time.time() + frames = [] + max_int16 = 32767.0 + while self.is_recording: + elapsed = time.time() - start_time + if elapsed >= self.max_duration: + self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒") + break + self.progress_signal.emit(int((elapsed / self.max_duration) * 100)) + data = stream.read(CHUNK, exception_on_overflow=False) + frames.append(data) + chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16 + rms = np.sqrt(np.mean(chunk_np ** 2)) + self.level_signal.emit(int(np.clip(rms * 500, 0, 100))) + stream.stop_stream(); stream.close(); p.terminate() + if frames: + wf = wave.open(self.temp_file, 'wb') + wf.setnchannels(CHANNELS) + wf.setsampwidth(pyaudio.PyAudio().get_sample_size(FORMAT)) + wf.setframerate(RATE) + wf.writeframes(b''.join(frames)) + wf.close() + self.finish_signal.emit(self.temp_file) + else: + self.update_signal.emit("未录到有效音频。") + except Exception as e: + self.update_signal.emit(f"录制错误: {e}") + + +# ====================== +# 推理处理线程(写入 DB) +# ====================== +class ProcessThread(QThread): + update_signal = pyqtSignal(str) + finish_signal = pyqtSignal(dict) + + def __init__(self, wav_path, model_path, scaler_path): + super().__init__() + self.wav_path, self.model_path, self.scaler_path = wav_path, model_path, scaler_path + + def run(self): + try: + self.update_signal.emit("加载模型和标准化器...") + model, scaler = joblib.load(self.model_path), joblib.load(self.scaler_path) + self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}") + sig, sr = librosa.load(self.wav_path, sr=None, mono=True) + sig = sig / max(np.max(np.abs(sig)), 1e-9) + ene = FeatureExtractor.frame_energy(sig) + hits = FeatureExtractor.detect_hits(ene) + segs = FeatureExtractor.extract_segments(sig, sr, hits) + if not segs: + self.update_signal.emit("未检测到有效片段!") + return + feats = [FeatureExtractor.extract_features(seg, sr) for seg in segs] + X = np.vstack(feats) + X_std = 
scaler.transform(X) + y_pred = model.predict(X_std) + if hasattr(model, "predict_proba"): + y_proba = model.predict_proba(X_std)[:, 1] + else: + y_proba = (y_pred.astype(float) + 0.0) + result = { + "filename": os.path.basename(self.wav_path), + "segments": len(segs), + "predictions": y_pred.tolist(), + "probabilities": [round(float(p), 4) for p in y_proba], + "mean_prob": round(float(np.mean(y_proba)), 4), + "final_label": int(np.mean(y_proba) >= 0.5) + } + DB.insert_result(result) + self.finish_signal.emit(result) + except Exception as e: + self.update_signal.emit(f"错误: {e}") + + +# ====================== +# 对话框:编辑 run +# ====================== +class RunEditDialog(QDialog): + def __init__(self, parent=None, run_row=None): + """ + run_row: (id, created_at, filename, segments, mean_prob, final_label) 或 None(新增) + """ + super().__init__(parent) + self.setWindowTitle("编辑记录" if run_row else "新增记录") + self.resize(380, 240) + self.run_row = run_row + + self.ed_id = QLineEdit(); self.ed_id.setReadOnly(True) + self.ed_created = QLineEdit(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + self.ed_filename = QLineEdit() + self.ed_segments = QLineEdit("0") + self.ed_mean_prob = QLineEdit("0.0") + self.ed_final_label = QLineEdit("0") # 0=实心 1=空心 + + form = QFormLayout() + form.addRow("ID:", self.ed_id) + form.addRow("时间(created_at):", self.ed_created) + form.addRow("文件名(filename):", self.ed_filename) + form.addRow("片段数(segments):", self.ed_segments) + form.addRow("平均概率(mean_prob):", self.ed_mean_prob) + form.addRow("最终标签(0=实 1=空):", self.ed_final_label) + + if run_row: + self.ed_id.setText(str(run_row[0])) + self.ed_created.setText(str(run_row[1])) + self.ed_filename.setText(str(run_row[2])) + self.ed_segments.setText(str(run_row[3])) + self.ed_mean_prob.setText(str(run_row[4])) + self.ed_final_label.setText(str(run_row[5])) + + btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + btns.accepted.connect(self.accept) + btns.rejected.connect(self.reject) + + 
lay = QVBoxLayout() + lay.addLayout(form) + lay.addWidget(btns) + self.setLayout(lay) + + def values(self): + return dict( + id=(int(self.ed_id.text()) if self.ed_id.text().strip() else None), + created_at=self.ed_created.text().strip(), + filename=self.ed_filename.text().strip(), + segments=int(float(self.ed_segments.text() or 0)), + mean_prob=float(self.ed_mean_prob.text() or 0.0), + final_label=int(float(self.ed_final_label.text() or 0)) + ) + + +# ====================== +# 对话框:管理 segments +# ====================== +class SegmentManagerDialog(QDialog): + def __init__(self, parent=None, run_id=None): + super().__init__(parent) + self.setWindowTitle(f"片段管理 - run_id={run_id}") + self.resize(560, 420) + self.run_id = run_id + + self.table = QTableWidget(0, 4) + self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"]) + + self.btn_add = QPushButton("新增片段") + self.btn_edit = QPushButton("编辑所选") + self.btn_del = QPushButton("删除所选") + self.btn_close = QPushButton("关闭") + + hl = QHBoxLayout() + for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close): + hl.addWidget(b) + + lay = QVBoxLayout() + lay.addWidget(self.table) + lay.addLayout(hl) + self.setLayout(lay) + + self.btn_add.clicked.connect(self.add_segment) + self.btn_edit.clicked.connect(self.edit_selected) + self.btn_del.clicked.connect(self.delete_selected) + self.btn_close.clicked.connect(self.accept) + + self.load_segments() + + def load_segments(self): + rows = DB.list_segments(self.run_id) + self.table.setRowCount(len(rows)) + for i, (sid, idx, label, proba) in enumerate(rows): + self.table.setItem(i, 0, QTableWidgetItem(str(sid))) + self.table.setItem(i, 1, QTableWidgetItem(str(idx))) + self.table.setItem(i, 2, QTableWidgetItem(str(label))) + self.table.setItem(i, 3, QTableWidgetItem(str(proba))) + self.table.resizeColumnsToContents() + + def _get_selected_seg_id(self): + r = self.table.currentRow() + if r < 0: + return None + return int(self.table.item(r, 0).text()) + + def 
add_segment(self): + d = QDialog(self) + d.setWindowTitle("新增片段") + form = QFormLayout() + ed_index = QLineEdit("1") + ed_label = QLineEdit("0") + ed_proba = QLineEdit("0.0") + form.addRow("seg_index:", ed_index) + form.addRow("label(0/1):", ed_label) + form.addRow("proba(0~1):", ed_proba) + btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + btns.accepted.connect(d.accept); btns.rejected.connect(d.reject) + v = QVBoxLayout(); v.addLayout(form); v.addWidget(btns); d.setLayout(v) + if d.exec_() == QDialog.Accepted: + DB.add_segment(self.run_id, int(float(ed_index.text() or 1)), int(float(ed_label.text() or 0)), + float(ed_proba.text() or 0.0)) + self.load_segments() + + def edit_selected(self): + seg_id = self._get_selected_seg_id() + if not seg_id: + QMessageBox.information(self, "提示", "请先选择一行片段。") + return + # 读当前值 + row = self.table.currentRow() + cur_idx = self.table.item(row, 1).text() + cur_label = self.table.item(row, 2).text() + cur_proba = self.table.item(row, 3).text() + + d = QDialog(self) + d.setWindowTitle("编辑片段") + form = QFormLayout() + ed_index = QLineEdit(cur_idx) + ed_label = QLineEdit(cur_label) + ed_proba = QLineEdit(cur_proba) + form.addRow("seg_index:", ed_index) + form.addRow("label(0/1):", ed_label) + form.addRow("proba(0~1):", ed_proba) + btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + btns.accepted.connect(d.accept); btns.rejected.connect(d.reject) + v = QVBoxLayout(); v.addLayout(form); v.addWidget(btns); d.setLayout(v) + if d.exec_() == QDialog.Accepted: + DB.update_segment(seg_id, int(float(ed_index.text() or cur_idx)), + int(float(ed_label.text() or cur_label)), + float(ed_proba.text() or cur_proba)) + self.load_segments() + + def delete_selected(self): + seg_id = self._get_selected_seg_id() + if not seg_id: + QMessageBox.information(self, "提示", "请先选择一行片段。") + return + if QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!", + QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes: + 
DB.delete_segment(seg_id) + self.load_segments() + + +# ====================== +# 主菜单 +# ====================== +class MainMenuWidget(QWidget): + def __init__(self, parent): + super().__init__(parent) + layout = QVBoxLayout() + title = QLabel("音频分类器"); title.setAlignment(Qt.AlignCenter) + title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;") + layout.addWidget(title) + btn1 = QPushButton("采集音频"); btn1.clicked.connect(lambda: parent.switch_to_input("record")) + btn2 = QPushButton("上传WAV文件"); btn2.clicked.connect(lambda: parent.switch_to_input("upload")) + btn3 = QPushButton("历史记录 / 导出 / CRUD"); btn3.clicked.connect(parent.switch_to_history) + for b in [btn1, btn2, btn3]: + b.setMinimumHeight(60); layout.addWidget(b) + layout.addStretch(1); self.setLayout(layout) + + +# ====================== +# 输入界面(录音 / 上传) +# ====================== +class InputWidget(QWidget): + def __init__(self, parent, mode="record"): + super().__init__(parent) + self.parent, self.mode = parent, mode + self.wav_path, self.record_thread = "", None + self.device_index = None + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu) + layout.addWidget(back) + if self.mode == "record": self.setup_record(layout) + else: self.setup_upload(layout) + self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl") + path_layout = QHBoxLayout() + path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path) + path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path) + layout.addLayout(path_layout) + self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False) + self.process_btn.clicked.connect(self.start_process) + layout.addWidget(self.process_btn) + self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log) + self.setLayout(layout) + + def setup_record(self, layout): + title = QLabel("音频采集"); 
title.setAlignment(Qt.AlignCenter) + layout.addWidget(title) + dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:")) + self.dev_combo = QComboBox(); self.refresh_devices() + self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData())) + dev_layout.addWidget(self.dev_combo) + btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices) + dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout) + self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80) + self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec + self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s") + self.level.setFormat("电平:%p%"); self.progress.setRange(0,100) + for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w) + self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur) + + def setup_upload(self, layout): + h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True) + b = QPushButton("浏览"); b.clicked.connect(self.browse) + h.addWidget(self.file); h.addWidget(b); layout.addLayout(h) + + def refresh_devices(self): + self.dev_combo.clear(); p = pyaudio.PyAudio() + for i in range(p.get_device_count()): + info = p.get_device_info_by_index(i) + if int(info.get('maxInputChannels',0))>0: + self.dev_combo.addItem(f"[{i}] {info.get('name')}", i) + if self.dev_combo.count()>0: + self.dev_combo.setCurrentIndex(0) + self.device_index = self.dev_combo.currentData() + p.terminate() + + def start_rec(self,e): + if e.button()==Qt.LeftButton: + self.record_thread=RecordThread(device_index=self.device_index) + self.record_thread.update_signal.connect(self.log.append) + self.record_thread.finish_signal.connect(self.finish_rec) + self.record_thread.progress_signal.connect(self.progress.setValue) + self.record_thread.level_signal.connect(self.level.setValue) + self.record_thread.start(); self.start=time.time(); 
self.timer.start(100) + + def stop_rec(self,e): + if e.button()==Qt.LeftButton and self.record_thread: + self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0) + + def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s") + + def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True) + + def browse(self): + f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)") + if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True) + + def start_process(self): + if self.mode=="upload": self.wav_path=self.file.text() + if not os.path.exists(self.wav_path): self.log.append("无效文件"); return + self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text()) + self.thread.update_signal.connect(self.log.append) + self.thread.finish_signal.connect(self.parent.switch_to_result) + self.thread.start() + + +# ====================== +# 历史记录(含 CRUD 按钮) +# ====================== +class HistoryWidget(QWidget): + def __init__(self, parent): + super().__init__(parent) + self.parent=parent + self.init_ui() + + def init_ui(self): + l=QVBoxLayout() + back=QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu); l.addWidget(back) + + top=QHBoxLayout() + self.search_edit=QLineEdit(); self.search_edit.setPlaceholderText("按文件名搜索...") + btn_search=QPushButton("搜索"); btn_search.clicked.connect(self.search) + btn_refresh=QPushButton("刷新"); btn_refresh.clicked.connect(self.load) + btn_export=QPushButton("导出CSV"); btn_export.clicked.connect(self.export_csv) + btn_add=QPushButton("新增记录"); btn_add.clicked.connect(self.add_run) + btn_edit=QPushButton("编辑所选"); btn_edit.clicked.connect(self.edit_selected) + btn_del=QPushButton("删除所选"); btn_del.clicked.connect(self.delete_selected) + btn_seg=QPushButton("片段管理"); btn_seg.clicked.connect(self.open_segments) + for w in [self.search_edit, btn_search, btn_refresh, btn_export, btn_add, btn_edit, btn_del, 
btn_seg]: + top.addWidget(w) + l.addLayout(top) + + self.table=QTableWidget(0,6) + self.table.setHorizontalHeaderLabels(["ID","时间","文件名","片段数","平均概率","最终标签"]) + l.addWidget(self.table) + self.setLayout(l) + self.load() + + def _current_run_id(self): + r=self.table.currentRow() + if r<0: return None + return int(self.table.item(r,0).text()) + + # ---- 数据加载 / 搜索 ---- + def load(self): + rows=DB.fetch_recent_runs(200) + self._fill(rows) + + def search(self): + kw=self.search_edit.text().strip() + rows=DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200) + self._fill(rows) + + def _fill(self, rows): + self.table.setRowCount(len(rows)) + for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows): + self.table.setItem(r_idx, 0, QTableWidgetItem(str(rid))) + self.table.setItem(r_idx, 1, QTableWidgetItem(str(created_at))) + self.table.setItem(r_idx, 2, QTableWidgetItem(str(fn))) + self.table.setItem(r_idx, 3, QTableWidgetItem(str(segs))) + self.table.setItem(r_idx, 4, QTableWidgetItem(str(mean_prob))) + self.table.setItem(r_idx, 5, QTableWidgetItem("实心" if int(final_label) else "空心")) + self.table.resizeColumnsToContents() + + # ---- 导出 ---- + def export_csv(self): + p=DB.export_runs_to_csv() + QMessageBox.information(self,"导出完成",f"已保存到 {os.path.abspath(p)}") + + # ---- CRUD: run ---- + def add_run(self): + dlg=RunEditDialog(self, run_row=None) + if dlg.exec_()==QDialog.Accepted: + v=dlg.values() + DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"]) + self.load() + + def edit_selected(self): + rid=self._current_run_id() + if not rid: + QMessageBox.information(self,"提示","请先选择一行记录。") + return + row=DB.get_run(rid) + dlg=RunEditDialog(self, run_row=row) + if dlg.exec_()==QDialog.Accepted: + v=dlg.values() + DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"]) + self.load() + + def delete_selected(self): + rid=self._current_run_id() + if not rid: + 
QMessageBox.information(self,"提示","请先选择一行记录。") + return + if QMessageBox.question(self,"确认","确认删除所选记录及其所有片段?此操作不可恢复!", + QMessageBox.Yes|QMessageBox.No)==QMessageBox.Yes: + DB.delete_run(rid) + self.load() + + # ---- 片段管理 ---- + def open_segments(self): + rid=self._current_run_id() + if not rid: + QMessageBox.information(self,"提示","请先选择一行记录。") + return + dlg=SegmentManagerDialog(self, run_id=rid) + dlg.exec_() + # 片段变更后可按需同步更新片段数/平均概率,此处保持手动编辑为主 + self.load() + + +# ====================== +# 结果页(简) +# ====================== +class ResultWidget(QWidget): + def __init__(self, parent, result=None): + super().__init__(parent); self.parent=parent + layout=QVBoxLayout(); b=QPushButton("返回"); b.clicked.connect(self.parent.switch_to_main_menu); layout.addWidget(b) + t=QLabel("处理结果"); t.setAlignment(Qt.AlignCenter); layout.addWidget(t) + area=QTextEdit(); area.setReadOnly(True) + if result: + txt=f"文件:{result['filename']}\n片段:{result['segments']}\n平均概率:{result['mean_prob']}\n最终:{'实心' if result['final_label'] else '空心'}" + area.setText(txt) + layout.addWidget(area); self.setLayout(layout) + + +# ====================== +# 主窗口 +# ====================== +class AudioClassifierGUI(QMainWindow): + def __init__(self): + super().__init__() + self.current_input_mode = "record" + self.process_result = None + self.init_ui() + + def init_ui(self): + self.setWindowTitle("音频分类器") + self.setGeometry(100, 100, 1000, 700) + + self.stacked_widget = QStackedWidget() + self.setCentralWidget(self.stacked_widget) + + self.main_menu_widget = MainMenuWidget(self) + self.record_input_widget = InputWidget(self, "record") + self.upload_input_widget = InputWidget(self, "upload") + self.result_widget = ResultWidget(self) + self.history_widget = HistoryWidget(self) + + self.stacked_widget.addWidget(self.main_menu_widget) # 0 + self.stacked_widget.addWidget(self.record_input_widget) # 1 + self.stacked_widget.addWidget(self.upload_input_widget) # 2 + self.stacked_widget.addWidget(self.result_widget) # 3 + 
self.stacked_widget.addWidget(self.history_widget) # 4 + + self.stacked_widget.setCurrentWidget(self.main_menu_widget) + + def switch_to_input(self, mode): + self.current_input_mode = mode + if mode == "record": + self.stacked_widget.setCurrentWidget(self.record_input_widget) + else: + self.stacked_widget.setCurrentWidget(self.upload_input_widget) + + def switch_to_main_menu(self): + self.stacked_widget.setCurrentWidget(self.main_menu_widget) + + def switch_to_result(self, result): + self.process_result = result + self.result_widget = ResultWidget(self, result) + self.stacked_widget.removeWidget(self.stacked_widget.widget(3)) + self.stacked_widget.addWidget(self.result_widget) + self.stacked_widget.setCurrentWidget(self.result_widget) + + def switch_to_history(self): + self.history_widget.load() + self.stacked_widget.setCurrentWidget(self.history_widget) + + +# ====================== +# 上传/录音 界面(放最后,避免打断阅读) +# ====================== +class InputWidget(QWidget): + def __init__(self, parent, mode="record"): + super().__init__(parent) + self.parent, self.mode = parent, mode + self.wav_path, self.record_thread = "", None + self.device_index = None + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu) + layout.addWidget(back) + if self.mode == "record": self.setup_record(layout) + else: self.setup_upload(layout) + self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl") + path_layout = QHBoxLayout() + path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path) + path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path) + layout.addLayout(path_layout) + self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False) + self.process_btn.clicked.connect(self.start_process) + layout.addWidget(self.process_btn) + self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log) + 
self.setLayout(layout) + + def setup_record(self, layout): + title = QLabel("音频采集"); title.setAlignment(Qt.AlignCenter) + layout.addWidget(title) + dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:")) + self.dev_combo = QComboBox(); self.refresh_devices() + self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData())) + dev_layout.addWidget(self.dev_combo) + btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices) + dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout) + self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80) + self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec + self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s") + self.level.setFormat("电平:%p%"); self.progress.setRange(0,100) + for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w) + self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur) + + def setup_upload(self, layout): + h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True) + b = QPushButton("浏览"); b.clicked.connect(self.browse) + h.addWidget(self.file); h.addWidget(b); layout.addLayout(h) + + def refresh_devices(self): + self.dev_combo.clear(); p = pyaudio.PyAudio() + for i in range(p.get_device_count()): + info = p.get_device_info_by_index(i) + if int(info.get('maxInputChannels',0))>0: + self.dev_combo.addItem(f"[{i}] {info.get('name')}", i) + if self.dev_combo.count()>0: + self.dev_combo.setCurrentIndex(0) + self.device_index = self.dev_combo.currentData() + p.terminate() + + def start_rec(self,e): + if e.button()==Qt.LeftButton: + self.record_thread=RecordThread(device_index=self.device_index) + self.record_thread.update_signal.connect(self.log.append) + self.record_thread.finish_signal.connect(self.finish_rec) + self.record_thread.progress_signal.connect(self.progress.setValue) + 
self.record_thread.level_signal.connect(self.level.setValue) + self.record_thread.start(); self.start=time.time(); self.timer.start(100) + + def stop_rec(self,e): + if e.button()==Qt.LeftButton and self.record_thread: + self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0) + + def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s") + + def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True) + + def browse(self): + f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)") + if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True) + + def start_process(self): + if self.mode=="upload": self.wav_path=self.file.text() + if not os.path.exists(self.wav_path): self.log.append("无效文件"); return + self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text()) + self.thread.update_signal.connect(self.log.append) + self.thread.finish_signal.connect(self.parent.switch_to_result) + self.thread.start() + + +if __name__ == "__main__": + try: + import pyaudio + except ImportError: + print("请先安装pyaudio: pip install pyaudio") + sys.exit(1) + + app = QApplication(sys.argv) + window = AudioClassifierGUI() + window.show() + sys.exit(app.exec_()) diff --git a/01src/features.pkl b/01src/features.pkl new file mode 100644 index 0000000..ae393fc Binary files /dev/null and b/01src/features.pkl differ diff --git a/01src/mat翻译(学习音频).py b/01src/mat翻译(学习音频).py new file mode 100644 index 0000000..642eeac --- /dev/null +++ b/01src/mat翻译(学习音频).py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +有标注学习音频特征提取:读取“瓷空1.wav”(标注为“空”),提取五维特征+标签,保存MAT/PKL(适配深度学习) +""" + +from pathlib import Path +import numpy as np +import scipy.io.wavfile as wav +from scipy.io import savemat +from scipy.signal import hilbert +import librosa +import matplotlib.pyplot as plt +import os +import pickle # 用于保存PKL文件 + +plt.rcParams['font.sans-serif'] = ['SimHei'] 
# NOTE(review): the matplotlib CJK font configuration was moved into main()
# so importing this module has no plotting side effects.

# ---------- Parameters (labelled training audio; label inferred from filename) ----------
WAV_FILE = r"D:\SummerSchool\sample\瓷空1.wav"  # labelled training WAV (name contains "空")
WIN_SIZE = 1024                 # frame length (matches the test-audio script)
OVERLAP = 512                   # frame overlap (matches the test-audio script)
STEP = WIN_SIZE - OVERLAP       # hop size
THRESH = 0.01                   # frame-energy threshold for hit detection
SEG_LEN_SEC = 0.2               # length of each extracted segment, in seconds
# Label convention: "空" (hollow) -> 0, "实" (solid) -> 1.
LABEL_MAP = {"空": 0, "实": 1}
# Outputs are saved next to the audio file, tagged "_train_" to mark them as labelled data.
OUT_MAT = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_train_features.mat"
OUT_PKL = Path(WAV_FILE).parent / f"{Path(WAV_FILE).stem}_train_features.pkl"

# ---------- Helpers (kept identical to the test-audio script for feature parity) ----------
def segment_signal(signal: np.ndarray, fs: int):
    """Split *signal* into fixed-length knock segments using frame energy.

    Returns a (possibly empty) list of 1-D numpy arrays of up to
    SEG_LEN_SEC * fs samples each.
    """
    if signal.ndim > 1:  # stereo -> first channel
        signal = signal[:, 0]
    # ROBUSTNESS FIX: librosa.util.frame raises for inputs shorter than one
    # frame (and np.max below raises for empty input) — return no segments.
    if signal.size < WIN_SIZE:
        return []
    signal = signal / (np.max(np.abs(signal)) + 1e-12)  # amplitude-normalize

    # Frame the signal and compute per-frame energy.
    frames = librosa.util.frame(signal, frame_length=WIN_SIZE, hop_length=STEP).T
    energy = np.sum(frames ** 2, axis=1)

    # Frames above the threshold mark candidate hits.
    idx = np.where(energy > THRESH)[0]
    if idx.size == 0:
        return []
    # A gap of more than 5 frames between active frames starts a new hit
    # (avoids counting consecutive loud frames as separate knocks).
    hit_mask = np.diff(np.concatenate(([0], idx))) > 5
    hit_starts = idx[hit_mask]

    # Cut fixed-length segments (truncated at the end of the signal).
    seg_len = int(round(SEG_LEN_SEC * fs))
    segments = []
    for start_frame in hit_starts:
        start_sample = start_frame * STEP
        end_sample = min(start_sample + seg_len, len(signal))
        segments.append(signal[start_sample:end_sample])
    return segments


def extract_features(sig: np.ndarray, fs: int):
    """Extract the 5-dim feature vector (identical to the test-audio script).

    Features: [RMS, dominant frequency (Hz), spectral skewness,
    mean of first MFCC, Hilbert-envelope peak]; dtype float32.
    """
    sig = sig.flatten()
    if sig.size == 0:  # empty segment -> zero vector instead of raising
        return np.zeros(5, dtype=np.float32)

    # 1. RMS (overall energy of the knock).
    rms = np.sqrt(np.mean(sig ** 2))
    # 2. Dominant frequency: bin with the largest one-sided spectrum magnitude.
    L = sig.size
    freq = np.fft.rfftfreq(L, d=1 / fs)
    fft_mag = np.abs(np.fft.rfft(sig))
    main_freq = freq[np.argmax(fft_mag)]
    # 3. Spectral skewness around the spectral centroid (1e-12 guards /0).
    spec_power = fft_mag
    freq_centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
    freq_spread = np.sqrt(np.sum(((freq - freq_centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12))
    skewness = np.sum(((freq - freq_centroid) ** 3) * spec_power) / ((np.sum(spec_power) + 1e-12) * (freq_spread ** 3 + 1e-12))
    # 4. Mean of the first MFCC coefficient; 0.0 when the segment is too short.
    try:
        mfcc = librosa.feature.mfcc(y=sig, sr=fs, n_mfcc=13)
        mfcc_mean = float(np.mean(mfcc[0, :]))
    except Exception:
        mfcc_mean = 0.0
    # 5. Peak of the Hilbert amplitude envelope (decay characteristic).
    amp_envelope = np.abs(hilbert(sig))
    env_peak = np.max(amp_envelope)

    # float32 keeps the matrix compatible with deep-learning frameworks.
    return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32)

# ---------- Main: auto-label from the filename, extract, save ----------
def main():
    """Read WAV_FILE, segment it, extract features + label, save MAT/PKL, plot."""
    # 1. Validate the input path.
    wav_path = Path(WAV_FILE)
    if not wav_path.exists():
        print(f"❌ 错误:音频文件 {WAV_FILE} 不存在!")
        return
    if wav_path.suffix != ".wav":
        print(f"❌ 错误:{wav_path.name} 不是WAV格式!")
        return

    # 2. Load audio with librosa (native sample rate, mono).
    audio, sr = librosa.load(wav_path, sr=None, mono=True)
    print(f"✅ 成功读取学习音频:{wav_path.name}")
    print(f" 采样率:{sr} Hz | 音频长度:{len(audio)/sr:.2f} 秒")

    # 3. Segment the knocks.
    segments = segment_signal(audio, sr)
    if len(segments) == 0:
        print(f"⚠️ 未检测到有效敲击片段!可尝试降低THRESH(当前{THRESH})或检查音频是否有敲击声。")
        return
    print(f"✅ 检测到 {len(segments)} 个有效敲击片段")

    # 4. Derive the label from the filename and extract per-segment features.
    features_list = []
    labels_list = []
    file_stem = wav_path.stem
    if "空" in file_stem:
        label = LABEL_MAP["空"]
        print(f"✅ 自动识别标注:{file_stem} → 标签 {label}(空)")
    elif "实" in file_stem:
        label = LABEL_MAP["实"]
        print(f"✅ 自动识别标注:{file_stem} → 标签 {label}(实)")
    else:
        # Fallback: no marker in the filename -> default to "hollow" (0).
        print(f"⚠️ 文件名 {file_stem} 不含'空'或'实',手动指定标签为0(空)!")
        label = LABEL_MAP["空"]

    for i, seg in enumerate(segments, 1):
        feat = extract_features(seg, sr)
        features_list.append(feat)
        labels_list.append(label)
        print(f" 片段{i:02d}:特征提取完成(维度:5)")

    # 5. Stack into (n_segments, 5) features and (n_segments, 1) labels.
    features_matrix = np.vstack(features_list)
    labels_array = np.array(labels_list, dtype=np.int8).reshape(-1, 1)
    print(f"\n✅ 特征与标签整理完成")
    print(f" 特征矩阵形状:{features_matrix.shape}(行=片段数,列=5维特征)")
    print(f" 标签矩阵形状:{labels_array.shape}(行=片段数,列=1)")

    # 6. Save MAT (MATLAB-compatible; same keys as the training set).
    savemat(OUT_MAT, {
        "matrix": features_matrix,
        "label": labels_array
    })
    print(f"✅ MAT文件已保存:{OUT_MAT}")

    # 7. Save PKL (Python-side consumers).
    with open(OUT_PKL, "wb") as f:
        pickle.dump({
            "matrix": features_matrix,
            "label": labels_array
        }, f)
    print(f"✅ PKL文件已保存:{OUT_PKL}")

    # 8. Optional visualization of the per-segment feature values.
    # Configure CJK-capable fonts here (kept out of module import time).
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(12, 8))
    feature_names = ["RMS(能量)", "主频(Hz)", "频谱偏度", "MFCC均值", "包络峰值"]
    for i in range(5):
        plt.subplot(2, 3, i + 1)
        plt.plot(range(1, len(features_matrix) + 1), features_matrix[:, i],
                 "-o", color="#1f77b4", linewidth=1.5, markersize=4)
        plt.xlabel("片段编号", fontsize=10)
        plt.ylabel("特征值", fontsize=10)
        plt.title(f"特征{i+1}:{feature_names[i]}", fontsize=11, fontweight="bold")
        plt.grid(True, alpha=0.3)
    # Summary panel with the file name, label and segment count.
    plt.subplot(2, 3, 6)
    plt.text(0.5, 0.6, f"音频文件:{wav_path.name}", ha="center", fontsize=11)
    plt.text(0.5, 0.4, f"标注标签:{label}({'空' if label==0 else '实'})", ha="center", fontsize=11)
    plt.text(0.5, 0.2, f"有效片段数:{len(features_matrix)}", ha="center", fontsize=11)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()
def extract_features(sig: np.ndarray, fs: int):
    """Compute the 5-dim feature vector (same definition as the training side).

    Features: [RMS, dominant frequency (Hz), spectral skewness,
    mean of first MFCC, Hilbert-envelope peak].
    """
    sig = sig.flatten()
    if sig.size == 0:
        # Empty segment: return an all-zero vector instead of raising.
        return np.zeros(5)

    # Root-mean-square energy.
    rms = np.sqrt(np.mean(sig ** 2))

    # One-sided spectrum; dominant frequency = bin with the largest magnitude.
    freqs = np.fft.rfftfreq(sig.size, d=1 / fs)
    spectrum = np.abs(np.fft.rfft(sig))
    dominant = freqs[np.argmax(spectrum)]

    # Spectral skewness around the spectral centroid (1e-12 guards /0).
    total = np.sum(spectrum) + 1e-12
    centroid = np.sum(freqs * spectrum) / total
    spread = np.sqrt(np.sum(((freqs - centroid) ** 2) * spectrum) / total)
    skewness = np.sum(((freqs - centroid) ** 3) * spectrum) / (total * (spread ** 3 + 1e-12))

    # Mean of the first MFCC coefficient; 0.0 when the segment is too short.
    try:
        mfcc_first = librosa.feature.mfcc(y=sig, sr=fs, n_mfcc=13)[0, :]
        mfcc_mean = float(np.mean(mfcc_first))
    except Exception:
        mfcc_mean = 0.0

    # Peak of the Hilbert amplitude envelope (decay characteristic).
    envelope_peak = np.max(np.abs(hilbert(sig)))

    return np.array([rms, dominant, skewness, mfcc_mean, envelope_peak])
切分有效片段 + segments = segment_signal(y, fs) + if len(segments) == 0: + print(f"⚠️ 未检测到有效音频片段,尝试再降低THRESH(当前{THRESH})!") + return + print(f"✅ 检测到 {len(segments)} 个有效片段") + + # 3. 提取五维特征 + features = [extract_features(seg, fs) for seg in segments] + features_matrix = np.vstack(features).astype(np.float32) # 特征矩阵(N行5列,N=片段数) + print(f"✅ 提取特征完成,特征矩阵形状:{features_matrix.shape}(行=片段数,列=5维特征)") + + # 4. 保存为MAT文件(兼容MATLAB) + savemat(OUT_MAT, {"matrix": features_matrix}) # 只存特征矩阵,无label + print(f"✅ MAT文件已保存:{OUT_MAT}") + + # 5. 保存为PKL文件(兼容Python后续模型推断) + with open(OUT_PKL, "wb") as f: + pickle.dump({"matrix": features_matrix}, f) # 和训练集PKL结构一致(只少label) + print(f"✅ PKL文件已保存:{OUT_PKL}") + + # (可选)绘制特征可视化图 + plt.figure(figsize=(10, 6)) + feature_names = ["RMS", "主频(Hz)", "频谱偏度", "MFCC均值", "包络峰值"] + for i in range(5): + plt.subplot(2, 3, i+1) + plt.plot(range(1, len(features_matrix)+1), features_matrix[:, i], "-o", linewidth=1.5) + plt.xlabel("片段编号") + plt.ylabel("特征值") + plt.title(f"特征:{feature_names[i]}") + plt.grid(True) + plt.tight_layout() + plt.show() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/01src/readme.txt b/01src/readme.txt new file mode 100644 index 0000000..5cd3329 --- /dev/null +++ b/01src/readme.txt @@ -0,0 +1,30 @@ +系统简介 +本系统是一个基于SVM机器学习的墙体声纹检测系统,支持通过录制或上传 WAV 格式音频文件,实现音频片段的自动切分、特征提取与分类,并将结果存储于本地数据库中,方便后续查看、管理与导出。主要应用场景包括对特定音频(如敲击声等)的分类识别,支持 "实心" 和 "空心" 两类标签的识别与管理。 +核心功能 +音频输入:支持两种输入方式 +实时录音:通过麦克风采集音频,支持设备选择与录音时长监控 +文件上传:上传本地 WAV 格式音频文件 +音频处理与分类 +自动切分:基于音频能量特征切分有效片段 +特征提取:提取音频片段的 RMS、主频、频谱偏度等多维特征 +模型预测:使用预训练的 SVM 模型进行分类,输出每个片段的标签(0/1)和概率值 +数据管理 +本地存储:采用 SQLite 数据库存储分类结果(含文件名、片段数、平均概率等)及片段详情 +历史记录:支持查询、搜索、编辑、删除历史记录 +片段管理:可查看、新增、编辑、删除特定记录的音频片段 +可视化交互:通过直观的 GUI 界面操作,包括主菜单、输入界面、结果展示、历史管理等模块 + +配置环境 +依赖库 +本系统依赖以下 Python 库,建议使用 Python 3.9 + 版本: +GUI 框架:PyQt5(用于构建图形界面) +音频处理:librosa(音频加载、特征提取)、pyaudio(录音功能) +数值计算:numpy(数组处理、特征矩阵运算) +机器学习:scikit-learn(SVM 模型、数据标准化) +模型存储:joblib(保存 / 加载训练好的模型) +数据库:sqlite3(本地数据库操作,Python 标准库) 
+可视化:matplotlib(特征可视化,可选) +数据序列化:pickle(特征数据存储与读取) + +SVM模型的开发依赖于Python下的Scikit-learn、Librosa、NumPy等一系列科学计算与音频处理库。这些库本身存在复杂的底层依赖,在不同操作系统上通过传统的pip安装方式极易出现依赖冲突、编译失败等问题。 +为彻底解决环境配置的复杂性,建议选用Anaconda,预编译的二进制包和强大的环境管理功能,能够自动处理库与库之间的依赖关系 \ No newline at end of file diff --git a/01src/sample/无保温实1.wav b/01src/sample/无保温实1.wav new file mode 100644 index 0000000..b583885 Binary files /dev/null and b/01src/sample/无保温实1.wav differ diff --git a/01src/sample/无保温空1.wav b/01src/sample/无保温空1.wav new file mode 100644 index 0000000..11cd098 Binary files /dev/null and b/01src/sample/无保温空1.wav differ diff --git a/01src/sample/漆实1.wav b/01src/sample/漆实1.wav new file mode 100644 index 0000000..7fb0439 Binary files /dev/null and b/01src/sample/漆实1.wav differ diff --git a/01src/sample/漆实2.wav b/01src/sample/漆实2.wav new file mode 100644 index 0000000..fa53291 Binary files /dev/null and b/01src/sample/漆实2.wav differ diff --git a/01src/sample/漆空1.wav b/01src/sample/漆空1.wav new file mode 100644 index 0000000..fbadd66 Binary files /dev/null and b/01src/sample/漆空1.wav differ diff --git a/01src/sample/漆空渐强1.wav b/01src/sample/漆空渐强1.wav new file mode 100644 index 0000000..853625f Binary files /dev/null and b/01src/sample/漆空渐强1.wav differ diff --git a/01src/sample/漆空渐强2.wav b/01src/sample/漆空渐强2.wav new file mode 100644 index 0000000..46c8c8d Binary files /dev/null and b/01src/sample/漆空渐强2.wav differ diff --git a/01src/sample/瓷实1.wav b/01src/sample/瓷实1.wav new file mode 100644 index 0000000..1625647 Binary files /dev/null and b/01src/sample/瓷实1.wav differ diff --git a/01src/sample/瓷实2.wav b/01src/sample/瓷实2.wav new file mode 100644 index 0000000..63ab74c Binary files /dev/null and b/01src/sample/瓷实2.wav differ diff --git a/01src/sample/瓷实渐强.wav b/01src/sample/瓷实渐强.wav new file mode 100644 index 0000000..97d2528 Binary files /dev/null and b/01src/sample/瓷实渐强.wav differ diff --git a/01src/sample/瓷空1.wav b/01src/sample/瓷空1.wav new file mode 100644 index 0000000..80f8c4d Binary files /dev/null and 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
基于交叉验证代码的训练脚本
Train the SelectKBest + RBF-SVM pipeline with randomized hyper-parameter
search and persist both the fitted pipeline and the StandardScaler used by
the GUI at inference time.
"""

import pickle
import numpy as np
from pathlib import Path
from sklearn.svm import SVC
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scipy.stats import randint, loguniform
import joblib


def load_dataset(pkl_path):
    """Load the pickled dataset; returns (feature matrix, labels)."""
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    return data['matrix'], data['label']


def train_cross_validated_model(dataset_path, pipeline_save_path, scaler_save_path):
    """Fit scaler + pipeline on *dataset_path* and save both; returns (model, scaler)."""
    # Load data.
    X_train, y_train = load_dataset(dataset_path)
    print(f"加载数据集:{X_train.shape[0]}个样本,{X_train.shape[1]}维特征")

    # Map labels {0,1} -> {-1,+1} (consistent with the cross-validation script).
    y_train_signed = np.where(y_train == 0, -1, 1)

    # BUG FIX: the scaler was fitted but the pipeline was trained on RAW
    # features, while the saved scaler is applied by the GUI before
    # prediction — the model would see differently-scaled inputs at
    # inference.  Fit the scaler first and train on standardized features
    # (this mirrors 交叉验证.py, which fits on X_train_std).
    scaler = StandardScaler().fit(X_train)
    X_train_std = scaler.transform(X_train)

    # Pipeline: mutual-information feature selection + RBF-SVM.
    pipe = Pipeline([
        ('sel', SelectKBest(mutual_info_classif)),
        ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
    ])

    # Search space (consistent with the cross-validation script).
    n_features = X_train.shape[1]
    param_dist = {
        'sel__k': randint(1, n_features + 1),
        'svm__C': loguniform(1e-3, 1e3),
        'svm__gamma': loguniform(1e-6, 1e1)
    }

    # Randomized hyper-parameter search with an inner stratified 5-fold CV.
    cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    search = RandomizedSearchCV(
        pipe,
        param_distributions=param_dist,
        n_iter=30,  # fewer iterations to keep training fast
        scoring='roc_auc',
        cv=cv_inner,
        n_jobs=-1,
        random_state=42,
        verbose=1
    )

    print("开始随机搜索优化...")
    search.fit(X_train_std, y_train_signed)

    best_params = search.best_params_
    print(f"\n最佳参数: {best_params}")
    print(f"最佳交叉验证AUC: {search.best_score_:.4f}")

    # RandomizedSearchCV already refits on the full data (refit=True); the
    # explicit fit is kept for clarity and now uses standardized features.
    final_model = search.best_estimator_
    final_model.fit(X_train_std, y_train_signed)

    # Persist the pipeline and the scaler the GUI applies before predicting.
    joblib.dump(final_model, pipeline_save_path)
    joblib.dump(scaler, scaler_save_path)

    print(f"流水线模型已保存至: {pipeline_save_path}")
    print(f"标准化器已保存至: {scaler_save_path}")

    return final_model, scaler


if __name__ == "__main__":
    # Training-set path; adjust to your environment.
    DATASET_PATH = r"D:\Python\空心检测\pythonProject\feature_dataset.pkl"
    PIPELINE_PATH = "pipeline_model.pkl"  # must match the path configured in the GUI
    SCALER_PATH = "scaler.pkl"

    train_cross_validated_model(DATASET_PATH, PIPELINE_PATH, SCALER_PATH)
# Load the merged dataset produced earlier (e.g. feature_dataset.pkl).
def load_dataset(pkl_path):
    """Return (feature matrix, labels) from a pickled dataset file."""
    with open(pkl_path, 'rb') as fh:
        payload = pickle.load(fh)
    return payload['matrix'], payload['label']


def train_and_save_model(dataset_path, model_save_path, scaler_save_path):
    """Train an RBF-SVM on *dataset_path* and persist the model and scaler."""
    feats, labels = load_dataset(dataset_path)
    print(f"加载数据集:{feats.shape[0]}个样本,{feats.shape[1]}维特征")

    # Stratified 80/20 split with a fixed seed for reproducibility.
    tr_x, te_x, tr_y, te_y = train_test_split(
        feats, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Standardize on the training split only, then apply to the test split.
    scaler = StandardScaler()
    tr_std = scaler.fit_transform(tr_x)
    te_std = scaler.transform(te_x)

    # RBF-SVM with balanced class weights and probability outputs.
    svm = SVC(kernel='rbf', class_weight='balanced', probability=True, random_state=42)
    svm.fit(tr_std, tr_y)

    # Hold-out accuracy.
    print(f"模型准确率:{accuracy_score(te_y, svm.predict(te_std)):.4f}")

    # Persist both artifacts for the GUI.
    joblib.dump(svm, model_save_path)
    joblib.dump(scaler, scaler_save_path)
    print(f"模型已保存至:{model_save_path}")
    print(f"标准化器已保存至:{scaler_save_path}")


if __name__ == "__main__":
    # Dataset path; adjust to your environment.
    DATASET_PATH = r"D:\SummerSchool\mat_cv\mat_cv\feature_dataset.pkl"
    # Output paths (must match those configured in the GUI).
    MODEL_PATH = "svm_model.pkl"
    SCALER_PATH = "scaler.pkl"

    train_and_save_model(DATASET_PATH, MODEL_PATH, SCALER_PATH)
# ---------- 3. Load the training set ----------
X_train, y_train = load_pkl_matrix(TRAIN_PKL)
if y_train is None:
    raise ValueError('训练集缺少 label 字段')
y_train = y_train.ravel()
# Map labels {0, 1} -> {-1, +1} (signed-label convention used downstream).
y_train_signed = np.where(y_train == 0, -1, 1)

# ---------- 4. Standardisation ----------
scaler = StandardScaler().fit(X_train)
X_train_std = scaler.transform(X_train)
n_features = X_train_std.shape[1]

# ---------- 5. RandomizedSearchCV ----------
# Mutual-information feature selection followed by an RBF-SVM.
pipe = Pipeline([
    ('sel', SelectKBest(mutual_info_classif)),
    ('svm', SVC(kernel='rbf', class_weight='balanced', probability=True))
])

param_dist = {
    'sel__k': randint(1, n_features + 1),
    'svm__C': loguniform(1e-3, 1e3),
    'svm__gamma': loguniform(1e-6, 1e1)
}

cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
search = RandomizedSearchCV(
    pipe,
    param_distributions=param_dist,
    n_iter=60,  # number of sampled parameter settings
    scoring='roc_auc',
    cv=cv_inner,
    n_jobs=-1,
    random_state=42,
    verbose=1
)
search.fit(X_train_std, y_train_signed)
best_params = search.best_params_
print("\n▶ RandomizedSearch 最佳参数:", best_params)
print(f" 内层 5-折 AUC ≈ {search.best_score_:.4f}")

# ---------- 6. Fit the final pipeline ----------
# NOTE(review): best_estimator_ is already refit on the whole training set
# (refit=True default of RandomizedSearchCV); this extra fit() is redundant
# but kept so behaviour is byte-for-byte identical to the original run.
final_model = search.best_estimator_
final_model.fit(X_train_std, y_train_signed)

# ---------- 6.5 Export model and scaler (consumed by the GUI) ----------
# Written to BASE_DIR; change the paths here if needed.
model_out = BASE_DIR / 'svm_model.pkl'
scaler_out = BASE_DIR / 'scaler.pkl'
joblib.dump(final_model, model_out)
joblib.dump(scaler, scaler_out)
print(f"\n✅ 已导出模型与标尺:\n 模型: {model_out}\n 标尺: {scaler_out}\n has_predict_proba: {hasattr(final_model, 'predict_proba')}")

# ---------- 7.
# ---------- 7. Outer 5-fold cross-validation ----------
cv_outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_auc = cross_val_score(final_model, X_train_std, y_train_signed,
                         cv=cv_outer, scoring='roc_auc', n_jobs=-1)
cv_acc = cross_val_score(final_model, X_train_std, y_train_signed,
                         cv=cv_outer, scoring='accuracy', n_jobs=-1)

print('\n========== 外层 5-折交叉验证 ==========')
print(f'AUC = {cv_auc.mean():.4f} ± {cv_auc.std():.4f}')
print(f'ACC = {cv_acc.mean():.4f} ± {cv_acc.std():.4f}')

# ---------- 8. Inference ----------
THRESHOLD = 0.5      # file-level decision threshold on mean prob(class 1)
Z = norm.ppf(0.975)  # ≈1.96, for a two-sided 95% confidence interval
infer_results = []
print('\n========== 推断结果 ==========')

for pkl_path in TEST_FILES:
    X_test, _ = load_pkl_matrix(pkl_path)
    X_test_std = scaler.transform(X_test)
    pred_signed = final_model.predict(X_test_std)
    # classes_ is [-1, +1] here, so column 1 is the probability of class +1.
    proba_pos = final_model.predict_proba(X_test_std)[:, 1]
    # Map signed predictions back to the original {0, 1} labels.
    pred_label = np.where(pred_signed == -1, 0, 1)

    # File-level fusion: mean probability with a normal-approximation 95% CI
    # (SEM is 0 when only a single sample is present).
    mean_p = proba_pos.mean()
    sem_p = proba_pos.std(ddof=1) / np.sqrt(len(proba_pos)) if len(proba_pos) > 1 else 0.0
    ci_low, ci_high = mean_p - Z * sem_p, mean_p + Z * sem_p
    file_label = int(mean_p >= THRESHOLD)

    print(f'\n▶ 文件: {pkl_path.name} (样本 {len(pred_label)})')
    for i, (lbl, prob) in enumerate(zip(pred_label, proba_pos), 1):
        print(f' Sample {i:02d}: pred={lbl} prob(1)={prob:.4f}')
    print(' ---- 文件级融合 ----')
    print(f' mean_prob(1) = {mean_p:.4f} (95% CI {ci_low:.4f} ~ {ci_high:.4f})')
    print(f' Final label = {file_label} (阈值 {THRESHOLD})')

    infer_results.append(dict(
        file=pkl_path.name,
        pred=pred_label.tolist(),
        prob=proba_pos.tolist(),
        mean_prob=float(mean_p),
        ci_low=float(ci_low),
        ci_high=float(ci_high),
        final_label=int(file_label)
    ))

# Best-effort: print the original labels of the first test file, if any.
try:
    print("TEST_FILES 标签:", load_pkl_matrix(TEST_FILES[0])[1])
except Exception:
    pass

# ---------- 9.
# ---------- 9. Save & visualise ----------
out_pkl = BASE_DIR / 'infer_results.pkl'
with open(out_pkl, 'wb') as f:
    pickle.dump(infer_results, f)
print(f'\n所有文件结果已保存到: {out_pkl}')

# Chinese-capable font and correct minus-sign rendering in matplotlib.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

labels = [r['file'] for r in infer_results]
means = [r['mean_prob'] for r in infer_results]
# Asymmetric error bars: distance from the mean to each CI bound.
yerr = [(r['mean_prob'] - r['ci_low'], r['ci_high'] - r['mean_prob'])
        for r in infer_results]

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(range(len(means)), means,
       yerr=np.array(yerr).T, capsize=5, alpha=0.8)
ax.axhline(THRESHOLD, color='red', ls='--', label=f'阈值 {THRESHOLD}')
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=15)
ax.set_ylim(0, 1)
# NOTE(review): the plotted value is the mean probability of class 1, but
# the label text reads "空心=0" — confirm which class index means "hollow"
# and align the label with the data (string left unchanged here).
ax.set_ylabel('mean_prob(空心=0)')
ax.set_title('文件级空心概率 (±95% CI)')
ax.legend()
plt.tight_layout()

desktop = Path.home() / 'Desktop'
save_path = desktop / 'infer_summary.png'
fig.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'可视化图已保存至: {save_path}')
plt.show()

# [diff boundary] new file: 01src/生成feature_dataset.py (index 0000000..1577cc2)

import os
import pickle
import numpy as np
from pathlib import Path


def merge_pkl_files(input_dir, output_path="feature_dataset.pkl"):
    """Merge per-recording feature pickles in *input_dir* into one dataset.

    Each input pickle must contain 'matrix' (N, d) and 'label' (N,) entries;
    files with missing fields, mismatched sample counts, or inconsistent
    feature dimensionality are skipped with a warning.  The merged result is
    written to *output_path* as {'matrix': ..., 'label': ...}.

    Raises FileNotFoundError when the directory holds no pkl files and
    ValueError when nothing valid could be merged.
    """
    all_features = []
    all_labels = []

    pkl_files = list(Path(input_dir).glob("*.pkl"))
    if not pkl_files:
        raise FileNotFoundError(f"在目录 {input_dir} 中未找到PKL文件")

    # Exclude previously merged / derived files.
    pkl_files = [f for f in pkl_files if f.name not in ["feature_dataset.pkl", "infer_results.pkl"]]
    print(f"发现 {len(pkl_files)} 个有效PKL文件,开始合并...")

    for file in pkl_files:
        try:
            with open(file, "rb") as f:
                data = pickle.load(f)

            if "matrix" not in data or "label" not in data:
                print(f"跳过 {file.name}:缺少'matrix'或'label'字段")
                continue

            features = data["matrix"]
            labels = data["label"]

            # Force labels to a flat int64 array.  np.asarray() first, so
            # list-typed labels are accepted too (plain lists have no
            # .ravel() and previously crashed into the except branch).
            labels = np.asarray(labels).ravel().astype(np.int64)

            # Feature rows and labels must match one-to-one.
            if len(features) != len(labels):
                print(f"跳过 {file.name}:特征({len(features)})与标签({len(labels)})数量不匹配")
                continue

            # All files must share the same feature dimensionality.
            if all_features and features.shape[1] != all_features[0].shape[1]:
                print(f"跳过 {file.name}:特征维度与已有数据不一致(现有{all_features[0].shape[1]}维,当前{features.shape[1]}维)")
                continue

            all_features.append(features)
            all_labels.append(labels)
            print(f"已加载 {file.name}:{len(features)} 条样本(特征{features.shape[1]}维)")

        except Exception as e:
            print(f"处理 {file.name} 时出错:{str(e)},已跳过")

    if not all_features:
        raise ValueError("没有有效数据可合并,请检查输入文件")

    # Stack features row-wise; labels are 1-D so a plain concatenate works.
    merged_matrix = np.vstack(all_features)
    merged_label = np.concatenate(all_labels, axis=0)

    print("\n合并结果:")
    print(f"总样本数:{len(merged_matrix)}")
    print(f"特征矩阵形状:{merged_matrix.shape}")
    # bincount assumes non-negative integer labels (guaranteed by the cast
    # above only if the source labels are non-negative — presumably {0, 1}).
    print(f"标签分布:{np.bincount(merged_label)} (索引对应标签值)")

    with open(output_path, "wb") as f:
        pickle.dump({"matrix": merged_matrix, "label": merged_label}, f)

    print(f"\n已成功保存至 {output_path}")


if __name__ == "__main__":
    INPUT_DIRECTORY = r"D:\SummerSchool\mat_cv\mat_cv-02"
    OUTPUT_FILE = r"D:\SummerSchool\mat_cv\mat_cv-02\feature_dataset00.pkl"  # absolute path
    merge_pkl_files(INPUT_DIRECTORY, OUTPUT_FILE)