diff --git a/audio_classifier_gui_01.py b/audio_classifier_gui_01.py deleted file mode 100644 index 63b2ade..0000000 --- a/audio_classifier_gui_01.py +++ /dev/null @@ -1,945 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import sys -import os -import time -import numpy as np -import joblib -import librosa -import pyaudio -import wave -import sqlite3 -from datetime import datetime -from scipy.signal import hilbert, find_peaks -from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel, - QTextEdit, QFileDialog, QVBoxLayout, QHBoxLayout, - QWidget, QProgressBar, QStackedWidget, QMessageBox, - QTableWidget, QTableWidgetItem, QHeaderView, - QLineEdit, QDialog, QDialogButtonBox, QFormLayout, - QComboBox, QSpinBox, QDoubleSpinBox, QTabWidget, QInputDialog) -from PyQt5.QtCore import Qt, QThread, pyqtSignal - - -# ---------------------- 数据库管理模块 ---------------------- -class DatabaseManager: - """数据库管理器,负责结果数据的存储与查询""" - - def __init__(self, db_path="audio_classification.db"): - self.db_path = db_path - self.init_database() - - def init_database(self): - """初始化数据库表结构""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - # 创建结果表 - cursor.execute(''' - CREATE TABLE IF NOT EXISTS classification_results ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - filename TEXT NOT NULL, - segment_count INTEGER NOT NULL, - segment_labels TEXT NOT NULL, - segment_probs TEXT NOT NULL, - mean_probability REAL NOT NULL, - final_label INTEGER NOT NULL, - label_text TEXT NOT NULL, - create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - notes TEXT - ) - ''') - - conn.commit() - conn.close() - - def insert_result(self, result_data, notes=""): - """插入新的分类结果""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - cursor.execute(''' - INSERT INTO classification_results - (filename, segment_count, segment_labels, segment_probs, - mean_probability, final_label, label_text, notes) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - result_data['filename'], - result_data['segment_count'], - str(result_data['segment_labels']), - str(result_data['segment_probs']), - result_data['mean_probability'], - result_data['final_label'], - result_data['label_text'], - notes - )) - - conn.commit() - conn.close() - return cursor.lastrowid - - def get_all_results(self): - """获取所有结果记录""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - cursor.execute(''' - SELECT * FROM classification_results - ORDER BY create_time DESC - ''') - - results = [] - columns = [desc[0] for desc in cursor.description] - - for row in cursor.fetchall(): - results.append(dict(zip(columns, row))) - - conn.close() - return results - - def search_results(self, filename_filter="", label_filter=None, date_filter=None): - """根据条件搜索结果""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - query = ''' - SELECT * FROM classification_results - WHERE 1=1 - ''' - params = [] - - if filename_filter: - query += " AND filename LIKE ?" - params.append(f'%{filename_filter}%') - - if label_filter is not None: - query += " AND final_label = ?" - params.append(label_filter) - - if date_filter: - query += " AND DATE(create_time) = ?" - params.append(date_filter) - - query += " ORDER BY create_time DESC" - - cursor.execute(query, params) - - results = [] - columns = [desc[0] for desc in cursor.description] - - for row in cursor.fetchall(): - results.append(dict(zip(columns, row))) - - conn.close() - return results - - def update_result(self, result_id, updates): - """更新结果记录""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - set_clause = ", ".join([f"{key} = ?" for key in updates.keys()]) - query = f"UPDATE classification_results SET {set_clause} WHERE id = ?" - - params = list(updates.values()) - params.append(result_id) - - cursor.execute(query, params) - conn.commit() - conn.close() - - return cursor.rowcount > 0 - - def delete_result(self, result_id): - """删除结果记录""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - cursor.execute("DELETE FROM classification_results WHERE id = ?", (result_id,)) - conn.commit() - conn.close() - - return cursor.rowcount > 0 - - -# ---------------------- 特征提取模块(与训练集保持一致) ---------------------- -class FeatureProcessor: - """特征提取器,统一使用训练时的特征维度和计算方式""" - WIN = 1024 - OVERLAP = 512 - THRESHOLD = 0.1 - SEG_LEN_S = 0.2 - STEP = WIN - OVERLAP - - @staticmethod - def frame_energy(signal: np.ndarray) -> np.ndarray: - """计算帧能量""" - frames = librosa.util.frame(signal, frame_length=FeatureProcessor.WIN, - hop_length=FeatureProcessor.STEP) - return np.sum(frames ** 2, axis=0) - - @staticmethod - def detect_impacts(energy: np.ndarray) -> np.ndarray: - """检测有效敲击片段起始点""" - idx = np.where(energy > FeatureProcessor.THRESHOLD)[0] - if idx.size == 0: - return np.array([]) - # 间隔超过5帧视为新片段 - flags = np.diff(np.concatenate(([0], idx))) > 5 - return idx[flags] - - @staticmethod - def extract_segments(signal: np.ndarray, sr: int, starts: np.ndarray) -> list[np.ndarray]: - """切分固定长度的音频片段""" - seg_len = int(FeatureProcessor.SEG_LEN_S * sr) - segments = [] - for frame_idx in starts: - start = frame_idx * FeatureProcessor.STEP - end = min(start + seg_len, len(signal)) - segments.append(signal[start:end]) - return segments - - @staticmethod - def extract_features(signal: np.ndarray, sr: int) -> np.ndarray: - """提取与训练集一致的5维特征(修正原GUI可能的特征不匹配问题)""" - signal = signal.flatten() - if len(signal) == 0: - return np.zeros(5, dtype=np.float32) - - # 1. RMS能量 - rms = np.sqrt(np.mean(signal ** 2)) - - # 2. 主频(频谱峰值) - fft = np.fft.fft(signal) - freq = np.fft.fftfreq(len(signal), 1 / sr) - positive_mask = freq >= 0 - freq = freq[positive_mask] - fft_mag = np.abs(fft[positive_mask]) - main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0 - - # 3. 频谱偏度 - spec_power = fft_mag - centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12) - spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / - (np.sum(spec_power) + 1e-12)) - skewness = np.sum(((freq - centroid) ** 3) * spec_power) / ( - (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if spread > 0 else 0 - - # 4. MFCC第一维均值 - mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13) - mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0 - - # 5. 包络峰值(希尔伯特变换) - env_peak = np.max(np.abs(hilbert(signal))) - - return np.array([rms, main_freq, skewness, mfcc_mean, env_peak], dtype=np.float32) - - -# ---------------------- 录音线程 ---------------------- -class AudioRecorder(QThread): - status_updated = pyqtSignal(str) - progress_updated = pyqtSignal(int) - recording_finished = pyqtSignal(str) - - def __init__(self, max_duration=60): - super().__init__() - self.max_duration = max_duration - self.recording = False - self.temp_file = "temp_audio.wav" - - def run(self): - # 音频参数(与特征提取兼容) - FORMAT = pyaudio.paFloat32 - CHANNELS = 1 - RATE = 22050 - CHUNK = 1024 - - try: - audio = pyaudio.PyAudio() - stream = audio.open( - format=FORMAT, - channels=CHANNELS, - rate=RATE, - input=True, - frames_per_buffer=CHUNK - ) - - self.recording = True - start_time = time.time() - frames = [] - - while self.recording: - elapsed = time.time() - start_time - if elapsed >= self.max_duration: - self.status_updated.emit(f"已达最大时长 {self.max_duration} 秒") - break - - self.progress_updated.emit(int((elapsed / self.max_duration) * 100)) - data = stream.read(CHUNK) - frames.append(data) - - # 停止录音并保存 - stream.stop_stream() - stream.close() - audio.terminate() - - if frames: - with wave.open(self.temp_file, 'wb') as wf: - wf.setnchannels(CHANNELS) - wf.setsampwidth(audio.get_sample_size(FORMAT)) - wf.setframerate(RATE) - wf.writeframes(b''.join(frames)) - self.status_updated.emit(f"录制完成,时长: {elapsed:.1f}秒") - self.recording_finished.emit(self.temp_file) - else: - self.status_updated.emit("未检测到音频输入") - - except Exception as e: - self.status_updated.emit(f"录音错误: {str(e)}") - - def stop(self): - self.recording = False - - -# ---------------------- 音频处理与预测线程 ---------------------- -class AudioProcessor(QThread): - status_updated = pyqtSignal(str) - result_generated = pyqtSignal(dict) - - def __init__(self, audio_path, model_path, scaler_path): - super().__init__() - self.audio_path = audio_path - self.model_path = model_path - self.scaler_path = scaler_path - self.class_threshold = 0.5 # 分类阈值 - - def run(self): - try: - # 加载模型和标准化器 - self.status_updated.emit("加载模型资源...") - model = joblib.load(self.model_path) - scaler = joblib.load(self.scaler_path) - - # 读取音频文件 - self.status_updated.emit(f"解析音频: {os.path.basename(self.audio_path)}") - signal, sr = librosa.load(self.audio_path, sr=None, mono=True) - signal = signal / np.max(np.abs(signal)) # 归一化 - - # 提取片段 - self.status_updated.emit("检测有效音频片段...") - energy = FeatureProcessor.frame_energy(signal) - impact_starts = FeatureProcessor.detect_impacts(energy) - segments = FeatureProcessor.extract_segments(signal, sr, impact_starts) - - if not segments: - self.status_updated.emit("未检测到有效敲击片段") - return - - # 提取特征并预测 - self.status_updated.emit(f"提取 {len(segments)} 个片段的特征...") - features = [FeatureProcessor.extract_features(seg, sr) for seg in segments] - X = np.vstack(features) - - # 标准化特征 - X_scaled = scaler.transform(X) - - # 模型预测(处理标签转换:-1/1 → 0/1) - predictions = model.predict(X_scaled) - pred_proba = model.predict_proba(X_scaled)[:, 1] # 正类概率 - pred_labels = np.where(predictions == -1, 0, 1) # 统一为0/1标签 - - # 计算文件级结果 - mean_prob = pred_proba.mean() - final_label = 1 if mean_prob >= self.class_threshold else 0 - result = { - "filename": os.path.basename(self.audio_path), - "segment_count": len(segments), - "segment_labels": pred_labels.tolist(), - "segment_probs": [round(p, 4) for p in pred_proba], - "mean_probability": round(mean_prob, 4), - "final_label": final_label, - "label_text": "空心" if final_label == 0 else "实心" # 假设0=空心,1=实心 - } - - self.result_generated.emit(result) - - except Exception as e: - self.status_updated.emit(f"处理错误: {str(e)}") - - -# ---------------------- 数据库编辑对话框 ---------------------- -class EditResultDialog(QDialog): - def __init__(self, result_data=None, parent=None): - super().__init__(parent) - self.result_data = result_data or {} - self.init_ui() - - def init_ui(self): - self.setWindowTitle("编辑结果记录") - self.setModal(True) - self.resize(400, 300) - - layout = QFormLayout() - - # 文件名 - self.filename_edit = QLineEdit(self.result_data.get('filename', '')) - layout.addRow("文件名:", self.filename_edit) - - # 片段数量 - self.segment_count_spin = QSpinBox() - self.segment_count_spin.setRange(0, 1000) - self.segment_count_spin.setValue(self.result_data.get('segment_count', 0)) - layout.addRow("片段数量:", self.segment_count_spin) - - # 平均概率 - self.mean_prob_spin = QDoubleSpinBox() - self.mean_prob_spin.setRange(0.0, 1.0) - self.mean_prob_spin.setDecimals(4) - self.mean_prob_spin.setSingleStep(0.01) - self.mean_prob_spin.setValue(self.result_data.get('mean_probability', 0.0)) - layout.addRow("平均概率:", self.mean_prob_spin) - - # 最终标签 - self.final_label_combo = QComboBox() - self.final_label_combo.addItem("空心", 0) - self.final_label_combo.addItem("实心", 1) - current_label = self.result_data.get('final_label', 0) - index = self.final_label_combo.findData(current_label) - if index >= 0: - self.final_label_combo.setCurrentIndex(index) - layout.addRow("最终分类:", self.final_label_combo) - - # 备注 - self.notes_edit = QTextEdit() - self.notes_edit.setMaximumHeight(80) - self.notes_edit.setText(self.result_data.get('notes', '')) - layout.addRow("备注:", self.notes_edit) - - # 按钮 - button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) - button_box.accepted.connect(self.accept) - button_box.rejected.connect(self.reject) - layout.addRow(button_box) - - self.setLayout(layout) - - def get_updated_data(self): - """获取更新后的数据""" - return { - 'filename': self.filename_edit.text(), - 'segment_count': self.segment_count_spin.value(), - 'mean_probability': self.mean_prob_spin.value(), - 'final_label': self.final_label_combo.currentData(), - 'label_text': self.final_label_combo.currentText(), - 'notes': self.notes_edit.toPlainText() - } - - -# ---------------------- 数据库汇总界面 ---------------------- -class DatabaseUI(QWidget): - def __init__(self, parent=None): - super().__init__(parent) - self.parent = parent - self.db_manager = DatabaseManager() - self.init_ui() - self.load_data() - - def init_ui(self): - layout = QVBoxLayout() - layout.setContentsMargins(20, 20, 20, 20) - - # 返回按钮 - btn_back = QPushButton("← 返回主菜单") - btn_back.clicked.connect(lambda: self.parent.switch_page("main")) - layout.addWidget(btn_back) - - # 标题 - title = QLabel("数据库汇总管理") - title.setStyleSheet("font-size: 20px; margin: 15px 0;") - layout.addWidget(title) - - # 创建标签页 - self.tab_widget = QTabWidget() - - # 数据浏览标签页 - self.browse_tab = QWidget() - self.init_browse_tab() - self.tab_widget.addTab(self.browse_tab, "数据浏览") - - # 搜索标签页 - self.search_tab = QWidget() - self.init_search_tab() - self.tab_widget.addTab(self.search_tab, "搜索过滤") - - layout.addWidget(self.tab_widget) - self.setLayout(layout) - - def init_browse_tab(self): - layout = QVBoxLayout() - - # 操作按钮 - btn_layout = QHBoxLayout() - - self.btn_refresh = QPushButton("刷新数据") - self.btn_refresh.clicked.connect(self.load_data) - - self.btn_add = QPushButton("新增记录") - self.btn_add.clicked.connect(self.add_result) - - self.btn_edit = QPushButton("编辑选中") - self.btn_edit.clicked.connect(self.edit_selected) - - self.btn_delete = QPushButton("删除选中") - self.btn_delete.clicked.connect(self.delete_selected) - - btn_layout.addWidget(self.btn_refresh) - btn_layout.addWidget(self.btn_add) - btn_layout.addWidget(self.btn_edit) - btn_layout.addWidget(self.btn_delete) - btn_layout.addStretch() - - layout.addLayout(btn_layout) - - # 数据表格 - self.table_widget = QTableWidget() - self.table_widget.setColumnCount(8) - self.table_widget.setHorizontalHeaderLabels([ - "ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作" - ]) - self.table_widget.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) - self.table_widget.setSelectionBehavior(QTableWidget.SelectRows) - - layout.addWidget(self.table_widget) - self.browse_tab.setLayout(layout) - - def init_search_tab(self): - layout = QVBoxLayout() - - # 搜索条件表单 - form_layout = QFormLayout() - - self.search_filename = QLineEdit() - self.search_filename.setPlaceholderText("输入文件名关键词") - form_layout.addRow("文件名:", self.search_filename) - - self.search_label = QComboBox() - self.search_label.addItem("全部", None) - self.search_label.addItem("空心", 0) - self.search_label.addItem("实心", 1) - form_layout.addRow("分类结果:", self.search_label) - - self.search_date = QLineEdit() - self.search_date.setPlaceholderText("YYYY-MM-DD") - form_layout.addRow("创建日期:", self.search_date) - - layout.addLayout(form_layout) - - # 搜索按钮 - btn_search = QPushButton("搜索") - btn_search.clicked.connect(self.perform_search) - layout.addWidget(btn_search) - - # 搜索结果表格 - self.search_table = QTableWidget() - self.search_table.setColumnCount(8) - self.search_table.setHorizontalHeaderLabels([ - "ID", "文件名", "片段数", "平均概率", "分类结果", "创建时间", "备注", "操作" - ]) - self.search_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) - self.search_table.setSelectionBehavior(QTableWidget.SelectRows) - - layout.addWidget(self.search_table) - self.search_tab.setLayout(layout) - - def load_data(self): - """加载所有数据到浏览表格""" - results = self.db_manager.get_all_results() - self.populate_table(self.table_widget, results) - - def perform_search(self): - """执行搜索操作""" - filename_filter = self.search_filename.text().strip() - label_filter = self.search_label.currentData() - date_filter = self.search_date.text().strip() or None - - results = self.db_manager.search_results(filename_filter, label_filter, date_filter) - self.populate_table(self.search_table, results) - - def populate_table(self, table, results): - """填充表格数据""" - table.setRowCount(len(results)) - - for row, result in enumerate(results): - # ID - table.setItem(row, 0, QTableWidgetItem(str(result['id']))) - - # 文件名 - table.setItem(row, 1, QTableWidgetItem(result['filename'])) - - # 片段数 - table.setItem(row, 2, QTableWidgetItem(str(result['segment_count']))) - - # 平均概率 - table.setItem(row, 3, QTableWidgetItem(f"{result['mean_probability']:.4f}")) - - # 分类结果 - table.setItem(row, 4, QTableWidgetItem(result['label_text'])) - - # 创建时间 - create_time = result['create_time'] - if isinstance(create_time, str): - table.setItem(row, 5, QTableWidgetItem(create_time)) - else: - table.setItem(row, 5, QTableWidgetItem(create_time.strftime("%Y-%m-%d %H:%M:%S"))) - - # 备注 - notes = result.get('notes', '') - table.setItem(row, 6, QTableWidgetItem(notes if notes else "无")) - - # 操作按钮 - btn_view = QPushButton("查看详情") - btn_view.clicked.connect(lambda checked, r=result: self.view_details(r)) - table.setCellWidget(row, 7, btn_view) - - def add_result(self): - """新增结果记录""" - dialog = EditResultDialog() - if dialog.exec_() == QDialog.Accepted: - new_data = dialog.get_updated_data() - # 设置一些默认值 - new_data['segment_labels'] = [] - new_data['segment_probs'] = [] - - try: - self.db_manager.insert_result(new_data, new_data.get('notes', '')) - self.load_data() - QMessageBox.information(self, "成功", "记录添加成功!") - except Exception as e: - QMessageBox.warning(self, "错误", f"添加记录失败: {str(e)}") - - def edit_selected(self): - """编辑选中的记录""" - current_table = self.tab_widget.currentWidget().layout().itemAt(1).widget() - current_row = current_table.currentRow() - - if current_row < 0: - QMessageBox.warning(self, "警告", "请先选择一条记录!") - return - - result_id = int(current_table.item(current_row, 0).text()) - results = self.db_manager.search_results() - result_data = next((r for r in results if r['id'] == result_id), None) - - if not result_data: - QMessageBox.warning(self, "错误", "未找到选中的记录!") - return - - dialog = EditResultDialog(result_data) - if dialog.exec_() == QDialog.Accepted: - updated_data = dialog.get_updated_data() - - try: - success = self.db_manager.update_result(result_id, updated_data) - if success: - self.load_data() - if current_table == self.search_table: - self.perform_search() - QMessageBox.information(self, "成功", "记录更新成功!") - else: - QMessageBox.warning(self, "错误", "更新记录失败!") - except Exception as e: - QMessageBox.warning(self, "错误", f"更新记录失败: {str(e)}") - - def delete_selected(self): - """删除选中的记录""" - current_table = self.tab_widget.currentWidget().layout().itemAt(1).widget() - current_row = current_table.currentRow() - - if current_row < 0: - QMessageBox.warning(self, "警告", "请先选择一条记录!") - return - - result_id = int(current_table.item(current_row, 0).text()) - filename = current_table.item(current_row, 1).text() - - reply = QMessageBox.question( - self, "确认删除", - f"确定要删除文件 '{filename}' 的记录吗?", - QMessageBox.Yes | QMessageBox.No - ) - - if reply == QMessageBox.Yes: - try: - success = self.db_manager.delete_result(result_id) - if success: - self.load_data() - if current_table == self.search_table: - self.perform_search() - QMessageBox.information(self, "成功", "记录删除成功!") - else: - QMessageBox.warning(self, "错误", "删除记录失败!") - except Exception as e: - QMessageBox.warning(self, "错误", f"删除记录失败: {str(e)}") - - def view_details(self, result_data): - """查看记录详情""" - details = ( - f"记录ID: {result_data['id']}\n" - f"文件名: {result_data['filename']}\n" - f"片段数量: {result_data['segment_count']}\n" - f"片段标签: {result_data['segment_labels']}\n" - f"片段概率: {result_data['segment_probs']}\n" - f"平均概率: {result_data['mean_probability']:.4f}\n" - f"最终分类: {result_data['label_text']}\n" - f"创建时间: {result_data['create_time']}\n" - f"备注: {result_data.get('notes', '无')}" - ) - - QMessageBox.information(self, "记录详情", details) - - -# ---------------------- 主界面组件 ---------------------- -class MainMenu(QWidget): - def __init__(self, parent=None): - super().__init__(parent) - self.parent = parent - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - layout.setSpacing(20) - - # 标题 - title = QLabel("音频分类系统") - title.setAlignment(Qt.AlignCenter) - title.setStyleSheet("font-size: 28px; font-weight: bold; margin: 30px 0;") - layout.addWidget(title) - - # 功能按钮 - btn_record = QPushButton("录制音频") - btn_record.setMinimumHeight(70) - btn_record.setStyleSheet("font-size: 18px;") - btn_record.clicked.connect(lambda: self.parent.switch_page("record")) - - btn_upload = QPushButton("上传音频文件") - btn_upload.setMinimumHeight(70) - btn_upload.setStyleSheet("font-size: 18px;") - btn_upload.clicked.connect(lambda: self.parent.switch_page("upload")) - - btn_database = QPushButton("数据库管理") - btn_database.setMinimumHeight(70) - btn_database.setStyleSheet("font-size: 18px;") - btn_database.clicked.connect(lambda: self.parent.switch_page("database")) - - layout.addWidget(btn_record) - layout.addWidget(btn_upload) - layout.addWidget(btn_database) - layout.addStretch() - - self.setLayout(layout) - - -class InputPage(QWidget): - def __init__(self, parent=None, mode="record"): - super().__init__(parent) - self.parent = parent - self.mode = mode # "record" 或 "upload" - self.audio_path = "" - self.recorder = None - self.db_manager = DatabaseManager() # 添加数据库管理器 - self.init_ui() - - def init_ui(self): - # 主布局 - main_layout = QVBoxLayout() - main_layout.setContentsMargins(20, 20, 20, 20) - - # 返回按钮 - btn_back = QPushButton("← 返回主菜单") - btn_back.clicked.connect(lambda: self.parent.switch_page("main")) - main_layout.addWidget(btn_back) - - # 标题 - title = QLabel("录制音频" if self.mode == "record" else "上传音频文件") - title.setStyleSheet("font-size: 20px; margin: 15px 0;") - main_layout.addWidget(title) - - # 操作区域 - if self.mode == "record": - # 录音控制 - self.btn_start_rec = QPushButton("开始录音") - self.btn_start_rec.clicked.connect(self.start_recording) - self.btn_stop_rec = QPushButton("停止录音") - self.btn_stop_rec.setEnabled(False) - self.btn_stop_rec.clicked.connect(self.stop_recording) - - # 进度条 - self.progress_bar = QProgressBar() - self.progress_bar.setRange(0, 100) - self.progress_bar.setValue(0) - - # 录音布局 - rec_layout = QHBoxLayout() - rec_layout.addWidget(self.btn_start_rec) - rec_layout.addWidget(self.btn_stop_rec) - main_layout.addLayout(rec_layout) - main_layout.addWidget(self.progress_bar) - - else: - # 上传控制 - self.btn_browse = QPushButton("选择WAV文件") - self.btn_browse.clicked.connect(self.browse_file) - self.lbl_file = QLabel("未选择文件") - self.lbl_file.setStyleSheet("color: #666;") - main_layout.addWidget(self.btn_browse) - main_layout.addWidget(self.lbl_file) - - # 状态显示 - self.status_display = QTextEdit() - self.status_display.setReadOnly(True) - self.status_display.setMinimumHeight(150) - main_layout.addWidget(QLabel("状态信息:")) - main_layout.addWidget(self.status_display) - - # 处理按钮 - self.btn_process = QPushButton("开始分析") - self.btn_process.setEnabled(False) - self.btn_process.clicked.connect(self.process_audio) - main_layout.addWidget(self.btn_process) - - main_layout.addStretch() - self.setLayout(main_layout) - - def start_recording(self): - self.status_display.append("开始录音...") - self.recorder = AudioRecorder(max_duration=60) - self.recorder.status_updated.connect(self.update_status) - self.recorder.progress_updated.connect(self.progress_bar.setValue) - self.recorder.recording_finished.connect(self.on_recording_finished) - self.recorder.start() - self.btn_start_rec.setEnabled(False) - self.btn_stop_rec.setEnabled(True) - - def stop_recording(self): - if self.recorder and self.recorder.recording: - self.recorder.stop() - self.btn_start_rec.setEnabled(True) - self.btn_stop_rec.setEnabled(False) - - def on_recording_finished(self, file_path): - self.audio_path = file_path - self.btn_process.setEnabled(True) - self.status_display.append(f"录音文件已保存: {file_path}") - - def browse_file(self): - file_path, _ = QFileDialog.getOpenFileName( - self, "选择音频文件", "", "WAV文件 (*.wav)" - ) - if file_path: - self.audio_path = file_path - self.lbl_file.setText(os.path.basename(file_path)) - self.btn_process.setEnabled(True) - self.update_status(f"已选择文件: {file_path}") - - def process_audio(self): - if not self.audio_path or not os.path.exists(self.audio_path): - QMessageBox.warning(self, "错误", "音频文件不存在") - return - - # 模型路径(请根据实际情况修改) - model_path = "pipeline_model.pkl" # 或 "svm_model.pkl" - scaler_path = "scaler.pkl" - - if not (os.path.exists(model_path) and os.path.exists(scaler_path)): - QMessageBox.warning(self, "错误", "模型文件或标准化器不存在") - return - - # 启动处理线程 - self.processor = AudioProcessor(self.audio_path, model_path, scaler_path) - self.processor.status_updated.connect(self.update_status) - self.processor.result_generated.connect(self.show_result) - self.processor.start() - self.btn_process.setEnabled(False) - self.update_status("开始分析音频...") - - def update_status(self, message): - self.status_display.append(f"[{time.strftime('%H:%M:%S')}] {message}") - # 自动滚动到底部 - self.status_display.moveCursor(self.status_display.textCursor().End) - - def show_result(self, result): - # 显示结果对话框,增加保存到数据库的选项 - msg = QMessageBox() - msg.setWindowTitle("分析结果") - msg.setIcon(QMessageBox.Information) - - text = ( - f"文件名: {result['filename']}\n" - f"有效片段数: {result['segment_count']}\n" - f"平均概率: {result['mean_probability']}\n" - f"最终分类: {result['label_text']}" - ) - msg.setText(text) - - # 添加保存到数据库的按钮 - msg.addButton("保存到数据库", QMessageBox.AcceptRole) - msg.addButton("仅查看", QMessageBox.RejectRole) - - reply = msg.exec_() - - if reply == 0: # 保存到数据库 - notes, ok = QInputDialog.getText( - self, "添加备注", "请输入备注信息(可选):", - QLineEdit.Normal, "" - ) - if ok: - try: - self.db_manager.insert_result(result, notes) - self.update_status("结果已保存到数据库") - except Exception as e: - QMessageBox.warning(self, "保存失败", f"保存到数据库失败: {str(e)}") - - self.update_status("分析完成") - self.btn_process.setEnabled(True) - - -# ---------------------- 主窗口 ---------------------- -class AudioClassifierApp(QMainWindow): - def __init__(self): - super().__init__() - self.init_ui() - - def init_ui(self): - self.setWindowTitle("音频分类器") - self.setGeometry(300, 300, 800, 600) - - # 堆叠窗口管理页面 - self.stack = QStackedWidget() - self.main_menu = MainMenu(self) - self.record_page = InputPage(self, mode="record") - self.upload_page = InputPage(self, mode="upload") - self.database_page = DatabaseUI(self) # 新增数据库页面 - - self.stack.addWidget(self.main_menu) - self.stack.addWidget(self.record_page) - self.stack.addWidget(self.upload_page) - self.stack.addWidget(self.database_page) - - self.setCentralWidget(self.stack) - - def switch_page(self, page_name): - """切换页面""" - if page_name == "main": - self.stack.setCurrentWidget(self.main_menu) - elif page_name == "record": - self.stack.setCurrentWidget(self.record_page) - elif page_name == "upload": - self.stack.setCurrentWidget(self.upload_page) - elif page_name == "database": - self.stack.setCurrentWidget(self.database_page) - - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = AudioClassifierApp() - window.show() - sys.exit(app.exec_()) \ No newline at end of file diff --git a/audio_classifier_gui_02.py b/audio_classifier_gui_02.py deleted file mode 100644 index a630bd9..0000000 --- a/audio_classifier_gui_02.py +++ /dev/null @@ -1,630 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import sys -import os -import time -import numpy as np -import joblib -import librosa -import pyaudio -import wave -from scipy.signal import hilbert -from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel, - QLineEdit, QTextEdit, QFileDialog, QVBoxLayout, - QHBoxLayout, QWidget, QProgressBar, QStackedWidget, - QComboBox) # >>> CHANGED: add QComboBox -from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer - - -# 特征提取类 -class FeatureExtractor: - WIN = 1024 - OVERLAP = 512 - THRESHOLD = 0.1 - SEG_LEN_S = 0.2 - STEP = WIN - OVERLAP - - @staticmethod - def frame_energy(signal: np.ndarray) -> np.ndarray: - frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP) - return np.sum(frames ** 2, axis=0) - - @staticmethod - def detect_hits(energy: np.ndarray) -> np.ndarray: - idx = np.where(energy > FeatureExtractor.THRESHOLD)[0] - if idx.size == 0: - return np.array([]) - flags = np.diff(np.concatenate(([0], idx))) > 5 - return idx[flags] - - @staticmethod - def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray) -> list[np.ndarray]: - seg_len = int(FeatureExtractor.SEG_LEN_S * sr) - segments = [] - for frame_idx in hit_starts: - s = frame_idx * FeatureExtractor.STEP - e = min(s + seg_len, len(signal)) - segments.append(signal[s:e]) - return segments - - @staticmethod - def extract_features(x: np.ndarray, sr: int) -> np.ndarray: - x = x.flatten() - if len(x) == 0: - return np.zeros(5) - - # 1. RMS能量 - rms = np.sqrt(np.mean(x ** 2)) - - # 2. 主频(频谱峰值频率) - fft = np.fft.fft(x) - freq = np.fft.fftfreq(len(x), 1 / sr) - positive_freq_mask = freq >= 0 - freq = freq[positive_freq_mask] - fft_mag = np.abs(fft[positive_freq_mask]) - main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0 - - # 3. 频谱偏度 - spec_power = fft_mag - centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12) - spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12)) - skewness = np.sum(((freq - centroid) ** 3) * spec_power) / ( - (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0 - - # 4. MFCC第一维均值 - mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13) - mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0 - - # 5. 包络峰值(希尔伯特变换) - env_peak = np.max(np.abs(hilbert(x))) - - return np.array([rms, main_freq, skewness, mfcc_mean, env_peak]) - - -# 录音线程 -class RecordThread(QThread): - update_signal = pyqtSignal(str) - finish_signal = pyqtSignal(str) - progress_signal = pyqtSignal(int) - level_signal = pyqtSignal(int) # >>> NEW: 实时电平(0-100) - - def __init__(self, max_duration=60, device_index=None): - super().__init__() - self.max_duration = max_duration - self.is_recording = False - self.temp_file = "temp_recording.wav" - self.device_index = device_index # >>> NEW - - def run(self): - # >>> CHANGED: 更通用的音频参数 - FORMAT = pyaudio.paInt16 - CHANNELS = 1 - RATE = 44100 - CHUNK = 1024 - - p = None - stream = None - try: - p = pyaudio.PyAudio() - - # 设备信息日志(帮助排查) - try: - device_log = [] - for i in range(p.get_device_count()): - info = p.get_device_info_by_index(i) - if int(info.get('maxInputChannels', 0)) > 0: - device_log.append(f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} sr={int(info.get('defaultSampleRate', 0))}") - if device_log: - self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log)) - except Exception as _: - pass - - # >>> CHANGED: 指定 input_device_index(若传入) - stream = p.open( - format=FORMAT, - channels=CHANNELS, - rate=RATE, - input=True, - input_device_index=self.device_index, # 可能为 None,则使用系统默认 - frames_per_buffer=CHUNK - ) - - self.is_recording = True - start_time = time.time() - frames = [] - - # 用于电平估计的满刻度 - max_int16 = 32767.0 - - while self.is_recording: - elapsed = time.time() - start_time - if elapsed >= self.max_duration: - self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒") - break - - self.progress_signal.emit(int((elapsed / self.max_duration) * 100)) - - data = stream.read(CHUNK, exception_on_overflow=False) # >>> CHANGED - frames.append(data) - - # >>> NEW: 计算电平(RMS) - # 将bytes转为np.int16,归一化到[-1,1],计算RMS并发射到UI - try: - chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16 - rms = np.sqrt(np.mean(chunk_np ** 2)) - # 简单映射到0-100 - level = int(np.clip(rms * 100 * 5, 0, 100)) # 放大系数5可根据环境调整 - self.level_signal.emit(level) - except Exception: - pass - - if stream is not None: - stream.stop_stream() - stream.close() - if p is not None: - p.terminate() - - if len(frames) > 0: - wf = wave.open(self.temp_file, 'wb') - wf.setnchannels(CHANNELS) - wf.setsampwidth(pyaudio.PyAudio().get_sample_size(FORMAT)) # 2 字节 - wf.setframerate(RATE) - wf.writeframes(b''.join(frames)) - wf.close() - self.update_signal.emit(f"录制完成,时长: {time.time() - start_time:.1f}秒") - self.finish_signal.emit(self.temp_file) - else: - self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)") - - except Exception as e: - # >>> NEW: 打开失败时提示设备列表 - err_msg = f"录制错误: {str(e)}" - try: - if p is not None: - err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)" - finally: - self.update_signal.emit(err_msg) - finally: - try: - if stream is not None: - stream.stop_stream() - stream.close() - except Exception: - pass - try: - if p is not None: - p.terminate() - except Exception: - pass - - -# 处理线程 -class ProcessThread(QThread): - update_signal = pyqtSignal(str) - finish_signal = pyqtSignal(dict) - - def __init__(self, wav_path, model_path, scaler_path): - super().__init__() - self.wav_path = wav_path - self.model_path = model_path - self.scaler_path = scaler_path - - def run(self): - try: - self.update_signal.emit("加载模型和标准化器...") - model = joblib.load(self.model_path) - scaler = joblib.load(self.scaler_path) - - self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}") - sig, sr = librosa.load(self.wav_path, sr=None, mono=True) - # 归一化避免静音/削顶影响 - if np.max(np.abs(sig)) > 0: - sig = sig / np.max(np.abs(sig)) - - self.update_signal.emit("提取特征...") - ene = FeatureExtractor.frame_energy(sig) - hit_starts = FeatureExtractor.detect_hits(ene) - segments = FeatureExtractor.extract_segments(sig, sr, hit_starts) - if not segments: - self.update_signal.emit("未检测到有效片段!") - return - - feats = [FeatureExtractor.extract_features(seg, sr) for seg in segments] - X = np.vstack(feats) - self.update_signal.emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}") - - X_std = scaler.transform(X) - y_pred = model.predict(X_std) - # 兼容无 predict_proba 的分类器 - if hasattr(model, "predict_proba"): - y_proba = model.predict_proba(X_std)[:, 1] - else: - # 退化:用决策函数映射到(0,1) - if hasattr(model, "decision_function"): - df = model.decision_function(X_std) - y_proba = 1 / (1 + np.exp(-df)) - else: - y_proba = (y_pred.astype(float) + 0.0) - - self.finish_signal.emit({ - "filename": os.path.basename(self.wav_path), - "segments": len(segments), - "predictions": y_pred.tolist(), - "probabilities": [round(float(p), 4) for p in y_proba], - "mean_prob": round(float(np.mean(y_proba)), 4), - "final_label": int(np.mean(y_proba) >= 0.5) - }) - - except Exception as e: - self.update_signal.emit(f"错误: {str(e)}") - - -# 第一层界面:主菜单 -class MainMenuWidget(QWidget): - def __init__(self, parent=None): - super().__init__(parent) - self.parent = parent - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - - title = QLabel("音频分类器") - title.setAlignment(Qt.AlignCenter) - title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;") - layout.addWidget(title) - - # 添加按钮 - record_btn = QPushButton("采集音频") - record_btn.setMinimumHeight(60) - record_btn.setStyleSheet("font-size: 16px;") - record_btn.clicked.connect(lambda: self.parent.switch_to_input("record")) - - upload_btn = QPushButton("上传外部WAV文件") - upload_btn.setMinimumHeight(60) - upload_btn.setStyleSheet("font-size: 16px;") - upload_btn.clicked.connect(lambda: self.parent.switch_to_input("upload")) - - layout.addWidget(record_btn) - layout.addWidget(upload_btn) - layout.addStretch(1) - - self.setLayout(layout) - - -# 第二层界面:输入界面(录音或上传文件) -class InputWidget(QWidget): - def __init__(self, parent=None, mode="record"): - super().__init__(parent) - self.parent = parent - self.mode = mode - self.wav_path = "" - self.record_thread = None - self.device_index = None # >>> NEW: 当前选中的输入设备索引 - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - - # 返回按钮 - back_btn = QPushButton("返回") - back_btn.clicked.connect(self.parent.switch_to_main_menu) - layout.addWidget(back_btn) - - if self.mode == "record": - self.setup_record_ui(layout) - else: - self.setup_upload_ui(layout) - - # 模型路径 - model_layout = QHBoxLayout() - self.model_path = QLineEdit("svm_model.pkl") - self.scaler_path = QLineEdit("scaler.pkl") - model_layout.addWidget(QLabel("模型路径:")) - model_layout.addWidget(self.model_path) - model_layout.addWidget(QLabel("标准化器路径:")) - model_layout.addWidget(self.scaler_path) - layout.addLayout(model_layout) - - # 处理按钮 - self.process_btn = QPushButton("开始处理") - self.process_btn.setMinimumHeight(50) - self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;") - self.process_btn.clicked.connect(self.start_process) - self.process_btn.setEnabled(False) # 初始不可用 - layout.addWidget(self.process_btn) - - # 日志区域 - self.log_area = QTextEdit() - self.log_area.setReadOnly(True) - layout.addWidget(QLabel("日志:")) - layout.addWidget(self.log_area) - - self.setLayout(layout) - - def setup_record_ui(self, layout): - title = QLabel("音频采集") - title.setAlignment(Qt.AlignCenter) - title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;") - layout.addWidget(title) - - # >>> NEW: 麦克风选择 - device_layout = QHBoxLayout() - device_layout.addWidget(QLabel("输入设备:")) - self.device_combo = QComboBox() - self.refresh_devices() - self.device_combo.currentIndexChanged.connect(self.on_device_changed) - device_layout.addWidget(self.device_combo) - refresh_btn = QPushButton("刷新设备") - refresh_btn.clicked.connect(self.refresh_devices) - device_layout.addWidget(refresh_btn) - layout.addLayout(device_layout) - - # 录音控制 - record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)") - record_hint.setAlignment(Qt.AlignCenter) - - self.record_btn = QPushButton("按住录音") - self.record_btn.setMinimumHeight(80) - self.record_btn.setStyleSheet(""" - QPushButton { - background-color: #ff4d4d; - color: white; - font-size: 18px; - border-radius: 10px; - } - QPushButton:pressed { - background-color: #cc0000; - } - """) - self.record_btn.mousePressEvent = self.start_recording - self.record_btn.mouseReleaseEvent = self.stop_recording - self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu) - - # >>> NEW: 实时电平与录音进度 - self.mic_level = QProgressBar() - self.mic_level.setRange(0, 100) - self.mic_level.setFormat("麦克风电平:%p%") - - self.record_progress = QProgressBar() - self.record_progress.setRange(0, 100) - self.record_progress.setValue(0) - - self.record_duration_label = QLabel("录音时长: 0.0秒") - self.record_duration_label.setAlignment(Qt.AlignCenter) - - layout.addWidget(record_hint) - layout.addWidget(self.record_btn) - layout.addWidget(self.mic_level) # >>> NEW - layout.addWidget(self.record_progress) - layout.addWidget(self.record_duration_label) - - # 录音计时器 - self.record_timer = QTimer(self) - self.record_timer.timeout.connect(self.update_record_duration) - self.record_start_time = 0 - - def refresh_devices(self): - """枚举有输入通道的设备,并填充到下拉框""" - self.device_combo.clear() - try: - p = pyaudio.PyAudio() - default_host_api = p.get_host_api_info_by_index(0) - default_input_index = default_host_api.get("defaultInputDevice", -1) - found = [] - for i in range(p.get_device_count()): - info = p.get_device_info_by_index(i) - if int(info.get('maxInputChannels', 0)) > 0: - name = info.get('name', f"Device {i}") - sr = int(info.get('defaultSampleRate', 0)) - label = f"[{i}] {name} (sr={sr})" - self.device_combo.addItem(label, i) - found.append(i) - p.terminate() - - # 选中默认输入设备(若存在) - if default_input_index in found: - idx = found.index(default_input_index) - self.device_combo.setCurrentIndex(idx) - self.device_index = default_input_index - elif found: - self.device_combo.setCurrentIndex(0) - self.device_index = self.device_combo.currentData() - else: - self.device_index = None - except Exception: - self.device_index = None - - def on_device_changed(self, _): - self.device_index = self.device_combo.currentData() - - def setup_upload_ui(self, layout): - title = QLabel("上传WAV文件") - title.setAlignment(Qt.AlignCenter) - title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;") - layout.addWidget(title) - - # 文件选择 - file_layout = QHBoxLayout() - self.file_path = QLineEdit() - self.file_path.setReadOnly(True) - self.browse_btn = QPushButton("浏览WAV文件") - self.browse_btn.clicked.connect(self.browse_file) - file_layout.addWidget(self.file_path) - file_layout.addWidget(self.browse_btn) - layout.addLayout(file_layout) - - def start_recording(self, event): - if event.button() == Qt.LeftButton: - if not self.record_thread or not self.record_thread.isRunning(): - # >>> CHANGED: 传入 device_index - self.record_thread = RecordThread(max_duration=60, device_index=self.device_index) - self.record_thread.update_signal.connect(self.update_log) - self.record_thread.finish_signal.connect(self.on_recording_finish) - self.record_thread.progress_signal.connect(self.record_progress.setValue) - self.record_thread.level_signal.connect(self.mic_level.setValue) # >>> NEW - self.record_thread.start() - - self.record_start_time = time.time() - self.record_timer.start(100) - self.update_log("开始录音...(松开按钮结束)") - - def stop_recording(self, event): - if event.button() == Qt.LeftButton and self.record_thread and self.record_thread.isRunning(): - self.record_thread.is_recording = False - self.record_timer.stop() - self.record_duration_label.setText("录音时长: 0.0秒") - self.record_progress.setValue(0) - # 让电平逐步回落 - self.mic_level.setValue(0) - - def update_record_duration(self): - elapsed = time.time() - self.record_start_time - self.record_duration_label.setText(f"录音时长: {elapsed:.1f}秒") - - def on_recording_finish(self, temp_file): - self.wav_path = temp_file - self.process_btn.setEnabled(True) - - def browse_file(self): - file, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)") - if file: - self.file_path.setText(file) - self.wav_path = file - self.process_btn.setEnabled(True) - - def start_process(self): - if self.mode == "upload": - self.wav_path = self.file_path.text() - - model_path = self.model_path.text() - scaler_path = self.scaler_path.text() - - if not self.wav_path or not os.path.exists(self.wav_path): - self.log_area.append("请先录音或选择有效的WAV文件!") - return - if not os.path.exists(model_path) or not os.path.exists(scaler_path): - self.log_area.append("模型文件不存在!请先运行train_model.py训练模型") - return - - self.log_area.clear() - self.process_btn.setEnabled(False) - - self.thread = ProcessThread(self.wav_path, model_path, scaler_path) - self.thread.update_signal.connect(self.update_log) - self.thread.finish_signal.connect(self.on_process_finish) - self.thread.start() - - def update_log(self, msg): - self.log_area.append(msg) - - def on_process_finish(self, result): - self.parent.switch_to_result(result) - - -# 第三层界面:结果显示界面 -class ResultWidget(QWidget): - def __init__(self, parent=None, result=None): - super().__init__(parent) - self.parent = parent - self.result = result - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - - # 返回按钮 - back_btn = QPushButton("返回") - back_btn.clicked.connect(self.parent.switch_to_input_from_result) - layout.addWidget(back_btn) - - title = QLabel("处理结果") - title.setAlignment(Qt.AlignCenter) - title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;") - layout.addWidget(title) - - # 结果显示 - self.result_area = QTextEdit() - self.result_area.setReadOnly(True) - - if self.result: - res = f"文件名: {self.result['filename']}\n" - res += f"片段数: {self.result['segments']}\n" - res += "预测结果:\n" - for i, (pred, prob) in enumerate(zip(self.result['predictions'], self.result['probabilities'])): - res += f" 片段{i + 1}: 标签={pred} (概率={prob})\n" - res += f"\n平均概率: {self.result['mean_prob']}\n" - res += f"最终结果: {'空心' if self.result['final_label'] else '实心'}" - self.result_area.setText(res) - - layout.addWidget(self.result_area) - self.setLayout(layout) - - -# 主窗口 -class AudioClassifierGUI(QMainWindow): - def __init__(self): - super().__init__() - self.current_input_mode = "record" # 记录当前输入模式 - self.process_result = None # 存储处理结果 - self.init_ui() - - def init_ui(self): - self.setWindowTitle("音频分类器") - self.setGeometry(100, 100, 800, 600) - - # 创建堆叠窗口 - self.stacked_widget = QStackedWidget() - self.setCentralWidget(self.stacked_widget) - - # 创建三个界面 - self.main_menu_widget = MainMenuWidget(self) - self.record_input_widget = InputWidget(self, "record") - self.upload_input_widget = InputWidget(self, "upload") - self.result_widget = ResultWidget(self) - - # 添加到堆叠窗口 - self.stacked_widget.addWidget(self.main_menu_widget) - self.stacked_widget.addWidget(self.record_input_widget) - self.stacked_widget.addWidget(self.upload_input_widget) - self.stacked_widget.addWidget(self.result_widget) - - # 显示主菜单 - self.stacked_widget.setCurrentWidget(self.main_menu_widget) - - def switch_to_input(self, mode): - self.current_input_mode = mode - if mode == "record": - self.stacked_widget.setCurrentWidget(self.record_input_widget) - else: - self.stacked_widget.setCurrentWidget(self.upload_input_widget) - - def switch_to_main_menu(self): - self.stacked_widget.setCurrentWidget(self.main_menu_widget) - - def switch_to_result(self, result): - self.process_result = result - self.result_widget = ResultWidget(self, result) - # 移除旧的结果界面并添加新的 - self.stacked_widget.removeWidget(self.stacked_widget.widget(3)) - self.stacked_widget.addWidget(self.result_widget) - self.stacked_widget.setCurrentWidget(self.result_widget) - - def switch_to_input_from_result(self): - if self.current_input_mode == "record": - self.stacked_widget.setCurrentWidget(self.record_input_widget) - else: - self.stacked_widget.setCurrentWidget(self.upload_input_widget) - - -if __name__ == "__main__": - try: - import pyaudio # 确保已安装 - except ImportError: - print("请先安装pyaudio: pip install pyaudio") - sys.exit(1) - - app = QApplication(sys.argv) - window = AudioClassifierGUI() - window.show() - sys.exit(app.exec_()) diff --git a/audio_classifier_gui_03.py b/audio_classifier_gui_03.py deleted file mode 100644 index d33068a..0000000 --- a/audio_classifier_gui_03.py +++ /dev/null @@ -1,917 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import sys -import os -import time -import csv -import sqlite3 -from datetime import datetime -import numpy as np -import joblib -import librosa -import pyaudio -import wave -from scipy.signal import hilbert - -from PyQt5.QtWidgets import ( - QApplication, QMainWindow, QPushButton, QLabel, - QLineEdit, QTextEdit, QFileDialog, QVBoxLayout, - QHBoxLayout, QWidget, QProgressBar, QStackedWidget, - QComboBox, QTableWidget, QTableWidgetItem, QMessageBox, - QDialog, QFormLayout, QDialogButtonBox -) -from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer - - -# ====================== -# 本地 SQLite 数据库(含 CRUD) -# ====================== -class DatabaseManager: - def __init__(self, db_path="results.db"): - self.db_path = db_path - self._ensure_schema() - - def _connect(self): - return sqlite3.connect(self.db_path) - - def _ensure_schema(self): - con = self._connect() - cur = con.cursor() - cur.execute(""" - CREATE TABLE IF NOT EXISTS runs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - filename TEXT, - segments INTEGER, - mean_prob REAL, - final_label INTEGER, - created_at TEXT - ); - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS run_segments ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id INTEGER, - seg_index INTEGER, - label INTEGER, - proba REAL, - FOREIGN KEY(run_id) REFERENCES runs(id) - ); - """) - con.commit() - con.close() - - # ---------- 原有写入 ---------- - def insert_result(self, result_dict): - con = self._connect() - cur = con.cursor() - cur.execute( - "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)", - ( - result_dict.get("filename"), - int(result_dict.get("segments", 0)), - float(result_dict.get("mean_prob", 0.0)), - int(result_dict.get("final_label", 0)), - datetime.now().strftime("%Y-%m-%d %H:%M:%S") - ) - ) - run_id = cur.lastrowid - preds = result_dict.get("predictions", []) - probas = result_dict.get("probabilities", []) - for i, (lab, pr) in enumerate(zip(preds, probas), start=1): - cur.execute( - "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)", - (run_id, i, int(lab), float(pr)) - ) - con.commit() - con.close() - return run_id - - # ---------- R(查询列表/单条) ---------- - def fetch_recent_runs(self, limit=50): - con = self._connect() - cur = con.cursor() - cur.execute(""" - SELECT id, created_at, filename, segments, mean_prob, final_label - FROM runs ORDER BY id DESC LIMIT ? - """, (limit,)) - rows = cur.fetchall() - con.close() - return rows - - def get_run(self, run_id: int): - con = self._connect() - cur = con.cursor() - cur.execute(""" - SELECT id, created_at, filename, segments, mean_prob, final_label - FROM runs WHERE id = ? - """, (run_id,)) - row = cur.fetchone() - con.close() - return row - - def search_runs(self, keyword: str, limit=100): - kw = f"%{keyword}%" - con = self._connect() - cur = con.cursor() - cur.execute(""" - SELECT id, created_at, filename, segments, mean_prob, final_label - FROM runs - WHERE filename LIKE ? - ORDER BY id DESC LIMIT ? - """, (kw, limit)) - rows = cur.fetchall() - con.close() - return rows - - # ---------- C/U/D(run) ---------- - def create_run(self, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str = None): - created_at = created_at or datetime.now().strftime("%Y-%m-%d %H:%M:%S") - con = self._connect() - cur = con.cursor() - cur.execute( - "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)", - (filename, int(segments), float(mean_prob), int(final_label), created_at) - ) - run_id = cur.lastrowid - con.commit() - con.close() - return run_id - - def update_run(self, run_id: int, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str): - con = self._connect() - cur = con.cursor() - cur.execute(""" - UPDATE runs - SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=? - WHERE id=? - """, (filename, int(segments), float(mean_prob), int(final_label), created_at, int(run_id))) - con.commit() - con.close() - - def delete_run(self, run_id: int): - con = self._connect() - cur = con.cursor() - # 先删子表 - cur.execute("DELETE FROM run_segments WHERE run_id=?", (int(run_id),)) - cur.execute("DELETE FROM runs WHERE id=?", (int(run_id),)) - con.commit() - con.close() - - # ---------- C/R/U/D(segments) ---------- - def list_segments(self, run_id: int): - con = self._connect() - cur = con.cursor() - cur.execute(""" - SELECT id, seg_index, label, proba FROM run_segments - WHERE run_id=? ORDER BY seg_index ASC - """, (int(run_id),)) - rows = cur.fetchall() - con.close() - return rows - - def add_segment(self, run_id: int, seg_index: int, label: int, proba: float): - con = self._connect() - cur = con.cursor() - cur.execute( - "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)", - (int(run_id), int(seg_index), int(label), float(proba)) - ) - seg_id = cur.lastrowid - con.commit() - con.close() - return seg_id - - def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float): - con = self._connect() - cur = con.cursor() - cur.execute(""" - UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=? - """, (int(seg_index), int(label), float(proba), int(seg_id))) - con.commit() - con.close() - - def delete_segment(self, seg_id: int): - con = self._connect() - cur = con.cursor() - cur.execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),)) - con.commit() - con.close() - - # ---------- 导出 ---------- - def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000): - rows = self.fetch_recent_runs(limit) - headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.writer(f) - writer.writerow(headers) - for r in rows: - writer.writerow(r) - return csv_path - - -DB = DatabaseManager("results.db") - - -# ====================== -# 特征提取(同前) -# ====================== -class FeatureExtractor: - WIN = 1024 - OVERLAP = 512 - THRESHOLD = 0.1 - SEG_LEN_S = 0.2 - STEP = WIN - OVERLAP - - @staticmethod - def frame_energy(signal: np.ndarray) -> np.ndarray: - frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP) - return np.sum(frames ** 2, axis=0) - - @staticmethod - def detect_hits(energy: np.ndarray) -> np.ndarray: - idx = np.where(energy > FeatureExtractor.THRESHOLD)[0] - if idx.size == 0: - return np.array([]) - flags = np.diff(np.concatenate(([0], idx))) > 5 - return idx[flags] - - @staticmethod - def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray): - seg_len = int(FeatureExtractor.SEG_LEN_S * sr) - segments = [] - for frame_idx in hit_starts: - s = frame_idx * FeatureExtractor.STEP - e = min(s + seg_len, len(signal)) - segments.append(signal[s:e]) - return segments - - @staticmethod - def extract_features(x: np.ndarray, sr: int) -> np.ndarray: - x = x.flatten() - if len(x) == 0: - return np.zeros(5) - rms = np.sqrt(np.mean(x ** 2)) - fft = np.fft.fft(x) - freq = np.fft.fftfreq(len(x), 1 / sr) - positive = freq >= 0 - freq, fft_mag = freq[positive], np.abs(fft[positive]) - main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0 - spec_power = fft_mag - centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12) - spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12)) - skewness = np.sum(((freq - centroid) ** 3) * spec_power) / ( - (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0 - mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13) - mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0 - env_peak = np.max(np.abs(hilbert(x))) - return np.array([rms, main_freq, skewness, mfcc_mean, env_peak]) - - -# ====================== -# 录音线程(同前) -# ====================== -class RecordThread(QThread): - update_signal = pyqtSignal(str) - finish_signal = pyqtSignal(str) - progress_signal = pyqtSignal(int) - level_signal = pyqtSignal(int) - - def __init__(self, max_duration=60, device_index=None): - super().__init__() - self.max_duration = max_duration - self.is_recording = False - self.temp_file = "temp_recording.wav" - self.device_index = device_index - - def run(self): - FORMAT = pyaudio.paInt16 - CHANNELS = 1 - RATE = 44100 - CHUNK = 1024 - p, stream = None, None - try: - p = pyaudio.PyAudio() - stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, - input_device_index=self.device_index, frames_per_buffer=CHUNK) - self.is_recording = True - start_time = time.time() - frames = [] - max_int16 = 32767.0 - while self.is_recording: - elapsed = time.time() - start_time - if elapsed >= self.max_duration: - self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒") - break - self.progress_signal.emit(int((elapsed / self.max_duration) * 100)) - data = stream.read(CHUNK, exception_on_overflow=False) - frames.append(data) - chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16 - rms = np.sqrt(np.mean(chunk_np ** 2)) - self.level_signal.emit(int(np.clip(rms * 500, 0, 100))) - stream.stop_stream(); stream.close(); p.terminate() - if frames: - wf = wave.open(self.temp_file, 'wb') - wf.setnchannels(CHANNELS) - wf.setsampwidth(pyaudio.PyAudio().get_sample_size(FORMAT)) - wf.setframerate(RATE) - wf.writeframes(b''.join(frames)) - wf.close() - self.finish_signal.emit(self.temp_file) - else: - self.update_signal.emit("未录到有效音频。") - except Exception as e: - self.update_signal.emit(f"录制错误: {e}") - - -# ====================== -# 推理处理线程(写入 DB) -# ====================== -class ProcessThread(QThread): - update_signal = pyqtSignal(str) - finish_signal = pyqtSignal(dict) - - def __init__(self, wav_path, model_path, scaler_path): - super().__init__() - self.wav_path, self.model_path, self.scaler_path = wav_path, model_path, scaler_path - - def run(self): - try: - self.update_signal.emit("加载模型和标准化器...") - model, scaler = joblib.load(self.model_path), joblib.load(self.scaler_path) - self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}") - sig, sr = librosa.load(self.wav_path, sr=None, mono=True) - sig = sig / max(np.max(np.abs(sig)), 1e-9) - ene = FeatureExtractor.frame_energy(sig) - hits = FeatureExtractor.detect_hits(ene) - segs = FeatureExtractor.extract_segments(sig, sr, hits) - if not segs: - self.update_signal.emit("未检测到有效片段!") - return - feats = [FeatureExtractor.extract_features(seg, sr) for seg in segs] - X = np.vstack(feats) - X_std = scaler.transform(X) - y_pred = model.predict(X_std) - if hasattr(model, "predict_proba"): - y_proba = model.predict_proba(X_std)[:, 1] - else: - y_proba = (y_pred.astype(float) + 0.0) - result = { - "filename": os.path.basename(self.wav_path), - "segments": len(segs), - "predictions": y_pred.tolist(), - "probabilities": [round(float(p), 4) for p in y_proba], - "mean_prob": round(float(np.mean(y_proba)), 4), - "final_label": int(np.mean(y_proba) >= 0.5) - } - DB.insert_result(result) - self.finish_signal.emit(result) - except Exception as e: - self.update_signal.emit(f"错误: {e}") - - -# ====================== -# 对话框:编辑 run -# ====================== -class RunEditDialog(QDialog): - def __init__(self, parent=None, run_row=None): - """ - run_row: (id, created_at, filename, segments, mean_prob, final_label) 或 None(新增) - """ - super().__init__(parent) - self.setWindowTitle("编辑记录" if run_row else "新增记录") - self.resize(380, 240) - self.run_row = run_row - - self.ed_id = QLineEdit(); self.ed_id.setReadOnly(True) - self.ed_created = QLineEdit(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) - self.ed_filename = QLineEdit() - self.ed_segments = QLineEdit("0") - self.ed_mean_prob = QLineEdit("0.0") - self.ed_final_label = QLineEdit("0") # 0=实心 1=空心 - - form = QFormLayout() - form.addRow("ID:", self.ed_id) - form.addRow("时间(created_at):", self.ed_created) - form.addRow("文件名(filename):", self.ed_filename) - form.addRow("片段数(segments):", self.ed_segments) - form.addRow("平均概率(mean_prob):", self.ed_mean_prob) - form.addRow("最终标签(0=实 1=空):", self.ed_final_label) - - if run_row: - self.ed_id.setText(str(run_row[0])) - self.ed_created.setText(str(run_row[1])) - self.ed_filename.setText(str(run_row[2])) - self.ed_segments.setText(str(run_row[3])) - self.ed_mean_prob.setText(str(run_row[4])) - self.ed_final_label.setText(str(run_row[5])) - - btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) - btns.accepted.connect(self.accept) - btns.rejected.connect(self.reject) - - lay = QVBoxLayout() - lay.addLayout(form) - lay.addWidget(btns) - self.setLayout(lay) - - def values(self): - return dict( - id=(int(self.ed_id.text()) if self.ed_id.text().strip() else None), - created_at=self.ed_created.text().strip(), - filename=self.ed_filename.text().strip(), - segments=int(float(self.ed_segments.text() or 0)), - mean_prob=float(self.ed_mean_prob.text() or 0.0), - final_label=int(float(self.ed_final_label.text() or 0)) - ) - - -# ====================== -# 对话框:管理 segments -# ====================== -class SegmentManagerDialog(QDialog): - def __init__(self, parent=None, run_id=None): - super().__init__(parent) - self.setWindowTitle(f"片段管理 - run_id={run_id}") - self.resize(560, 420) - self.run_id = run_id - - self.table = QTableWidget(0, 4) - self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"]) - - self.btn_add = QPushButton("新增片段") - self.btn_edit = QPushButton("编辑所选") - self.btn_del = QPushButton("删除所选") - self.btn_close = QPushButton("关闭") - - hl = QHBoxLayout() - for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close): - hl.addWidget(b) - - lay = QVBoxLayout() - lay.addWidget(self.table) - lay.addLayout(hl) - self.setLayout(lay) - - self.btn_add.clicked.connect(self.add_segment) - self.btn_edit.clicked.connect(self.edit_selected) - self.btn_del.clicked.connect(self.delete_selected) - self.btn_close.clicked.connect(self.accept) - - self.load_segments() - - def load_segments(self): - rows = DB.list_segments(self.run_id) - self.table.setRowCount(len(rows)) - for i, (sid, idx, label, proba) in enumerate(rows): - self.table.setItem(i, 0, QTableWidgetItem(str(sid))) - self.table.setItem(i, 1, QTableWidgetItem(str(idx))) - self.table.setItem(i, 2, QTableWidgetItem(str(label))) - self.table.setItem(i, 3, QTableWidgetItem(str(proba))) - self.table.resizeColumnsToContents() - - def _get_selected_seg_id(self): - r = self.table.currentRow() - if r < 0: - return None - return int(self.table.item(r, 0).text()) - - def add_segment(self): - d = QDialog(self) - d.setWindowTitle("新增片段") - form = QFormLayout() - ed_index = QLineEdit("1") - ed_label = QLineEdit("0") - ed_proba = QLineEdit("0.0") - form.addRow("seg_index:", ed_index) - form.addRow("label(0/1):", ed_label) - form.addRow("proba(0~1):", ed_proba) - btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) - btns.accepted.connect(d.accept); btns.rejected.connect(d.reject) - v = QVBoxLayout(); v.addLayout(form); v.addWidget(btns); d.setLayout(v) - if d.exec_() == QDialog.Accepted: - DB.add_segment(self.run_id, int(float(ed_index.text() or 1)), int(float(ed_label.text() or 0)), - float(ed_proba.text() or 0.0)) - self.load_segments() - - def edit_selected(self): - seg_id = self._get_selected_seg_id() - if not seg_id: - QMessageBox.information(self, "提示", "请先选择一行片段。") - return - # 读当前值 - row = self.table.currentRow() - cur_idx = self.table.item(row, 1).text() - cur_label = self.table.item(row, 2).text() - cur_proba = self.table.item(row, 3).text() - - d = QDialog(self) - d.setWindowTitle("编辑片段") - form = QFormLayout() - ed_index = QLineEdit(cur_idx) - ed_label = QLineEdit(cur_label) - ed_proba = QLineEdit(cur_proba) - form.addRow("seg_index:", ed_index) - form.addRow("label(0/1):", ed_label) - form.addRow("proba(0~1):", ed_proba) - btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) - btns.accepted.connect(d.accept); btns.rejected.connect(d.reject) - v = QVBoxLayout(); v.addLayout(form); v.addWidget(btns); d.setLayout(v) - if d.exec_() == QDialog.Accepted: - DB.update_segment(seg_id, int(float(ed_index.text() or cur_idx)), - int(float(ed_label.text() or cur_label)), - float(ed_proba.text() or cur_proba)) - self.load_segments() - - def delete_selected(self): - seg_id = self._get_selected_seg_id() - if not seg_id: - QMessageBox.information(self, "提示", "请先选择一行片段。") - return - if QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!", - QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes: - DB.delete_segment(seg_id) - self.load_segments() - - -# ====================== -# 主菜单 -# ====================== -class MainMenuWidget(QWidget): - def __init__(self, parent): - super().__init__(parent) - layout = QVBoxLayout() - title = QLabel("音频分类器"); title.setAlignment(Qt.AlignCenter) - title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;") - layout.addWidget(title) - btn1 = QPushButton("采集音频"); btn1.clicked.connect(lambda: parent.switch_to_input("record")) - btn2 = QPushButton("上传WAV文件"); btn2.clicked.connect(lambda: parent.switch_to_input("upload")) - btn3 = QPushButton("历史记录 / 导出 / CRUD"); btn3.clicked.connect(parent.switch_to_history) - for b in [btn1, btn2, btn3]: - b.setMinimumHeight(60); layout.addWidget(b) - layout.addStretch(1); self.setLayout(layout) - - -# ====================== -# 输入界面(录音 / 上传) -# ====================== -class InputWidget(QWidget): - def __init__(self, parent, mode="record"): - super().__init__(parent) - self.parent, self.mode = parent, mode - self.wav_path, self.record_thread = "", None - self.device_index = None - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu) - layout.addWidget(back) - if self.mode == "record": self.setup_record(layout) - else: self.setup_upload(layout) - self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl") - path_layout = QHBoxLayout() - path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path) - path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path) - layout.addLayout(path_layout) - self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False) - self.process_btn.clicked.connect(self.start_process) - layout.addWidget(self.process_btn) - self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log) - self.setLayout(layout) - - def setup_record(self, layout): - title = QLabel("音频采集"); title.setAlignment(Qt.AlignCenter) - layout.addWidget(title) - dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:")) - self.dev_combo = QComboBox(); self.refresh_devices() - self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData())) - dev_layout.addWidget(self.dev_combo) - btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices) - dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout) - self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80) - self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec - self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s") - self.level.setFormat("电平:%p%"); self.progress.setRange(0,100) - for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w) - self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur) - - def setup_upload(self, layout): - h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True) - b = QPushButton("浏览"); b.clicked.connect(self.browse) - h.addWidget(self.file); h.addWidget(b); layout.addLayout(h) - - def refresh_devices(self): - self.dev_combo.clear(); p = pyaudio.PyAudio() - for i in range(p.get_device_count()): - info = p.get_device_info_by_index(i) - if int(info.get('maxInputChannels',0))>0: - self.dev_combo.addItem(f"[{i}] {info.get('name')}", i) - if self.dev_combo.count()>0: - self.dev_combo.setCurrentIndex(0) - self.device_index = self.dev_combo.currentData() - p.terminate() - - def start_rec(self,e): - if e.button()==Qt.LeftButton: - self.record_thread=RecordThread(device_index=self.device_index) - self.record_thread.update_signal.connect(self.log.append) - self.record_thread.finish_signal.connect(self.finish_rec) - self.record_thread.progress_signal.connect(self.progress.setValue) - self.record_thread.level_signal.connect(self.level.setValue) - self.record_thread.start(); self.start=time.time(); self.timer.start(100) - - def stop_rec(self,e): - if e.button()==Qt.LeftButton and self.record_thread: - self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0) - - def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s") - - def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True) - - def browse(self): - f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)") - if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True) - - def start_process(self): - if self.mode=="upload": self.wav_path=self.file.text() - if not os.path.exists(self.wav_path): self.log.append("无效文件"); return - self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text()) - self.thread.update_signal.connect(self.log.append) - self.thread.finish_signal.connect(self.parent.switch_to_result) - self.thread.start() - - -# ====================== -# 历史记录(含 CRUD 按钮) -# ====================== -class HistoryWidget(QWidget): - def __init__(self, parent): - super().__init__(parent) - self.parent=parent - self.init_ui() - - def init_ui(self): - l=QVBoxLayout() - back=QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu); l.addWidget(back) - - top=QHBoxLayout() - self.search_edit=QLineEdit(); self.search_edit.setPlaceholderText("按文件名搜索...") - btn_search=QPushButton("搜索"); btn_search.clicked.connect(self.search) - btn_refresh=QPushButton("刷新"); btn_refresh.clicked.connect(self.load) - btn_export=QPushButton("导出CSV"); btn_export.clicked.connect(self.export_csv) - btn_add=QPushButton("新增记录"); btn_add.clicked.connect(self.add_run) - btn_edit=QPushButton("编辑所选"); btn_edit.clicked.connect(self.edit_selected) - btn_del=QPushButton("删除所选"); btn_del.clicked.connect(self.delete_selected) - btn_seg=QPushButton("片段管理"); btn_seg.clicked.connect(self.open_segments) - for w in [self.search_edit, btn_search, btn_refresh, btn_export, btn_add, btn_edit, btn_del, btn_seg]: - top.addWidget(w) - l.addLayout(top) - - self.table=QTableWidget(0,6) - self.table.setHorizontalHeaderLabels(["ID","时间","文件名","片段数","平均概率","最终标签"]) - l.addWidget(self.table) - self.setLayout(l) - self.load() - - def _current_run_id(self): - r=self.table.currentRow() - if r<0: return None - return int(self.table.item(r,0).text()) - - # ---- 数据加载 / 搜索 ---- - def load(self): - rows=DB.fetch_recent_runs(200) - self._fill(rows) - - def search(self): - kw=self.search_edit.text().strip() - rows=DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200) - self._fill(rows) - - def _fill(self, rows): - self.table.setRowCount(len(rows)) - for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows): - self.table.setItem(r_idx, 0, QTableWidgetItem(str(rid))) - self.table.setItem(r_idx, 1, QTableWidgetItem(str(created_at))) - self.table.setItem(r_idx, 2, QTableWidgetItem(str(fn))) - self.table.setItem(r_idx, 3, QTableWidgetItem(str(segs))) - self.table.setItem(r_idx, 4, QTableWidgetItem(str(mean_prob))) - self.table.setItem(r_idx, 5, QTableWidgetItem("空心" if int(final_label) else "实心")) - self.table.resizeColumnsToContents() - - # ---- 导出 ---- - def export_csv(self): - p=DB.export_runs_to_csv() - QMessageBox.information(self,"导出完成",f"已保存到 {os.path.abspath(p)}") - - # ---- CRUD: run ---- - def add_run(self): - dlg=RunEditDialog(self, run_row=None) - if dlg.exec_()==QDialog.Accepted: - v=dlg.values() - DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"]) - self.load() - - def edit_selected(self): - rid=self._current_run_id() - if not rid: - QMessageBox.information(self,"提示","请先选择一行记录。") - return - row=DB.get_run(rid) - dlg=RunEditDialog(self, run_row=row) - if dlg.exec_()==QDialog.Accepted: - v=dlg.values() - DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"]) - self.load() - - def delete_selected(self): - rid=self._current_run_id() - if not rid: - QMessageBox.information(self,"提示","请先选择一行记录。") - return - if QMessageBox.question(self,"确认","确认删除所选记录及其所有片段?此操作不可恢复!", - QMessageBox.Yes|QMessageBox.No)==QMessageBox.Yes: - DB.delete_run(rid) - self.load() - - # ---- 片段管理 ---- - def open_segments(self): - rid=self._current_run_id() - if not rid: - QMessageBox.information(self,"提示","请先选择一行记录。") - return - dlg=SegmentManagerDialog(self, run_id=rid) - dlg.exec_() - # 片段变更后可按需同步更新片段数/平均概率,此处保持手动编辑为主 - self.load() - - -# ====================== -# 结果页(简) -# ====================== -class ResultWidget(QWidget): - def __init__(self, parent, result=None): - super().__init__(parent); self.parent=parent - layout=QVBoxLayout(); b=QPushButton("返回"); b.clicked.connect(self.parent.switch_to_main_menu); layout.addWidget(b) - t=QLabel("处理结果"); t.setAlignment(Qt.AlignCenter); layout.addWidget(t) - area=QTextEdit(); area.setReadOnly(True) - if result: - txt=f"文件:{result['filename']}\n片段:{result['segments']}\n平均概率:{result['mean_prob']}\n最终:{'空心' if result['final_label'] else '实心'}" - area.setText(txt) - layout.addWidget(area); self.setLayout(layout) - - -# ====================== -# 主窗口 -# ====================== -class AudioClassifierGUI(QMainWindow): - def __init__(self): - super().__init__() - self.current_input_mode = "record" - self.process_result = None - self.init_ui() - - def init_ui(self): - self.setWindowTitle("音频分类器") - self.setGeometry(100, 100, 1000, 700) - - self.stacked_widget = QStackedWidget() - self.setCentralWidget(self.stacked_widget) - - self.main_menu_widget = MainMenuWidget(self) - self.record_input_widget = InputWidget(self, "record") - self.upload_input_widget = InputWidget(self, "upload") - self.result_widget = ResultWidget(self) - self.history_widget = HistoryWidget(self) - - self.stacked_widget.addWidget(self.main_menu_widget) # 0 - self.stacked_widget.addWidget(self.record_input_widget) # 1 - self.stacked_widget.addWidget(self.upload_input_widget) # 2 - self.stacked_widget.addWidget(self.result_widget) # 3 - self.stacked_widget.addWidget(self.history_widget) # 4 - - self.stacked_widget.setCurrentWidget(self.main_menu_widget) - - def switch_to_input(self, mode): - self.current_input_mode = mode - if mode == "record": - self.stacked_widget.setCurrentWidget(self.record_input_widget) - else: - self.stacked_widget.setCurrentWidget(self.upload_input_widget) - - def switch_to_main_menu(self): - self.stacked_widget.setCurrentWidget(self.main_menu_widget) - - def switch_to_result(self, result): - self.process_result = result - self.result_widget = ResultWidget(self, result) - self.stacked_widget.removeWidget(self.stacked_widget.widget(3)) - self.stacked_widget.addWidget(self.result_widget) - self.stacked_widget.setCurrentWidget(self.result_widget) - - def switch_to_history(self): - self.history_widget.load() - self.stacked_widget.setCurrentWidget(self.history_widget) - - -# ====================== -# 上传/录音 界面(放最后,避免打断阅读) -# ====================== -class InputWidget(QWidget): - def __init__(self, parent, mode="record"): - super().__init__(parent) - self.parent, self.mode = parent, mode - self.wav_path, self.record_thread = "", None - self.device_index = None - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - back = QPushButton("返回"); back.clicked.connect(self.parent.switch_to_main_menu) - layout.addWidget(back) - if self.mode == "record": self.setup_record(layout) - else: self.setup_upload(layout) - self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl") - path_layout = QHBoxLayout() - path_layout.addWidget(QLabel("模型:")); path_layout.addWidget(self.model_path) - path_layout.addWidget(QLabel("标准化器:")); path_layout.addWidget(self.scaler_path) - layout.addLayout(path_layout) - self.process_btn = QPushButton("开始处理"); self.process_btn.setEnabled(False) - self.process_btn.clicked.connect(self.start_process) - layout.addWidget(self.process_btn) - self.log = QTextEdit(); self.log.setReadOnly(True); layout.addWidget(self.log) - self.setLayout(layout) - - def setup_record(self, layout): - title = QLabel("音频采集"); title.setAlignment(Qt.AlignCenter) - layout.addWidget(title) - dev_layout = QHBoxLayout(); dev_layout.addWidget(QLabel("设备:")) - self.dev_combo = QComboBox(); self.refresh_devices() - self.dev_combo.currentIndexChanged.connect(lambda: setattr(self, "device_index", self.dev_combo.currentData())) - dev_layout.addWidget(self.dev_combo) - btn_ref = QPushButton("刷新设备"); btn_ref.clicked.connect(self.refresh_devices) - dev_layout.addWidget(btn_ref); layout.addLayout(dev_layout) - self.rec_btn = QPushButton("按住录音"); self.rec_btn.setMinimumHeight(80) - self.rec_btn.mousePressEvent = self.start_rec; self.rec_btn.mouseReleaseEvent = self.stop_rec - self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s") - self.level.setFormat("电平:%p%"); self.progress.setRange(0,100) - for w in [self.rec_btn,self.level,self.progress,self.dur]: layout.addWidget(w) - self.timer = QTimer(self); self.timer.timeout.connect(self.update_dur) - - def setup_upload(self, layout): - h = QHBoxLayout(); self.file = QLineEdit(); self.file.setReadOnly(True) - b = QPushButton("浏览"); b.clicked.connect(self.browse) - h.addWidget(self.file); h.addWidget(b); layout.addLayout(h) - - def refresh_devices(self): - self.dev_combo.clear(); p = pyaudio.PyAudio() - for i in range(p.get_device_count()): - info = p.get_device_info_by_index(i) - if int(info.get('maxInputChannels',0))>0: - self.dev_combo.addItem(f"[{i}] {info.get('name')}", i) - if self.dev_combo.count()>0: - self.dev_combo.setCurrentIndex(0) - self.device_index = self.dev_combo.currentData() - p.terminate() - - def start_rec(self,e): - if e.button()==Qt.LeftButton: - self.record_thread=RecordThread(device_index=self.device_index) - self.record_thread.update_signal.connect(self.log.append) - self.record_thread.finish_signal.connect(self.finish_rec) - self.record_thread.progress_signal.connect(self.progress.setValue) - self.record_thread.level_signal.connect(self.level.setValue) - self.record_thread.start(); self.start=time.time(); self.timer.start(100) - - def stop_rec(self,e): - if e.button()==Qt.LeftButton and self.record_thread: - self.record_thread.is_recording=False; self.timer.stop(); self.progress.setValue(0); self.level.setValue(0) - - def update_dur(self): self.dur.setText(f"{time.time()-self.start:.1f}s") - - def finish_rec(self,file): self.wav_path=file; self.process_btn.setEnabled(True) - - def browse(self): - f,_=QFileDialog.getOpenFileName(self,"选择WAV","","WAV文件 (*.wav)") - if f: self.file.setText(f); self.wav_path=f; self.process_btn.setEnabled(True) - - def start_process(self): - if self.mode=="upload": self.wav_path=self.file.text() - if not os.path.exists(self.wav_path): self.log.append("无效文件"); return - self.thread=ProcessThread(self.wav_path,self.model_path.text(),self.scaler_path.text()) - self.thread.update_signal.connect(self.log.append) - self.thread.finish_signal.connect(self.parent.switch_to_result) - self.thread.start() - - -if __name__ == "__main__": - try: - import pyaudio - except ImportError: - print("请先安装pyaudio: pip install pyaudio") - sys.exit(1) - - app = QApplication(sys.argv) - window = AudioClassifierGUI() - window.show() - sys.exit(app.exec_()) diff --git a/infer_results.pkl b/infer_results.pkl deleted file mode 100644 index 2746cb3..0000000 Binary files a/infer_results.pkl and /dev/null differ diff --git a/sample/无保温实1.wav b/sample/无保温实1.wav deleted file mode 100644 index b583885..0000000 Binary files a/sample/无保温实1.wav and /dev/null differ diff --git a/sample/无保温空1.wav b/sample/无保温空1.wav deleted file mode 100644 index 11cd098..0000000 Binary files a/sample/无保温空1.wav and /dev/null differ diff --git a/sample/漆实1.wav b/sample/漆实1.wav deleted file mode 100644 index 7fb0439..0000000 Binary files a/sample/漆实1.wav and /dev/null differ diff --git a/sample/漆实2.wav b/sample/漆实2.wav deleted file mode 100644 index fa53291..0000000 Binary files a/sample/漆实2.wav and /dev/null differ diff --git a/sample/漆空1.wav b/sample/漆空1.wav deleted file mode 100644 index fbadd66..0000000 Binary files a/sample/漆空1.wav and /dev/null differ diff --git a/sample/漆空渐强1.wav b/sample/漆空渐强1.wav deleted file mode 100644 index 853625f..0000000 Binary files a/sample/漆空渐强1.wav and /dev/null differ diff --git a/sample/漆空渐强2.wav b/sample/漆空渐强2.wav deleted file mode 100644 index 46c8c8d..0000000 Binary files a/sample/漆空渐强2.wav and /dev/null differ diff --git a/sample/瓷实1.wav b/sample/瓷实1.wav deleted file mode 100644 index 1625647..0000000 Binary files a/sample/瓷实1.wav and /dev/null differ diff --git a/sample/瓷实2.wav b/sample/瓷实2.wav deleted file mode 100644 index 63ab74c..0000000 Binary files a/sample/瓷实2.wav and /dev/null differ diff --git a/sample/瓷实渐强.wav b/sample/瓷实渐强.wav deleted file mode 100644 index 97d2528..0000000 Binary files a/sample/瓷实渐强.wav and /dev/null differ diff --git a/sample/瓷空1.wav b/sample/瓷空1.wav deleted file mode 100644 index 80f8c4d..0000000 Binary files a/sample/瓷空1.wav and /dev/null differ diff --git a/sample/瓷空2.wav b/sample/瓷空2.wav deleted file mode 100644 index 4af472d..0000000 Binary files a/sample/瓷空2.wav and /dev/null differ diff --git a/sample/瓷空渐强.wav b/sample/瓷空渐强.wav deleted file mode 100644 index 44b1536..0000000 Binary files a/sample/瓷空渐强.wav and /dev/null differ diff --git a/scaler.pkl b/scaler.pkl deleted file mode 100644 index d520b18..0000000 Binary files a/scaler.pkl and /dev/null differ diff --git a/svm_model.pkl b/svm_model.pkl deleted file mode 100644 index b7ef935..0000000 Binary files a/svm_model.pkl and /dev/null differ