|
|
#!/usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
import sys
|
|
|
import os
|
|
|
import time
|
|
|
import numpy as np
|
|
|
import joblib
|
|
|
import librosa
|
|
|
import pyaudio
|
|
|
import wave
|
|
|
from scipy.signal import hilbert
|
|
|
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
|
|
|
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
|
|
|
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
|
|
|
QComboBox) # >>> CHANGED: add QComboBox
|
|
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
|
|
|
|
|
|
|
|
|
# Feature extraction helpers
class FeatureExtractor:
    """Stateless helpers that turn a mono waveform into per-hit feature vectors.

    Pipeline: ``frame_energy`` -> ``detect_hits`` -> ``extract_segments`` ->
    ``extract_features`` (one 5-element vector per segment).
    """

    WIN = 1024            # analysis frame length, in samples
    OVERLAP = 512         # overlap between consecutive frames, in samples
    THRESHOLD = 0.1       # frame energy above which a frame counts as "loud"
    SEG_LEN_S = 0.2       # length of each extracted segment, in seconds
    STEP = WIN - OVERLAP  # hop between consecutive frames, in samples
    MIN_GAP_FRAMES = 5    # a loud frame more than this many frames after the
                          # previous loud frame starts a new hit

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the sum of squared samples of each frame of ``signal``.

        Frames are ``WIN`` samples long and advance by ``STEP`` samples.
        A signal shorter than one frame is zero-padded to ``WIN`` samples, so
        exactly one frame is produced instead of ``librosa`` raising.
        """
        signal = np.asarray(signal)
        if len(signal) < FeatureExtractor.WIN:
            # librosa.util.frame rejects inputs shorter than frame_length;
            # zero padding does not change the energy of the existing samples.
            signal = np.pad(signal, (0, FeatureExtractor.WIN - len(signal)))
        frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN,
                                    hop_length=FeatureExtractor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices at which each hit (burst of loud frames) starts.

        A frame is loud when its energy exceeds ``THRESHOLD``. The first loud
        frame always starts a hit. Each later loud frame starts a new hit only
        when it lies more than ``MIN_GAP_FRAMES`` frames after the previous
        loud frame. Returns an empty integer array when no frame is loud.
        """
        idx = np.where(np.asarray(energy) > FeatureExtractor.THRESHOLD)[0]
        if idx.size == 0:
            return np.array([], dtype=int)
        # The first loud frame is always a hit start. Diffing against a
        # prepended 0 would drop it whenever it lies within the first
        # MIN_GAP_FRAMES frames of the signal.
        is_start = np.concatenate(([True], np.diff(idx) > FeatureExtractor.MIN_GAP_FRAMES))
        return idx[is_start]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray) -> list[np.ndarray]:
        """Cut a ``SEG_LEN_S``-second slice of ``signal`` at every hit start.

        ``hit_starts`` are frame indices (as returned by ``detect_hits``). They
        are converted to sample offsets by multiplying by ``STEP``. A slice is
        truncated when it would run past the end of the signal.
        """
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            start = int(frame_idx) * FeatureExtractor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Compute a 5-element feature vector for one audio segment.

        Features, in order:
          0. RMS energy of the samples.
          1. Dominant frequency: the frequency (Hz) of the largest FFT
             magnitude among the non-negative frequency bins.
          2. Spectral skewness, weighted by the FFT magnitude.
          3. Mean of the first MFCC coefficient over all MFCC frames.
          4. Peak of the Hilbert (analytic-signal) envelope.

        ``x`` is flattened first. An empty segment yields ``np.zeros(5)``.
        """
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)

        # 1. RMS energy
        rms = np.sqrt(np.mean(x ** 2))

        # 2. Dominant frequency (peak of the magnitude spectrum)
        fft = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive_freq_mask = freq >= 0
        freq = freq[positive_freq_mask]
        fft_mag = np.abs(fft[positive_freq_mask])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # 3. Spectral skewness. Weights are the magnitude spectrum (despite the
        #    name spec_power). Changing this would change the feature values
        #    fed to the already-trained model.
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0

        # 4. Mean of the first MFCC coefficient
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        # 5. Envelope peak (magnitude of the Hilbert analytic signal)
        env_peak = np.max(np.abs(hilbert(x)))

        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
|
|
|
|
|
|
|
|
|
# Recording thread
class RecordThread(QThread):
    """Record mono 16-bit audio from an input device into a temporary WAV file.

    Recording continues until ``is_recording`` is set to ``False`` (the GUI
    does this on button release) or ``max_duration`` seconds elapse. Log
    text, progress (0-100) and a live input level (0-100) are reported via
    signals. On success, ``finish_signal`` carries the path of the WAV file.

    Create a new instance per recording: ``is_recording`` is not reset by
    ``run()``.
    """

    update_signal = pyqtSignal(str)    # log / status messages
    finish_signal = pyqtSignal(str)    # path of the written WAV file
    progress_signal = pyqtSignal(int)  # elapsed / max_duration, 0-100
    level_signal = pyqtSignal(int)     # live input level (RMS), 0-100

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        # True from construction and never re-set in run(). Otherwise a stop
        # request made before the stream has opened (a very short press) would
        # be overwritten and the recording would run for max_duration.
        self.is_recording = True
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index  # None -> system default input device

    def _log_input_devices(self, p):
        """Best-effort: emit the list of devices that have input channels."""
        try:
            device_log = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    device_log.append(f"[{i}] {info.get('name')} | in={info.get('maxInputChannels')} sr={int(info.get('defaultSampleRate', 0))}")
            if device_log:
                self.update_signal.emit("可用输入设备:\n" + "\n".join(device_log))
        except Exception:
            # Diagnostic output only; must never abort the recording.
            pass

    def _emit_level(self, data, full_scale):
        """Emit the RMS level of one raw int16 chunk, mapped to 0-100."""
        try:
            chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / full_scale
            rms = np.sqrt(np.mean(chunk_np ** 2))
            # Gain of 5 is an empirical factor; adjust for the environment.
            level = int(np.clip(rms * 100 * 5, 0, 100))
            self.level_signal.emit(level)
        except Exception:
            # The level meter is cosmetic; ignore conversion problems.
            pass

    def run(self):
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024
        max_int16 = 32767.0  # full scale used to normalise int16 samples

        p = None
        stream = None
        try:
            p = pyaudio.PyAudio()
            self._log_input_devices(p)

            stream = p.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                input_device_index=self.device_index,  # None -> system default
                frames_per_buffer=CHUNK,
            )

            start_time = time.time()
            frames = []
            while self.is_recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))

                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
                self._emit_level(data, max_int16)

            duration = time.time() - start_time

            # Release the device before writing the file. The handles are
            # cleared so the finally block does not close them a second time.
            stream.stop_stream()
            stream.close()
            stream = None
            p.terminate()
            p = None

            if frames:
                with wave.open(self.temp_file, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    # Module-level helper: no extra PortAudio instance needed.
                    wf.setsampwidth(pyaudio.get_sample_size(FORMAT))  # 2 bytes
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                self.update_signal.emit(f"录制完成,时长: {duration:.1f}秒")
                self.finish_signal.emit(self.temp_file)
            else:
                self.update_signal.emit("未录制到有效音频(请检查麦克风选择与系统权限)")

        except Exception as e:
            err_msg = f"录制错误: {str(e)}"
            if p is not None:
                # The failure happened while the audio device was in use.
                err_msg += "\n(提示:在“采集音频”界面尝试切换输入设备,或检查系统隐私权限中麦克风是否对本应用开放)"
            self.update_signal.emit(err_msg)
        finally:
            # Only reached with live handles when the normal path failed.
            if stream is not None:
                try:
                    stream.stop_stream()
                    stream.close()
                except Exception:
                    pass
            if p is not None:
                try:
                    p.terminate()
                except Exception:
                    pass
|
|
|
|
|
|
|
|
|
# Processing thread
class ProcessThread(QThread):
    """Classify the hits in one WAV file with a pre-trained model, off the GUI thread.

    Progress and error text go out on ``update_signal``. On success a single
    result dict (filename, segment count, per-segment predictions and scores,
    mean score, final label) goes out on ``finish_signal``.
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)

    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path, self.model_path, self.scaler_path = wav_path, model_path, scaler_path

    @staticmethod
    def _positive_scores(model, X_std, y_pred):
        """Per-segment score for the positive class, falling back gracefully.

        Uses column 1 of ``predict_proba`` when available. Otherwise it uses a
        sigmoid of ``decision_function``, and finally the hard predictions cast
        to float.
        """
        if hasattr(model, "predict_proba"):
            return model.predict_proba(X_std)[:, 1]
        if hasattr(model, "decision_function"):
            return 1 / (1 + np.exp(-model.decision_function(X_std)))
        return y_pred.astype(float) + 0.0

    def run(self):
        emit = self.update_signal.emit
        try:
            emit("加载模型和标准化器...")
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
            # Peak-normalise so the absolute recording level does not matter.
            peak = np.max(np.abs(sig))
            if peak > 0:
                sig = sig / peak

            emit("提取特征...")
            energy = FeatureExtractor.frame_energy(sig)
            starts = FeatureExtractor.detect_hits(energy)
            segments = FeatureExtractor.extract_segments(sig, sr, starts)
            if not segments:
                emit("未检测到有效片段!")
                return

            X = np.vstack([FeatureExtractor.extract_features(seg, sr) for seg in segments])
            emit(f"提取到 {len(segments)} 个片段,特征维度: {X.shape[1]}")

            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)
            y_proba = self._positive_scores(model, X_std, y_pred)

            mean_prob = np.mean(y_proba)
            self.finish_signal.emit({
                "filename": os.path.basename(self.wav_path),
                "segments": len(segments),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(score), 4) for score in y_proba],
                "mean_prob": round(float(mean_prob), 4),
                "final_label": int(mean_prob >= 0.5),
            })
        except Exception as e:
            emit(f"错误: {str(e)}")
|
|
|
|
|
|
|
|
|
# Level 1 screen: main menu
class MainMenuWidget(QWidget):
    """Landing screen with one button per input mode (record / upload)."""

    def __init__(self, parent=None):
        super().__init__(parent)
        # NOTE(review): this attribute shadows QWidget.parent(); the button
        # callbacks below rely on it, so it is kept as-is.
        self.parent = parent
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()

        heading = QLabel("音频分类器")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        layout.addWidget(heading)

        # One large button per input mode, in display order.
        for caption, mode in (("采集音频", "record"), ("上传外部WAV文件", "upload")):
            button = QPushButton(caption)
            button.setMinimumHeight(60)
            button.setStyleSheet("font-size: 16px;")
            # Bind mode as a default so each lambda keeps its own value.
            button.clicked.connect(lambda _checked=False, m=mode: self.parent.switch_to_input(m))
            layout.addWidget(button)

        layout.addStretch(1)
        self.setLayout(layout)
|
|
|
|
|
|
|
|
|
# Level 2 screen: input (record audio or upload a WAV file)
class InputWidget(QWidget):
    """Input screen: record from a microphone or pick a WAV file, then classify it.

    ``mode == "record"`` builds the microphone UI; any other value builds the
    file-upload UI. ``parent`` must provide ``switch_to_main_menu()`` and
    ``switch_to_result(result)``.
    """

    def __init__(self, parent=None, mode="record"):
        super().__init__(parent)
        # NOTE(review): shadows QWidget.parent(); kept for compatibility.
        self.parent = parent
        self.mode = mode
        self.wav_path = ""
        self.record_thread = None
        # Named process_thread (not "thread") to avoid shadowing QObject.thread().
        self.process_thread = None
        self.device_index = None  # selected input device index (None -> system default)
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()

        # Back button
        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back_btn)

        if self.mode == "record":
            self.setup_record_ui(layout)
        else:
            self.setup_upload_ui(layout)

        # Model / scaler paths
        model_layout = QHBoxLayout()
        self.model_path = QLineEdit("svm_model.pkl")
        self.scaler_path = QLineEdit("scaler.pkl")
        model_layout.addWidget(QLabel("模型路径:"))
        model_layout.addWidget(self.model_path)
        model_layout.addWidget(QLabel("标准化器路径:"))
        model_layout.addWidget(self.scaler_path)
        layout.addLayout(model_layout)

        # Process button; enabled once an input WAV is available.
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setMinimumHeight(50)
        self.process_btn.setStyleSheet("font-size: 16px; background-color: #4CAF50; color: white;")
        self.process_btn.clicked.connect(self.start_process)
        self.process_btn.setEnabled(False)
        layout.addWidget(self.process_btn)

        # Log area
        self.log_area = QTextEdit()
        self.log_area.setReadOnly(True)
        layout.addWidget(QLabel("日志:"))
        layout.addWidget(self.log_area)

        self.setLayout(layout)

    def setup_record_ui(self, layout):
        """Add the microphone selector, push-to-talk button and meters to ``layout``."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        # Input device selection
        device_layout = QHBoxLayout()
        device_layout.addWidget(QLabel("输入设备:"))
        self.device_combo = QComboBox()
        self.refresh_devices()
        self.device_combo.currentIndexChanged.connect(self.on_device_changed)
        device_layout.addWidget(self.device_combo)
        refresh_btn = QPushButton("刷新设备")
        refresh_btn.clicked.connect(self.refresh_devices)
        device_layout.addWidget(refresh_btn)
        layout.addLayout(device_layout)

        # Push-to-talk control
        record_hint = QLabel("按住按钮开始录音,松开结束(说话同时观察下方“麦克风电平”是否跳动)")
        record_hint.setAlignment(Qt.AlignCenter)

        self.record_btn = QPushButton("按住录音")
        self.record_btn.setMinimumHeight(80)
        self.record_btn.setStyleSheet("""
            QPushButton {
                background-color: #ff4d4d;
                color: white;
                font-size: 18px;
                border-radius: 10px;
            }
            QPushButton:pressed {
                background-color: #cc0000;
            }
        """)
        # Press/release drive recording. The handlers forward to QPushButton's
        # own implementation so the ":pressed" style above is applied.
        self.record_btn.mousePressEvent = self.start_recording
        self.record_btn.mouseReleaseEvent = self.stop_recording
        self.record_btn.setContextMenuPolicy(Qt.PreventContextMenu)

        # Live input level and recording progress
        self.mic_level = QProgressBar()
        self.mic_level.setRange(0, 100)
        self.mic_level.setFormat("麦克风电平:%p%")

        self.record_progress = QProgressBar()
        self.record_progress.setRange(0, 100)
        self.record_progress.setValue(0)

        self.record_duration_label = QLabel("录音时长: 0.0秒")
        self.record_duration_label.setAlignment(Qt.AlignCenter)

        layout.addWidget(record_hint)
        layout.addWidget(self.record_btn)
        layout.addWidget(self.mic_level)
        layout.addWidget(self.record_progress)
        layout.addWidget(self.record_duration_label)

        # Updates the elapsed-time label while recording.
        self.record_timer = QTimer(self)
        self.record_timer.timeout.connect(self.update_record_duration)
        self.record_start_time = 0

    def refresh_devices(self):
        """List devices that have input channels and preselect the default one.

        Falls back to the first listed device, or to ``None`` (system default)
        when nothing can be enumerated.
        """
        self.device_combo.clear()
        p = None
        try:
            p = pyaudio.PyAudio()
            # Use PortAudio's default host API rather than assuming index 0.
            default_input_index = p.get_default_host_api_info().get("defaultInputDevice", -1)
            found = []
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    name = info.get('name', f"Device {i}")
                    sr = int(info.get('defaultSampleRate', 0))
                    label = f"[{i}] {name} (sr={sr})"
                    self.device_combo.addItem(label, i)
                    found.append(i)
        except Exception:
            # Enumeration is best-effort; fall back to the system default.
            self.device_index = None
            return
        finally:
            if p is not None:
                try:
                    p.terminate()
                except Exception:
                    pass

        if default_input_index in found:
            self.device_combo.setCurrentIndex(found.index(default_input_index))
            self.device_index = default_input_index
        elif found:
            self.device_combo.setCurrentIndex(0)
            self.device_index = self.device_combo.currentData()
        else:
            self.device_index = None

    def on_device_changed(self, _):
        self.device_index = self.device_combo.currentData()

    def setup_upload_ui(self, layout):
        """Add the WAV file picker to ``layout``."""
        title = QLabel("上传WAV文件")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(title)

        file_layout = QHBoxLayout()
        self.file_path = QLineEdit()
        self.file_path.setReadOnly(True)
        self.browse_btn = QPushButton("浏览WAV文件")
        self.browse_btn.clicked.connect(self.browse_file)
        file_layout.addWidget(self.file_path)
        file_layout.addWidget(self.browse_btn)
        layout.addLayout(file_layout)

    def start_recording(self, event):
        """Mouse-press handler of the record button: start a new recording."""
        # Keep QPushButton's own press handling (pressed state / style).
        QPushButton.mousePressEvent(self.record_btn, event)
        if event.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            # The previous recording is still being finalised; nothing starts.
            return

        self.record_thread = RecordThread(max_duration=60, device_index=self.device_index)
        self.record_thread.update_signal.connect(self.update_log)
        self.record_thread.finish_signal.connect(self.on_recording_finish)
        self.record_thread.progress_signal.connect(self.record_progress.setValue)
        self.record_thread.level_signal.connect(self.mic_level.setValue)
        self.record_thread.start()

        self.record_start_time = time.time()
        self.record_timer.start(100)
        self.update_log("开始录音...(松开按钮结束)")

    def stop_recording(self, event):
        """Mouse-release handler of the record button: stop and reset the meters."""
        # Keep QPushButton's own release handling (clears pressed state).
        QPushButton.mouseReleaseEvent(self.record_btn, event)
        if event.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            self.record_thread.is_recording = False
        # Reset the UI even if the thread already stopped on its own (max
        # duration reached); otherwise the duration timer would keep running.
        self.record_timer.stop()
        self.record_duration_label.setText("录音时长: 0.0秒")
        self.record_progress.setValue(0)
        self.mic_level.setValue(0)

    def update_record_duration(self):
        elapsed = time.time() - self.record_start_time
        self.record_duration_label.setText(f"录音时长: {elapsed:.1f}秒")

    def on_recording_finish(self, temp_file):
        self.wav_path = temp_file
        self.process_btn.setEnabled(True)

    def browse_file(self):
        file, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if file:
            self.file_path.setText(file)
            self.wav_path = file
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Validate inputs and run ProcessThread on the current WAV file."""
        if self.mode == "upload":
            self.wav_path = self.file_path.text()

        model_path = self.model_path.text()
        scaler_path = self.scaler_path.text()

        if not self.wav_path or not os.path.exists(self.wav_path):
            self.log_area.append("请先录音或选择有效的WAV文件!")
            return
        if not os.path.exists(model_path) or not os.path.exists(scaler_path):
            self.log_area.append("模型文件不存在!请先运行train_model.py训练模型")
            return
        if self.process_thread is not None and self.process_thread.isRunning():
            # A run is already in progress; replacing the reference would
            # destroy a running QThread.
            return

        self.log_area.clear()
        self.process_btn.setEnabled(False)

        self.process_thread = ProcessThread(self.wav_path, model_path, scaler_path)
        self.process_thread.update_signal.connect(self.update_log)
        self.process_thread.finish_signal.connect(self.on_process_finish)
        # Re-enable the button however the run ends (success, error, or no
        # segments found); ProcessThread does not emit finish_signal on failure.
        self.process_thread.finished.connect(self._on_process_thread_done)
        self.process_thread.start()

    def _on_process_thread_done(self):
        """Allow processing again once the worker thread has exited."""
        self.process_btn.setEnabled(bool(self.wav_path))

    def update_log(self, msg):
        self.log_area.append(msg)

    def on_process_finish(self, result):
        self.parent.switch_to_result(result)
|
|
|
|
|
|
|
|
|
# Level 3 screen: result display
class ResultWidget(QWidget):
    """Read-only view of the result dict emitted by ProcessThread."""

    def __init__(self, parent=None, result=None):
        super().__init__(parent)
        self.parent = parent
        self.result = result
        self.init_ui()

    @staticmethod
    def _format_result(result):
        """Render the result dict as the multi-line text shown to the user."""
        parts = [
            f"文件名: {result['filename']}\n",
            f"片段数: {result['segments']}\n",
            "预测结果:\n",
        ]
        pairs = zip(result['predictions'], result['probabilities'])
        for number, (pred, prob) in enumerate(pairs, start=1):
            parts.append(f" 片段{number}: 标签={pred} (概率={prob})\n")
        parts.append(f"\n平均概率: {result['mean_prob']}\n")
        # final_label 1 -> hollow ("空心"), 0 -> solid ("实心")
        verdict = '空心' if result['final_label'] else '实心'
        parts.append(f"最终结果: {verdict}")
        return "".join(parts)

    def init_ui(self):
        layout = QVBoxLayout()

        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.parent.switch_to_input_from_result)
        layout.addWidget(back_btn)

        heading = QLabel("处理结果")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        layout.addWidget(heading)

        self.result_area = QTextEdit()
        self.result_area.setReadOnly(True)
        if self.result:
            self.result_area.setText(self._format_result(self.result))

        layout.addWidget(self.result_area)
        self.setLayout(layout)
|
|
|
|
|
|
|
|
|
# Main window
class AudioClassifierGUI(QMainWindow):
    """Main window: a QStackedWidget that switches between the menu, the two
    input screens (record / upload) and the result screen."""

    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"  # input screen to return to from the result screen
        self.process_result = None  # last result dict received from processing
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 800, 600)

        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)

        # Screens
        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)

        self.stacked_widget.addWidget(self.main_menu_widget)
        self.stacked_widget.addWidget(self.record_input_widget)
        self.stacked_widget.addWidget(self.upload_input_widget)
        self.stacked_widget.addWidget(self.result_widget)

        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_input(self, mode):
        """Show the record screen for ``mode == "record"``, else the upload screen."""
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def switch_to_main_menu(self):
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_result(self, result):
        """Replace the result screen with one showing ``result`` and display it."""
        self.process_result = result
        # removeWidget() only detaches the widget (it stays a hidden child of
        # the stack), so delete the old screen explicitly to avoid piling up.
        old_widget = self.result_widget
        self.stacked_widget.removeWidget(old_widget)
        old_widget.deleteLater()

        self.result_widget = ResultWidget(self, result)
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)

    def switch_to_input_from_result(self):
        """Return to whichever input screen was used last."""
        self.switch_to_input(self.current_input_mode)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pyaudio is imported at module level, so a missing installation already
    # raises ImportError before this point. A second availability check here
    # could never trigger, so none is done.
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    sys.exit(app.exec_())
|