You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
git-02/audio_classifier_gui_02.py

631 lines
23 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import numpy as np
import joblib
import librosa
import pyaudio
import wave
from scipy.signal import hilbert
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
QComboBox) # >>> CHANGED: add QComboBox
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
# Feature extraction
class FeatureExtractor:
    """Hit detection and per-segment feature extraction for audio signals.

    All methods are static; the class only groups the framing constants and
    the processing steps (frame energy -> hit onsets -> segments -> features).
    """

    WIN = 1024  # frame length, in samples
    OVERLAP = 512  # samples shared by two consecutive frames
    THRESHOLD = 0.1  # a frame counts as active when its energy exceeds this
    SEG_LEN_S = 0.2  # length of the segment cut out after each hit, in seconds
    STEP = WIN - OVERLAP  # hop between consecutive frames, in samples
    MIN_GAP_FRAMES = 5  # active frames further apart than this start a new hit

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the energy (sum of squared samples) of every analysis frame.

        A signal shorter than one frame yields an empty array instead of
        letting ``librosa.util.frame`` raise.
        """
        if len(signal) < FeatureExtractor.WIN:
            return np.zeros(0)
        frames = librosa.util.frame(
            signal,
            frame_length=FeatureExtractor.WIN,
            hop_length=FeatureExtractor.STEP,
        )
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices at which a new hit starts.

        A frame is active when its energy exceeds ``THRESHOLD``. The first
        active frame always starts a hit; a later active frame starts a new
        hit when it lies more than ``MIN_GAP_FRAMES`` frames after the
        previous active frame. The result is always an integer array (empty
        when nothing exceeds the threshold).
        """
        active = np.flatnonzero(energy > FeatureExtractor.THRESHOLD)
        if active.size == 0:
            return np.array([], dtype=int)
        # Comparing against a prepended 0 would discard a hit that begins in
        # the first few frames, so the first active frame is marked explicitly.
        is_onset = np.ones(active.size, dtype=bool)
        is_onset[1:] = np.diff(active) > FeatureExtractor.MIN_GAP_FRAMES
        return active[is_onset]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray) -> list[np.ndarray]:
        """Cut one segment of up to ``SEG_LEN_S`` seconds out of ``signal`` per hit.

        ``hit_starts`` holds frame indices; each is converted to a sample
        offset with ``STEP``. Segments that would run past the end of the
        signal are truncated.
        """
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        total = len(signal)
        segments = []
        for frame_idx in hit_starts:
            start = int(frame_idx) * FeatureExtractor.STEP
            segments.append(signal[start:min(start + seg_len, total)])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return ``[rms, main_freq, skewness, mfcc_mean, env_peak]`` for one segment.

        An empty segment yields five zeros.
        """
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)

        # 1. RMS energy
        rms = np.sqrt(np.mean(x ** 2))

        # 2. Dominant frequency (frequency of the spectral peak), using only
        #    the non-negative half of the FFT
        fft = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive_freq_mask = freq >= 0
        freq = freq[positive_freq_mask]
        fft_mag = np.abs(fft[positive_freq_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # 3. Spectral skewness, weighted by the magnitude spectrum; the 1e-12
        #    terms guard against division by zero
        spec_power = fft_mag
        weight_sum = np.sum(spec_power) + 1e-12
        centroid = np.sum(freq * spec_power) / weight_sum
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / weight_sum)
        if spread > 0:
            skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
                weight_sum * (spread ** 3 + 1e-12))
        else:
            skewness = 0

        # 4. Mean of the first MFCC coefficient
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        # 5. Envelope peak (magnitude of the Hilbert-transform analytic signal)
        env_peak = np.max(np.abs(hilbert(x)))

        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# Recording thread
class RecordThread(QThread):
    """Worker thread that records microphone audio into a temporary WAV file.

    Recording runs until ``is_recording`` is set to False from outside or
    until ``max_duration`` seconds have elapsed. Progress, input level and
    log text are reported through the signals below; ``finish_signal``
    carries the path of the WAV file once it has been written.
    """

    update_signal = pyqtSignal(str)  # log text for the UI
    finish_signal = pyqtSignal(str)  # path of the written WAV file
    progress_signal = pyqtSignal(int)  # elapsed share of max_duration, 0-100
    level_signal = pyqtSignal(int)  # live input level, 0-100

    def __init__(self, max_duration=60, device_index=None):
        """Store the recording limit (seconds) and the input device index.

        ``device_index`` may be None, in which case it is passed to PyAudio
        as-is (the original author notes this selects the system default).
        """
        super().__init__()
        self.max_duration = max_duration
        self.is_recording = False
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index

    def run(self):
        """Record 16-bit mono audio at 44.1 kHz and write it to ``temp_file``."""
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024

        p = None
        stream = None
        frames = []
        start_time = time.time()
        try:
            try:
                p = pyaudio.PyAudio()
                self._log_input_devices(p)
                stream = p.open(
                    format=FORMAT,
                    channels=CHANNELS,
                    rate=RATE,
                    input=True,
                    input_device_index=self.device_index,
                    frames_per_buffer=CHUNK,
                )
                self.is_recording = True
                start_time = time.time()
                while self.is_recording:
                    elapsed = time.time() - start_time
                    if elapsed >= self.max_duration:
                        self.update_signal.emit(f"已达最大时长 {self.max_duration}")
                        break
                    self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                    data = stream.read(CHUNK, exception_on_overflow=False)
                    frames.append(data)
                    level = self._chunk_level(data)
                    if level is not None:
                        self.level_signal.emit(level)
            finally:
                # Release the audio device exactly once, before touching the
                # file system, and make the flag reflect that capture is over.
                self.is_recording = False
                self._release_audio(stream, p)

            if not frames:
                self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)")
                return

            # The module-level helper avoids creating (and leaking) a second
            # PyAudio instance just to look up the sample width.
            with wave.open(self.temp_file, 'wb') as wf:
                wf.setnchannels(CHANNELS)
                wf.setsampwidth(pyaudio.get_sample_size(FORMAT))  # 2 bytes
                wf.setframerate(RATE)
                wf.writeframes(b''.join(frames))
            self.update_signal.emit(f"录制完成,时长: {time.time() - start_time:.1f}")
            self.finish_signal.emit(self.temp_file)
        except Exception as e:
            # Thread boundary: report every failure to the UI instead of
            # letting it escape the worker thread.
            err_msg = f"录制错误: {str(e)}"
            if p is not None:
                # PyAudio itself started, so the device or permission is the
                # likely culprit; add a hint for the user.
                err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)"
            self.update_signal.emit(err_msg)

    def _log_input_devices(self, p):
        """Emit a log line listing every device with input channels (best effort)."""
        try:
            device_log = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    device_log.append(
                        f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} "
                        f"sr={int(info.get('defaultSampleRate', 0))}"
                    )
            if device_log:
                self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log))
        except Exception:
            # Purely diagnostic; a failure here must not prevent recording.
            pass

    @staticmethod
    def _chunk_level(data, gain=5):
        """Map the RMS of one int16 chunk to a 0-100 level, or None if unusable.

        ``gain`` is an empirical amplification factor that can be tuned to
        the recording environment.
        """
        try:
            samples = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32767.0
        except ValueError:
            return None
        if samples.size == 0:
            return None
        rms = np.sqrt(np.mean(samples ** 2))
        return int(np.clip(rms * 100 * gain, 0, 100))

    @staticmethod
    def _release_audio(stream, p):
        """Stop and close the stream and terminate PyAudio, ignoring cleanup errors."""
        if stream is not None:
            try:
                stream.stop_stream()
                stream.close()
            except Exception:
                pass  # best-effort cleanup
        if p is not None:
            try:
                p.terminate()
            except Exception:
                pass  # best-effort cleanup
# Processing thread
class ProcessThread(QThread):
    """Worker thread that classifies the hits found in one WAV file.

    ``update_signal`` carries log text. ``finish_signal`` carries a dict with
    the keys ``filename``, ``segments``, ``predictions``, ``probabilities``,
    ``mean_prob`` and ``final_label``; it is emitted only when at least one
    segment was detected and classified.
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)

    def __init__(self, wav_path, model_path, scaler_path):
        """Remember the audio file and the model / scaler files to load."""
        super().__init__()
        self.wav_path = wav_path
        self.model_path = model_path
        self.scaler_path = scaler_path

    def run(self):
        """Load model and scaler, extract per-hit features and emit the result."""
        try:
            self.update_signal.emit("加载模型和标准化器...")
            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only load model/scaler files from trusted sources.
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)

            # Audio shorter than one analysis frame (including an empty file)
            # cannot contain a segment; report that instead of failing later
            # with a low-level numpy/librosa error.
            if sig.size < FeatureExtractor.WIN:
                self.update_signal.emit("未检测到有效片段!")
                return

            # Peak-normalize so recording level does not affect detection;
            # an all-zero signal is left untouched.
            peak = np.max(np.abs(sig))
            if peak > 0:
                sig = sig / peak

            self.update_signal.emit("提取特征...")
            ene = FeatureExtractor.frame_energy(sig)
            hit_starts = FeatureExtractor.detect_hits(ene)
            segments = FeatureExtractor.extract_segments(sig, sr, hit_starts)
            if not segments:
                self.update_signal.emit("未检测到有效片段!")
                return

            X = np.vstack([FeatureExtractor.extract_features(seg, sr) for seg in segments])
            self.update_signal.emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}")
            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)

            # Score of the second class per segment, with fallbacks for
            # classifiers that do not provide predict_proba.
            if hasattr(model, "predict_proba"):
                y_proba = model.predict_proba(X_std)[:, 1]
            elif hasattr(model, "decision_function"):
                # Squash the decision value into (0, 1) with a logistic function.
                df = model.decision_function(X_std)
                y_proba = 1 / (1 + np.exp(-df))
            else:
                # Last resort: use the predicted labels themselves as scores.
                y_proba = y_pred.astype(float)

            mean_prob = float(np.mean(y_proba))
            self.finish_signal.emit({
                "filename": os.path.basename(self.wav_path),
                "segments": len(segments),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(p), 4) for p in y_proba],
                "mean_prob": round(mean_prob, 4),
                "final_label": int(mean_prob >= 0.5),
            })
        except Exception as e:
            # Thread boundary: surface every failure in the UI log.
            self.update_signal.emit(f"错误: {str(e)}")
# Level 1 screen: main menu
class MainMenuWidget(QWidget):
    """Start page with one button per input mode (record / upload)."""

    def __init__(self, parent=None):
        super().__init__(parent)
        # The owning window; its switch_to_input() is called by the buttons.
        self.parent = parent
        self.init_ui()

    def init_ui(self):
        """Build the title and the two mode buttons."""
        root = QVBoxLayout()

        heading = QLabel("音频分类器")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        root.addWidget(heading)

        # One button per input mode, in display order.
        for caption, mode in (("采集音频", "record"), ("上传外部WAV文件", "upload")):
            root.addWidget(self._make_mode_button(caption, mode))

        root.addStretch(1)
        self.setLayout(root)

    def _make_mode_button(self, caption, mode):
        """Create a menu button that opens the input screen for ``mode``."""
        button = QPushButton(caption)
        button.setMinimumHeight(60)
        button.setStyleSheet("font-size: 16px;")

        def open_mode():
            self.parent.switch_to_input(mode)

        button.clicked.connect(open_mode)
        return button
# Level 2 screen: input (record audio or upload a file)
class InputWidget(QWidget):
    """Input page: record from a microphone or pick a WAV file, then classify it.

    ``mode`` is ``"record"`` or anything else (treated as upload). The page
    owns the recording thread, the processing thread and a log area.
    """

    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        # The owning window; used for page navigation.
        self.parent = parent
        self.mode = mode
        self.wav_path = ""
        self.record_thread = None
        self.device_index = None  # index of the currently selected input device
        self.init_ui()

    def init_ui(self):
        """Build the page: back button, mode-specific controls, model paths, log."""
        layout = QVBoxLayout()

        # Back button
        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back_btn)

        if self.mode == "record":
            self.setup_record_ui(layout)
        else:
            self.setup_upload_ui(layout)

        # Model and scaler paths
        model_layout = QHBoxLayout()
        self.model_path = QLineEdit("svm_model.pkl")
        self.scaler_path = QLineEdit("scaler.pkl")
        model_layout.addWidget(QLabel("模型路径:"))
        model_layout.addWidget(self.model_path)
        model_layout.addWidget(QLabel("标准化器路径:"))
        model_layout.addWidget(self.scaler_path)
        layout.addLayout(model_layout)

        # Process button; disabled until a recording or file is available
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setMinimumHeight(50)
        self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;")
        self.process_btn.clicked.connect(self.start_process)
        self.process_btn.setEnabled(False)
        layout.addWidget(self.process_btn)

        # Log area
        self.log_area = QTextEdit()
        self.log_area.setReadOnly(True)
        layout.addWidget(QLabel("日志:"))
        layout.addWidget(self.log_area)

        self.setLayout(layout)

    def setup_record_ui(self, layout):
        """Add the recording controls (device picker, hold-to-record button, meters)."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        # Microphone selection. The combo box is filled before the change
        # signal is connected, so the initial fill does not fire the handler.
        device_layout = QHBoxLayout()
        device_layout.addWidget(QLabel("输入设备:"))
        self.device_combo = QComboBox()
        self.refresh_devices()
        self.device_combo.currentIndexChanged.connect(self.on_device_changed)
        device_layout.addWidget(self.device_combo)
        refresh_btn = QPushButton("刷新设备")
        refresh_btn.clicked.connect(self.refresh_devices)
        device_layout.addWidget(refresh_btn)
        layout.addLayout(device_layout)

        # Recording control
        record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)")
        record_hint.setAlignment(Qt.AlignCenter)
        self.record_btn = QPushButton("按住录音")
        self.record_btn.setMinimumHeight(80)
        self.record_btn.setStyleSheet("""
            QPushButton {
                background-color: #ff4d4d;
                color: white;
                font-size: 18px;
                border-radius: 10px;
            }
            QPushButton:pressed {
                background-color: #cc0000;
            }
        """)
        # Hold-to-record: the button's mouse handlers are replaced on the
        # instance; the replacements forward to QPushButton themselves.
        self.record_btn.mousePressEvent = self.start_recording
        self.record_btn.mouseReleaseEvent = self.stop_recording
        self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu)

        # Live input level and recording progress
        self.mic_level = QProgressBar()
        self.mic_level.setRange(0, 100)
        self.mic_level.setFormat("麦克风电平:%p%")
        self.record_progress = QProgressBar()
        self.record_progress.setRange(0, 100)
        self.record_progress.setValue(0)
        self.record_duration_label = QLabel("录音时长: 0.0秒")
        self.record_duration_label.setAlignment(Qt.AlignCenter)

        layout.addWidget(record_hint)
        layout.addWidget(self.record_btn)
        layout.addWidget(self.mic_level)
        layout.addWidget(self.record_progress)
        layout.addWidget(self.record_duration_label)

        # Timer that refreshes the elapsed-time label while recording
        self.record_timer = QTimer(self)
        self.record_timer.timeout.connect(self.update_record_duration)
        self.record_start_time = 0

    def refresh_devices(self):
        """Enumerate devices that have input channels and fill the combo box.

        The default input device of the default host API is preselected when
        it is among the listed devices; otherwise the first device is used.
        ``device_index`` is None when no device is found or enumeration fails.
        """
        self.device_combo.clear()
        p = None
        found = []
        default_input_index = -1
        try:
            p = pyaudio.PyAudio()
            try:
                # Ask for the default host API rather than assuming it is
                # the one at index 0.
                default_input_index = p.get_default_host_api_info().get("defaultInputDevice", -1)
            except OSError:
                default_input_index = -1
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    name = info.get('name', f"Device {i}")
                    sr = int(info.get('defaultSampleRate', 0))
                    self.device_combo.addItem(f"[{i}] {name} (sr={sr})", i)
                    found.append(i)
        except Exception:
            # Enumeration is best effort; fall back to "no explicit device".
            self.device_index = None
            return
        finally:
            # Always release PyAudio, even when enumeration failed midway.
            if p is not None:
                p.terminate()

        if default_input_index in found:
            self.device_combo.setCurrentIndex(found.index(default_input_index))
            self.device_index = default_input_index
        elif found:
            self.device_combo.setCurrentIndex(0)
            self.device_index = self.device_combo.currentData()
        else:
            self.device_index = None

    def on_device_changed(self, _):
        """Track the device index stored in the selected combo-box entry."""
        self.device_index = self.device_combo.currentData()

    def setup_upload_ui(self, layout):
        """Add the file-selection controls for upload mode."""
        title = QLabel("上传WAV文件")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        # File selection
        file_layout = QHBoxLayout()
        self.file_path = QLineEdit()
        self.file_path.setReadOnly(True)
        self.browse_btn = QPushButton("浏览WAV文件")
        self.browse_btn.clicked.connect(self.browse_file)
        file_layout.addWidget(self.file_path)
        file_layout.addWidget(self.browse_btn)
        layout.addLayout(file_layout)

    def start_recording(self, event):
        """Mouse-press handler of the record button: start a recording thread."""
        # Forward to the base class so the button enters its pressed state
        # (otherwise the ":pressed" style never applies).
        QPushButton.mousePressEvent(self.record_btn, event)
        if event.button() != Qt.LeftButton:
            return
        if self.record_thread and self.record_thread.isRunning():
            return

        self.record_thread = RecordThread(max_duration=60, device_index=self.device_index)
        self.record_thread.update_signal.connect(self.update_log)
        self.record_thread.finish_signal.connect(self.on_recording_finish)
        self.record_thread.progress_signal.connect(self.record_progress.setValue)
        self.record_thread.level_signal.connect(self.mic_level.setValue)
        self.record_thread.start()

        self.record_start_time = time.time()
        self.record_timer.start(100)
        self.update_log("开始录音...(松开按钮结束)")

    def stop_recording(self, event):
        """Mouse-release handler of the record button: stop recording, reset meters."""
        QPushButton.mouseReleaseEvent(self.record_btn, event)
        if event.button() != Qt.LeftButton:
            return
        if self.record_thread and self.record_thread.isRunning():
            self.record_thread.is_recording = False
        # Reset the UI even when the thread already stopped on its own
        # (maximum duration reached or an error); otherwise the duration
        # timer would keep counting after the button is released.
        self.record_timer.stop()
        self.record_duration_label.setText("录音时长: 0.0秒")
        self.record_progress.setValue(0)
        self.mic_level.setValue(0)

    def update_record_duration(self):
        """Show the time elapsed since recording started."""
        elapsed = time.time() - self.record_start_time
        self.record_duration_label.setText(f"录音时长: {elapsed:.1f}")

    def on_recording_finish(self, temp_file):
        """Remember the recorded file and allow processing."""
        self.wav_path = temp_file
        self.process_btn.setEnabled(True)

    def browse_file(self):
        """Let the user pick a WAV file and allow processing once one is chosen."""
        file, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if file:
            self.file_path.setText(file)
            self.wav_path = file
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Validate the inputs and start a ProcessThread for the current file."""
        if self.mode == "upload":
            self.wav_path = self.file_path.text()
        model_path = self.model_path.text()
        scaler_path = self.scaler_path.text()

        if not self.wav_path or not os.path.exists(self.wav_path):
            self.log_area.append("请先录音或选择有效的WAV文件")
            return
        if not os.path.exists(model_path) or not os.path.exists(scaler_path):
            self.log_area.append("模型文件不存在请先运行train_model.py训练模型")
            return

        self.log_area.clear()
        self.process_btn.setEnabled(False)
        self.thread = ProcessThread(self.wav_path, model_path, scaler_path)
        self.thread.update_signal.connect(self.update_log)
        self.thread.finish_signal.connect(self.on_process_finish)
        # Re-enable the button whenever the worker ends (success, no segments
        # or error); otherwise the page stays locked after one run.
        self.thread.finished.connect(self._on_process_thread_done)
        self.thread.start()

    def _on_process_thread_done(self):
        """Allow another processing run once the worker thread has ended."""
        self.process_btn.setEnabled(True)

    def update_log(self, msg):
        """Append one message to the log area."""
        self.log_area.append(msg)

    def on_process_finish(self, result):
        """Hand the classification result to the owning window for display."""
        self.parent.switch_to_result(result)
# Level 3 screen: result display
class ResultWidget(QWidget):
    """Read-only page that shows one classification result."""

    def __init__(self, parent=None, result=None):
        super().__init__(parent)
        # The owning window; used for the back navigation.
        self.parent = parent
        self.result = result
        self.init_ui()

    def init_ui(self):
        """Build the back button, the heading and the result text box."""
        root = QVBoxLayout()

        # Back button
        back_button = QPushButton("返回")
        back_button.clicked.connect(self.parent.switch_to_input_from_result)
        root.addWidget(back_button)

        heading = QLabel("处理结果")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        root.addWidget(heading)

        # Result text; left empty when no result was supplied
        self.result_area = QTextEdit()
        self.result_area.setReadOnly(True)
        if self.result:
            self.result_area.setText(self._format_result(self.result))
        root.addWidget(self.result_area)

        self.setLayout(root)

    @staticmethod
    def _format_result(result):
        """Render the result dict as the multi-line text shown to the user."""
        lines = [
            f"文件名: {result['filename']}",
            f"片段数: {result['segments']}",
            "预测结果:",
        ]
        pairs = zip(result['predictions'], result['probabilities'])
        for number, (pred, prob) in enumerate(pairs, start=1):
            lines.append(f"  片段{number}: 标签={pred} (概率={prob})")
        # An empty entry produces the blank line before the summary.
        lines.append("")
        lines.append(f"平均概率: {result['mean_prob']}")
        lines.append(f"最终结果: {'空心' if result['final_label'] else '实心'}")
        return "\n".join(lines)
# Main window
class AudioClassifierGUI(QMainWindow):
    """Main window: a stacked widget holding the menu, input and result pages."""

    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"  # input mode of the last opened input page
        self.process_result = None  # most recent processing result
        self.init_ui()

    def init_ui(self):
        """Create all pages, add them to the stack and show the main menu."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 800, 600)

        # Stacked container for the pages
        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)

        # Pages
        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)

        for page in (
            self.main_menu_widget,
            self.record_input_widget,
            self.upload_input_widget,
            self.result_widget,
        ):
            self.stacked_widget.addWidget(page)

        # Start on the main menu
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_input(self, mode):
        """Show the input page for ``mode`` ("record", otherwise upload)."""
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def switch_to_main_menu(self):
        """Show the main menu page."""
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_result(self, result):
        """Replace the result page with one built from ``result`` and show it."""
        self.process_result = result
        previous = self.result_widget
        self.result_widget = ResultWidget(self, result)
        # removeWidget() only detaches the page from the stack; schedule the
        # old page for deletion so repeated runs do not accumulate widgets.
        self.stacked_widget.removeWidget(previous)
        previous.deleteLater()
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)

    def switch_to_input_from_result(self):
        """Return from the result page to the input page that was used last."""
        self.switch_to_input(self.current_input_mode)
if __name__ == "__main__":
    # Print an installation hint and exit with status 1 if PyAudio cannot be imported.
    try:
        import pyaudio  # make sure it is installed
    except ImportError:
        print("请先安装pyaudio: pip install pyaudio")
        sys.exit(1)

    # Create the Qt application, show the main window and run the event loop;
    # the loop's return value becomes the process exit status.
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)