You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
git-02/audio_classifier_gui_03.py

918 lines
37 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import csv
import os
import sqlite3
import sys
import time
import wave
from contextlib import closing
from datetime import datetime

import numpy as np
import joblib
import librosa
import pyaudio
from scipy.signal import hilbert
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QPushButton, QLabel,
    QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
    QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
    QComboBox, QTableWidget, QTableWidgetItem, QMessageBox,
    QDialog, QFormLayout, QDialogButtonBox
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
# ======================
# 本地 SQLite 数据库(含 CRUD)
# ======================
class DatabaseManager:
    """SQLite persistence for classification runs and their per-segment results.

    Schema (created on construction if missing):

    * ``runs``: one row per processed file -- ``id``, ``filename``,
      ``segments``, ``mean_prob``, ``final_label``, ``created_at`` (text in
      ``TIMESTAMP_FORMAT``, local time).
    * ``run_segments``: one row per segment -- ``id``, ``run_id``,
      ``seg_index``, ``label``, ``proba``.

    Every method opens its own connection and closes it before returning,
    also when an error is raised. Writes are committed on success and rolled
    back on error, so a failure never leaves a half-written run or a dangling
    write transaction that would keep the database file locked.

    Args:
        db_path: Path of the SQLite database file.
    """

    TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
    _RUN_SELECT = "SELECT id, created_at, filename, segments, mean_prob, final_label FROM runs"

    def __init__(self, db_path="results.db"):
        self.db_path = db_path
        self._ensure_schema()

    def _connect(self):
        """Open a new connection to ``db_path``."""
        return sqlite3.connect(self.db_path)

    @classmethod
    def _now(cls):
        """Return the current local time formatted like ``runs.created_at``."""
        return datetime.now().strftime(cls.TIMESTAMP_FORMAT)

    def _ensure_schema(self):
        """Create the ``runs`` and ``run_segments`` tables if they do not exist."""
        # ``closing`` guarantees close(); ``with con`` commits or rolls back.
        with closing(self._connect()) as con, con:
            con.execute("""
                CREATE TABLE IF NOT EXISTS runs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT,
                    segments INTEGER,
                    mean_prob REAL,
                    final_label INTEGER,
                    created_at TEXT
                );
            """)
            # SQLite only enforces this FOREIGN KEY with PRAGMA foreign_keys=ON,
            # which is never set here; delete_run removes child rows itself.
            con.execute("""
                CREATE TABLE IF NOT EXISTS run_segments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id INTEGER,
                    seg_index INTEGER,
                    label INTEGER,
                    proba REAL,
                    FOREIGN KEY(run_id) REFERENCES runs(id)
                );
            """)

    # ---------- Write a complete processing result ----------
    def insert_result(self, result_dict):
        """Store one processed file and its per-segment predictions atomically.

        Keys read from ``result_dict``: ``filename``, ``segments``,
        ``mean_prob``, ``final_label``, ``predictions`` and ``probabilities``.
        The last two are zipped, so surplus items in the longer list are
        ignored. Segment indices are stored 1-based.

        Returns:
            The new ``runs.id``.

        Raises:
            ValueError/TypeError if a value cannot be converted; nothing is
            stored in that case.
        """
        preds = result_dict.get("predictions", [])
        probas = result_dict.get("probabilities", [])
        with closing(self._connect()) as con, con:
            cur = con.execute(
                "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)",
                (
                    result_dict.get("filename"),
                    int(result_dict.get("segments", 0)),
                    float(result_dict.get("mean_prob", 0.0)),
                    int(result_dict.get("final_label", 0)),
                    self._now(),
                ),
            )
            run_id = cur.lastrowid
            con.executemany(
                "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
                [(run_id, i, int(lab), float(pr))
                 for i, (lab, pr) in enumerate(zip(preds, probas), start=1)],
            )
        return run_id

    # ---------- Read: list / single row ----------
    def fetch_recent_runs(self, limit=50):
        """Return up to ``limit`` runs as tuples, newest (highest id) first."""
        with closing(self._connect()) as con:
            return con.execute(
                f"{self._RUN_SELECT} ORDER BY id DESC LIMIT ?", (limit,)
            ).fetchall()

    def get_run(self, run_id: int):
        """Return the run tuple with ``id == run_id``, or None if it does not exist."""
        with closing(self._connect()) as con:
            return con.execute(f"{self._RUN_SELECT} WHERE id = ?", (run_id,)).fetchone()

    def search_runs(self, keyword: str, limit=100):
        """Return up to ``limit`` runs whose filename contains ``keyword``, newest first.

        ``%``, ``_`` and ``\\`` in ``keyword`` are matched literally instead
        of acting as LIKE wildcards. (SQLite's LIKE is case-insensitive for
        ASCII letters by default.)
        """
        escaped = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        with closing(self._connect()) as con:
            return con.execute(
                f"{self._RUN_SELECT} WHERE filename LIKE ? ESCAPE '\\' ORDER BY id DESC LIMIT ?",
                (f"%{escaped}%", limit),
            ).fetchall()

    # ---------- Create / update / delete runs ----------
    def create_run(self, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str = None):
        """Insert a run without segments and return its id.

        ``created_at`` defaults to the current local time.
        """
        created_at = created_at or self._now()
        with closing(self._connect()) as con, con:
            cur = con.execute(
                "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) VALUES (?, ?, ?, ?, ?)",
                (filename, int(segments), float(mean_prob), int(final_label), created_at),
            )
            return cur.lastrowid

    def update_run(self, run_id: int, filename: str, segments: int, mean_prob: float, final_label: int, created_at: str):
        """Overwrite every editable column of run ``run_id`` (no-op if it does not exist)."""
        with closing(self._connect()) as con, con:
            con.execute("""
                UPDATE runs
                SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=?
                WHERE id=?
            """, (filename, int(segments), float(mean_prob), int(final_label), created_at, int(run_id)))

    def delete_run(self, run_id: int):
        """Delete run ``run_id`` together with all of its segments in one transaction."""
        with closing(self._connect()) as con, con:
            # Child rows first: the foreign key is not enforced (see _ensure_schema).
            con.execute("DELETE FROM run_segments WHERE run_id=?", (int(run_id),))
            con.execute("DELETE FROM runs WHERE id=?", (int(run_id),))

    # ---------- Create / read / update / delete segments ----------
    def list_segments(self, run_id: int):
        """Return ``(id, seg_index, label, proba)`` tuples of a run, ordered by ``seg_index``."""
        with closing(self._connect()) as con:
            return con.execute("""
                SELECT id, seg_index, label, proba FROM run_segments
                WHERE run_id=? ORDER BY seg_index ASC
            """, (int(run_id),)).fetchall()

    def add_segment(self, run_id: int, seg_index: int, label: int, proba: float):
        """Insert one segment for ``run_id`` and return the new segment id."""
        with closing(self._connect()) as con, con:
            cur = con.execute(
                "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
                (int(run_id), int(seg_index), int(label), float(proba)),
            )
            return cur.lastrowid

    def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float):
        """Overwrite ``seg_index``, ``label`` and ``proba`` of segment ``seg_id``."""
        with closing(self._connect()) as con, con:
            con.execute(
                "UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=?",
                (int(seg_index), int(label), float(proba), int(seg_id)),
            )

    def delete_segment(self, seg_id: int):
        """Delete segment ``seg_id``."""
        with closing(self._connect()) as con, con:
            con.execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),))

    # ---------- Export ----------
    def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000):
        """Write the newest ``limit`` runs (newest first) with a header row to ``csv_path``.

        Returns:
            ``csv_path``.

        Raises:
            OSError if the file cannot be written.
        """
        rows = self.fetch_recent_runs(limit)
        headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"]
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(headers)
            writer.writerows(rows)
        return csv_path
# Module-level DatabaseManager shared by ProcessThread (stores each result) and
# the history / segment dialogs. "results.db" is a relative path, so the file is
# resolved against the process's current working directory; constructing the
# manager creates the tables if they are missing.
DB = DatabaseManager("results.db")
# ======================
# 特征提取(同前)
# ======================
class FeatureExtractor:
    """Energy-based hit detection and 5-value per-segment feature extraction.

    Class attributes:
        WIN: Frame length in samples.
        OVERLAP: Overlap between consecutive frames in samples.
        STEP: Hop between frames in samples (``WIN - OVERLAP``).
        THRESHOLD: Frame energy above which a frame counts as active.
        SEG_LEN_S: Length in seconds of the segment cut at each hit.
        MIN_GAP_FRAMES: An active frame starts a new hit only if it lies more
            than this many frames after the previous active frame.
    """

    WIN = 1024
    OVERLAP = 512
    THRESHOLD = 0.1
    SEG_LEN_S = 0.2
    STEP = WIN - OVERLAP
    MIN_GAP_FRAMES = 5

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the sum of squared samples of each ``WIN``-long frame (hop ``STEP``).

        A signal shorter than one frame is zero-padded to ``WIN`` samples so it
        yields a single frame instead of making ``librosa.util.frame`` raise.
        """
        signal = np.asarray(signal)
        if len(signal) < FeatureExtractor.WIN:
            signal = np.pad(signal, (0, FeatureExtractor.WIN - len(signal)))
        frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN, hop_length=FeatureExtractor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices at which hits (bursts of active frames) start.

        The first active frame always starts a hit; a later active frame
        starts a new hit only if it is more than ``MIN_GAP_FRAMES`` frames
        after the previous active frame. Returns an empty integer array when
        no frame exceeds ``THRESHOLD``.
        """
        active = np.flatnonzero(np.asarray(energy) > FeatureExtractor.THRESHOLD)
        if active.size == 0:
            return np.array([], dtype=int)
        # Leading True: a hit at the very start of the signal must not be
        # dropped (prepending 0 before diff would discard hits in frames 0-5).
        starts = np.concatenate(([True], np.diff(active) > FeatureExtractor.MIN_GAP_FRAMES))
        return active[starts]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray):
        """Cut a ``SEG_LEN_S``-second slice of ``signal`` starting at each hit frame.

        Slices near the end of the signal are shorter than ``SEG_LEN_S``.
        Returns a list of 1-D arrays (views into ``signal``).
        """
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            # int(): tolerate float-typed index arrays.
            start = int(frame_idx) * FeatureExtractor.STEP
            segments.append(signal[start:min(start + seg_len, len(signal))])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return ``[rms, main_freq, skewness, mfcc_mean, env_peak]`` for one segment.

        * ``rms``: root-mean-square amplitude.
        * ``main_freq``: frequency (Hz) of the largest non-negative-frequency
          FFT magnitude bin.
        * ``skewness``: third standardized moment of frequency, weighted by
          FFT magnitude (0 when the weighted spread is 0).
        * ``mfcc_mean``: mean over frames of the first MFCC coefficient.
        * ``env_peak``: peak of the Hilbert amplitude envelope.

        An empty segment yields a zero vector. NOTE(review): the loaded model
        and scaler are presumably fitted on features computed exactly like
        this, so changing any formula would invalidate them.
        """
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)
        rms = np.sqrt(np.mean(x ** 2))
        fft = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive = freq >= 0
        freq, fft_mag = freq[positive], np.abs(fft[positive])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0
        # The magnitude (not squared) spectrum is used as the weighting.
        spec_power = fft_mag
        centroid = np.sum(freq * spec_power) / (np.sum(spec_power) + 1e-12)
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / (np.sum(spec_power) + 1e-12))
        skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
            (np.sum(spec_power) + 1e-12) * (spread ** 3 + 1e-12)) if (spread > 0) else 0
        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0
        env_peak = np.max(np.abs(hilbert(x)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
# ======================
# 录音线程(同前)
# ======================
class RecordThread(QThread):
    """Record mono 16-bit 44.1 kHz audio into ``temp_file`` on a worker thread.

    Recording runs until the GUI sets ``is_recording`` to False or
    ``max_duration`` seconds elapse. Signals:

    * ``update_signal(str)``: status / error text.
    * ``finish_signal(str)``: path of the written WAV file.
    * ``progress_signal(int)``: elapsed share of ``max_duration`` in percent.
    * ``level_signal(int)``: input level 0-100 (RMS of the last chunk x 500, clipped).

    Args:
        max_duration: Maximum recording length in seconds.
        device_index: PortAudio input device index (None = default device).
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    level_signal = pyqtSignal(int)

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        # Armed at construction, not in run(): the GUI may clear this flag
        # (very short press) before run() is scheduled or while the stream is
        # still opening, and run() must not overwrite that stop request.
        self.is_recording = True
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index

    @staticmethod
    def _release_audio(p, stream):
        """Stop/close ``stream`` and terminate ``p``; either may be None."""
        if stream is not None:
            try:
                stream.stop_stream()
                stream.close()
            except Exception:
                # Best-effort cleanup: captured frames are already in memory.
                pass
        if p is not None:
            p.terminate()

    def run(self):
        """Capture until stopped or timed out, then write ``temp_file`` and emit ``finish_signal``.

        The stream and PyAudio instance are released on every exit path,
        including errors while opening or reading the stream.
        """
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 44100
        CHUNK = 1024
        max_int16 = 32767.0
        frames = []
        p, stream = None, None
        try:
            p = pyaudio.PyAudio()
            stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True,
                            input_device_index=self.device_index, frames_per_buffer=CHUNK)
            start_time = time.time()
            while self.is_recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration}")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(CHUNK, exception_on_overflow=False)
                frames.append(data)
                chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
                rms = np.sqrt(np.mean(chunk_np ** 2))
                self.level_signal.emit(int(np.clip(rms * 500, 0, 100)))
        except Exception as e:
            self.update_signal.emit(f"录制错误: {e}")
            return
        finally:
            self._release_audio(p, stream)

        if not frames:
            self.update_signal.emit("未录到有效音频。")
            return
        try:
            with wave.open(self.temp_file, 'wb') as wf:
                wf.setnchannels(CHANNELS)
                # Module-level helper: no throw-away (and never terminated)
                # PyAudio instance just to look up the sample width.
                wf.setsampwidth(pyaudio.get_sample_size(FORMAT))
                wf.setframerate(RATE)
                wf.writeframes(b''.join(frames))
        except Exception as e:
            self.update_signal.emit(f"录制错误: {e}")
            return
        self.finish_signal.emit(self.temp_file)
# ======================
# 推理处理线程(写入 DB)
# ======================
class ProcessThread(QThread):
    """Run the whole classification pipeline for one WAV file off the GUI thread.

    Steps:
    1. Load the joblib model and scaler.
    2. Load the audio as mono at its native rate and peak-normalize it.
    3. Detect hits and cut one segment per hit.
    4. Extract features, scale them and predict.
    5. Store the result via ``DB.insert_result`` and emit it on ``finish_signal``.

    Status and error text go out on ``update_signal``. When no segment is
    found, or an exception occurs, ``finish_signal`` is not emitted.
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)

    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path = wav_path
        self.model_path = model_path
        self.scaler_path = scaler_path

    def run(self):
        try:
            self.update_signal.emit("加载模型和标准化器...")
            # NOTE(review): joblib.load unpickles the file and can execute
            # arbitrary code -- only point these paths at trusted files.
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            signal, sr = librosa.load(self.wav_path, sr=None, mono=True)
            peak = max(np.max(np.abs(signal)), 1e-9)
            signal = signal / peak

            hits = FeatureExtractor.detect_hits(FeatureExtractor.frame_energy(signal))
            segments = FeatureExtractor.extract_segments(signal, sr, hits)
            if not segments:
                self.update_signal.emit("未检测到有效片段!")
                return

            features = np.vstack([FeatureExtractor.extract_features(seg, sr) for seg in segments])
            scaled = scaler.transform(features)
            labels = model.predict(scaled)
            if hasattr(model, "predict_proba"):
                # Probability of the class in column 1 (the "positive" class).
                probs = model.predict_proba(scaled)[:, 1]
            else:
                probs = labels.astype(float) + 0.0

            average = np.mean(probs)
            result = {
                "filename": os.path.basename(self.wav_path),
                "segments": len(segments),
                "predictions": labels.tolist(),
                "probabilities": [round(float(p), 4) for p in probs],
                "mean_prob": round(float(average), 4),
                "final_label": int(average >= 0.5),
            }
            DB.insert_result(result)
            self.finish_signal.emit(result)
        except Exception as e:
            self.update_signal.emit(f"错误: {e}")
# ======================
# 对话框:编辑 run
# ======================
class RunEditDialog(QDialog):
    """Form dialog for creating or editing one row of the ``runs`` table.

    Args:
        parent: Parent widget.
        run_row: ``(id, created_at, filename, segments, mean_prob,
            final_label)`` to edit, or None to create a new record.

    OK only closes the dialog once every numeric field parses. After
    ``exec_()`` returns Accepted, :meth:`values` therefore does not raise.
    """

    def __init__(self, parent=None, run_row=None):
        super().__init__(parent)
        self.setWindowTitle("编辑记录" if run_row else "新增记录")
        self.resize(380, 240)
        self.run_row = run_row
        self.ed_id = QLineEdit()
        self.ed_id.setReadOnly(True)
        self.ed_created = QLineEdit(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        self.ed_filename = QLineEdit()
        self.ed_segments = QLineEdit("0")
        self.ed_mean_prob = QLineEdit("0.0")
        self.ed_final_label = QLineEdit("0")  # 0 = solid, 1 = hollow
        form = QFormLayout()
        form.addRow("ID", self.ed_id)
        form.addRow("时间(created_at)", self.ed_created)
        form.addRow("文件名(filename)", self.ed_filename)
        form.addRow("片段数(segments)", self.ed_segments)
        form.addRow("平均概率(mean_prob)", self.ed_mean_prob)
        form.addRow("最终标签(0=实 1=空)", self.ed_final_label)
        if run_row:
            editors = (self.ed_id, self.ed_created, self.ed_filename,
                       self.ed_segments, self.ed_mean_prob, self.ed_final_label)
            for editor, value in zip(editors, run_row):
                editor.setText(str(value))
        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(self.accept)
        btns.rejected.connect(self.reject)
        lay = QVBoxLayout()
        lay.addLayout(form)
        lay.addWidget(btns)
        self.setLayout(lay)

    def values(self):
        """Return the form contents as a dict with parsed numeric fields.

        Empty numeric fields default to 0 / 0.0; ``id`` is None when empty.

        Raises:
            ValueError: a numeric field is not a number (or is NaN for an int field).
            OverflowError: an integer field holds an infinite value.
        """
        id_text = self.ed_id.text().strip()
        return dict(
            id=int(id_text) if id_text else None,
            created_at=self.ed_created.text().strip(),
            filename=self.ed_filename.text().strip(),
            segments=int(float(self.ed_segments.text().strip() or 0)),
            mean_prob=float(self.ed_mean_prob.text().strip() or 0.0),
            final_label=int(float(self.ed_final_label.text().strip() or 0)),
        )

    def accept(self):
        """Close with Accepted only if :meth:`values` succeeds; otherwise warn and stay open.

        Without this, a parse error would surface in the caller's slot, where
        an unhandled exception aborts a PyQt5 application.
        """
        try:
            self.values()
        except (ValueError, OverflowError) as exc:
            QMessageBox.warning(self, "输入错误", f"请检查数值字段: {exc}")
            return
        super().accept()
# ======================
# 对话框:管理 segments
# ======================
class SegmentManagerDialog(QDialog):
    """Dialog listing the ``run_segments`` rows of one run, with add / edit / delete.

    Args:
        parent: Parent widget.
        run_id: ``runs.id`` whose segments are managed.
    """

    def __init__(self, parent=None, run_id=None):
        super().__init__(parent)
        self.setWindowTitle(f"片段管理 - run_id={run_id}")
        self.resize(560, 420)
        self.run_id = run_id
        self.table = QTableWidget(0, 4)
        self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"])
        # Display-only: cell edits would never reach the database, and
        # edit_selected() pre-fills its form from the cell text.
        self.table.setEditTriggers(QTableWidget.NoEditTriggers)
        self.table.setSelectionBehavior(QTableWidget.SelectRows)
        self.btn_add = QPushButton("新增片段")
        self.btn_edit = QPushButton("编辑所选")
        self.btn_del = QPushButton("删除所选")
        self.btn_close = QPushButton("关闭")
        hl = QHBoxLayout()
        for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close):
            hl.addWidget(b)
        lay = QVBoxLayout()
        lay.addWidget(self.table)
        lay.addLayout(hl)
        self.setLayout(lay)
        self.btn_add.clicked.connect(self.add_segment)
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_del.clicked.connect(self.delete_selected)
        self.btn_close.clicked.connect(self.accept)
        self.load_segments()

    def load_segments(self):
        """Reload the table from ``DB.list_segments(run_id)``."""
        rows = DB.list_segments(self.run_id)
        self.table.setRowCount(len(rows))
        for i, (sid, idx, label, proba) in enumerate(rows):
            self.table.setItem(i, 0, QTableWidgetItem(str(sid)))
            self.table.setItem(i, 1, QTableWidgetItem(str(idx)))
            self.table.setItem(i, 2, QTableWidgetItem(str(label)))
            self.table.setItem(i, 3, QTableWidgetItem(str(proba)))
        self.table.resizeColumnsToContents()

    def _get_selected_seg_id(self):
        """Return the segment id of the current row, or None if no row is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        return int(self.table.item(r, 0).text())

    def _prompt_segment(self, title, seg_index, label, proba):
        """Show a seg_index / label / proba form pre-filled with the given values.

        Returns:
            ``(seg_index, label, proba)`` as ``(int, int, float)`` when the user
            confirms with valid numbers; None on cancel or invalid input (a
            warning is shown in the latter case). Empty fields fall back to
            the pre-filled values.
        """
        d = QDialog(self)
        d.setWindowTitle(title)
        form = QFormLayout()
        ed_index = QLineEdit(str(seg_index))
        ed_label = QLineEdit(str(label))
        ed_proba = QLineEdit(str(proba))
        form.addRow("seg_index", ed_index)
        form.addRow("label(0/1)", ed_label)
        form.addRow("proba(0~1)", ed_proba)
        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(d.accept)
        btns.rejected.connect(d.reject)
        v = QVBoxLayout()
        v.addLayout(form)
        v.addWidget(btns)
        d.setLayout(v)
        if d.exec_() != QDialog.Accepted:
            return None
        try:
            return (
                int(float(ed_index.text().strip() or seg_index)),
                int(float(ed_label.text().strip() or label)),
                float(ed_proba.text().strip() or proba),
            )
        except (ValueError, OverflowError) as exc:
            # An exception escaping a Qt slot would abort a PyQt5 application.
            QMessageBox.warning(self, "输入错误", f"请检查数值字段: {exc}")
            return None

    def add_segment(self):
        """Ask for a new segment's values and insert it for ``run_id``."""
        values = self._prompt_segment("新增片段", 1, 0, 0.0)
        if values is not None:
            DB.add_segment(self.run_id, *values)
            self.load_segments()

    def edit_selected(self):
        """Edit the selected segment, pre-filling the form with its current values."""
        seg_id = self._get_selected_seg_id()
        if seg_id is None:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        row = self.table.currentRow()
        current = [self.table.item(row, col).text() for col in (1, 2, 3)]
        values = self._prompt_segment("编辑片段", *current)
        if values is not None:
            DB.update_segment(seg_id, *values)
            self.load_segments()

    def delete_selected(self):
        """Delete the selected segment after confirmation."""
        seg_id = self._get_selected_seg_id()
        if seg_id is None:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        if QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!",
                                QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes:
            DB.delete_segment(seg_id)
            self.load_segments()
# ======================
# 主菜单
# ======================
class MainMenuWidget(QWidget):
    """Start page: a title and one large button per feature.

    ``parent`` must provide ``switch_to_input(mode)`` and ``switch_to_history()``.
    """

    def __init__(self, parent):
        super().__init__(parent)
        layout = QVBoxLayout()

        heading = QLabel("音频分类器")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        layout.addWidget(heading)

        # (button text, click handler) in display order.
        entries = (
            ("采集音频", lambda: parent.switch_to_input("record")),
            ("上传WAV文件", lambda: parent.switch_to_input("upload")),
            ("历史记录 / 导出 / CRUD", parent.switch_to_history),
        )
        for text, handler in entries:
            button = QPushButton(text)
            button.clicked.connect(handler)
            button.setMinimumHeight(60)
            layout.addWidget(button)

        layout.addStretch(1)
        self.setLayout(layout)
# ======================
# 输入界面(录音 / 上传)
# ======================
class InputWidget(QWidget):
    """Input page for one classification run, in "record" or "upload" mode.

    In "record" mode, audio is captured by a ``RecordThread`` while the left
    mouse button is held on the record button. In "upload" mode, the user
    picks a WAV file. "开始处理" then runs a ``ProcessThread`` on the file and
    forwards its result dict to ``parent.switch_to_result``.

    Args:
        parent: Main window; must provide ``switch_to_main_menu()`` and
            ``switch_to_result(result)``. Stored as ``self.parent``, which
            shadows ``QWidget.parent()`` on this instance.
        mode: ``"record"`` for recording; any other value selects upload.

    NOTE: this module defines ``InputWidget`` twice; the later definition
    rebinds the name and is the one ``AudioClassifierGUI`` instantiates.
    """

    def __init__(self, parent, mode="record"):
        super().__init__(parent)
        self.parent, self.mode = parent, mode
        self.wav_path, self.record_thread = "", None
        # Named process_thread (not "thread") so QObject.thread() stays usable.
        self.process_thread = None
        self.device_index = None
        self.init_ui()

    def init_ui(self):
        """Build the page: back button, mode controls, model paths, process button and log."""
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back)
        if self.mode == "record":
            self.setup_record(layout)
        else:
            self.setup_upload(layout)
        self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
        path_layout = QHBoxLayout()
        path_layout.addWidget(QLabel("模型:"))
        path_layout.addWidget(self.model_path)
        path_layout.addWidget(QLabel("标准化器:"))
        path_layout.addWidget(self.scaler_path)
        layout.addLayout(path_layout)
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setEnabled(False)
        self.process_btn.clicked.connect(self.start_process)
        layout.addWidget(self.process_btn)
        self.log = QTextEdit()
        self.log.setReadOnly(True)
        layout.addWidget(self.log)
        self.setLayout(layout)

    def setup_record(self, layout):
        """Add device selector, push-to-talk button, level/progress bars and duration label."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)
        dev_layout = QHBoxLayout()
        dev_layout.addWidget(QLabel("设备:"))
        self.dev_combo = QComboBox()
        self.refresh_devices()
        self.dev_combo.currentIndexChanged.connect(
            lambda: setattr(self, "device_index", self.dev_combo.currentData()))
        dev_layout.addWidget(self.dev_combo)
        btn_ref = QPushButton("刷新设备")
        btn_ref.clicked.connect(self.refresh_devices)
        dev_layout.addWidget(btn_ref)
        layout.addLayout(dev_layout)
        self.rec_btn = QPushButton("按住录音")
        self.rec_btn.setMinimumHeight(80)
        # Push-to-talk: the button's mouse handlers are replaced so that
        # press starts and release stops the recording.
        self.rec_btn.mousePressEvent = self.start_rec
        self.rec_btn.mouseReleaseEvent = self.stop_rec
        self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
        self.level.setFormat("电平:%p%")
        self.progress.setRange(0, 100)
        for w in [self.rec_btn, self.level, self.progress, self.dur]:
            layout.addWidget(w)
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_dur)

    def setup_upload(self, layout):
        """Add a read-only path field and a browse button."""
        h = QHBoxLayout()
        self.file = QLineEdit()
        self.file.setReadOnly(True)
        b = QPushButton("浏览")
        b.clicked.connect(self.browse)
        h.addWidget(self.file)
        h.addWidget(b)
        layout.addLayout(h)

    def refresh_devices(self):
        """Fill the device combo box with every device that has input channels.

        Item data is the PortAudio device index. ``device_index`` ends up as
        the first listed device, or None when no input device exists.
        """
        self.dev_combo.clear()
        p = pyaudio.PyAudio()
        try:
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
        finally:
            p.terminate()
        if self.dev_combo.count() > 0:
            self.dev_combo.setCurrentIndex(0)
        self.device_index = self.dev_combo.currentData()

    def start_rec(self, e):
        """Mouse-press handler of the record button: start a RecordThread on left press.

        The press is ignored while the previous recording thread still runs
        (e.g. still writing its WAV). Replacing ``record_thread`` then would
        drop the last Python reference to a running QThread, and both threads
        would write the same temporary file.
        """
        if e.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            self.log.append("上一次录音仍在保存,请稍候。")
            return
        self.record_thread = RecordThread(device_index=self.device_index)
        self.record_thread.update_signal.connect(self.log.append)
        self.record_thread.finish_signal.connect(self.finish_rec)
        self.record_thread.progress_signal.connect(self.progress.setValue)
        self.record_thread.level_signal.connect(self.level.setValue)
        self.record_thread.start()
        self.start = time.time()
        self.timer.start(100)

    def stop_rec(self, e):
        """Mouse-release handler: ask the recording thread to stop and reset the meters."""
        if e.button() == Qt.LeftButton and self.record_thread is not None:
            self.record_thread.is_recording = False
            self.timer.stop()
            self.progress.setValue(0)
            self.level.setValue(0)

    def update_dur(self):
        """Show the seconds elapsed since the record button was pressed."""
        self.dur.setText(f"{time.time()-self.start:.1f}s")

    def finish_rec(self, file):
        """Remember the recorded WAV path and allow processing."""
        self.wav_path = file
        self.process_btn.setEnabled(True)

    def browse(self):
        """Let the user pick a WAV file and allow processing if one was chosen."""
        f, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if f:
            self.file.setText(f)
            self.wav_path = f
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Start a ProcessThread on the current WAV file.

        The process button stays disabled until the thread finishes, and a
        request is ignored while one is still running: replacing the
        reference to a running QThread lets Python destroy it mid-run.
        """
        if self.mode == "upload":
            self.wav_path = self.file.text()
        if not os.path.exists(self.wav_path):
            self.log.append("无效文件")
            return
        if self.process_thread is not None and self.process_thread.isRunning():
            self.log.append("正在处理,请稍候。")
            return
        self.process_thread = ProcessThread(self.wav_path, self.model_path.text(), self.scaler_path.text())
        self.process_thread.update_signal.connect(self.log.append)
        self.process_thread.finish_signal.connect(self.parent.switch_to_result)
        self.process_thread.finished.connect(self._on_process_finished)
        self.process_btn.setEnabled(False)
        self.process_thread.start()

    def _on_process_finished(self):
        """Re-enable the process button once the worker thread has ended (success or not)."""
        self.process_btn.setEnabled(True)
# ======================
# 历史记录(含 CRUD 按钮)
# ======================
class HistoryWidget(QWidget):
    """History page: table of stored runs with search, CSV export and run/segment CRUD.

    Args:
        parent: Main window; must provide ``switch_to_main_menu()``. Stored as
            ``self.parent``, which shadows ``QWidget.parent()`` on this instance.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.parent = parent
        self.init_ui()

    def init_ui(self):
        """Build the toolbar and the read-only runs table, then load recent runs."""
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back)
        top = QHBoxLayout()
        self.search_edit = QLineEdit()
        self.search_edit.setPlaceholderText("按文件名搜索...")
        top.addWidget(self.search_edit)
        for text, handler in (
            ("搜索", self.search),
            ("刷新", self.load),
            ("导出CSV", self.export_csv),
            ("新增记录", self.add_run),
            ("编辑所选", self.edit_selected),
            ("删除所选", self.delete_selected),
            ("片段管理", self.open_segments),
        ):
            button = QPushButton(text)
            button.clicked.connect(handler)
            top.addWidget(button)
        layout.addLayout(top)
        self.table = QTableWidget(0, 6)
        self.table.setHorizontalHeaderLabels(["ID", "时间", "文件名", "片段数", "平均概率", "最终标签"])
        # Display-only: in-cell edits would never be written back to the DB.
        self.table.setEditTriggers(QTableWidget.NoEditTriggers)
        self.table.setSelectionBehavior(QTableWidget.SelectRows)
        layout.addWidget(self.table)
        self.setLayout(layout)
        self.load()

    def _current_run_id(self):
        """Return the run id of the current row, or None if no row is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        return int(self.table.item(r, 0).text())

    # ---- Loading / searching ----
    def load(self):
        """Show the 200 most recent runs."""
        self._fill(DB.fetch_recent_runs(200))

    def search(self):
        """Show runs whose filename contains the search text (all recent runs if empty)."""
        kw = self.search_edit.text().strip()
        rows = DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200)
        self._fill(rows)

    def _fill(self, rows):
        """Replace the table contents with ``rows`` (tuples as returned by DB.fetch_recent_runs)."""
        self.table.setRowCount(len(rows))
        for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows):
            self.table.setItem(r_idx, 0, QTableWidgetItem(str(rid)))
            self.table.setItem(r_idx, 1, QTableWidgetItem(str(created_at)))
            self.table.setItem(r_idx, 2, QTableWidgetItem(str(fn)))
            self.table.setItem(r_idx, 3, QTableWidgetItem(str(segs)))
            self.table.setItem(r_idx, 4, QTableWidgetItem(str(mean_prob)))
            # final_label: 1 = hollow (空心), 0 = solid (实心)
            self.table.setItem(r_idx, 5, QTableWidgetItem("空心" if int(final_label) else "实心"))
        self.table.resizeColumnsToContents()

    # ---- Export ----
    def export_csv(self):
        """Export recent runs to CSV and report the absolute path, or the error if writing fails."""
        try:
            p = DB.export_runs_to_csv()
        except OSError as exc:
            # e.g. the file is locked by another program; an exception leaving
            # this slot would abort a PyQt5 application.
            QMessageBox.warning(self, "导出失败", f"无法写入 CSV 文件: {exc}")
            return
        QMessageBox.information(self, "导出完成", f"已保存到 {os.path.abspath(p)}")

    # ---- CRUD: runs ----
    def add_run(self):
        """Create a run from the edit dialog's values."""
        dlg = RunEditDialog(self, run_row=None)
        if dlg.exec_() == QDialog.Accepted:
            v = dlg.values()
            DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
            self.load()

    def edit_selected(self):
        """Edit the selected run; refresh instead if it no longer exists in the DB."""
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        row = DB.get_run(rid)
        if row is None:
            QMessageBox.information(self, "提示", "该记录已不存在。")
            self.load()
            return
        dlg = RunEditDialog(self, run_row=row)
        if dlg.exec_() == QDialog.Accepted:
            v = dlg.values()
            DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
            self.load()

    def delete_selected(self):
        """Delete the selected run and its segments after confirmation."""
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        if QMessageBox.question(self, "确认", "确认删除所选记录及其所有片段?此操作不可恢复!",
                                QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes:
            DB.delete_run(rid)
            self.load()

    # ---- Segment management ----
    def open_segments(self):
        """Open the segment manager for the selected run, then reload the table."""
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        dlg = SegmentManagerDialog(self, run_id=rid)
        dlg.exec_()
        # Segment edits are not propagated to runs.segments / runs.mean_prob;
        # those columns remain manually edited.
        self.load()
# ======================
# 结果页(简)
# ======================
class ResultWidget(QWidget):
    """Result page summarizing one processing result dict.

    Args:
        parent: Main window; must provide ``switch_to_main_menu()``.
        result: Dict with ``filename``, ``segments``, ``mean_prob`` and
            ``final_label`` keys; when falsy the text area stays empty.
    """

    def __init__(self, parent, result=None):
        super().__init__(parent)
        self.parent = parent
        layout = QVBoxLayout()

        back_btn = QPushButton("返回")
        back_btn.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back_btn)

        heading = QLabel("处理结果")
        heading.setAlignment(Qt.AlignCenter)
        layout.addWidget(heading)

        summary = QTextEdit()
        summary.setReadOnly(True)
        if result:
            verdict = '空心' if result['final_label'] else '实心'
            summary.setText("\n".join([
                f"文件:{result['filename']}",
                f"片段:{result['segments']}",
                f"平均概率:{result['mean_prob']}",
                f"最终:{verdict}",
            ]))
        layout.addWidget(summary)
        self.setLayout(layout)
# ======================
# 主窗口
# ======================
class AudioClassifierGUI(QMainWindow):
    """Main window: a QStackedWidget holding every page, plus page-switch helpers.

    Stack order: main menu (0), record input (1), upload input (2),
    result (3), history (4). ``process_result`` holds the last result dict
    passed to :meth:`switch_to_result`.
    """

    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"
        self.process_result = None
        self.init_ui()

    def init_ui(self):
        """Create all pages, add them to the stack and show the main menu."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 1000, 700)
        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)
        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)
        self.history_widget = HistoryWidget(self)
        self.stacked_widget.addWidget(self.main_menu_widget)  # 0
        self.stacked_widget.addWidget(self.record_input_widget)  # 1
        self.stacked_widget.addWidget(self.upload_input_widget)  # 2
        self.stacked_widget.addWidget(self.result_widget)  # 3
        self.stacked_widget.addWidget(self.history_widget)  # 4
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_input(self, mode):
        """Show the record page for ``mode == "record"``, otherwise the upload page."""
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def switch_to_main_menu(self):
        """Show the main menu page."""
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_result(self, result):
        """Replace the result page with one showing ``result`` and display it.

        The old page is located by reference rather than by a fixed stack
        index, and the new page is inserted at the same position. Removing
        ``widget(3)`` and appending would move the result page behind the
        history page, so a second call would remove the history page instead.
        The old page is scheduled for deletion so runs do not pile up hidden
        widgets.
        """
        self.process_result = result
        old_widget = self.result_widget
        index = self.stacked_widget.indexOf(old_widget)
        self.result_widget = ResultWidget(self, result)
        self.stacked_widget.insertWidget(index, self.result_widget)
        self.stacked_widget.removeWidget(old_widget)
        old_widget.deleteLater()
        self.stacked_widget.setCurrentWidget(self.result_widget)

    def switch_to_history(self):
        """Reload the history table and show the history page."""
        self.history_widget.load()
        self.stacked_widget.setCurrentWidget(self.history_widget)
# ======================
# 上传/录音 界面(放最后,避免打断阅读)
# ======================
class InputWidget(QWidget):
    """Input page for one classification run, in "record" or "upload" mode.

    In "record" mode, audio is captured by a ``RecordThread`` while the left
    mouse button is held on the record button. In "upload" mode, the user
    picks a WAV file. "开始处理" then runs a ``ProcessThread`` on the file and
    forwards its result dict to ``parent.switch_to_result``.

    Args:
        parent: Main window; must provide ``switch_to_main_menu()`` and
            ``switch_to_result(result)``. Stored as ``self.parent``, which
            shadows ``QWidget.parent()`` on this instance.
        mode: ``"record"`` for recording; any other value selects upload.

    NOTE: this module defines ``InputWidget`` twice; this later definition
    rebinds the name and is the one ``AudioClassifierGUI`` instantiates.
    """

    def __init__(self, parent, mode="record"):
        super().__init__(parent)
        self.parent, self.mode = parent, mode
        self.wav_path, self.record_thread = "", None
        # Named process_thread (not "thread") so QObject.thread() stays usable.
        self.process_thread = None
        self.device_index = None
        self.init_ui()

    def init_ui(self):
        """Build the page: back button, mode controls, model paths, process button and log."""
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back)
        if self.mode == "record":
            self.setup_record(layout)
        else:
            self.setup_upload(layout)
        self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
        path_layout = QHBoxLayout()
        path_layout.addWidget(QLabel("模型:"))
        path_layout.addWidget(self.model_path)
        path_layout.addWidget(QLabel("标准化器:"))
        path_layout.addWidget(self.scaler_path)
        layout.addLayout(path_layout)
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setEnabled(False)
        self.process_btn.clicked.connect(self.start_process)
        layout.addWidget(self.process_btn)
        self.log = QTextEdit()
        self.log.setReadOnly(True)
        layout.addWidget(self.log)
        self.setLayout(layout)

    def setup_record(self, layout):
        """Add device selector, push-to-talk button, level/progress bars and duration label."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)
        dev_layout = QHBoxLayout()
        dev_layout.addWidget(QLabel("设备:"))
        self.dev_combo = QComboBox()
        self.refresh_devices()
        self.dev_combo.currentIndexChanged.connect(
            lambda: setattr(self, "device_index", self.dev_combo.currentData()))
        dev_layout.addWidget(self.dev_combo)
        btn_ref = QPushButton("刷新设备")
        btn_ref.clicked.connect(self.refresh_devices)
        dev_layout.addWidget(btn_ref)
        layout.addLayout(dev_layout)
        self.rec_btn = QPushButton("按住录音")
        self.rec_btn.setMinimumHeight(80)
        # Push-to-talk: the button's mouse handlers are replaced so that
        # press starts and release stops the recording.
        self.rec_btn.mousePressEvent = self.start_rec
        self.rec_btn.mouseReleaseEvent = self.stop_rec
        self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
        self.level.setFormat("电平:%p%")
        self.progress.setRange(0, 100)
        for w in [self.rec_btn, self.level, self.progress, self.dur]:
            layout.addWidget(w)
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_dur)

    def setup_upload(self, layout):
        """Add a read-only path field and a browse button."""
        h = QHBoxLayout()
        self.file = QLineEdit()
        self.file.setReadOnly(True)
        b = QPushButton("浏览")
        b.clicked.connect(self.browse)
        h.addWidget(self.file)
        h.addWidget(b)
        layout.addLayout(h)

    def refresh_devices(self):
        """Fill the device combo box with every device that has input channels.

        Item data is the PortAudio device index. ``device_index`` ends up as
        the first listed device, or None when no input device exists.
        """
        self.dev_combo.clear()
        p = pyaudio.PyAudio()
        try:
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
        finally:
            p.terminate()
        if self.dev_combo.count() > 0:
            self.dev_combo.setCurrentIndex(0)
        self.device_index = self.dev_combo.currentData()

    def start_rec(self, e):
        """Mouse-press handler of the record button: start a RecordThread on left press.

        The press is ignored while the previous recording thread still runs
        (e.g. still writing its WAV). Replacing ``record_thread`` then would
        drop the last Python reference to a running QThread, and both threads
        would write the same temporary file.
        """
        if e.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            self.log.append("上一次录音仍在保存,请稍候。")
            return
        self.record_thread = RecordThread(device_index=self.device_index)
        self.record_thread.update_signal.connect(self.log.append)
        self.record_thread.finish_signal.connect(self.finish_rec)
        self.record_thread.progress_signal.connect(self.progress.setValue)
        self.record_thread.level_signal.connect(self.level.setValue)
        self.record_thread.start()
        self.start = time.time()
        self.timer.start(100)

    def stop_rec(self, e):
        """Mouse-release handler: ask the recording thread to stop and reset the meters."""
        if e.button() == Qt.LeftButton and self.record_thread is not None:
            self.record_thread.is_recording = False
            self.timer.stop()
            self.progress.setValue(0)
            self.level.setValue(0)

    def update_dur(self):
        """Show the seconds elapsed since the record button was pressed."""
        self.dur.setText(f"{time.time()-self.start:.1f}s")

    def finish_rec(self, file):
        """Remember the recorded WAV path and allow processing."""
        self.wav_path = file
        self.process_btn.setEnabled(True)

    def browse(self):
        """Let the user pick a WAV file and allow processing if one was chosen."""
        f, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if f:
            self.file.setText(f)
            self.wav_path = f
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Start a ProcessThread on the current WAV file.

        The process button stays disabled until the thread finishes, and a
        request is ignored while one is still running: replacing the
        reference to a running QThread lets Python destroy it mid-run.
        """
        if self.mode == "upload":
            self.wav_path = self.file.text()
        if not os.path.exists(self.wav_path):
            self.log.append("无效文件")
            return
        if self.process_thread is not None and self.process_thread.isRunning():
            self.log.append("正在处理,请稍候。")
            return
        self.process_thread = ProcessThread(self.wav_path, self.model_path.text(), self.scaler_path.text())
        self.process_thread.update_signal.connect(self.log.append)
        self.process_thread.finish_signal.connect(self.parent.switch_to_result)
        self.process_thread.finished.connect(self._on_process_finished)
        self.process_btn.setEnabled(False)
        self.process_thread.start()

    def _on_process_finished(self):
        """Re-enable the process button once the worker thread has ended (success or not)."""
        self.process_btn.setEnabled(True)
if __name__ == "__main__":
    # pyaudio is imported unconditionally at the top of this module, so a
    # missing package already aborts with ImportError before this block runs;
    # a re-import check here could never print its hint.
    # (Install it with: pip install pyaudio)
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    sys.exit(app.exec_())