|
|
#!/usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
import sys
|
|
|
import os
|
|
|
import time
|
|
|
import csv
|
|
|
import sqlite3
|
|
|
from datetime import datetime
|
|
|
import numpy as np
|
|
|
import joblib
|
|
|
import librosa
|
|
|
import pyaudio
|
|
|
import wave
|
|
|
from scipy.signal import hilbert
|
|
|
|
|
|
from PyQt5.QtWidgets import (
|
|
|
QApplication, QMainWindow, QPushButton, QLabel,
|
|
|
QLineEdit, QTextEdit, QFileDialog, QVBoxLayout,
|
|
|
QHBoxLayout, QWidget, QProgressBar, QStackedWidget,
|
|
|
QComboBox, QTableWidget, QTableWidgetItem, QMessageBox,
|
|
|
QDialog, QFormLayout, QDialogButtonBox
|
|
|
)
|
|
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 本地 SQLite 数据库(含 CRUD)
|
|
|
# ======================
|
|
|
class DatabaseManager:
    """Local SQLite store for classification runs and their per-segment results.

    Two tables are maintained: ``runs`` (one row per processed file) and
    ``run_segments`` (one row per segment, linked through ``run_id``).
    Every public method opens its own connection and closes it again, even
    when the statement raises; writes are committed as a single transaction
    and rolled back on error.
    """

    _RUN_COLUMNS = "id, created_at, filename, segments, mean_prob, final_label"
    _INSERT_RUN = (
        "INSERT INTO runs(filename, segments, mean_prob, final_label, created_at) "
        "VALUES (?, ?, ?, ?, ?)"
    )
    _INSERT_SEGMENT = (
        "INSERT INTO run_segments(run_id, seg_index, label, proba) VALUES (?, ?, ?, ?)"
    )
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, db_path="results.db"):
        """Remember ``db_path`` and create the tables if they do not exist."""
        self.db_path = db_path
        self._ensure_schema()

    # ---------- connection helpers ----------
    def _connect(self):
        """Open a new connection to the database file."""
        return sqlite3.connect(self.db_path)

    def _transaction(self, work):
        """Call ``work(cursor)`` inside one transaction and return its result.

        The connection context manager commits on success and rolls back if
        ``work`` raises; the connection is closed in either case.
        """
        con = self._connect()
        try:
            with con:
                return work(con.cursor())
        finally:
            con.close()

    def _execute(self, sql, params=()):
        """Run a single write statement and return the cursor's ``lastrowid``."""
        return self._transaction(lambda cur: cur.execute(sql, params).lastrowid)

    def _query(self, sql, params=()):
        """Run a SELECT and return all rows as a list of tuples."""
        con = self._connect()
        try:
            return con.execute(sql, params).fetchall()
        finally:
            con.close()

    def _ensure_schema(self):
        """Create the ``runs`` and ``run_segments`` tables when missing."""
        def create_tables(cur):
            cur.execute("""
                CREATE TABLE IF NOT EXISTS runs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT,
                    segments INTEGER,
                    mean_prob REAL,
                    final_label INTEGER,
                    created_at TEXT
                );
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS run_segments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id INTEGER,
                    seg_index INTEGER,
                    label INTEGER,
                    proba REAL,
                    FOREIGN KEY(run_id) REFERENCES runs(id)
                );
            """)

        self._transaction(create_tables)

    # ---------- original bulk write ----------
    def insert_result(self, result_dict):
        """Store one inference result (run row plus its segment rows).

        Reads the keys ``filename``, ``segments``, ``mean_prob``,
        ``final_label``, ``predictions`` and ``probabilities`` from
        ``result_dict``.  Segments are numbered from 1 in the order given.
        The run and all of its segments are written atomically.

        Returns:
            The id of the new ``runs`` row.
        """
        run_values = (
            result_dict.get("filename"),
            int(result_dict.get("segments", 0)),
            float(result_dict.get("mean_prob", 0.0)),
            int(result_dict.get("final_label", 0)),
            datetime.now().strftime(self._TIME_FORMAT),
        )
        pairs = zip(result_dict.get("predictions", []),
                    result_dict.get("probabilities", []))

        def write(cur):
            cur.execute(self._INSERT_RUN, run_values)
            run_id = cur.lastrowid  # capture before executemany touches the cursor
            cur.executemany(
                self._INSERT_SEGMENT,
                [(run_id, i, int(lab), float(pr))
                 for i, (lab, pr) in enumerate(pairs, start=1)],
            )
            return run_id

        return self._transaction(write)

    # ---------- R (list / single row) ----------
    def fetch_recent_runs(self, limit=50):
        """Return up to ``limit`` runs, newest id first."""
        return self._query(
            f"SELECT {self._RUN_COLUMNS} FROM runs ORDER BY id DESC LIMIT ?",
            (limit,),
        )

    def get_run(self, run_id: int):
        """Return the run row with ``run_id``, or ``None`` when absent."""
        rows = self._query(
            f"SELECT {self._RUN_COLUMNS} FROM runs WHERE id = ?", (run_id,)
        )
        return rows[0] if rows else None

    def search_runs(self, keyword: str, limit=100):
        """Return runs whose filename contains ``keyword`` literally.

        ``%`` and ``_`` in the keyword are escaped so they match themselves
        instead of acting as LIKE wildcards.
        """
        escaped = (
            str(keyword)
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        return self._query(
            f"SELECT {self._RUN_COLUMNS} FROM runs "
            "WHERE filename LIKE ? ESCAPE '\\' "
            "ORDER BY id DESC LIMIT ?",
            (f"%{escaped}%", limit),
        )

    # ---------- C/U/D (run) ----------
    def create_run(self, filename: str, segments: int, mean_prob: float,
                   final_label: int, created_at: str = None):
        """Insert a run row and return its id; ``created_at`` defaults to now."""
        created_at = created_at or datetime.now().strftime(self._TIME_FORMAT)
        return self._execute(
            self._INSERT_RUN,
            (filename, int(segments), float(mean_prob), int(final_label), created_at),
        )

    def update_run(self, run_id: int, filename: str, segments: int,
                   mean_prob: float, final_label: int, created_at: str):
        """Overwrite every editable column of the run with ``run_id``."""
        self._execute(
            "UPDATE runs "
            "SET filename=?, segments=?, mean_prob=?, final_label=?, created_at=? "
            "WHERE id=?",
            (filename, int(segments), float(mean_prob), int(final_label),
             created_at, int(run_id)),
        )

    def delete_run(self, run_id: int):
        """Delete a run together with its segments in one transaction."""
        key = (int(run_id),)

        def remove(cur):
            # Child rows first, then the parent row.
            cur.execute("DELETE FROM run_segments WHERE run_id=?", key)
            cur.execute("DELETE FROM runs WHERE id=?", key)

        self._transaction(remove)

    # ---------- C/R/U/D (segments) ----------
    def list_segments(self, run_id: int):
        """Return ``(id, seg_index, label, proba)`` rows ordered by ``seg_index``."""
        return self._query(
            "SELECT id, seg_index, label, proba FROM run_segments "
            "WHERE run_id=? ORDER BY seg_index ASC",
            (int(run_id),),
        )

    def add_segment(self, run_id: int, seg_index: int, label: int, proba: float):
        """Insert one segment row and return its id."""
        return self._execute(
            self._INSERT_SEGMENT,
            (int(run_id), int(seg_index), int(label), float(proba)),
        )

    def update_segment(self, seg_id: int, seg_index: int, label: int, proba: float):
        """Overwrite index, label and probability of the segment ``seg_id``."""
        self._execute(
            "UPDATE run_segments SET seg_index=?, label=?, proba=? WHERE id=?",
            (int(seg_index), int(label), float(proba), int(seg_id)),
        )

    def delete_segment(self, seg_id: int):
        """Delete the segment row with ``seg_id``."""
        self._execute("DELETE FROM run_segments WHERE id=?", (int(seg_id),))

    # ---------- export ----------
    def export_runs_to_csv(self, csv_path="results_export.csv", limit=1000):
        """Write the most recent ``limit`` runs to ``csv_path`` (UTF-8).

        Returns:
            ``csv_path``, unchanged.
        """
        rows = self.fetch_recent_runs(limit)
        headers = ["id", "created_at", "filename", "segments", "mean_prob", "final_label"]
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(headers)
            writer.writerows(rows)
        return csv_path
|
|
|
|
|
|
|
|
|
# Module-wide database handle shared by the worker thread and the history UI.
# Constructing it creates the tables in "results.db" if they are missing.
DB = DatabaseManager(db_path="results.db")
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 特征提取(同前)
|
|
|
# ======================
|
|
|
class FeatureExtractor:
    """Energy-based hit detection and per-segment feature extraction.

    All methods are static; the class only groups the shared constants.
    """

    WIN = 1024              # analysis frame length, in samples
    OVERLAP = 512           # samples shared by consecutive frames
    THRESHOLD = 0.1         # frame energy above which a frame counts as active
    SEG_LEN_S = 0.2         # length of each extracted segment, in seconds
    STEP = WIN - OVERLAP    # hop between frame starts, in samples
    MIN_GAP_FRAMES = 5      # an active frame starts a new hit only after a gap larger than this

    @staticmethod
    def frame_energy(signal: np.ndarray) -> np.ndarray:
        """Return the sum of squared samples of every analysis frame.

        A signal shorter than one frame yields an empty array instead of the
        error ``librosa.util.frame`` raises for it.
        """
        if len(signal) < FeatureExtractor.WIN:
            return np.array([])
        frames = librosa.util.frame(
            signal,
            frame_length=FeatureExtractor.WIN,
            hop_length=FeatureExtractor.STEP,
        )
        return np.sum(frames ** 2, axis=0)

    @staticmethod
    def detect_hits(energy: np.ndarray) -> np.ndarray:
        """Return the frame indices where a new hit begins.

        A frame is active when its energy exceeds ``THRESHOLD``.  An active
        frame starts a hit when it lies more than ``MIN_GAP_FRAMES`` frames
        after the previous active frame.
        """
        active = np.where(energy > FeatureExtractor.THRESHOLD)[0]
        if active.size == 0:
            return np.array([], dtype=int)
        # NOTE(review): the gap of the first active frame is measured from
        # frame 0, so a hit beginning within the first MIN_GAP_FRAMES frames
        # is not reported.  Kept as-is to match existing behaviour; confirm
        # whether this is intended.
        gaps = np.diff(np.concatenate(([0], active)))
        return active[gaps > FeatureExtractor.MIN_GAP_FRAMES]

    @staticmethod
    def extract_segments(signal: np.ndarray, sr: int, hit_starts: np.ndarray):
        """Cut one ``SEG_LEN_S``-second slice of ``signal`` per hit.

        ``hit_starts`` holds frame indices; each is converted to a sample
        offset with ``STEP``.  Slices are truncated at the end of the signal.
        """
        seg_len = int(FeatureExtractor.SEG_LEN_S * sr)
        segments = []
        for frame_idx in hit_starts:
            start = int(frame_idx) * FeatureExtractor.STEP
            stop = min(start + seg_len, len(signal))
            segments.append(signal[start:stop])
        return segments

    @staticmethod
    def extract_features(x: np.ndarray, sr: int) -> np.ndarray:
        """Return ``[rms, main_freq, skewness, mfcc_mean, env_peak]`` for ``x``.

        An empty input yields five zeros.  ``mfcc_mean`` is the mean of the
        first MFCC row; ``env_peak`` is the maximum of the Hilbert envelope.
        """
        x = x.flatten()
        if len(x) == 0:
            return np.zeros(5)

        rms = np.sqrt(np.mean(x ** 2))

        # Magnitude spectrum restricted to the non-negative frequency bins.
        spectrum = np.fft.fft(x)
        freq = np.fft.fftfreq(len(x), 1 / sr)
        positive = freq >= 0
        freq, fft_mag = freq[positive], np.abs(spectrum[positive])
        main_freq = freq[np.argmax(fft_mag)] if len(fft_mag) > 0 else 0

        # Spectral moments weighted by magnitude; 1e-12 guards the divisions.
        spec_power = fft_mag
        total = np.sum(spec_power) + 1e-12
        centroid = np.sum(freq * spec_power) / total
        spread = np.sqrt(np.sum(((freq - centroid) ** 2) * spec_power) / total)
        if spread > 0:
            skewness = np.sum(((freq - centroid) ** 3) * spec_power) / (
                total * (spread ** 3 + 1e-12))
        else:
            skewness = 0

        mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=13)
        mfcc_mean = np.mean(mfcc[0]) if mfcc.size > 0 else 0

        env_peak = np.max(np.abs(hilbert(x)))
        return np.array([rms, main_freq, skewness, mfcc_mean, env_peak])
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 录音线程(同前)
|
|
|
# ======================
|
|
|
class RecordThread(QThread):
    """Background thread that records microphone audio into a WAV file.

    Recording continues until ``is_recording`` is set to ``False`` by the
    owner or ``max_duration`` seconds have elapsed.

    Signals:
        update_signal(str): status and error messages.
        finish_signal(str): path of the written WAV file.
        progress_signal(int): elapsed share of ``max_duration``, 0-100.
        level_signal(int): input level derived from the chunk RMS, 0-100.
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    level_signal = pyqtSignal(int)

    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 44100
    CHUNK = 1024

    def __init__(self, max_duration=60, device_index=None):
        super().__init__()
        self.max_duration = max_duration
        self.is_recording = False
        self.temp_file = "temp_recording.wav"
        self.device_index = device_index

    def start(self, *args, **kwargs):
        """Arm the recording flag, then start the thread.

        The flag is raised here, in the caller's thread, so that a stop
        request issued before ``run`` reaches its loop is not overwritten.
        """
        self.is_recording = True
        super().start(*args, **kwargs)

    def run(self):
        """Record, save the audio and report the outcome through signals."""
        try:
            frames = self._capture()
            if frames:
                self._save(frames)
                self.finish_signal.emit(self.temp_file)
            else:
                self.update_signal.emit("未录到有效音频。")
        except Exception as e:
            # Thread boundary: report any failure to the UI instead of raising.
            self.update_signal.emit(f"录制错误: {e}")
        finally:
            self.is_recording = False

    def _capture(self):
        """Read chunks from the input device and return them as a list of bytes.

        The stream and the PyAudio instance are released even when opening
        or reading fails.
        """
        max_int16 = 32767.0
        frames = []
        stream = None
        p = pyaudio.PyAudio()
        try:
            stream = p.open(format=self.FORMAT, channels=self.CHANNELS, rate=self.RATE,
                            input=True, input_device_index=self.device_index,
                            frames_per_buffer=self.CHUNK)
            start_time = time.time()
            while self.is_recording:
                elapsed = time.time() - start_time
                if elapsed >= self.max_duration:
                    self.update_signal.emit(f"已达最大时长 {self.max_duration} 秒")
                    break
                self.progress_signal.emit(int((elapsed / self.max_duration) * 100))
                data = stream.read(self.CHUNK, exception_on_overflow=False)
                frames.append(data)
                # Normalise int16 samples and scale the RMS to a 0-100 meter value.
                chunk_np = np.frombuffer(data, dtype=np.int16).astype(np.float32) / max_int16
                rms = np.sqrt(np.mean(chunk_np ** 2))
                self.level_signal.emit(int(np.clip(rms * 500, 0, 100)))
            return frames
        finally:
            if stream is not None:
                stream.stop_stream()
                stream.close()
            p.terminate()

    def _save(self, frames):
        """Write the recorded chunks to ``self.temp_file`` as a WAV file."""
        with wave.open(self.temp_file, 'wb') as wf:
            wf.setnchannels(self.CHANNELS)
            # Module-level helper: avoids creating a PyAudio instance just for this.
            wf.setsampwidth(pyaudio.get_sample_size(self.FORMAT))
            wf.setframerate(self.RATE)
            wf.writeframes(b''.join(frames))
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 推理处理线程(写入 DB)
|
|
|
# ======================
|
|
|
class ProcessThread(QThread):
    """Background thread that classifies one WAV file and stores the result.

    Signals:
        update_signal(str): progress and error messages.
        finish_signal(dict): the result dictionary, emitted after it has
            been written to the database.
    """

    update_signal = pyqtSignal(str)
    finish_signal = pyqtSignal(dict)

    def __init__(self, wav_path, model_path, scaler_path):
        super().__init__()
        self.wav_path = wav_path
        self.model_path = model_path
        self.scaler_path = scaler_path

    def run(self):
        """Load model and audio, classify each detected segment, save the result."""
        try:
            self.update_signal.emit("加载模型和标准化器...")
            # SECURITY NOTE(review): joblib.load unpickles the file, which can
            # execute arbitrary code.  Only load model files from trusted sources.
            model = joblib.load(self.model_path)
            scaler = joblib.load(self.scaler_path)

            self.update_signal.emit(f"读取音频: {os.path.basename(self.wav_path)}")
            sig, sr = librosa.load(self.wav_path, sr=None, mono=True)
            if sig.size == 0:
                # An empty file has no peak to normalise by and no segments.
                self.update_signal.emit("未检测到有效片段!")
                return
            sig = sig / max(np.max(np.abs(sig)), 1e-9)  # peak-normalise

            ene = FeatureExtractor.frame_energy(sig)
            hits = FeatureExtractor.detect_hits(ene)
            segs = FeatureExtractor.extract_segments(sig, sr, hits)
            if not segs:
                self.update_signal.emit("未检测到有效片段!")
                return

            X = np.vstack([FeatureExtractor.extract_features(seg, sr) for seg in segs])
            X_std = scaler.transform(X)
            y_pred = model.predict(X_std)
            if hasattr(model, "predict_proba"):
                # Column 1 is used as the probability of label 1.
                y_proba = model.predict_proba(X_std)[:, 1]
            else:
                # No probabilities available: fall back to the hard labels.
                y_proba = y_pred.astype(float) + 0.0

            mean_proba = float(np.mean(y_proba))
            result = {
                "filename": os.path.basename(self.wav_path),
                "segments": len(segs),
                "predictions": y_pred.tolist(),
                "probabilities": [round(float(p), 4) for p in y_proba],
                "mean_prob": round(mean_proba, 4),
                "final_label": int(mean_proba >= 0.5),
            }
            DB.insert_result(result)
            self.finish_signal.emit(result)
        except Exception as e:
            # Thread boundary: report any failure to the UI instead of raising.
            self.update_signal.emit(f"错误: {e}")
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 对话框:编辑 run
|
|
|
# ======================
|
|
|
class RunEditDialog(QDialog):
    """Dialog for creating a run record or editing an existing one."""

    def __init__(self, parent=None, run_row=None):
        """
        Args:
            parent: parent widget.
            run_row: ``(id, created_at, filename, segments, mean_prob,
                final_label)`` to edit, or ``None`` to create a new record.
        """
        super().__init__(parent)
        self.setWindowTitle("编辑记录" if run_row else "新增记录")
        self.resize(380, 240)
        self.run_row = run_row

        self.ed_id = QLineEdit()
        self.ed_id.setReadOnly(True)
        self.ed_created = QLineEdit(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        self.ed_filename = QLineEdit()
        self.ed_segments = QLineEdit("0")
        self.ed_mean_prob = QLineEdit("0.0")
        self.ed_final_label = QLineEdit("0")  # 0 = solid, 1 = hollow

        form = QFormLayout()
        form.addRow("ID:", self.ed_id)
        form.addRow("时间(created_at):", self.ed_created)
        form.addRow("文件名(filename):", self.ed_filename)
        form.addRow("片段数(segments):", self.ed_segments)
        form.addRow("平均概率(mean_prob):", self.ed_mean_prob)
        form.addRow("最终标签(0=实 1=空):", self.ed_final_label)

        if run_row:
            editors = (self.ed_id, self.ed_created, self.ed_filename,
                       self.ed_segments, self.ed_mean_prob, self.ed_final_label)
            for editor, value in zip(editors, run_row):
                editor.setText(str(value))

        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(self.accept)
        btns.rejected.connect(self.reject)

        lay = QVBoxLayout()
        lay.addLayout(form)
        lay.addWidget(btns)
        self.setLayout(lay)

    def accept(self):
        """Close with ``Accepted`` only when the numeric fields parse.

        Invalid input keeps the dialog open and shows a warning, so callers
        of ``values()`` after an accepted dialog do not hit a parse error.
        """
        try:
            self.values()
        except (ValueError, OverflowError):
            QMessageBox.warning(self, "输入错误", "片段数、平均概率和最终标签必须是数字。")
            return
        super().accept()

    def values(self):
        """Return the field contents as a dict with converted types.

        Empty numeric fields fall back to 0.

        Raises:
            ValueError: if a numeric field holds non-numeric text.
        """
        id_text = self.ed_id.text().strip()
        return dict(
            id=(int(id_text) if id_text else None),
            created_at=self.ed_created.text().strip(),
            filename=self.ed_filename.text().strip(),
            segments=int(float(self.ed_segments.text().strip() or 0)),
            mean_prob=float(self.ed_mean_prob.text().strip() or 0.0),
            final_label=int(float(self.ed_final_label.text().strip() or 0)),
        )
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 对话框:管理 segments
|
|
|
# ======================
|
|
|
class SegmentManagerDialog(QDialog):
    """Dialog that lists the segments of one run and lets the user edit them."""

    def __init__(self, parent=None, run_id=None):
        super().__init__(parent)
        self.setWindowTitle(f"片段管理 - run_id={run_id}")
        self.resize(560, 420)
        self.run_id = run_id

        self.table = QTableWidget(0, 4)
        self.table.setHorizontalHeaderLabels(["ID", "seg_index", "label", "proba"])

        self.btn_add = QPushButton("新增片段")
        self.btn_edit = QPushButton("编辑所选")
        self.btn_del = QPushButton("删除所选")
        self.btn_close = QPushButton("关闭")

        hl = QHBoxLayout()
        for b in (self.btn_add, self.btn_edit, self.btn_del, self.btn_close):
            hl.addWidget(b)

        lay = QVBoxLayout()
        lay.addWidget(self.table)
        lay.addLayout(hl)
        self.setLayout(lay)

        self.btn_add.clicked.connect(self.add_segment)
        self.btn_edit.clicked.connect(self.edit_selected)
        self.btn_del.clicked.connect(self.delete_selected)
        self.btn_close.clicked.connect(self.accept)

        self.load_segments()

    def load_segments(self):
        """Reload the table from the database."""
        rows = DB.list_segments(self.run_id)
        self.table.setRowCount(len(rows))
        for i, row in enumerate(rows):
            # row is (id, seg_index, label, proba), matching the column order.
            for col, value in enumerate(row):
                self.table.setItem(i, col, QTableWidgetItem(str(value)))
        self.table.resizeColumnsToContents()

    def _get_selected_seg_id(self):
        """Return the segment id of the current row, or ``None`` if none is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        item = self.table.item(r, 0)
        return int(item.text()) if item is not None else None

    def _prompt_segment(self, title, index_text, label_text, proba_text):
        """Ask for segment values in a modal form.

        The three ``*_text`` arguments pre-fill the fields and also serve as
        the fallback when a field is left empty.

        Returns:
            ``(seg_index, label, proba)``, or ``None`` when the user
            cancelled or entered non-numeric text (a warning is shown).
        """
        d = QDialog(self)
        d.setWindowTitle(title)
        ed_index = QLineEdit(index_text)
        ed_label = QLineEdit(label_text)
        ed_proba = QLineEdit(proba_text)

        form = QFormLayout()
        form.addRow("seg_index:", ed_index)
        form.addRow("label(0/1):", ed_label)
        form.addRow("proba(0~1):", ed_proba)

        btns = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        btns.accepted.connect(d.accept)
        btns.rejected.connect(d.reject)

        v = QVBoxLayout()
        v.addLayout(form)
        v.addWidget(btns)
        d.setLayout(v)

        if d.exec_() != QDialog.Accepted:
            return None
        try:
            return (
                int(float(ed_index.text().strip() or index_text)),
                int(float(ed_label.text().strip() or label_text)),
                float(ed_proba.text().strip() or proba_text),
            )
        except (ValueError, OverflowError):
            QMessageBox.warning(self, "输入错误", "seg_index、label 和 proba 必须是数字。")
            return None

    def add_segment(self):
        """Prompt for a new segment and insert it for the current run."""
        entered = self._prompt_segment("新增片段", "1", "0", "0.0")
        if entered is None:
            return
        DB.add_segment(self.run_id, *entered)
        self.load_segments()

    def edit_selected(self):
        """Prompt with the selected segment's values and save the changes."""
        seg_id = self._get_selected_seg_id()
        if seg_id is None:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        # Read the current values from the selected table row.
        row = self.table.currentRow()
        current = [self.table.item(row, col).text() for col in (1, 2, 3)]
        entered = self._prompt_segment("编辑片段", *current)
        if entered is None:
            return
        DB.update_segment(seg_id, *entered)
        self.load_segments()

    def delete_selected(self):
        """Delete the selected segment after confirmation."""
        seg_id = self._get_selected_seg_id()
        if seg_id is None:
            QMessageBox.information(self, "提示", "请先选择一行片段。")
            return
        answer = QMessageBox.question(self, "确认", "确认删除该片段?此操作不可恢复!",
                                      QMessageBox.Yes | QMessageBox.No)
        if answer == QMessageBox.Yes:
            DB.delete_segment(seg_id)
            self.load_segments()
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 主菜单
|
|
|
# ======================
|
|
|
class MainMenuWidget(QWidget):
    """Start page: a title and three buttons that switch to the other pages."""

    def __init__(self, parent):
        """Build the menu; ``parent`` must provide the ``switch_to_*`` methods."""
        super().__init__(parent)

        heading = QLabel("音频分类器")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 24px; font-weight: bold; margin: 20px;")

        # (caption, click handler) for each menu entry, in display order.
        entries = (
            ("采集音频", lambda: parent.switch_to_input("record")),
            ("上传WAV文件", lambda: parent.switch_to_input("upload")),
            ("历史记录 / 导出 / CRUD", parent.switch_to_history),
        )

        layout = QVBoxLayout()
        layout.addWidget(heading)
        for caption, handler in entries:
            button = QPushButton(caption)
            button.clicked.connect(handler)
            button.setMinimumHeight(60)
            layout.addWidget(button)
        layout.addStretch(1)
        self.setLayout(layout)
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 输入界面(录音 / 上传)
|
|
|
# ======================
|
|
|
class InputWidget(QWidget):
    """Input page: record audio or pick a WAV file, then start processing.

    NOTE: this file defines ``InputWidget`` twice.  Python binds the name to
    the later definition, so this earlier copy is shadowed once the module
    has finished loading.

    NOTE(review): ``self.parent`` shadows ``QWidget.parent()``; the attribute
    name is kept because the rest of the class relies on it.
    """

    def __init__(self, parent, mode="record"):
        """
        Args:
            parent: main window providing ``switch_to_main_menu`` and
                ``switch_to_result``.
            mode: ``"record"`` for microphone capture, anything else for upload.
        """
        super().__init__(parent)
        self.parent, self.mode = parent, mode
        self.wav_path, self.record_thread = "", None
        self.process_thread = None
        self.device_index = None
        self.init_ui()

    def init_ui(self):
        """Build the page: back button, mode-specific controls, paths, log."""
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back)

        if self.mode == "record":
            self.setup_record(layout)
        else:
            self.setup_upload(layout)

        self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
        path_layout = QHBoxLayout()
        path_layout.addWidget(QLabel("模型:"))
        path_layout.addWidget(self.model_path)
        path_layout.addWidget(QLabel("标准化器:"))
        path_layout.addWidget(self.scaler_path)
        layout.addLayout(path_layout)

        self.process_btn = QPushButton("开始处理")
        self.process_btn.setEnabled(False)
        self.process_btn.clicked.connect(self.start_process)
        layout.addWidget(self.process_btn)

        self.log = QTextEdit()
        self.log.setReadOnly(True)
        layout.addWidget(self.log)
        self.setLayout(layout)

    def setup_record(self, layout):
        """Add device selection, the hold-to-record button and the meters."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        dev_layout = QHBoxLayout()
        dev_layout.addWidget(QLabel("设备:"))
        self.dev_combo = QComboBox()
        self.refresh_devices()
        self.dev_combo.currentIndexChanged.connect(
            lambda: setattr(self, "device_index", self.dev_combo.currentData()))
        dev_layout.addWidget(self.dev_combo)
        btn_ref = QPushButton("刷新设备")
        btn_ref.clicked.connect(self.refresh_devices)
        dev_layout.addWidget(btn_ref)
        layout.addLayout(dev_layout)

        self.rec_btn = QPushButton("按住录音")
        self.rec_btn.setMinimumHeight(80)
        # Press starts recording, release stops it.
        self.rec_btn.mousePressEvent = self.start_rec
        self.rec_btn.mouseReleaseEvent = self.stop_rec

        self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
        self.level.setFormat("电平:%p%")
        self.progress.setRange(0, 100)
        for w in [self.rec_btn, self.level, self.progress, self.dur]:
            layout.addWidget(w)

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_dur)

    def setup_upload(self, layout):
        """Add the read-only file path field and the browse button."""
        h = QHBoxLayout()
        self.file = QLineEdit()
        self.file.setReadOnly(True)
        b = QPushButton("浏览")
        b.clicked.connect(self.browse)
        h.addWidget(self.file)
        h.addWidget(b)
        layout.addLayout(h)

    def refresh_devices(self):
        """Repopulate the device list with every device that has input channels."""
        self.dev_combo.clear()
        p = pyaudio.PyAudio()
        try:
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
        finally:
            p.terminate()  # release PortAudio even if enumeration fails
        if self.dev_combo.count() > 0:
            self.dev_combo.setCurrentIndex(0)
            self.device_index = self.dev_combo.currentData()
        else:
            self.device_index = None  # no input device: do not keep a stale index

    def start_rec(self, e):
        """Mouse-press handler: start a recording thread on left click."""
        if e.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            # Replacing a running QThread would drop the only reference to it
            # while it still writes the temp file.
            return
        self.record_thread = RecordThread(device_index=self.device_index)
        self.record_thread.update_signal.connect(self.log.append)
        self.record_thread.finish_signal.connect(self.finish_rec)
        self.record_thread.progress_signal.connect(self.progress.setValue)
        self.record_thread.level_signal.connect(self.level.setValue)
        self.record_thread.start()
        self.start = time.time()
        self.timer.start(100)

    def stop_rec(self, e):
        """Mouse-release handler: ask the recording thread to stop."""
        if e.button() == Qt.LeftButton and self.record_thread:
            self.record_thread.is_recording = False
            self.timer.stop()
            self.progress.setValue(0)
            self.level.setValue(0)

    def update_dur(self):
        """Show the seconds elapsed since recording started."""
        self.dur.setText(f"{time.time()-self.start:.1f}s")

    def finish_rec(self, file):
        """Remember the recorded file and enable processing."""
        self.wav_path = file
        self.process_btn.setEnabled(True)

    def browse(self):
        """Let the user choose a WAV file and enable processing."""
        f, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if f:
            self.file.setText(f)
            self.wav_path = f
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Start a ``ProcessThread`` for the current file unless one is running."""
        if self.mode == "upload":
            self.wav_path = self.file.text()
        if not os.path.isfile(self.wav_path):
            self.log.append("无效文件")
            return
        if self.process_thread is not None and self.process_thread.isRunning():
            self.log.append("正在处理中,请稍候...")
            return
        # Stored as process_thread so QObject.thread() is not shadowed.
        self.process_thread = ProcessThread(
            self.wav_path, self.model_path.text(), self.scaler_path.text())
        self.process_thread.update_signal.connect(self.log.append)
        self.process_thread.finish_signal.connect(self.parent.switch_to_result)
        self.process_thread.start()
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 历史记录(含 CRUD 按钮)
|
|
|
# ======================
|
|
|
class HistoryWidget(QWidget):
    """History page: browse, search, export and edit stored runs.

    NOTE(review): ``self.parent`` shadows ``QWidget.parent()``; the attribute
    name is kept because the rest of the class relies on it.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.parent = parent
        self.init_ui()

    def init_ui(self):
        """Build the toolbar and the result table, then load the data."""
        l = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.parent.switch_to_main_menu)
        l.addWidget(back)

        self.search_edit = QLineEdit()
        self.search_edit.setPlaceholderText("按文件名搜索...")

        top = QHBoxLayout()
        top.addWidget(self.search_edit)
        # (caption, handler) for each toolbar button, in display order.
        actions = (
            ("搜索", self.search),
            ("刷新", self.load),
            ("导出CSV", self.export_csv),
            ("新增记录", self.add_run),
            ("编辑所选", self.edit_selected),
            ("删除所选", self.delete_selected),
            ("片段管理", self.open_segments),
        )
        for caption, handler in actions:
            btn = QPushButton(caption)
            btn.clicked.connect(handler)
            top.addWidget(btn)
        l.addLayout(top)

        self.table = QTableWidget(0, 6)
        self.table.setHorizontalHeaderLabels(["ID", "时间", "文件名", "片段数", "平均概率", "最终标签"])
        l.addWidget(self.table)
        self.setLayout(l)
        self.load()

    def _current_run_id(self):
        """Return the run id of the current row, or ``None`` if none is selected."""
        r = self.table.currentRow()
        if r < 0:
            return None
        item = self.table.item(r, 0)
        return int(item.text()) if item is not None else None

    def _dialog_values(self, dlg):
        """Return ``dlg.values()``, or ``None`` after warning about bad numbers."""
        try:
            return dlg.values()
        except (ValueError, OverflowError):
            QMessageBox.warning(self, "输入错误", "片段数、平均概率和最终标签必须是数字。")
            return None

    # ---- loading / searching ----
    def load(self):
        """Show the 200 most recent runs."""
        self._fill(DB.fetch_recent_runs(200))

    def search(self):
        """Show runs matching the search field, or the recent runs if it is empty."""
        kw = self.search_edit.text().strip()
        rows = DB.search_runs(kw, limit=200) if kw else DB.fetch_recent_runs(200)
        self._fill(rows)

    def _fill(self, rows):
        """Replace the table contents with ``rows``."""
        self.table.setRowCount(len(rows))
        for r_idx, (rid, created_at, fn, segs, mean_prob, final_label) in enumerate(rows):
            # A NULL label is displayed like label 0.
            label_text = "空心" if int(final_label or 0) else "实心"
            cells = (rid, created_at, fn, segs, mean_prob, label_text)
            for col, value in enumerate(cells):
                self.table.setItem(r_idx, col, QTableWidgetItem(str(value)))
        self.table.resizeColumnsToContents()

    # ---- export ----
    def export_csv(self):
        """Export the runs to CSV and report the file location or the error."""
        try:
            p = DB.export_runs_to_csv()
        except OSError as e:
            QMessageBox.warning(self, "导出失败", f"无法写入CSV文件: {e}")
            return
        QMessageBox.information(self, "导出完成", f"已保存到 {os.path.abspath(p)}")

    # ---- CRUD: run ----
    def add_run(self):
        """Create a run from the values entered in a ``RunEditDialog``."""
        dlg = RunEditDialog(self, run_row=None)
        if dlg.exec_() != QDialog.Accepted:
            return
        v = self._dialog_values(dlg)
        if v is None:
            return
        DB.create_run(v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
        self.load()

    def edit_selected(self):
        """Edit the selected run in a ``RunEditDialog`` and save the changes."""
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        dlg = RunEditDialog(self, run_row=DB.get_run(rid))
        if dlg.exec_() != QDialog.Accepted:
            return
        v = self._dialog_values(dlg)
        if v is None:
            return
        DB.update_run(rid, v["filename"], v["segments"], v["mean_prob"], v["final_label"], v["created_at"])
        self.load()

    def delete_selected(self):
        """Delete the selected run and its segments after confirmation."""
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        answer = QMessageBox.question(self, "确认", "确认删除所选记录及其所有片段?此操作不可恢复!",
                                      QMessageBox.Yes | QMessageBox.No)
        if answer == QMessageBox.Yes:
            DB.delete_run(rid)
            self.load()

    # ---- segment management ----
    def open_segments(self):
        """Open the segment manager for the selected run, then reload the table."""
        rid = self._current_run_id()
        if rid is None:
            QMessageBox.information(self, "提示", "请先选择一行记录。")
            return
        dlg = SegmentManagerDialog(self, run_id=rid)
        dlg.exec_()
        # Segment edits do not update the run's segment count or mean
        # probability; those stay under manual control.
        self.load()
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 结果页(简)
|
|
|
# ======================
|
|
|
class ResultWidget(QWidget):
    """Result page: a back button, a heading and a read-only summary."""

    def __init__(self, parent, result=None):
        """Build the page; when ``result`` is truthy its summary is displayed.

        ``result`` is read through the keys ``filename``, ``segments``,
        ``mean_prob`` and ``final_label``.
        """
        super().__init__(parent)
        self.parent = parent

        back_button = QPushButton("返回")
        back_button.clicked.connect(self.parent.switch_to_main_menu)

        heading = QLabel("处理结果")
        heading.setAlignment(Qt.AlignCenter)

        summary = QTextEdit()
        summary.setReadOnly(True)
        if result:
            txt=f"文件:{result['filename']}\n片段:{result['segments']}\n平均概率:{result['mean_prob']}\n最终:{'空心' if result['final_label'] else '实心'}"
            summary.setText(txt)

        layout = QVBoxLayout()
        for widget in (back_button, heading, summary):
            layout.addWidget(widget)
        self.setLayout(layout)
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 主窗口
|
|
|
# ======================
|
|
|
class AudioClassifierGUI(QMainWindow):
    """Main window: hosts all pages in a ``QStackedWidget`` and switches between them."""

    def __init__(self):
        super().__init__()
        self.current_input_mode = "record"
        self.process_result = None
        self.init_ui()

    def init_ui(self):
        """Create every page and show the main menu."""
        self.setWindowTitle("音频分类器")
        self.setGeometry(100, 100, 1000, 700)

        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)

        self.main_menu_widget = MainMenuWidget(self)
        self.record_input_widget = InputWidget(self, "record")
        self.upload_input_widget = InputWidget(self, "upload")
        self.result_widget = ResultWidget(self)
        self.history_widget = HistoryWidget(self)

        for page in (self.main_menu_widget, self.record_input_widget,
                     self.upload_input_widget, self.result_widget,
                     self.history_widget):
            self.stacked_widget.addWidget(page)

        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_input(self, mode):
        """Show the record page for ``"record"``, otherwise the upload page."""
        self.current_input_mode = mode
        if mode == "record":
            self.stacked_widget.setCurrentWidget(self.record_input_widget)
        else:
            self.stacked_widget.setCurrentWidget(self.upload_input_widget)

    def switch_to_main_menu(self):
        """Show the main menu."""
        self.stacked_widget.setCurrentWidget(self.main_menu_widget)

    def switch_to_result(self, result):
        """Replace the result page with one built from ``result`` and show it.

        The previous page is removed by reference, not by stack index: page
        indices shift after the first replacement, so removing a fixed index
        would take the history page out of the stack on the second result.
        """
        self.process_result = result
        previous = self.result_widget
        self.result_widget = ResultWidget(self, result)
        self.stacked_widget.removeWidget(previous)
        previous.deleteLater()  # removeWidget does not delete the page
        self.stacked_widget.addWidget(self.result_widget)
        self.stacked_widget.setCurrentWidget(self.result_widget)

    def switch_to_history(self):
        """Reload and show the history page."""
        self.history_widget.load()
        self.stacked_widget.setCurrentWidget(self.history_widget)
|
|
|
|
|
|
|
|
|
# ======================
|
|
|
# 上传/录音 界面(放最后,避免打断阅读)
|
|
|
# ======================
|
|
|
class InputWidget(QWidget):
    """Input page: record audio or pick a WAV file, then start processing.

    NOTE: an earlier class of the same name exists in this file.  Python
    binds the name to this later definition, so this is the one in effect
    when the main window is built.

    NOTE(review): ``self.parent`` shadows ``QWidget.parent()``; the attribute
    name is kept because the rest of the class relies on it.
    """

    def __init__(self, parent, mode="record"):
        """
        Args:
            parent: main window providing ``switch_to_main_menu`` and
                ``switch_to_result``.
            mode: ``"record"`` for microphone capture, anything else for upload.
        """
        super().__init__(parent)
        self.parent, self.mode = parent, mode
        self.wav_path, self.record_thread = "", None
        self.process_thread = None
        self.device_index = None
        self.init_ui()

    def init_ui(self):
        """Build the page: back button, mode-specific controls, paths, log."""
        layout = QVBoxLayout()
        back = QPushButton("返回")
        back.clicked.connect(self.parent.switch_to_main_menu)
        layout.addWidget(back)

        if self.mode == "record":
            self.setup_record(layout)
        else:
            self.setup_upload(layout)

        self.model_path, self.scaler_path = QLineEdit("svm_model.pkl"), QLineEdit("scaler.pkl")
        path_layout = QHBoxLayout()
        path_layout.addWidget(QLabel("模型:"))
        path_layout.addWidget(self.model_path)
        path_layout.addWidget(QLabel("标准化器:"))
        path_layout.addWidget(self.scaler_path)
        layout.addLayout(path_layout)

        self.process_btn = QPushButton("开始处理")
        self.process_btn.setEnabled(False)
        self.process_btn.clicked.connect(self.start_process)
        layout.addWidget(self.process_btn)

        self.log = QTextEdit()
        self.log.setReadOnly(True)
        layout.addWidget(self.log)
        self.setLayout(layout)

    def setup_record(self, layout):
        """Add device selection, the hold-to-record button and the meters."""
        title = QLabel("音频采集")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        dev_layout = QHBoxLayout()
        dev_layout.addWidget(QLabel("设备:"))
        self.dev_combo = QComboBox()
        self.refresh_devices()
        self.dev_combo.currentIndexChanged.connect(
            lambda: setattr(self, "device_index", self.dev_combo.currentData()))
        dev_layout.addWidget(self.dev_combo)
        btn_ref = QPushButton("刷新设备")
        btn_ref.clicked.connect(self.refresh_devices)
        dev_layout.addWidget(btn_ref)
        layout.addLayout(dev_layout)

        self.rec_btn = QPushButton("按住录音")
        self.rec_btn.setMinimumHeight(80)
        # Press starts recording, release stops it.
        self.rec_btn.mousePressEvent = self.start_rec
        self.rec_btn.mouseReleaseEvent = self.stop_rec

        self.level, self.progress, self.dur = QProgressBar(), QProgressBar(), QLabel("0s")
        self.level.setFormat("电平:%p%")
        self.progress.setRange(0, 100)
        for w in [self.rec_btn, self.level, self.progress, self.dur]:
            layout.addWidget(w)

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_dur)

    def setup_upload(self, layout):
        """Add the read-only file path field and the browse button."""
        h = QHBoxLayout()
        self.file = QLineEdit()
        self.file.setReadOnly(True)
        b = QPushButton("浏览")
        b.clicked.connect(self.browse)
        h.addWidget(self.file)
        h.addWidget(b)
        layout.addLayout(h)

    def refresh_devices(self):
        """Repopulate the device list with every device that has input channels."""
        self.dev_combo.clear()
        p = pyaudio.PyAudio()
        try:
            for i in range(p.get_device_count()):
                info = p.get_device_info_by_index(i)
                if int(info.get('maxInputChannels', 0)) > 0:
                    self.dev_combo.addItem(f"[{i}] {info.get('name')}", i)
        finally:
            p.terminate()  # release PortAudio even if enumeration fails
        if self.dev_combo.count() > 0:
            self.dev_combo.setCurrentIndex(0)
            self.device_index = self.dev_combo.currentData()
        else:
            self.device_index = None  # no input device: do not keep a stale index

    def start_rec(self, e):
        """Mouse-press handler: start a recording thread on left click."""
        if e.button() != Qt.LeftButton:
            return
        if self.record_thread is not None and self.record_thread.isRunning():
            # Replacing a running QThread would drop the only reference to it
            # while it still writes the temp file.
            return
        self.record_thread = RecordThread(device_index=self.device_index)
        self.record_thread.update_signal.connect(self.log.append)
        self.record_thread.finish_signal.connect(self.finish_rec)
        self.record_thread.progress_signal.connect(self.progress.setValue)
        self.record_thread.level_signal.connect(self.level.setValue)
        self.record_thread.start()
        self.start = time.time()
        self.timer.start(100)

    def stop_rec(self, e):
        """Mouse-release handler: ask the recording thread to stop."""
        if e.button() == Qt.LeftButton and self.record_thread:
            self.record_thread.is_recording = False
            self.timer.stop()
            self.progress.setValue(0)
            self.level.setValue(0)

    def update_dur(self):
        """Show the seconds elapsed since recording started."""
        self.dur.setText(f"{time.time()-self.start:.1f}s")

    def finish_rec(self, file):
        """Remember the recorded file and enable processing."""
        self.wav_path = file
        self.process_btn.setEnabled(True)

    def browse(self):
        """Let the user choose a WAV file and enable processing."""
        f, _ = QFileDialog.getOpenFileName(self, "选择WAV", "", "WAV文件 (*.wav)")
        if f:
            self.file.setText(f)
            self.wav_path = f
            self.process_btn.setEnabled(True)

    def start_process(self):
        """Start a ``ProcessThread`` for the current file unless one is running."""
        if self.mode == "upload":
            self.wav_path = self.file.text()
        if not os.path.isfile(self.wav_path):
            self.log.append("无效文件")
            return
        if self.process_thread is not None and self.process_thread.isRunning():
            self.log.append("正在处理中,请稍候...")
            return
        # Stored as process_thread so QObject.thread() is not shadowed.
        self.process_thread = ProcessThread(
            self.wav_path, self.model_path.text(), self.scaler_path.text())
        self.process_thread.update_signal.connect(self.log.append)
        self.process_thread.finish_signal.connect(self.parent.switch_to_result)
        self.process_thread.start()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: check for pyaudio, then run the Qt event loop.
    # NOTE(review): pyaudio is also imported at the top of this file, so a
    # missing package presumably fails there before this check runs - confirm.
    try:
        import pyaudio
    except ImportError:
        print("请先安装pyaudio: pip install pyaudio")
        sys.exit(1)

    app = QApplication(sys.argv)
    window = AudioClassifierGUI()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)
|