#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Tap-sound classifier GUI (PyQt5).

Records audio from a microphone (or loads a WAV file), detects individual
hits by frame energy, extracts five features per hit and classifies them with
a model + scaler loaded through joblib.

Pages: main menu -> input (record / upload) -> result.
"""
import os
import sys
import time
import wave
from contextlib import suppress

import joblib
import librosa
import numpy as np
from PyQt5.QtCore import Qt, QThread, QTimer, pyqtSignal
from PyQt5.QtWidgets import (QApplication, QComboBox, QFileDialog,
                             QHBoxLayout, QLabel, QLineEdit, QMainWindow,
                             QProgressBar, QPushButton, QStackedWidget,
                             QTextEdit, QVBoxLayout, QWidget)
from scipy.signal import hilbert

try:
    import pyaudio
except ImportError:  # reported with an install hint by main()
    pyaudio = None


# Feature extraction
class FeatureExtractor:
    """Hit detection and per-hit feature extraction for a mono signal.

    Used as a namespace of static helpers; the constants define the framing
    used for hit detection.
    """

    WIN = 1024            # frame length (samples) for energy computation
    OVERLAP = 512         # overlap between consecutive frames (samples)
    THRESHOLD = 0.1       # frame energy above which a frame counts as "loud"
    SEG_LEN_S = 0.2       # length of the slice cut after each hit (seconds)
    STEP = WIN - OVERLAP  # hop between frames (samples)
    MIN_GAP_FRAMES = 5    # loud frames further apart than this start a new hit

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the energy (sum of squares) of each WIN-sample frame.

        Frames advance by STEP samples. A signal shorter than one frame
        yields an empty array instead of letting librosa raise.
        """
        if len(signal) < FeatureExtractor.WIN:
            return np.zeros(0)
        frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN,
                                    hop_length=FeatureExtractor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices at which hits start.

        A hit starts at a loud frame (energy > THRESHOLD) that is either the
        first loud frame or more than MIN_GAP_FRAMES frames after the
        previous loud frame.
        """
        idx = np.where(energy > FeatureExtractor.THRESHOLD)[0]
        if idx.size == 0:
            return np.array([], dtype=int)
        # The first loud frame always starts a hit. Diffing against a leading
        # 0 instead would drop a hit that begins within the first
        # MIN_GAP_FRAMES frames of the recording.
        is_start = np.concatenate(
            ([True], np.diff(idx) > FeatureExtractor.MIN_GAP_FRAMES))
        return idx[is_start]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int,
                         hit_starts: np.ndarray) -> list[np.ndarray]:
        """Cut a SEG_LEN_S-second slice of ``signal`` at each hit start.

        ``hit_starts`` holds frame indices (see detect_hits) which are mapped
        back to sample offsets via STEP. Slices are truncated at the end of
        the signal.
        """
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            start = int(frame_idx) * FeatureExtractor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return the 5-element feature vector of one segment.

        Order: RMS energy, dominant frequency (Hz), spectral skewness, mean of
        the first MFCC, peak of the Hilbert envelope. An empty segment yields
        five zeros.

        NOTE(review): presumably the training script (train_model.py) computes
        the same features; keep the two in sync.
        """
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)

        # 1. RMS energy
        rms = np.sqrt(np.mean(x ** 2))

        # 2. Dominant frequency: bin with the largest FFT magnitude among the
        #    non-negative frequencies
        fft = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive = freq >= 0
        freq = freq[positive]
        fft_mag = np.abs(fft[positive])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # 3. Spectral skewness of the magnitude spectrum (1e-12 avoids /0)
        spec_power = fft_mag
        total = np.sum(spec_power) + 1e-12
        centroid = np.sum(freq * spec_power) / total
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / total)
        if spread > 0:
            skewness = (np.sum(((freq - centroid) ** 3) * spec_power)
                        / (total * (spread ** 3 + 1e-12)))
        else:
            skewness = 0

        # 4. Mean of the first MFCC coefficient
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        # 5. Envelope peak via the analytic signal (Hilbert transform)
        env_peak = np.max(np.abs(hilbert(x)))

        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])


# Recording thread
class RecordThread(QThread):
    """Record mono 16-bit PCM at 44.1 kHz into ``temp_file``.

    Recording ends when stop() is called, when ``is_recording`` is cleared,
    or after ``max_duration`` seconds. Results are reported only through the
    signals below; the thread never touches widgets directly.
    """

    update_signal = pyqtSignal(str)    # human-readable log messages
    finish_signal = pyqtSignal(str)    # path of the written WAV file
    progress_signal = pyqtSignal(int)  # elapsed time as % of max_duration
    level_signal = pyqtSignal(int)     # live input level, 0-100

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        self.is_recording = False
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index  # None -> system default input
        # Sticky stop flag: run() sets is_recording = True only once the
        # stream is open, which would overwrite a stop requested earlier
        # (e.g. a very short click on the record button).
        self._stop_requested = False

    def stop(self):
        """Ask the recording loop to finish; safe to call from the GUI thread."""
        self._stop_requested = True
        self.is_recording = False

    def run(self):
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024

        p = None
        stream = None
        frames = []
        start_time = time.time()
        try:
            FORMAT = pyaudio.paInt16
            p = pyaudio.PyAudio()
            self._log_input_devices(p)
            stream = p.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                input_device_index=self.device_index,  # None -> system default
                frames_per_buffer=CHUNK,
            )
            self.is_recording = True
            start_time = time.time()
            while self.is_recording and not self._stop_requested:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
                self.level_signal.emit(self._chunk_level(data))
        except Exception as e:
            err_msg = f"录制错误: {str(e)}"
            if p is not None:
                # PortAudio came up, so the device itself is the likely culprit
                err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)"
            self.update_signal.emit(err_msg)
            return
        finally:
            # Single place that releases the device, on every exit path
            self.is_recording = False
            self._release(p, stream)

        if not frames:
            self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)")
            return
        try:
            with wave.open(self.temp_file, 'wb') as wf:
                wf.setnchannels(CHANNELS)
                # module-level helper: no extra PortAudio instance needed
                wf.setsampwidth(pyaudio.get_sample_size(FORMAT))  # 2 bytes
                wf.setframerate(RATE)
                wf.writeframes(b''.join(frames))
        except (OSError, wave.Error) as e:
            self.update_signal.emit(f"录制错误: {str(e)}")
            return
        self.update_signal.emit(f"录制完成,时长: {time.time() - start_time:.1f}秒")
        self.finish_signal.emit(self.temp_file)

    def _log_input_devices(self, p):
        """Emit the list of input-capable devices (diagnostic only, best effort)."""
        try:
            device_log = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    device_log.append(f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} sr={int(info.get('defaultSampleRate', 0))}")
        except Exception:
            return  # enumeration is only a diagnostic aid; never block recording
        if device_log:
            self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log))

    @staticmethod
    def _chunk_level(data: bytes, full_scale: float = 32767.0) -> int:
        """Map the RMS of one int16 PCM chunk to a 0-100 meter value."""
        samples = np.frombuffer(data, dtype=np.int16).astype(np.float32) / full_scale
        if samples.size == 0:
            return 0
        rms = np.sqrt(np.mean(samples ** 2))
        # x5 gain so ordinary speech visibly moves the meter; tune per setup
        return int(np.clip(rms * 100 * 5, 0, 100))

    @staticmethod
    def _release(p, stream):
        """Stop/close ``stream`` and terminate ``p``, ignoring cleanup errors."""
        if stream is not None:
            with suppress(Exception):
                stream.stop_stream()
            with suppress(Exception):
                stream.close()
        if p is not None:
            with suppress(Exception):
                p.terminate()


# Processing thread
class ProcessThread(QThread):
    """Load model + scaler, segment the WAV into hits and classify each hit.

    Emits ``finish_signal`` with a result dict on success; on "no segments"
    or any error only ``update_signal`` messages are emitted.
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)

    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path = wav_path
        self.model_path = model_path
        self.scaler_path = scaler_path

    def run(self):
        try:
            self.update_signal.emit("加载模型和标准化器...")
            # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
            # Peak-normalise so recording gain does not shift the features
            peak = float(np.max(np.abs(sig))) if sig.size else 0.0
            if peak > 0:
                sig = sig / peak

            self.update_signal.emit("提取特征...")
            ene = FeatureExtractor.frame_energy(sig)
            hit_starts = FeatureExtractor.detect_hits(ene)
            segments = FeatureExtractor.extract_segments(sig, sr, hit_starts)
            if not segments:
                self.update_signal.emit("未检测到有效片段!")
                return

            X = np.vstack([FeatureExtractor.extract_features(seg, sr) for seg in segments])
            self.update_signal.emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}")

            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)
            y_proba = self._positive_scores(model, X_std, y_pred)

            self.finish_signal.emit({
                "filename": os.path.basename(self.wav_path),
                "segments": len(segments),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(score), 4) for score in y_proba],
                "mean_prob": round(float(np.mean(y_proba)), 4),
                "final_label": int(np.mean(y_proba) >= 0.5),
            })
        except Exception as e:
            self.update_signal.emit(f"错误: {str(e)}")

    @staticmethod
    def _positive_scores(model, X_std, y_pred):
        """Return per-sample scores in [0, 1] for class column 1.

        Uses predict_proba when available, else a logistic squash of
        decision_function, else the hard predictions as floats.
        """
        if hasattr(model, "predict_proba"):
            return model.predict_proba(X_std)[:, 1]
        if hasattr(model, "decision_function"):
            df = model.decision_function(X_std)
            return 1 / (1 + np.exp(-df))
        return y_pred.astype(float)


# Page 1: main menu
class MainMenuWidget(QWidget):
    """Landing page: choose between recording and uploading a WAV file."""

    def __init__(self, parent=None):
        super().__init__(parent)
        # Not stored as ``self.parent``: that would shadow QWidget.parent().
        self.main_window = parent
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()

        title = QLabel("音频分类器")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        layout.addWidget(title)

        record_btn = QPushButton("采集音频")
        record_btn.setMinimumHeight(60)
        record_btn.setStyleSheet("font-size: 16px;")
        record_btn.clicked.connect(lambda: self.main_window.switch_to_input("record"))

        upload_btn = QPushButton("上传外部WAV文件")
        upload_btn.setMinimumHeight(60)
        upload_btn.setStyleSheet("font-size: 16px;")
        upload_btn.clicked.connect(lambda: self.main_window.switch_to_input("upload"))

        layout.addWidget(record_btn)
        layout.addWidget(upload_btn)
        layout.addStretch(1)
        self.setLayout(layout)


# Page 2: input (record or upload)
class InputWidget(QWidget):
    """Input page.

    In "record" mode the user holds a button to record; in "upload" mode the
    user picks a WAV file. Either way ProcessThread is then run on the file.
    """

    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        self.main_window = parent  # not ``self.parent``: avoid shadowing QWidget.parent()
        self.mode = mode
        self.wav_path = ""
        self.record_thread = None
        self.process_thread = None
        self.device_index = None  # PortAudio index of the selected input device
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()

        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.main_window.switch_to_main_menu)
        layout.addWidget(back_btn)

        if self.mode == "record":
            self.setup_record_ui(layout)
        else:
            self.setup_upload_ui(layout)

        # Model / scaler paths
        model_layout = QHBoxLayout()
        self.model_path = QLineEdit("svm_model.pkl")
        self.scaler_path = QLineEdit("scaler.pkl")
        model_layout.addWidget(QLabel("模型路径:"))
        model_layout.addWidget(self.model_path)
        model_layout.addWidget(QLabel("标准化器路径:"))
        model_layout.addWidget(self.scaler_path)
        layout.addLayout(model_layout)

        # Process button; enabled once a WAV file is available
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setMinimumHeight(50)
        self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;")
        self.process_btn.clicked.connect(self.start_process)
        self.process_btn.setEnabled(False)
        layout.addWidget(self.process_btn)

        # Log area
        self.log_area = QTextEdit()
        self.log_area.setReadOnly(True)
        layout.addWidget(QLabel("日志:"))
        layout.addWidget(self.log_area)

        self.setLayout(layout)

    def setup_record_ui(self, layout):
        """Build the recording controls: device picker, hold-to-record button, meters."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        # Microphone selection
        device_layout = QHBoxLayout()
        device_layout.addWidget(QLabel("输入设备:"))
        self.device_combo = QComboBox()
        self.refresh_devices()
        self.device_combo.currentIndexChanged.connect(self.on_device_changed)
        device_layout.addWidget(self.device_combo)
        refresh_btn = QPushButton("刷新设备")
        refresh_btn.clicked.connect(self.refresh_devices)
        device_layout.addWidget(refresh_btn)
        layout.addLayout(device_layout)

        # Hold-to-record button
        record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)")
        record_hint.setAlignment(Qt.AlignCenter)
        self.record_btn = QPushButton("按住录音")
        self.record_btn.setMinimumHeight(80)
        self.record_btn.setStyleSheet("""
            QPushButton {
                background-color: #ff4d4d;
                color: white;
                font-size: 18px;
                border-radius: 10px;
            }
            QPushButton:pressed {
                background-color: #cc0000;
            }
        """)
        # Raw mouse handlers rather than pressed/released signals: a release
        # outside the button must still stop the recording.
        self.record_btn.mousePressEvent = self.start_recording
        self.record_btn.mouseReleaseEvent = self.stop_recording
        self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu)

        # Live level meter and recording progress
        self.mic_level = QProgressBar()
        self.mic_level.setRange(0, 100)
        self.mic_level.setFormat("麦克风电平:%p%")
        self.record_progress = QProgressBar()
        self.record_progress.setRange(0, 100)
        self.record_progress.setValue(0)
        self.record_duration_label = QLabel("录音时长: 0.0秒")
        self.record_duration_label.setAlignment(Qt.AlignCenter)

        layout.addWidget(record_hint)
        layout.addWidget(self.record_btn)
        layout.addWidget(self.mic_level)
        layout.addWidget(self.record_progress)
        layout.addWidget(self.record_duration_label)

        # Timer that refreshes the duration label while recording
        self.record_timer = QTimer(self)
        self.record_timer.timeout.connect(self.update_record_duration)
        self.record_start_time = 0

    def refresh_devices(self):
        """Fill the device combo with input-capable devices and preselect the default.

        Best effort: on any PortAudio failure ``device_index`` ends up None,
        so recording falls back to the system default device.
        """
        self.device_combo.clear()
        self.device_index = None
        p = None
        try:
            p = pyaudio.PyAudio()
            try:
                default_input_index = p.get_default_input_device_info().get("index", -1)
            except OSError:  # raised when the system has no default input device
                default_input_index = -1
            found = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    name = info.get('name', f"Device {i}")
                    sr = int(info.get('defaultSampleRate', 0))
                    self.device_combo.addItem(f"[{i}] {name} (sr={sr})", i)
                    found.append(i)
        except Exception:
            # The log widget does not exist yet on the first call, so there is
            # nowhere to report this; RecordThread will surface the real error.
            self.device_index = None
            return
        finally:
            if p is not None:
                with suppress(Exception):
                    p.terminate()

        if default_input_index in found:
            self.device_combo.setCurrentIndex(found.index(default_input_index))
            self.device_index = default_input_index
        elif found:
            self.device_combo.setCurrentIndex(0)
            self.device_index = self.device_combo.currentData()

    def on_device_changed(self, _):
        self.device_index = self.device_combo.currentData()

    def setup_upload_ui(self, layout):
        """Build the file-picker controls for upload mode."""
        title = QLabel("上传WAV文件")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        file_layout = QHBoxLayout()
        self.file_path = QLineEdit()
        self.file_path.setReadOnly(True)
        self.browse_btn = QPushButton("浏览WAV文件")
        self.browse_btn.clicked.connect(self.browse_file)
        file_layout.addWidget(self.file_path)
        file_layout.addWidget(self.browse_btn)
        layout.addLayout(file_layout)

    def start_recording(self, event):
        """Record-button press handler: start a RecordThread on a left click."""
        # Keep QPushButton's own handling so the ':pressed' style is applied.
        QPushButton.mousePressEvent(self.record_btn, event)
        if event.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            return
        thread = RecordThread(max_duration=60, device_index=self.device_index)
        thread.update_signal.connect(self.update_log)
        thread.finish_signal.connect(self.on_recording_finish)
        thread.progress_signal.connect(self.record_progress.setValue)
        thread.level_signal.connect(self.mic_level.setValue)
        thread.finished.connect(self._on_record_thread_finished)
        self.record_thread = thread
        thread.start()
        self.record_start_time = time.time()
        self.record_timer.start(100)
        self.update_log("开始录音...(松开按钮结束)")

    def stop_recording(self, event):
        """Record-button release handler: ask the running RecordThread to stop."""
        QPushButton.mouseReleaseEvent(self.record_btn, event)
        if event.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            self.record_thread.stop()
            self._reset_record_indicators()

    def _on_record_thread_finished(self):
        """Reset the meters when a recording ends on its own (e.g. max duration)."""
        # Ignore a late notification from an older thread if a new one runs.
        if self.record_thread is not None and self.record_thread.isRunning():
            return
        self._reset_record_indicators()

    def _reset_record_indicators(self):
        """Stop the duration timer and zero the duration/progress/level widgets."""
        self.record_timer.stop()
        self.record_duration_label.setText("录音时长: 0.0秒")
        self.record_progress.setValue(0)
        self.mic_level.setValue(0)

    def update_record_duration(self):
        elapsed = time.time() - self.record_start_time
        self.record_duration_label.setText(f"录音时长: {elapsed:.1f}秒")

    def on_recording_finish(self, temp_file):
        self.wav_path = temp_file
        self.process_btn.setEnabled(True)

    def browse_file(self):
        file, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if file:
            self.file_path.setText(file)
            self.wav_path = file
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Validate paths and run ProcessThread on the current WAV file."""
        if self.mode == "upload":
            self.wav_path = self.file_path.text()
        model_path = self.model_path.text()
        scaler_path = self.scaler_path.text()

        if not self.wav_path or not os.path.exists(self.wav_path):
            self.log_area.append("请先录音或选择有效的WAV文件!")
            return
        if not os.path.exists(model_path) or not os.path.exists(scaler_path):
            self.log_area.append("模型文件不存在!请先运行train_model.py训练模型")
            return

        self.log_area.clear()
        self.process_btn.setEnabled(False)
        # Named process_thread: ``self.thread`` would shadow QObject.thread().
        self.process_thread = ProcessThread(self.wav_path, model_path, scaler_path)
        self.process_thread.update_signal.connect(self.update_log)
        self.process_thread.finish_signal.connect(self.on_process_finish)
        # Re-enable the button however the run ends (success, no hits, error)
        self.process_thread.finished.connect(self._on_process_thread_finished)
        self.process_thread.start()

    def _on_process_thread_finished(self):
        self.process_btn.setEnabled(True)

    def shutdown(self):
        """Stop/await background threads so they are not destroyed while running."""
        if self.record_thread is not None and self.record_thread.isRunning():
            self.record_thread.stop()
            self.record_thread.wait()
        if self.process_thread is not None and self.process_thread.isRunning():
            self.process_thread.wait()

    def update_log(self, msg):
        self.log_area.append(msg)

    def on_process_finish(self, result):
        self.main_window.switch_to_result(result)


# Page 3: results
class ResultWidget(QWidget):
    """Result page: per-segment predictions and the aggregated verdict."""

    def __init__(self, parent=None, result=None):
        super().__init__(parent)
        self.main_window = parent  # not ``self.parent``: avoid shadowing QWidget.parent()
        self.result = result
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()

        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.main_window.switch_to_input_from_result)
        layout.addWidget(back_btn)

        title = QLabel("处理结果")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        self.result_area = QTextEdit()
        self.result_area.setReadOnly(True)
        if self.result:
            self.result_area.setText(self._format_result(self.result))
        layout.addWidget(self.result_area)

        self.setLayout(layout)

    @staticmethod
    def _format_result(result):
        """Render the ProcessThread result dict as the text shown on this page."""
        parts = [
            f"文件名: {result['filename']}\n",
            f"片段数: {result['segments']}\n",
            "预测结果:\n",
        ]
        for i, (pred, prob) in enumerate(zip(result['predictions'], result['probabilities'])):
            parts.append(f" 片段{i + 1}: 标签={pred} (概率={prob})\n")
        parts.append(f"\n平均概率: {result['mean_prob']}\n")
        # final_label 1 is displayed as 空心 (hollow), 0 as 实心 (solid)
        parts.append(f"最终结果: {'空心' if result['final_label'] else '实心'}")
        return "".join(parts)


# Main window
class AudioClassifierGUI(QMainWindow):
    """Main window hosting the three pages in a QStackedWidget."""

    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"  # input page to return to from results
        self.process_result = None          # last result dict from ProcessThread
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 800, 600)

        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)

        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)

        self.stacked_widget.addWidget(self.main_menu_widget)
        self.stacked_widget.addWidget(self.record_input_widget)
        self.stacked_widget.addWidget(self.upload_input_widget)
        self.stacked_widget.addWidget(self.result_widget)

        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_input(self, mode):
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def switch_to_main_menu(self):
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_result(self, result):
        """Replace the result page with a fresh one showing ``result``."""
        self.process_result = result
        old_widget = self.result_widget
        self.result_widget = ResultWidget(self, result)
        # removeWidget() does not delete the widget; free it explicitly.
        self.stacked_widget.removeWidget(old_widget)
        old_widget.deleteLater()
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)

    def switch_to_input_from_result(self):
        if self.current_input_mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def closeEvent(self, event):
        """Wind down worker threads first; destroying a running QThread aborts."""
        for page in (self.record_input_widget, self.upload_input_widget):
            page.shutdown()
        super().closeEvent(event)


def main():
    """Entry point: check for pyaudio, then start the Qt event loop."""
    if pyaudio is None:
        print("请先安装pyaudio: pip install pyaudio")
        sys.exit(1)
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    main()