#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Desktop audio classifier.

Records audio from an input device or loads a WAV file, splits the signal into
energy-triggered segments, classifies every segment with a joblib-serialised
model and scaler, and stores each run in a local SQLite database that can be
browsed, edited and exported to CSV from the GUI.
"""
import csv
import os
import sqlite3
import sys
import time
import wave
from contextlib import contextmanager
from datetime import datetime

import joblib
import librosa
import numpy as np
from scipy.signal import hilbert
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QPushButton, QLabel, QLineEdit, QTextEdit,
    QFileDialog, QVBoxLayout, QHBoxLayout, QWidget, QProgressBar,
    QStackedWidget, QComboBox, QTableWidget, QTableWidgetItem, QMessageBox,
    QDialog, QFormLayout, QDialogButtonBox,
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer

try:
    import pyaudio
except ImportError:
    # Deferred: the __main__ guard prints an install hint instead of a traceback.
    pyaudio = None

# Format used for runs.created_at.
TIMESTAMP_FMT = "%Y-%m-%d %H:%M:%S"


def _now_str():
    """Return the current local time formatted with TIMESTAMP_FMT."""
    return datetime.now().strftime(TIMESTAMP_FMT)


def label_text(final_label):
    """Return the display text for a final label: truthy -> "实心" (solid), else "空心" (hollow).

    NOTE(review): RunEditDialog's hints say 0=solid / 1=hollow, which is the
    opposite of this mapping (used by the history table and the result page).
    Confirm which one matches the model's training labels.
    """
    return "实心" if int(final_label) else "空心"


def _warn_invalid_number(parent):
    """Tell the user a numeric form field could not be parsed."""
    QMessageBox.warning(parent, "输入错误", "请输入有效的数字。")


# ======================
# Local SQLite database (with CRUD)
# ======================
class DatabaseManager:
    """CRUD access to the ``runs`` and ``run_segments`` tables of a SQLite file.

    Every public method opens and closes its own connection, which is what lets
    ProcessThread call insert_result() from a worker thread.
    """

    _RUN_COLUMNS = "id, created_at, filename, segments, mean_prob, final_label"

    def __init__(self, db_path="results.db"):
        self.db_path = db_path
        self._ensure_schema()

    def _connect(self):
        return sqlite3.connect(self.db_path)

    @contextmanager
    def _cursor(self):
        """Yield a cursor; commit if the body succeeds and always close the connection.

        Closing without committing discards pending writes, so a failure halfway
        through a multi-statement operation leaves the database unchanged.
        """
        con = self._connect()
        try:
            yield con.cursor()
            con.commit()
        finally:
            con.close()

    def _ensure_schema(self):
        """Create the runs / run_segments tables if they do not exist yet."""
        with self._cursor() as cur:
            cur.execute("""
                CREATE TABLE IF NOT EXISTS runs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT,
                    segments INTEGER,
                    mean_prob REAL,
                    final_label INTEGER,
                    created_at TEXT
                );
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS run_segments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id INTEGER,
                    seg_index INTEGER,
                    label INTEGER,
                    proba REAL,
                    FOREIGN KEY(run_id) REFERENCES runs(id)
                );
            """)

    # ---------- Insert a full inference result ----------
    def insert_result(self, result_dict):
        """Store one inference result (run row + one row per segment) atomically.

        Reads keys filename, segments, mean_prob, final_label, predictions and
        probabilities (missing keys fall back to defaults). Returns the new run id.
        """
        with self._cursor() as cur:
            cur.execute(
                "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (
                    result_dict.get("filename"),
                    int(result_dict.get("segments", 0)),
                    float(result_dict.get("mean_prob", 0.0)),
                    int(result_dict.get("final_label", 0)),
                    _now_str(),
                ),
            )
            run_id = cur.lastrowid
            preds = result_dict.get("predictions", [])
            probas = result_dict.get("probabilities", [])
            # Segment indices are 1-based; zip truncates to the shorter list.
            seg_rows = [
                (run_id, i, int(lab), float(pr))
                for i, (lab, pr) in enumerate(zip(preds, probas), start=1)
            ]
            cur.executemany(
                "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
                seg_rows,
            )
        return run_id

    # ---------- R (list / single row) ----------
    def fetch_recent_runs(self, limit=50):
        """Return up to ``limit`` runs, newest id first."""
        with self._cursor() as cur:
            cur.execute(
                f"SELECT {self._RUN_COLUMNS} FROM runs ORDER BY id DESC LIMIT ?",
                (limit,),
            )
            return cur.fetchall()

    def get_run(self, run_id: int):
        """Return the run row with ``run_id`` or None."""
        with self._cursor() as cur:
            cur.execute(f"SELECT {self._RUN_COLUMNS} FROM runs WHERE id = ?", (run_id,))
            return cur.fetchone()

    def search_runs(self, keyword: str, limit=100):
        """Return runs whose filename contains ``keyword`` (SQL LIKE), newest first."""
        with self._cursor() as cur:
            cur.execute(
                f"SELECT {self._RUN_COLUMNS} FROM runs "
                "WHERE filename LIKE ? ORDER BY id DESC LIMIT ?",
                (f"%{keyword}%", limit),
            )
            return cur.fetchall()

    # ---------- C/U/D (run) ----------
    def create_run(self, filename: str, segments: int, mean_prob: float, final_label: int,
                   created_at: str = None):
        """Insert a run row; an empty/None ``created_at`` means "now". Returns the new id."""
        created_at = created_at or _now_str()
        with self._cursor() as cur:
            cur.execute(
                "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (filename, int(segments), float(mean_prob), int(final_label), created_at),
            )
            return cur.lastrowid

    def update_run(self, run_id: int, filename: str, segments: int, mean_prob: float,
                   final_label: int, created_at: str):
        """Overwrite every editable column of run ``run_id``."""
        with self._cursor() as cur:
            cur.execute(
                "UPDATE runs SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=? "
                "WHERE id=?",
                (filename, int(segments), float(mean_prob), int(final_label), created_at,
                 int(run_id)),
            )

    def delete_run(self, run_id: int):
        """Delete a run together with its segments."""
        with self._cursor() as cur:
            # Child rows first: foreign keys are not enforced (no PRAGMA), so no cascade.
            cur.execute("DELETE FROM run_segments WHERE run_id=?", (int(run_id),))
            cur.execute("DELETE FROM runs WHERE id=?", (int(run_id),))

    # ---------- C/R/U/D (segments) ----------
    def list_segments(self, run_id: int):
        """Return (id, seg_index, label, proba) rows of a run ordered by seg_index."""
        with self._cursor() as cur:
            cur.execute(
                "SELECT id, seg_index, label, proba FROM run_segments "
                "WHERE run_id=? ORDER BY seg_index ASC",
                (int(run_id),),
            )
            return cur.fetchall()

    def add_segment(self, run_id: int, seg_index: int, label: int, proba: float):
        """Insert a segment row for ``run_id``; returns the new segment id."""
        with self._cursor() as cur:
            cur.execute(
                "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)",
                (int(run_id), int(seg_index), int(label), float(proba)),
            )
            return cur.lastrowid

    def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float):
        """Overwrite seg_index / label / proba of segment ``seg_id``."""
        with self._cursor() as cur:
            cur.execute(
                "UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=?",
                (int(seg_index), int(label), float(proba), int(seg_id)),
            )

    def delete_segment(self, seg_id: int):
        """Delete a single segment row."""
        with self._cursor() as cur:
            cur.execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),))

    # ---------- Export ----------
    def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000):
        """Write the newest ``limit`` runs to ``csv_path`` (UTF-8) and return the path."""
        rows = self.fetch_recent_runs(limit)
        headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"]
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(headers)
            writer.writerows(rows)
        return csv_path


# Module-level singleton; creates/opens results.db in the working directory at import.
DB = DatabaseManager("results.db")


# ======================
# Feature extraction
# ======================
class FeatureExtractor:
    """Energy-based hit detection and the 5-value per-segment feature vector."""

    WIN = 1024            # frame length in samples
    OVERLAP = 512         # overlap between consecutive frames in samples
    THRESHOLD = 0.1       # frame-energy threshold (ProcessThread peak-normalises the signal)
    SEG_LEN_S = 0.2       # seconds of audio taken from each hit start
    STEP = WIN - OVERLAP  # hop length in samples
    MIN_GAP_FRAMES = 5    # an energetic frame more than this many frames after the previous one starts a new hit

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the sum of squares of each WIN-sample frame (hop STEP).

        Signals shorter than one frame yield an empty array instead of the
        ParameterError librosa would raise.
        """
        if len(signal) < FeatureExtractor.WIN:
            return np.zeros(0)
        frames = librosa.util.frame(signal, frame_length=FeatureExtractor.WIN,
                                    hop_length=FeatureExtractor.STEP)
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices where a run of above-threshold frames begins.

        The first energetic frame always starts a hit (previously it was dropped
        when it lay within MIN_GAP_FRAMES of the start of the file).
        """
        idx = np.flatnonzero(energy > FeatureExtractor.THRESHOLD)
        if idx.size == 0:
            return np.array([], dtype=int)
        is_onset = np.concatenate(([True], np.diff(idx) > FeatureExtractor.MIN_GAP_FRAMES))
        return idx[is_onset]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray):
        """Cut up to SEG_LEN_S seconds of ``signal`` starting at each hit frame."""
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            start = int(frame_idx) * FeatureExtractor.STEP
            end = min(start + seg_len, len(signal))
            segments.append(signal[start:end])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return [rms, main_freq, skewness, mfcc_mean, env_peak] for one segment.

        An empty segment yields five zeros. The spectral moments are weighted by
        the FFT magnitude (not power); keep the computation unchanged because the
        scaler/model were presumably fitted on features computed exactly this way.
        """
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)

        rms = np.sqrt(np.mean(x ** 2))

        spectrum = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive = freq >= 0
        freq = freq[positive]
        weights = np.abs(spectrum[positive])
        main_freq = freq[np.argmax(weights)] if len(weights) > 0 else 0

        total = np.sum(weights) + 1e-12  # epsilon avoids division by zero on silence
        centroid = np.sum(freq * weights) / total
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * weights) / total)
        if spread > 0:
            skewness = np.sum(((freq - centroid) ** 3) * weights) / (total * (spread ** 3 + 1e-12))
        else:
            skewness = 0

        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        env_peak = np.max(np.abs(hilbert(x)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])


# ======================
# Recording thread
# ======================
class RecordThread(QThread):
    """Capture mono 16-bit audio until stop() is called or max_duration elapses.

    Signals:
        update_signal(str): status / error messages.
        finish_signal(str): path of the written WAV file.
        progress_signal(int): elapsed time as a percentage of max_duration.
        level_signal(int): 0-100 input level derived from each chunk's RMS.
    """
    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    level_signal = pyqtSignal(int)

    CHANNELS = 1
    RATE = 44100
    CHUNK = 1024

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        self.is_recording = False  # True only while the capture loop is running
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index
        # Set by stop() and never reset by run(), so a stop requested before the
        # thread actually starts running is still honoured.
        self._stop_requested = False

    def stop(self):
        """Ask the capture loop to finish after the current chunk."""
        self._stop_requested = True

    def run(self):
        try:
            frames, sample_width = self._capture()
            if frames:
                self._write_wav(frames, sample_width)
                self.finish_signal.emit(self.temp_file)
            else:
                self.update_signal.emit("未录到有效音频。")
        except Exception as e:
            # Worker-thread boundary: report the error to the GUI log.
            self.update_signal.emit(f"录制错误: {e}")

    def _capture(self):
        """Read chunks from the input device; return (list of raw chunks, bytes per sample)."""
        fmt = pyaudio.paInt16
        max_int16 = 32767.0
        frames = []
        p = pyaudio.PyAudio()
        stream = None
        try:
            sample_width = p.get_sample_size(fmt)
            stream = p.open(format=fmt, channels=self.CHANNELS, rate=self.RATE, input=True,
                            input_device_index=self.device_index,
                            frames_per_buffer=self.CHUNK)
            self.is_recording = True
            start_time = time.time()
            while not self._stop_requested:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(self.CHUNK, exception_on_overflow=False)
                frames.append(data)
                chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
                rms = np.sqrt(np.mean(chunk ** 2))
                # x500 scaling makes ordinary speech levels visible on the 0-100 bar.
                self.level_signal.emit(int(np.clip(rms * 500, 0, 100)))
        finally:
            self.is_recording = False
            try:
                if stream is not None:
                    stream.stop_stream()
                    stream.close()
            finally:
                p.terminate()
        return frames, sample_width

    def _write_wav(self, frames, sample_width):
        """Write the captured chunks to self.temp_file as a mono WAV."""
        with wave.open(self.temp_file, "wb") as wf:
            wf.setnchannels(self.CHANNELS)
            wf.setsampwidth(sample_width)
            wf.setframerate(self.RATE)
            wf.writeframes(b"".join(frames))


# ======================
# Inference thread (writes results to the DB)
# ======================
class ProcessThread(QThread):
    """Segment a WAV file, classify each segment, store and emit the result dict."""
    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)

    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path, self.model_path, self.scaler_path = wav_path, model_path, scaler_path

    def run(self):
        try:
            self.update_signal.emit("加载模型和标准化器...")
            # SECURITY NOTE(review): joblib.load unpickles the file and can execute
            # arbitrary code; only load model/scaler files from trusted sources.
            model, scaler = joblib.load(self.model_path), joblib.load(self.scaler_path)

            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
            # Peak-normalise so FeatureExtractor.THRESHOLD is independent of input gain.
            sig = sig / max(np.max(np.abs(sig)), 1e-9)

            energy = FeatureExtractor.frame_energy(sig)
            hits = FeatureExtractor.detect_hits(energy)
            segs = FeatureExtractor.extract_segments(sig, sr, hits)
            if not segs:
                self.update_signal.emit("未检测到有效片段!")
                return

            X = np.vstack([FeatureExtractor.extract_features(seg, sr) for seg in segs])
            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)
            if hasattr(model, "predict_proba"):
                # assumes a binary classifier whose column 1 is the positive class
                y_proba = model.predict_proba(X_std)[:, 1]
            else:
                y_proba = y_pred.astype(float)

            result = {
                "filename": os.path.basename(self.wav_path),
                "segments": len(segs),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(p), 4) for p in y_proba],
                "mean_prob": round(float(np.mean(y_proba)), 4),
                "final_label": int(np.mean(y_proba) >= 0.5),
            }
            DB.insert_result(result)
            self.finish_signal.emit(result)
        except Exception as e:
            self.update_signal.emit(f"错误: {e}")


# ======================
# Dialog: edit a run
# ======================
class RunEditDialog(QDialog):
    """Form for creating (run_row=None) or editing a run row.

    run_row: (id, created_at, filename, segments, mean_prob, final_label), or None to add.
    """

    def __init__(self, parent=None, run_row=None):
        super().__init__(parent)
        self.setWindowTitle("编辑记录" if run_row else "新增记录")
        self.resize(380, 240)
        self.run_row = run_row

        self.ed_id = QLineEdit()
        self.ed_id.setReadOnly(True)
        self.ed_created = QLineEdit(_now_str())
        self.ed_filename = QLineEdit()
        self.ed_segments = QLineEdit("0")
        self.ed_mean_prob = QLineEdit("0.0")
        self.ed_final_label = QLineEdit("0")  # dialog hint: 0 = solid (实心), 1 = hollow (空心)

        form = QFormLayout()
        form.addRow("ID:", self.ed_id)
        form.addRow("时间(created_at):", self.ed_created)
        form.addRow("文件名(filename):", self.ed_filename)
        form.addRow("片段数(segments):", self.ed_segments)
        form.addRow("平均概率(mean_prob):", self.ed_mean_prob)
        form.addRow("最终标签(0=实 1=空):", self.ed_final_label)

        if run_row:
            editors = (self.ed_id, self.ed_created, self.ed_filename,
                       self.ed_segments, self.ed_mean_prob, self.ed_final_label)
            for editor, value in zip(editors, run_row):
                editor.setText(str(value))

        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(self.accept)
        btns.rejected.connect(self.reject)

        lay = QVBoxLayout()
        lay.addLayout(form)
        lay.addWidget(btns)
        self.setLayout(lay)

    def values(self):
        """Return the form contents as a dict of typed values.

        Empty numeric fields fall back to 0 / 0.0.

        Raises:
            ValueError: if a numeric field cannot be parsed.
        """
        return dict(
            id=(int(self.ed_id.text()) if self.ed_id.text().strip() else None),
            created_at=self.ed_created.text().strip(),
            filename=self.ed_filename.text().strip(),
            segments=int(float(self.ed_segments.text() or 0)),
            mean_prob=float(self.ed_mean_prob.text() or 0.0),
            final_label=int(float(self.ed_final_label.text() or 0)),
        )


# ======================
# Dialog: manage segments of a run
# ======================
class SegmentManagerDialog(QDialog):
    """Table of a run's segments with add / edit / delete actions."""

    def __init__(self, parent=None, run_id=None):
        super().__init__(parent)
        self.setWindowTitle(f"片段管理 - run_id={run_id}")
        self.resize(560, 420)
        self.run_id = run_id

        self.table = QTableWidget(0, 4)
        self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"])

        self.btn_add = QPushButton("新增片段")
        self.btn_edit = QPushButton("编辑所选")
        self.btn_del = QPushButton("删除所选")
        self.btn_close = QPushButton("关闭")

        hl = QHBoxLayout()
        for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close):
            hl.addWidget(b)

        lay = QVBoxLayout()
        lay.addWidget(self.table)
        lay.addLayout(hl)
        self.setLayout(lay)

        self.btn_add.clicked.connect(self.add_segment)
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_del.clicked.connect(self.delete_selected)
        self.btn_close.clicked.connect(self.accept)

        self.load_segments()

    def load_segments(self):
        """Reload the table from the database."""
        rows = DB.list_segments(self.run_id)
        self.table.setRowCount(len(rows))
        for i, row in enumerate(rows):
            for col, value in enumerate(row):
                self.table.setItem(i, col, QTableWidgetItem(str(value)))
        self.table.resizeColumnsToContents()

    def _get_selected_seg_id(self):
        """Return the segment id of the current row, or None if nothing is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        return int(self.table.item(r, 0).text())

    def _prompt_segment(self, title, index_text, label_text_, proba_text):
        """Show a seg_index/label/proba form prefilled with the given texts.

        Returns (seg_index, label, proba), or None if the user cancelled or entered
        something that is not a number (a warning is shown in that case). Empty
        fields fall back to the prefilled text.
        """
        d = QDialog(self)
        d.setWindowTitle(title)
        form = QFormLayout()
        ed_index = QLineEdit(index_text)
        ed_label = QLineEdit(label_text_)
        ed_proba = QLineEdit(proba_text)
        form.addRow("seg_index:", ed_index)
        form.addRow("label(0/1):", ed_label)
        form.addRow("proba(0~1):", ed_proba)
        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(d.accept)
        btns.rejected.connect(d.reject)
        v = QVBoxLayout()
        v.addLayout(form)
        v.addWidget(btns)
        d.setLayout(v)

        if d.exec_() != QDialog.Accepted:
            return None
        try:
            return (int(float(ed_index.text() or index_text)),
                    int(float(ed_label.text() or label_text_)),
                    float(ed_proba.text() or proba_text))
        except ValueError:
            # An exception escaping a Qt slot would abort the whole application.
            _warn_invalid_number(self)
            return None

    def add_segment(self):
        values = self._prompt_segment("新增片段", "1", "0", "0.0")
        if values is not None:
            DB.add_segment(self.run_id, *values)
            self.load_segments()

    def edit_selected(self):
        seg_id = self._get_selected_seg_id()
        if seg_id is None:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        row = self.table.currentRow()
        values = self._prompt_segment(
            "编辑片段",
            self.table.item(row, 1).text(),
            self.table.item(row, 2).text(),
            self.table.item(row, 3).text(),
        )
        if values is not None:
            DB.update_segment(seg_id, *values)
            self.load_segments()

    def delete_selected(self):
        seg_id = self._get_selected_seg_id()
        if seg_id is None:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        answer = QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!",
                                      QMessageBox.Yes | QMessageBox.No)
        if answer == QMessageBox.Yes:
            DB.delete_segment(seg_id)
            self.load_segments()


# ======================
# Main menu
# ======================
class MainMenuWidget(QWidget):
    """Landing page with buttons for recording, uploading and history."""

    def __init__(self, parent):
        super().__init__(parent)
        layout = QVBoxLayout()
        title = QLabel("音频分类器")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")
        layout.addWidget(title)

        btn1 = QPushButton("采集音频")
        btn1.clicked.connect(lambda: parent.switch_to_input("record"))
        btn2 = QPushButton("上传WAV文件")
        btn2.clicked.connect(lambda: parent.switch_to_input("upload"))
        btn3 = QPushButton("历史记录 / 导出 / CRUD")
        btn3.clicked.connect(parent.switch_to_history)
        for b in (btn1, btn2, btn3):
            b.setMinimumHeight(60)
            layout.addWidget(b)
        layout.addStretch(1)
        self.setLayout(layout)


# ======================
# Input page (record / upload)
# ======================
class InputWidget(QWidget):
    """Input page: record from a device (mode="record") or pick a WAV (mode="upload"), then process it."""

    def __init__(self, parent, mode="record"):
        super().__init__(parent)
        # Named main_window (not `parent`/`thread`) to avoid shadowing QWidget.parent()
        # and QObject.thread().
        self.main_window = parent
        self.mode = mode
        self.wav_path = ""
        self.record_thread = None
        self.process_thread = None
        self.device_index = None
        self._rec_started_at = 0.0
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.main_window.switch_to_main_menu)
        layout.addWidget(back)

        if self.mode == "record":
            self.setup_record(layout)
        else:
            self.setup_upload(layout)

        self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
        path_layout = QHBoxLayout()
        path_layout.addWidget(QLabel("模型:"))
        path_layout.addWidget(self.model_path)
        path_layout.addWidget(QLabel("标准化器:"))
        path_layout.addWidget(self.scaler_path)
        layout.addLayout(path_layout)

        self.process_btn = QPushButton("开始处理")
        self.process_btn.setEnabled(False)
        self.process_btn.clicked.connect(self.start_process)
        layout.addWidget(self.process_btn)

        self.log = QTextEdit()
        self.log.setReadOnly(True)
        layout.addWidget(self.log)
        self.setLayout(layout)

    def setup_record(self, layout):
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        dev_layout = QHBoxLayout()
        dev_layout.addWidget(QLabel("设备:"))
        self.dev_combo = QComboBox()
        self.refresh_devices()
        self.dev_combo.currentIndexChanged.connect(
            lambda: setattr(self, "device_index", self.dev_combo.currentData()))
        dev_layout.addWidget(self.dev_combo)
        btn_ref = QPushButton("刷新设备")
        btn_ref.clicked.connect(self.refresh_devices)
        dev_layout.addWidget(btn_ref)
        layout.addLayout(dev_layout)

        # Push-to-talk: recording runs while the left mouse button is held down.
        self.rec_btn = QPushButton("按住录音")
        self.rec_btn.setMinimumHeight(80)
        self.rec_btn.mousePressEvent = self.start_rec
        self.rec_btn.mouseReleaseEvent = self.stop_rec

        self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
        self.level.setFormat("电平:%p%")
        self.progress.setRange(0, 100)
        for w in (self.rec_btn, self.level, self.progress, self.dur):
            layout.addWidget(w)

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_dur)

    def setup_upload(self, layout):
        h = QHBoxLayout()
        self.file = QLineEdit()
        self.file.setReadOnly(True)
        b = QPushButton("浏览")
        b.clicked.connect(self.browse)
        h.addWidget(self.file)
        h.addWidget(b)
        layout.addLayout(h)

    def refresh_devices(self):
        """Repopulate the combo with input-capable devices (item data = device index)."""
        self.dev_combo.clear()
        p = pyaudio.PyAudio()
        try:
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
        finally:
            p.terminate()
        if self.dev_combo.count() > 0:
            self.dev_combo.setCurrentIndex(0)
        self.device_index = self.dev_combo.currentData()

    def start_rec(self, e):
        if e.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            # The previous recording is still finishing; replacing the only reference
            # to a running QThread would destroy it mid-run.
            return
        self.record_thread = RecordThread(device_index=self.device_index)
        self.record_thread.update_signal.connect(self.log.append)
        self.record_thread.finish_signal.connect(self.finish_rec)
        self.record_thread.progress_signal.connect(self.progress.setValue)
        self.record_thread.level_signal.connect(self.level.setValue)
        self.record_thread.start()
        self._rec_started_at = time.time()
        self.timer.start(100)

    def stop_rec(self, e):
        if e.button() == Qt.LeftButton and self.record_thread:
            self.record_thread.stop()
            self.timer.stop()
            self.progress.setValue(0)
            self.level.setValue(0)

    def update_dur(self):
        self.dur.setText(f"{time.time() - self._rec_started_at:.1f}s")

    def finish_rec(self, file):
        self.wav_path = file
        self.process_btn.setEnabled(True)

    def browse(self):
        f, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if f:
            self.file.setText(f)
            self.wav_path = f
            self.process_btn.setEnabled(True)

    def start_process(self):
        if self.mode == "upload":
            self.wav_path = self.file.text()
        if not os.path.exists(self.wav_path):
            self.log.append("无效文件")
            return
        if self.process_thread is not None and self.process_thread.isRunning():
            return
        self.process_thread = ProcessThread(self.wav_path, self.model_path.text(),
                                            self.scaler_path.text())
        self.process_thread.update_signal.connect(self.log.append)
        self.process_thread.finish_signal.connect(self.main_window.switch_to_result)
        # Re-enable on every exit path (success, "no segments", or error).
        self.process_thread.finished.connect(lambda: self.process_btn.setEnabled(True))
        self.process_btn.setEnabled(False)
        self.process_thread.start()


# ======================
# History page (with CRUD buttons)
# ======================
class HistoryWidget(QWidget):
    """Searchable table of stored runs with export and CRUD actions."""

    def __init__(self, parent):
        super().__init__(parent)
        self.main_window = parent
        self.init_ui()

    def init_ui(self):
        l = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.main_window.switch_to_main_menu)
        l.addWidget(back)

        top = QHBoxLayout()
        self.search_edit = QLineEdit()
        self.search_edit.setPlaceholderText("按文件名搜索...")
        actions = [
            ("搜索", self.search),
            ("刷新", self.load),
            ("导出CSV", self.export_csv),
            ("新增记录", self.add_run),
            ("编辑所选", self.edit_selected),
            ("删除所选", self.delete_selected),
            ("片段管理", self.open_segments),
        ]
        top.addWidget(self.search_edit)
        for text, slot in actions:
            btn = QPushButton(text)
            btn.clicked.connect(slot)
            top.addWidget(btn)
        l.addLayout(top)

        self.table = QTableWidget(0, 6)
        self.table.setHorizontalHeaderLabels(["ID", "时间", "文件名", "片段数", "平均概率", "最终标签"])
        l.addWidget(self.table)
        self.setLayout(l)
        self.load()

    def _current_run_id(self):
        """Return the run id of the current row, or None if nothing is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        return int(self.table.item(r, 0).text())

    # ---- Loading / searching ----
    def load(self):
        self._fill(DB.fetch_recent_runs(200))

    def search(self):
        kw = self.search_edit.text().strip()
        rows = DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200)
        self._fill(rows)

    def _fill(self, rows):
        self.table.setRowCount(len(rows))
        for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows):
            cells = (str(rid), str(created_at), str(fn), str(segs), str(mean_prob),
                     label_text(final_label))
            for col, text in enumerate(cells):
                self.table.setItem(r_idx, col, QTableWidgetItem(text))
        self.table.resizeColumnsToContents()

    # ---- Export ----
    def export_csv(self):
        try:
            p = DB.export_runs_to_csv()
        except OSError as e:
            # e.g. the CSV is open (locked) in another program.
            QMessageBox.warning(self, "导出失败", str(e))
            return
        QMessageBox.information(self, "导出完成", f"已保存到 {os.path.abspath(p)}")

    # ---- CRUD: run ----
    def add_run(self):
        dlg = RunEditDialog(self, run_row=None)
        if dlg.exec_() != QDialog.Accepted:
            return
        try:
            v = dlg.values()
        except ValueError:
            _warn_invalid_number(self)
            return
        DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"],
                      v["created_at"])
        self.load()

    def edit_selected(self):
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        row = DB.get_run(rid)
        if row is None:
            # Deleted since the table was loaded: refresh instead of opening an empty form.
            self.load()
            return
        dlg = RunEditDialog(self, run_row=row)
        if dlg.exec_() != QDialog.Accepted:
            return
        try:
            v = dlg.values()
        except ValueError:
            _warn_invalid_number(self)
            return
        DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"],
                      v["created_at"])
        self.load()

    def delete_selected(self):
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        answer = QMessageBox.question(self, "确认", "确认删除所选记录及其所有片段?此操作不可恢复!",
                                      QMessageBox.Yes | QMessageBox.No)
        if answer == QMessageBox.Yes:
            DB.delete_run(rid)
            self.load()

    # ---- Segment management ----
    def open_segments(self):
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        SegmentManagerDialog(self, run_id=rid).exec_()
        # Segment edits do not recompute the run's segment count / mean probability;
        # those fields remain manually edited.
        self.load()


# ======================
# Result page
# ======================
class ResultWidget(QWidget):
    """Read-only summary of one ProcessThread result dict (empty when result is None)."""

    def __init__(self, parent, result=None):
        super().__init__(parent)
        self.main_window = parent
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.main_window.switch_to_main_menu)
        layout.addWidget(back)

        title = QLabel("处理结果")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        area = QTextEdit()
        area.setReadOnly(True)
        if result:
            area.setText(
                f"文件:{result['filename']}\n"
                f"片段:{result['segments']}\n"
                f"平均概率:{result['mean_prob']}\n"
                f"最终:{label_text(result['final_label'])}"
            )
        layout.addWidget(area)
        self.setLayout(layout)


# ======================
# Main window
# ======================
class AudioClassifierGUI(QMainWindow):
    """Main window hosting all pages in a QStackedWidget."""

    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"
        self.process_result = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 1000, 700)

        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)

        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)
        self.history_widget = HistoryWidget(self)

        for page in (self.main_menu_widget, self.record_input_widget,
                     self.upload_input_widget, self.result_widget, self.history_widget):
            self.stacked_widget.addWidget(page)
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_input(self, mode):
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def switch_to_main_menu(self):
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_result(self, result):
        """Replace the result page with one showing ``result`` and display it.

        The old page is located by reference, not by a fixed index: removing
        index 3 after the first swap would remove the history page instead.
        """
        self.process_result = result
        old = self.result_widget
        index = self.stacked_widget.indexOf(old)  # -1 (=> append) if somehow missing
        self.result_widget = ResultWidget(self, result)
        self.stacked_widget.insertWidget(index, self.result_widget)
        self.stacked_widget.removeWidget(old)
        old.deleteLater()  # removeWidget() does not delete the widget
        self.stacked_widget.setCurrentWidget(self.result_widget)

    def switch_to_history(self):
        self.history_widget.load()
        self.stacked_widget.setCurrentWidget(self.history_widget)


if __name__ == "__main__":
    if pyaudio is None:
        print("请先安装pyaudio: pip install pyaudio")
        sys.exit(1)
    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    sys.exit(app.exec_())