Compare commits

...

21 Commits

Binary file not shown.

@ -1,135 +0,0 @@
#include "FFVideoFormatConvert.h"
#include "VideoObjNetwork.h"
// Construct an idle converter: no scaler context, frame, pixel buffer or
// image is held yet, and the cached dimensions are zero.
CFFVideoFormatConvert::CFFVideoFormatConvert()
    : m_img_convert_ctx(NULL),
      m_pFrame(NULL),
      m_pBuffer(NULL),
      m_uBufferSize(0),
      m_iWidth(0),
      m_iHeight(0),
      m_pImage(NULL)
{
}
// Release everything the converter still holds by delegating to Close().
CFFVideoFormatConvert::~CFFVideoFormatConvert()
{
    this->Close();
}
// Free the scaler context, the frame, the pixel buffer and the cached image,
// then zero the cached dimensions so the next conversion re-creates them.
void CFFVideoFormatConvert::Close()
{
    if (m_img_convert_ctx != NULL)
    {
        sws_freeContext(m_img_convert_ctx);
        m_img_convert_ctx = NULL;
    }

    if (m_pFrame != NULL)
    {
        av_frame_free(&m_pFrame);
        m_pFrame = NULL;
    }

    if (m_pBuffer != NULL)
    {
        av_free(m_pBuffer);
        m_pBuffer = NULL;
        m_uBufferSize = 0;
    }

    // Deleting a null pointer is a no-op, so no guard is needed here.
    delete m_pImage;
    m_pImage = NULL;

    m_iHeight = 0;
    m_iWidth = 0;
}
bool CFFVideoFormatConvert::RGB32toYUV420P(const QImage* pIn, AVFrame** pOut)
{
// reinitialize object if image width or height changed
if (pIn->width() != m_iWidth || pIn->height() != m_iHeight)
{
Close();
}
if (m_pBuffer == NULL)
{
m_iWidth = pIn->width();
m_iHeight = pIn->height();
m_uBufferSize = av_image_get_buffer_size(AV_PIX_FMT_YUV420P, pIn->width(), pIn->height(), 1);
m_pBuffer = (uint8_t *) av_malloc(m_uBufferSize);
}
if (m_pFrame == NULL)
{
m_pFrame = av_frame_alloc();
m_pFrame->width = pIn->width();
m_pFrame->height = pIn->height();
av_image_fill_arrays(m_pFrame->data, m_pFrame->linesize,
m_pBuffer, AV_PIX_FMT_YUV420P, pIn->width(), pIn->height(), 1);
}
if (m_img_convert_ctx == NULL)
{
m_img_convert_ctx = sws_getContext(pIn->width(), pIn->height(),
AV_PIX_FMT_RGB32,
pIn->width(), pIn->height(),
AV_PIX_FMT_YUV420P, SWS_BICUBIC, NULL, NULL, NULL);
}
const uint8_t *const srcSlice[] = { pIn->bits() };
const int srcStride[] = { pIn->bytesPerLine()};
sws_scale(m_img_convert_ctx,
srcSlice,
srcStride, 0, pIn->height(),
m_pFrame->data,
m_pFrame->linesize);
*pOut = m_pFrame;
return true;
}
bool CFFVideoFormatConvert::YUV420P2RGB32(const AVFrame* pIn, QImage** pOut)
{
// reinitialize object if image width or height changed
if (pIn->width != m_iWidth || pIn->height != m_iHeight)
{
Close();
}
if (m_pBuffer == NULL)
{
m_iWidth = pIn->width;
m_iHeight = pIn->height;
m_uBufferSize = av_image_get_buffer_size(AV_PIX_FMT_RGB32, pIn->width, pIn->height, 1);
m_pBuffer = (uint8_t *)av_malloc(m_uBufferSize);
}
if (m_pFrame == NULL)
{
m_pFrame = av_frame_alloc();
av_image_fill_arrays(m_pFrame->data, m_pFrame->linesize,
m_pBuffer, AV_PIX_FMT_RGB32, pIn->width, pIn->height, 1);
}
if (m_img_convert_ctx == NULL)
{
m_img_convert_ctx = sws_getContext(pIn->width, pIn->height,
AV_PIX_FMT_YUV420P,
pIn->width, pIn->height,
AV_PIX_FMT_RGB32, SWS_BICUBIC, NULL, NULL, NULL);
}
sws_scale(m_img_convert_ctx,
(uint8_t const * const *)pIn->data,
pIn->linesize, 0, pIn->height,
m_pFrame->data,
m_pFrame->linesize);
*pOut = new QImage((uchar *)m_pFrame->data[0], m_iWidth, m_iHeight, QImage::Format_RGB32);
return true;
}

@ -1,28 +0,0 @@
#pragma once
#include <QImage>
struct AVFrame;
struct SwsContext;
class CFFVideoFormatConvert
{
public:
CFFVideoFormatConvert(void);
~CFFVideoFormatConvert(void);
bool RGB32toYUV420P(const QImage* pIn, AVFrame** pOut);
bool YUV420P2RGB32(const AVFrame* pIn, QImage** pOut);
private:
void Close();
private:
SwsContext* m_img_convert_ctx;
AVFrame* m_pFrame;
uint8_t* m_pBuffer;
uint m_uBufferSize;
int m_iWidth;
int m_iHeight;
QImage* m_pImage;
};

@ -1,70 +0,0 @@
#ifndef QTCAMERACAPTURE_H
#define QTCAMERACAPTURE_H
#include <QObject>
#include <QAbstractVideoSurface>
#include <QDebug>
// Video surface that receives camera frames through present() and publishes
// them as QImage via the frameAvailable() signal.
class QtCameraCapture : public QAbstractVideoSurface
{
    Q_OBJECT
public:
    // Pixel formats exposed to the meta-object system through Q_ENUM.
    // NOTE(review): the enumerators appear to mirror QVideoFrame::PixelFormat;
    // keep the order in sync with it if that is the intent.
    enum PixelFormat {
        Format_Invalid,
        Format_ARGB32,
        Format_ARGB32_Premultiplied,
        Format_RGB32,
        Format_RGB24,
        Format_RGB565,
        Format_RGB555,
        Format_ARGB8565_Premultiplied,
        Format_BGRA32,
        Format_BGRA32_Premultiplied,
        Format_BGR32,
        Format_BGR24,
        Format_BGR565,
        Format_BGR555,
        Format_BGRA5658_Premultiplied,
        Format_AYUV444,
        Format_AYUV444_Premultiplied,
        Format_YUV444,
        Format_YUV420P,
        Format_YV12,
        Format_UYVY,
        Format_YUYV,
        Format_NV12,
        Format_NV21,
        Format_IMC1,
        Format_IMC2,
        Format_IMC3,
        Format_IMC4,
        Format_Y8,
        Format_Y16,
        Format_Jpeg,
        Format_CameraRaw,
        Format_AdobeDng,
#ifndef Q_QDOC
        NPixelFormats,
#endif
        Format_User = 1000
    };
    Q_ENUM(PixelFormat)

    explicit QtCameraCapture(QObject *parent = 0);

    // Pixel formats this surface accepts for the given buffer handle type.
    // Marked override, like present(), so a signature mismatch with the base
    // class is caught at compile time.
    QList<QVideoFrame::PixelFormat> supportedPixelFormats(
        QAbstractVideoBuffer::HandleType handleType = QAbstractVideoBuffer::NoHandle) const override;

    // Called with each new video frame.
    bool present(const QVideoFrame &frame) override;

signals:
    // Carries a frame converted to QImage.
    void frameAvailable(QImage frame);
};
#endif // QTCAMERACAPTURE_H

@ -1,36 +0,0 @@
#pragma once
#define GET_STR(x) #x
#define A_VER 3
#define T_VER 4
// Vertex shader source (stringified by GET_STR): forwards the vertex position
// unchanged and hands the texture coordinate to the fragment stage.
// Declared `static ... const` so every translation unit that includes this
// header gets its own internal-linkage copy instead of a duplicate external
// definition, which would fail at link time.
static const char *const vString = GET_STR(
    attribute vec4 vertexIn;
    attribute vec2 textureIn;
    varying vec2 textureOut;
    void main(void)
    {
        gl_Position = vertexIn;
        textureOut = textureIn;
    }
);
// Fragment (texture) shader source (stringified by GET_STR): samples the Y, U
// and V planes from three separate textures, re-centres the chroma samples
// around zero and multiplies by a YUV-to-RGB matrix (the mat3 is written
// column by column: the Y, U and V contributions respectively).
// Declared `static ... const` so every translation unit that includes this
// header gets its own internal-linkage copy instead of a duplicate external
// definition, which would fail at link time.
static const char *const tString = GET_STR(
    varying vec2 textureOut;
    uniform sampler2D tex_y;
    uniform sampler2D tex_u;
    uniform sampler2D tex_v;
    void main(void)
    {
        vec3 yuv;
        vec3 rgb;
        yuv.x = texture2D(tex_y, textureOut).r;
        yuv.y = texture2D(tex_u, textureOut).r - 0.5;
        yuv.z = texture2D(tex_v, textureOut).r - 0.5;
        rgb = mat3(1.0, 1.0, 1.0,
                   0.0, -0.39465, 2.03211,
                   1.13983, -0.58060, 0.0) * yuv;
        gl_FragColor = vec4(rgb, 1.0);
    }
);

@ -1,66 +0,0 @@
#pragma once
#include <string>
#include <mutex>
#include <thread>
extern "C"
{
#include "libavcodec/avcodec.h"
#include "libavformat/avformat.h"
#include "libavutil/avutil.h"
#include "libswscale/swscale.h"
#include "libavutil/imgutils.h"
};
typedef void (*VideoDataCallback)(int iEncode, int iWidth, int iHeight, const char* pData, long lLen, long lPTS, void* pUserParam);
class VLKVideoWidget;
// Network video source object.
// Judging by its members it holds an FFmpeg demuxer (AVFormatContext) and
// decoder (AVCodecContext), a worker std::thread and a target VLKVideoWidget.
// NOTE(review): only the declaration is visible here; confirm behavioural
// details against the implementation file.
class CVideoObjNetwork
{
public:
CVideoObjNetwork();
virtual ~CVideoObjNetwork();
// Presumably opens the stream at strURL and renders into pVideoWidget — confirm in the implementation.
virtual bool Open(const std::string& strURL, VLKVideoWidget* pVideoWidget);
virtual bool IsOpen();
// Registers the data callback and the user parameter stored in m_cbFunc / m_lUserParam.
// NOTE(review): VideoDataCallback receives its user parameter as void*, while
// this setter and m_lUserParam carry it as long; where long is narrower than a
// pointer a pointer value would be truncated — confirm how callers use it.
void SetDataCallback(VideoDataCallback pVideoDataCB, long lUserParam);
virtual void Clear();
virtual void Close();
// Local recording control — semantics defined in the implementation.
virtual bool StartLocalRecord();
virtual void StopLocalRecord();
virtual void Capture();
private:
// Static thread entry point taking the object as argument; presumably forwards to OnThreadFunc().
static void ThreadFunc(CVideoObjNetwork* pThis);
virtual void OnThreadFunc();
// Demuxer helpers.
bool OpenDemux(const std::string& strURL);
void CloseDemux();
void WriteLocalRecord(const AVPacket* pkt);
// Static callback with a void* context argument — presumably installed as the FFmpeg I/O interrupt callback; verify.
static int interrupt_callback(void* para);
void ReadPacketLoop();
// Decoder helpers.
bool OpenDecoder(const AVCodecParameters *para);
void Send2Decode(const AVPacket* pkt);
void Send2Display(const AVFrame* frame);
void CloseDecoder();
private:
static bool m_bInit;
std::mutex m_mutex;
std::string m_strURL;
VLKVideoWidget* m_pVideoWidget;
AVFormatContext* m_pAVFmtContext;
int m_iVideoStreamIndex;
int m_iAudioStreamIndex;
int m_iWidth;
int m_iHeight;
std::thread* m_pThread;
// NOTE(review): plain bool; if it is written by one thread and read by the
// worker thread without holding m_mutex, it should be std::atomic<bool> — confirm.
bool m_bExit;
VideoDataCallback m_cbFunc;
long m_lUserParam;
AVCodecContext* m_pCodecContext;
};

@ -1,27 +0,0 @@
#ifndef IMAGEPREVIEWDIALOG_H
#define IMAGEPREVIEWDIALOG_H
#include <QDialog>
#include <QLabel>
#include <QVBoxLayout>
#include <QPushButton>
#include <QScrollArea>
// Dialog constructed from an image file path; its members are a label, a
// scroll area and a close button.
// NOTE(review): only the declaration is visible here; layout and ownership of
// the child widgets are defined in the implementation file.
class ImagePreviewDialog : public QDialog
{
Q_OBJECT
public:
// imagePath: path of the image passed to the dialog; parent: optional parent widget.
explicit ImagePreviewDialog(const QString &imagePath, QWidget *parent = nullptr);
~ImagePreviewDialog();
private:
QLabel *m_imageLabel;
QScrollArea *m_scrollArea;
QPushButton *m_closeButton;
// Builds the dialog's widgets (implementation not shown here).
void setupUi();
// Loads the image at imagePath (implementation not shown here).
void loadImage(const QString &imagePath);
};
#endif // IMAGEPREVIEWDIALOG_H

@ -1,28 +0,0 @@
#include "widget.h"
#include <QApplication>
#include <QDebug>
int main(int argc, char *argv[])
{
QApplication a(argc, argv);
// print SDK Version
qDebug() << "ViewLink SDK Version: " << GetSDKVersion();
// initialize SDK
VLK_Init();
Widget w;
w.show();
int ret = a.exec();
// diconnect all
VLK_Disconnect();
// uninitialize SDK
VLK_UnInit();
return ret;
}

@ -1,247 +0,0 @@
#ifndef WIDGET_H
#define WIDGET_H
#include <QWidget>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QJsonDocument>
#include <QJsonObject>
#include <QTimer>
#include <QLabel>
#include <QGraphicsScene>
#include <QGraphicsView>
#include <QGraphicsItem>
#include <QGeoCoordinate>
#include <QPushButton>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <netinet/in.h>
#include "ViewLink.h"
#include "VideoObjNetwork.h"
#include "VideoObjUSBCamera.h"
#include "../identification system/src/image_processor.h"
#include "../identification system/src/zhipu_api.h"
// Use the MAVLinkClient from workproject in place of the previous UAV connection class
#include "../workproject/include/mavlink_client.h"
#include "MapWidget.h"
namespace Ui {
class Widget;
}
// Main application widget.
// Judging by its declarations it combines gimbal control through the ViewLink
// SDK, network/USB video, target identification, UAV control over MAVLink, a
// map view, sound-source localization and a TCP movement-control link.
// NOTE(review): only the declaration is visible here; confirm behavioural
// details against the implementation file.
class Widget : public QWidget
{
Q_OBJECT
public:
explicit Widget(QWidget *parent = 0);
~Widget();
private slots:
void on_btnConnectTCP_clicked();
void on_btnConnctSerialPort_clicked();
private:
// Static callbacks in the form expected by the ViewLink SDK; pUserParam is an opaque user pointer.
static int VLK_ConnStatusCallback(int iConnStatus, const char* szMessage, int iMsgLen, void* pUserParam);
static int VLK_DevStatusCallback(int iType, const char* szBuffer, int iBufLen, void* pUserParam);
signals:
// Signals paired one-to-one with the onSlot* handlers declared below.
void SignalConnectionStatus(int iConnStatus, const QString& strMessage);
void SignalDeviceModel(VLK_DEV_MODEL model);
void SignalDeviceConfig(VLK_DEV_CONFIG config);
void SignalDeviceTelemetry(VLK_DEV_TELEMETRY telemetry);
private slots:
void onSlotConnectionStatus(int iConnStatus, const QString& strMessage);
void onSlotDeviceModel(VLK_DEV_MODEL model);
void onSlotDeviceConfig(VLK_DEV_CONFIG config);
void onSlotDeviceTelemetry(VLK_DEV_TELEMETRY telemetry);
void on_btnUp_pressed();
void on_btnUp_released();
void on_btnLeft_pressed();
void on_btnLeft_released();
void on_btnHome_clicked();
void on_btnRight_pressed();
void on_btnRight_released();
void on_btnDown_pressed();
void on_btnDown_released();
void on_cmbImageSensor_activated(int index);
void on_cmbIRColor_activated(int index);
void on_checkBoxPIP_clicked(bool checked);
void on_btnOpenNetworkVideo_clicked();
void on_btnOpenUSBVideo_clicked();
void on_btnZoomIn_pressed();
void on_btnZoomIn_released();
void on_btnZoomOut_pressed();
void on_btnZoomOut_released();
void on_btnGimbalTakePhoto_clicked();
void on_btnStartRecord_clicked();
void on_btnStopRecord_clicked();
void on_btnIdentifyTarget_clicked();
void on_btnClearIdentification_clicked();
void on_sliderConfidence_valueChanged(int value);
// Slot for state changes of the preprocessing checkbox
void on_checkBoxPreprocess_stateChanged(int state);
// UAV control slots
void on_btnConnectUAV_clicked();
void on_btnArmDisarm_clicked();
void on_btnTakeoff_clicked();
void on_btnLand_clicked();
void on_btnRTL_clicked();
void on_btnMode_clicked();
// Slot for switching the map type
void on_btnSwitchMapType_clicked();
// Sound-source localization slots
void on_btnStartSoundLocator_clicked();
void on_btnStopSoundLocator_clicked();
void onSoundDataReceived();
void updateSoundVisualization(double x, double y, double strength, double angle);
// Slot for showing targets on the map
void on_btnShowTargetsOnMap_clicked();
private:
// initialize UI control
void InitUI();
void initTcpConnect();
void initSerialConnect();
// Update the UI state
void updateUIState();
// Save the detected targets and distance data
void saveDetectionDataToFile();
// Build the context information sent to the large model
QString buildDetectionContext();
// MAVLink callbacks
void handleHeartbeat(const mavlink_heartbeat_t& heartbeat);
void handleSystemStatus(const mavlink_sys_status_t& status);
void handleAttitude(const mavlink_attitude_t& attitude);
void handlePosition(const mavlink_global_position_int_t& position);
void handleGPS(const mavlink_gps_raw_int_t& gps);
// Send a MAVLink command
bool sendMavlinkCommand(uint16_t command, float param1 = 0, float param2 = 0,
float param3 = 0, float param4 = 0, float param5 = 0,
float param6 = 0, float param7 = 0);
// Sound-source localization helpers
void fetchSoundLocatorData();
void initSoundVisualization();
// Compute the target's geographic position
QGeoCoordinate calculateTargetPosition(double distance, double gimbalYaw, double gimbalPitch);
// Show the detected targets on the map
void showDetectedTargetsOnMap();
// Add a single target to the map
void addTargetToMap(const DetectedObject& target);
// Compute the uncertainty of the target position estimate (meters)
double calculateUncertainty(double distance);
// TCP control helpers
bool connectToMoveControlTCP();
void disconnectFromMoveControlTCP();
bool sendMoveCommand(double x, double y);
void moveTowardSoundSource(double angle);
private:
Ui::Widget *ui;
CVideoObjNetwork m_VideoObjNetwork;
CVideoObjUSBCamera m_VideoObjUSBCamera;
ImageProcessor* m_pImageProcessor;
ZhipuAPI* m_pZhipuAPI;
std::vector<DetectedObject> m_detectedObjects;
QString m_lastCapturedImagePath;
float m_confidenceThreshold;
// Target-distance members
QLabel* m_labelDistanceValue;
QString m_detectionLogFile;
QString m_lastDetectionTime;
float m_lastLaserDistance;
bool m_distanceEstimationEnabled;
// Map target-marker members
bool m_targetsOnMap; // whether the targets are currently shown on the map
double m_lastGimbalYaw; // last gimbal yaw angle
double m_lastGimbalPitch; // last gimbal pitch angle
QPushButton* m_btnShowTargetsOnMap; // "show targets" button
// Replaced with MAVLinkClient
MAVLinkClient *m_mavlinkClient;
MapWidget *m_mapWidget;
bool m_isUAVArmed;
// Latest flight data
mavlink_heartbeat_t m_heartbeat;
mavlink_sys_status_t m_sysStatus;
mavlink_attitude_t m_attitude;
mavlink_global_position_int_t m_position;
mavlink_gps_raw_int_t m_gps;
bool m_isMAVConnected;
// Sound localization members
QNetworkAccessManager* m_soundNetworkManager;
QTimer* m_soundDataTimer;
bool m_isSoundLocatorRunning;
QString m_soundLocatorIP;
int m_soundLocatorPort;
QGraphicsScene* m_soundScene;
QGraphicsView* m_soundView;
QGraphicsEllipseItem* m_soundSourceItem;
QGraphicsLineItem* m_soundDirectionLine;
QGraphicsTextItem* m_soundInfoText;
QGraphicsEllipseItem* m_soundDetectorItem;
double m_soundX;
double m_soundY;
double m_soundStrength;
double m_soundAngle;
// TCP movement-control members
int m_moveControlSocket;
struct sockaddr_in m_moveControlAddr;
bool m_moveControlConnected;
QString m_moveControlIP;
int m_moveControlPort;
bool m_lastSoundProcessed; // whether the previous sound has been processed
// GPS status labels
QLabel* m_labelGPSValue;
QLabel* m_labelSatellitesValue;
QLabel* m_labelHDOPValue;
QLabel* m_labelVDOPValue;
};
#endif // WIDGET_H

@ -1,344 +0,0 @@
<template>
<div class="app-container">
<header class="app-header">
<h1>声源定位系统</h1>
<div class="system-status">
<el-tag :type="connectionStatus ? 'success' : 'danger'">
{{ connectionStatus ? '已连接到数据源' : '未连接到数据源' }}
</el-tag>
</div>
</header>
<main class="main-content">
<div class="control-panel">
<h2>控制面板</h2>
<div class="control-buttons">
<el-button type="primary" @click="startMonitoring" :disabled="isMonitoring">
开始监听
</el-button>
<el-button type="danger" @click="stopMonitoring" :disabled="!isMonitoring">
停止监听
</el-button>
</div>
<div class="source-info">
<h3>声源信息</h3>
<el-descriptions :column="1" border>
<el-descriptions-item label="X 坐标">{{ sourceData.X.toFixed(2) }}</el-descriptions-item>
<el-descriptions-item label="Y 坐标">{{ sourceData.Y.toFixed(2) }}</el-descriptions-item>
<el-descriptions-item label="强度">{{ sourceData.strength.toFixed(2) }}</el-descriptions-item>
<el-descriptions-item label="角度">{{ sourceData.angle.toFixed(2) }}°</el-descriptions-item>
</el-descriptions>
</div>
</div>
<div class="visualization-panel">
<h2>声源定位可视化</h2>
<div ref="chartContainer" class="chart-container"></div>
</div>
</main>
<footer class="app-footer">
<p>© 2025 声源定位系统. 基于 K210 麦克风阵列.</p>
</footer>
</div>
</template>
<script>
import axios from 'axios';
import * as echarts from 'echarts';
/**
 * Root component of the sound-source localization dashboard.
 *
 * Polls `${API_BASE_URL}/data` every 500 ms while monitoring is active and
 * plots the reported source position, strength and direction with ECharts.
 */
export default {
  name: 'App',
  data() {
    return {
      // Latest reading; all four fields are always finite numbers so the
      // template and chart can call `.toFixed()` on them unconditionally.
      sourceData: {
        X: 0.0,
        Y: 0.0,
        strength: 0.0,
        angle: 0.0,
      },
      // ECharts instance created in initChart().
      // NOTE(review): storing the instance in data() makes Vue wrap it in a
      // reactive proxy; consider markRaw() if ECharts misbehaves.
      chart: null,
      // Window 'resize' handler, kept so it can be removed on unmount.
      resizeHandler: null,
      connectionStatus: false,
      isMonitoring: false,
      pollingInterval: null,
      // Base URL of the backend that serves the localization data.
      API_BASE_URL: 'http://127.0.0.1:5000',
    };
  },
  mounted() {
    this.initChart();
    this.checkConnection();
  },
  beforeUnmount() {
    if (this.pollingInterval) {
      clearInterval(this.pollingInterval);
      this.pollingInterval = null;
    }
    // Remove the listener before disposing so a late resize event cannot
    // reach a disposed chart.
    if (this.resizeHandler) {
      window.removeEventListener('resize', this.resizeHandler);
      this.resizeHandler = null;
    }
    if (this.chart) {
      this.chart.dispose();
      this.chart = null;
    }
  },
  methods: {
    /** Probe the backend once and reflect the outcome in `connectionStatus`. */
    checkConnection() {
      axios.get(`${this.API_BASE_URL}/data`)
        .then(() => {
          this.connectionStatus = true;
        })
        .catch(() => {
          this.connectionStatus = false;
        });
    },
    /** Create the ECharts instance, draw it once and keep it sized to the window. */
    initChart() {
      const chartDom = this.$refs.chartContainer;
      if (!chartDom) {
        return;
      }
      this.chart = echarts.init(chartDom);
      this.updateChart();
      // Keep a reference to the handler so beforeUnmount() can remove it.
      this.resizeHandler = () => {
        if (this.chart) {
          this.chart.resize();
        }
      };
      window.addEventListener('resize', this.resizeHandler);
    },
    /** Redraw the chart from the current `sourceData`. */
    updateChart() {
      // A response may arrive before the chart exists or after unmount.
      if (!this.chart) {
        return;
      }
      const { X, Y, strength, angle } = this.sourceData;
      // Direction line: fixed length, angle in degrees measured from the
      // positive Y axis towards the positive X axis.
      const angleRad = (angle * Math.PI) / 180;
      const directionLength = 10;
      const dirX = directionLength * Math.sin(angleRad);
      const dirY = directionLength * Math.cos(angleRad);
      const option = {
        title: {
          text: '实时声源定位地图',
          left: 'center'
        },
        tooltip: {
          trigger: 'item',
          formatter: function(params) {
            // Only the source-position series (index 0) has a tooltip.
            if (params.seriesIndex === 0) {
              return `声源位置:<br/>X: ${X.toFixed(2)}<br/>Y: ${Y.toFixed(2)}<br/>强度: ${strength.toFixed(2)}<br/>角度: ${angle.toFixed(2)}°`;
            }
            return '';
          }
        },
        legend: {
          data: ['声源位置', '方向'],
          bottom: 10
        },
        grid: {
          top: 80,
          left: 50,
          right: 50,
          bottom: 60
        },
        xAxis: {
          type: 'value',
          min: -15,
          max: 15,
          name: 'X 坐标',
          nameLocation: 'center',
          nameGap: 30,
          axisLine: {
            show: true,
            onZero: true
          },
          splitLine: {
            show: true,
            lineStyle: {
              type: 'dashed'
            }
          }
        },
        yAxis: {
          type: 'value',
          min: -15,
          max: 15,
          name: 'Y 坐标',
          nameLocation: 'center',
          nameGap: 30,
          axisLine: {
            show: true,
            onZero: true
          },
          splitLine: {
            show: true,
            lineStyle: {
              type: 'dashed'
            }
          }
        },
        series: [
          {
            name: '声源位置',
            type: 'scatter',
            // Marker grows with signal strength but never drops below 10 px.
            symbolSize: Math.max(10, strength * 5),
            data: [[X, Y]],
            itemStyle: {
              color: '#F56C6C'
            },
            emphasis: {
              itemStyle: {
                shadowBlur: 10,
                shadowColor: 'rgba(245, 108, 108, 0.5)'
              }
            },
            label: {
              show: true,
              position: 'top',
              formatter: `强度: ${strength.toFixed(2)}\n角度: ${angle.toFixed(2)}°`
            },
            z: 10
          },
          {
            name: '方向',
            type: 'line',
            data: [[0, 0], [dirX, dirY]],
            lineStyle: {
              width: 2,
              color: '#409EFF'
            },
            symbol: ['circle', 'arrow'],
            symbolSize: [5, 12],
            label: {
              show: false
            },
            z: 5
          },
          {
            name: '探测器',
            type: 'scatter',
            data: [[0, 0]],
            symbolSize: 10,
            itemStyle: {
              color: '#67C23A'
            },
            label: {
              show: true,
              position: 'bottom',
              formatter: '探测器'
            },
            z: 8
          }
        ]
      };
      this.chart.setOption(option);
    },
    /** Start polling the backend every 500 ms. */
    startMonitoring() {
      // Never stack a second interval on top of a running one.
      if (this.pollingInterval) {
        return;
      }
      this.isMonitoring = true;
      this.pollingInterval = setInterval(() => {
        this.fetchSourceData();
      }, 500);
    },
    /** Stop polling the backend. */
    stopMonitoring() {
      this.isMonitoring = false;
      if (this.pollingInterval) {
        clearInterval(this.pollingInterval);
        this.pollingInterval = null;
      }
    },
    /** Fetch one reading, store it as finite numbers and redraw the chart. */
    fetchSourceData() {
      axios.get(`${this.API_BASE_URL}/data`)
        .then(response => {
          const payload = response.data || {};
          const previous = this.sourceData;
          // A missing or non-numeric field keeps its previous value instead
          // of breaking `.toFixed()` in the template and the chart.
          const toNumber = (value, fallback) => {
            const parsed = Number(value);
            return Number.isFinite(parsed) ? parsed : fallback;
          };
          this.sourceData = {
            X: toNumber(payload.X, previous.X),
            Y: toNumber(payload.Y, previous.Y),
            strength: toNumber(payload.strength, previous.strength),
            angle: toNumber(payload.angle, previous.angle),
          };
          this.connectionStatus = true;
          this.updateChart();
        })
        .catch(error => {
          console.error('获取声源数据失败:', error);
          this.connectionStatus = false;
        });
    }
  }
}
</script>
<style>
.app-container {
display: flex;
flex-direction: column;
min-height: 100vh;
background-color: #f5f7fa;
font-family: 'Helvetica Neue', Helvetica, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', Arial, sans-serif;
}
.app-header {
background-color: #304156;
color: white;
padding: 1rem 2rem;
display: flex;
align-items: center;
justify-content: space-between;
box-shadow: 0 2px 12px 0 rgba(0, 0, 0, 0.1);
}
.app-header h1 {
margin: 0;
font-size: 1.6rem;
}
.main-content {
flex: 1;
display: flex;
padding: 1.5rem;
gap: 1.5rem;
}
.control-panel {
flex: 1;
background-color: white;
border-radius: 4px;
padding: 1.5rem;
box-shadow: 0 2px 12px 0 rgba(0, 0, 0, 0.1);
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.visualization-panel {
flex: 2;
background-color: white;
border-radius: 4px;
padding: 1.5rem;
box-shadow: 0 2px 12px 0 rgba(0, 0, 0, 0.1);
}
.chart-container {
height: 500px;
width: 100%;
}
.control-buttons {
display: flex;
gap: 1rem;
margin: 1rem 0;
}
.app-footer {
background-color: #304156;
color: #a7b0bc;
text-align: center;
padding: 1rem;
font-size: 0.9rem;
}
h2, h3 {
margin-top: 0;
color: #303133;
border-bottom: 1px solid #ebeef5;
padding-bottom: 0.5rem;
}
@media (max-width: 768px) {
.main-content {
flex-direction: column;
}
.chart-container {
height: 350px;
}
}
</style>

@ -1 +0,0 @@
Subproject commit 39c8188929657d71dfecbac4288025765589c300

@ -0,0 +1,69 @@
# Details
Date : 2025-06-10 08:55:17
Directory e:\\pycharm_projects\\AudioClassification-Pytorch-master
Total : 54 files, 6851 codes, 838 comments, 1201 blanks, all 8890 lines
[Summary](results.md) / Details / [Diff Summary](diff.md) / [Diff Details](diff-details.md)
## Files
| filename | language | code | comment | blank | total |
| :--- | :--- | ---: | ---: | ---: | ---: |
| [README.md](/README.md) | Markdown | 302 | 0 | 72 | 374 |
| [README\_en.md](/README_en.md) | Markdown | 231 | 0 | 45 | 276 |
| [audio-classification-platform/README.md](/audio-classification-platform/README.md) | Markdown | 96 | 0 | 41 | 137 |
| [audio-classification-platform/backend/app.py](/audio-classification-platform/backend/app.py) | Python | 255 | 35 | 61 | 351 |
| [audio-classification-platform/backend/config.py](/audio-classification-platform/backend/config.py) | Python | 22 | 8 | 10 | 40 |
| [audio-classification-platform/backend/requirements.txt](/audio-classification-platform/backend/requirements.txt) | pip requirements | 7 | 0 | 1 | 8 |
| [audio-classification-platform/frontend/index.html](/audio-classification-platform/frontend/index.html) | HTML | 25 | 0 | 2 | 27 |
| [audio-classification-platform/frontend/package.json](/audio-classification-platform/frontend/package.json) | JSON | 24 | 0 | 1 | 25 |
| [audio-classification-platform/frontend/src/App.vue](/audio-classification-platform/frontend/src/App.vue) | Vue | 88 | 5 | 14 | 107 |
| [audio-classification-platform/frontend/src/components/AudioRecorder.vue](/audio-classification-platform/frontend/src/components/AudioRecorder.vue) | Vue | 551 | 38 | 102 | 691 |
| [audio-classification-platform/frontend/src/components/AudioRecorder\_new.vue](/audio-classification-platform/frontend/src/components/AudioRecorder_new.vue) | Vue | 811 | 48 | 162 | 1,021 |
| [audio-classification-platform/frontend/src/components/AudioUpload.vue](/audio-classification-platform/frontend/src/components/AudioUpload.vue) | Vue | 580 | 28 | 107 | 715 |
| [audio-classification-platform/frontend/src/components/HistoryList.vue](/audio-classification-platform/frontend/src/components/HistoryList.vue) | Vue | 513 | 25 | 79 | 617 |
| [audio-classification-platform/frontend/src/components/PredictionResult.vue](/audio-classification-platform/frontend/src/components/PredictionResult.vue) | Vue | 803 | 22 | 117 | 942 |
| [audio-classification-platform/frontend/src/main.js](/audio-classification-platform/frontend/src/main.js) | JavaScript | 13 | 1 | 5 | 19 |
| [audio-classification-platform/frontend/src/router/index.js](/audio-classification-platform/frontend/src/router/index.js) | JavaScript | 14 | 0 | 4 | 18 |
| [audio-classification-platform/frontend/src/utils/api.js](/audio-classification-platform/frontend/src/utils/api.js) | JavaScript | 156 | 25 | 33 | 214 |
| [audio-classification-platform/frontend/src/views/HomePage.vue](/audio-classification-platform/frontend/src/views/HomePage.vue) | Vue | 692 | 40 | 110 | 842 |
| [audio-classification-platform/frontend/vite.config.js](/audio-classification-platform/frontend/vite.config.js) | JavaScript | 32 | 0 | 2 | 34 |
| [audio-classification-platform/start.bat](/audio-classification-platform/start.bat) | Batch | 20 | 0 | 5 | 25 |
| [audio-classification-platform/start.sh](/audio-classification-platform/start.sh) | Shell Script | 26 | 3 | 8 | 37 |
| [configs/augmentation.yml](/configs/augmentation.yml) | YAML | 21 | 21 | 5 | 47 |
| [configs/cam++.yml](/configs/cam++.yml) | YAML | 43 | 33 | 5 | 81 |
| [configs/ecapa\_tdnn.yml](/configs/ecapa_tdnn.yml) | YAML | 43 | 33 | 5 | 81 |
| [configs/eres2net.yml](/configs/eres2net.yml) | YAML | 43 | 33 | 5 | 81 |
| [configs/panns.yml](/configs/panns.yml) | YAML | 43 | 33 | 5 | 81 |
| [configs/res2net.yml](/configs/res2net.yml) | YAML | 43 | 33 | 5 | 81 |
| [configs/resnet\_se.yml](/configs/resnet_se.yml) | YAML | 43 | 33 | 5 | 81 |
| [configs/tdnn.yml](/configs/tdnn.yml) | YAML | 43 | 33 | 5 | 81 |
| [create\_data.py](/create_data.py) | Python | 75 | 9 | 16 | 100 |
| [eval.py](/eval.py) | Python | 20 | 2 | 5 | 27 |
| [extract\_features.py](/extract_features.py) | Python | 13 | 2 | 5 | 20 |
| [infer.py](/infer.py) | Python | 17 | 1 | 6 | 24 |
| [infer\_record.py](/infer_record.py) | Python | 40 | 7 | 12 | 59 |
| [macls/\_\_init\_\_.py](/macls/__init__.py) | Python | 1 | 0 | 1 | 2 |
| [macls/data\_utils/\_\_init\_\_.py](/macls/data_utils/__init__.py) | Python | 0 | 0 | 1 | 1 |
| [macls/data\_utils/collate\_fn.py](/macls/data_utils/collate_fn.py) | Python | 17 | 4 | 3 | 24 |
| [macls/data\_utils/featurizer.py](/macls/data_utils/featurizer.py) | Python | 88 | 36 | 9 | 133 |
| [macls/data\_utils/reader.py](/macls/data_utils/reader.py) | Python | 114 | 33 | 11 | 158 |
| [macls/metric/\_\_init\_\_.py](/macls/metric/__init__.py) | Python | 0 | 0 | 1 | 1 |
| [macls/metric/metrics.py](/macls/metric/metrics.py) | Python | 9 | 1 | 3 | 13 |
| [macls/optimizer/\_\_init\_\_.py](/macls/optimizer/__init__.py) | Python | 26 | 0 | 7 | 33 |
| [macls/optimizer/scheduler.py](/macls/optimizer/scheduler.py) | Python | 42 | 0 | 7 | 49 |
| [macls/predict.py](/macls/predict.py) | Python | 124 | 47 | 7 | 178 |
| [macls/trainer.py](/macls/trainer.py) | Python | 338 | 99 | 20 | 457 |
| [macls/utils/\_\_init\_\_.py](/macls/utils/__init__.py) | Python | 0 | 0 | 1 | 1 |
| [macls/utils/checkpoint.py](/macls/utils/checkpoint.py) | Python | 113 | 40 | 10 | 163 |
| [macls/utils/record.py](/macls/utils/record.py) | Python | 18 | 8 | 6 | 32 |
| [macls/utils/utils.py](/macls/utils/utils.py) | Python | 99 | 16 | 17 | 132 |
| [record\_audio.py](/record_audio.py) | Python | 10 | 0 | 5 | 15 |
| [requirements.txt](/requirements.txt) | pip requirements | 17 | 0 | 1 | 18 |
| [setup.py](/setup.py) | Python | 43 | 1 | 11 | 55 |
| [tools/download\_language\_data.sh](/tools/download_language_data.sh) | Shell Script | 19 | 1 | 10 | 30 |
| [train.py](/train.py) | Python | 25 | 1 | 5 | 31 |
[Summary](results.md) / Details / [Diff Summary](diff.md) / [Diff Details](diff-details.md)

@ -0,0 +1,15 @@
# Diff Details
Date : 2025-06-10 08:55:17
Directory e:\\pycharm_projects\\AudioClassification-Pytorch-master
Total : 0 files, 0 codes, 0 comments, 0 blanks, all 0 lines
[Summary](results.md) / [Details](details.md) / [Diff Summary](diff.md) / Diff Details
## Files
| filename | language | code | comment | blank | total |
| :--- | :--- | ---: | ---: | ---: | ---: |
[Summary](results.md) / [Details](details.md) / [Diff Summary](diff.md) / Diff Details

@ -0,0 +1,2 @@
"filename", "language", "", "comment", "blank", "total"
"Total", "-", , 0, 0, 0
1 filename language comment blank total
2 Total - 0 0 0

@ -0,0 +1,19 @@
# Diff Summary
Date : 2025-06-10 08:55:17
Directory e:\\pycharm_projects\\AudioClassification-Pytorch-master
Total : 0 files, 0 codes, 0 comments, 0 blanks, all 0 lines
[Summary](results.md) / [Details](details.md) / Diff Summary / [Diff Details](diff-details.md)
## Languages
| language | files | code | comment | blank | total |
| :--- | ---: | ---: | ---: | ---: | ---: |
## Directories
| path | files | code | comment | blank | total |
| :--- | ---: | ---: | ---: | ---: | ---: |
[Summary](results.md) / [Details](details.md) / Diff Summary / [Diff Details](diff-details.md)

@ -0,0 +1,22 @@
Date : 2025-06-10 08:55:17
Directory : e:\pycharm_projects\AudioClassification-Pytorch-master
Total : 0 files, 0 codes, 0 comments, 0 blanks, all 0 lines
Languages
+----------+------------+------------+------------+------------+------------+
| language | files | code | comment | blank | total |
+----------+------------+------------+------------+------------+------------+
+----------+------------+------------+------------+------------+------------+
Directories
+------+------------+------------+------------+------------+------------+
| path | files | code | comment | blank | total |
+------+------------+------------+------------+------------+------------+
+------+------------+------------+------------+------------+------------+
Files
+----------+----------+------------+------------+------------+------------+
| filename | language | code | comment | blank | total |
+----------+----------+------------+------------+------------+------------+
| Total | | 0 | 0 | 0 | 0 |
+----------+----------+------------+------------+------------+------------+

@ -0,0 +1,56 @@
"filename", "language", "Python", "pip requirements", "Markdown", "Shell Script", "YAML", "Batch", "JavaScript", "Vue", "JSON", "HTML", "comment", "blank", "total"
"e:\pycharm_projects\AudioClassification-Pytorch-master\README.md", "Markdown", 0, 0, 302, 0, 0, 0, 0, 0, 0, 0, 0, 72, 374
"e:\pycharm_projects\AudioClassification-Pytorch-master\README_en.md", "Markdown", 0, 0, 231, 0, 0, 0, 0, 0, 0, 0, 0, 45, 276
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\README.md", "Markdown", 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 41, 137
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\app.py", "Python", 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 61, 351
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\config.py", "Python", 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 10, 40
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\requirements.txt", "pip requirements", 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 8
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\index.html", "HTML", 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 2, 27
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\package.json", "JSON", 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 1, 25
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\App.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 88, 0, 0, 5, 14, 107
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioRecorder.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 551, 0, 0, 38, 102, 691
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioRecorder_new.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 811, 0, 0, 48, 162, 1021
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioUpload.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 580, 0, 0, 28, 107, 715
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\HistoryList.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 513, 0, 0, 25, 79, 617
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\PredictionResult.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 803, 0, 0, 22, 117, 942
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\main.js", "JavaScript", 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 1, 5, 19
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\router\index.js", "JavaScript", 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 4, 18
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\utils\api.js", "JavaScript", 0, 0, 0, 0, 0, 0, 156, 0, 0, 0, 25, 33, 214
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\views\HomePage.vue", "Vue", 0, 0, 0, 0, 0, 0, 0, 692, 0, 0, 40, 110, 842
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\vite.config.js", "JavaScript", 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 2, 34
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\start.bat", "Batch", 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 5, 25
"e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\start.sh", "Shell Script", 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 3, 8, 37
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\augmentation.yml", "YAML", 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 21, 5, 47
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\cam++.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\ecapa_tdnn.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\eres2net.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\panns.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\res2net.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\resnet_se.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\configs\tdnn.yml", "YAML", 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 33, 5, 81
"e:\pycharm_projects\AudioClassification-Pytorch-master\create_data.py", "Python", 75, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 16, 100
"e:\pycharm_projects\AudioClassification-Pytorch-master\eval.py", "Python", 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 27
"e:\pycharm_projects\AudioClassification-Pytorch-master\extract_features.py", "Python", 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 20
"e:\pycharm_projects\AudioClassification-Pytorch-master\infer.py", "Python", 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 6, 24
"e:\pycharm_projects\AudioClassification-Pytorch-master\infer_record.py", "Python", 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 12, 59
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\__init__.py", "Python", 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\__init__.py", "Python", 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\collate_fn.py", "Python", 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 24
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\featurizer.py", "Python", 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 36, 9, 133
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\reader.py", "Python", 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33, 11, 158
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\metric\__init__.py", "Python", 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\metric\metrics.py", "Python", 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 13
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\optimizer\__init__.py", "Python", 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 33
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\optimizer\scheduler.py", "Python", 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 49
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\predict.py", "Python", 124, 0, 0, 0, 0, 0, 0, 0, 0, 0, 47, 7, 178
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\trainer.py", "Python", 338, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 20, 457
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\__init__.py", "Python", 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\checkpoint.py", "Python", 113, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 10, 163
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\record.py", "Python", 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 6, 32
"e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\utils.py", "Python", 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 17, 132
"e:\pycharm_projects\AudioClassification-Pytorch-master\record_audio.py", "Python", 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 15
"e:\pycharm_projects\AudioClassification-Pytorch-master\requirements.txt", "pip requirements", 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 18
"e:\pycharm_projects\AudioClassification-Pytorch-master\setup.py", "Python", 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 11, 55
"e:\pycharm_projects\AudioClassification-Pytorch-master\tools\download_language_data.sh", "Shell Script", 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 1, 10, 30
"e:\pycharm_projects\AudioClassification-Pytorch-master\train.py", "Python", 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 5, 31
"Total", "-", 1509, 24, 629, 45, 322, 20, 215, 4038, 24, 25, 838, 1201, 8890
1 filename language Python pip requirements Markdown Shell Script YAML Batch JavaScript Vue JSON HTML comment blank total
2 e:\pycharm_projects\AudioClassification-Pytorch-master\README.md Markdown 0 0 302 0 0 0 0 0 0 0 0 72 374
3 e:\pycharm_projects\AudioClassification-Pytorch-master\README_en.md Markdown 0 0 231 0 0 0 0 0 0 0 0 45 276
4 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\README.md Markdown 0 0 96 0 0 0 0 0 0 0 0 41 137
5 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\app.py Python 255 0 0 0 0 0 0 0 0 0 35 61 351
6 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\config.py Python 22 0 0 0 0 0 0 0 0 0 8 10 40
7 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\requirements.txt pip requirements 0 7 0 0 0 0 0 0 0 0 0 1 8
8 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\index.html HTML 0 0 0 0 0 0 0 0 0 25 0 2 27
9 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\package.json JSON 0 0 0 0 0 0 0 0 24 0 0 1 25
10 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\App.vue Vue 0 0 0 0 0 0 0 88 0 0 5 14 107
11 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioRecorder.vue Vue 0 0 0 0 0 0 0 551 0 0 38 102 691
12 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioRecorder_new.vue Vue 0 0 0 0 0 0 0 811 0 0 48 162 1021
13 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioUpload.vue Vue 0 0 0 0 0 0 0 580 0 0 28 107 715
14 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\HistoryList.vue Vue 0 0 0 0 0 0 0 513 0 0 25 79 617
15 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\PredictionResult.vue Vue 0 0 0 0 0 0 0 803 0 0 22 117 942
16 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\main.js JavaScript 0 0 0 0 0 0 13 0 0 0 1 5 19
17 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\router\index.js JavaScript 0 0 0 0 0 0 14 0 0 0 0 4 18
18 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\utils\api.js JavaScript 0 0 0 0 0 0 156 0 0 0 25 33 214
19 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\views\HomePage.vue Vue 0 0 0 0 0 0 0 692 0 0 40 110 842
20 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\vite.config.js JavaScript 0 0 0 0 0 0 32 0 0 0 0 2 34
21 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\start.bat Batch 0 0 0 0 0 20 0 0 0 0 0 5 25
22 e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\start.sh Shell Script 0 0 0 26 0 0 0 0 0 0 3 8 37
23 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\augmentation.yml YAML 0 0 0 0 21 0 0 0 0 0 21 5 47
24 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\cam++.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
25 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\ecapa_tdnn.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
26 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\eres2net.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
27 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\panns.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
28 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\res2net.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
29 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\resnet_se.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
30 e:\pycharm_projects\AudioClassification-Pytorch-master\configs\tdnn.yml YAML 0 0 0 0 43 0 0 0 0 0 33 5 81
31 e:\pycharm_projects\AudioClassification-Pytorch-master\create_data.py Python 75 0 0 0 0 0 0 0 0 0 9 16 100
32 e:\pycharm_projects\AudioClassification-Pytorch-master\eval.py Python 20 0 0 0 0 0 0 0 0 0 2 5 27
33 e:\pycharm_projects\AudioClassification-Pytorch-master\extract_features.py Python 13 0 0 0 0 0 0 0 0 0 2 5 20
34 e:\pycharm_projects\AudioClassification-Pytorch-master\infer.py Python 17 0 0 0 0 0 0 0 0 0 1 6 24
35 e:\pycharm_projects\AudioClassification-Pytorch-master\infer_record.py Python 40 0 0 0 0 0 0 0 0 0 7 12 59
36 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\__init__.py Python 1 0 0 0 0 0 0 0 0 0 0 1 2
37 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\__init__.py Python 0 0 0 0 0 0 0 0 0 0 0 1 1
38 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\collate_fn.py Python 17 0 0 0 0 0 0 0 0 0 4 3 24
39 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\featurizer.py Python 88 0 0 0 0 0 0 0 0 0 36 9 133
40 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\reader.py Python 114 0 0 0 0 0 0 0 0 0 33 11 158
41 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\metric\__init__.py Python 0 0 0 0 0 0 0 0 0 0 0 1 1
42 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\metric\metrics.py Python 9 0 0 0 0 0 0 0 0 0 1 3 13
43 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\optimizer\__init__.py Python 26 0 0 0 0 0 0 0 0 0 0 7 33
44 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\optimizer\scheduler.py Python 42 0 0 0 0 0 0 0 0 0 0 7 49
45 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\predict.py Python 124 0 0 0 0 0 0 0 0 0 47 7 178
46 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\trainer.py Python 338 0 0 0 0 0 0 0 0 0 99 20 457
47 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\__init__.py Python 0 0 0 0 0 0 0 0 0 0 0 1 1
48 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\checkpoint.py Python 113 0 0 0 0 0 0 0 0 0 40 10 163
49 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\record.py Python 18 0 0 0 0 0 0 0 0 0 8 6 32
50 e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\utils.py Python 99 0 0 0 0 0 0 0 0 0 16 17 132
51 e:\pycharm_projects\AudioClassification-Pytorch-master\record_audio.py Python 10 0 0 0 0 0 0 0 0 0 0 5 15
52 e:\pycharm_projects\AudioClassification-Pytorch-master\requirements.txt pip requirements 0 17 0 0 0 0 0 0 0 0 0 1 18
53 e:\pycharm_projects\AudioClassification-Pytorch-master\setup.py Python 43 0 0 0 0 0 0 0 0 0 1 11 55
54 e:\pycharm_projects\AudioClassification-Pytorch-master\tools\download_language_data.sh Shell Script 0 0 0 19 0 0 0 0 0 0 1 10 30
55 e:\pycharm_projects\AudioClassification-Pytorch-master\train.py Python 25 0 0 0 0 0 0 0 0 0 1 5 31
56 Total - 1509 24 629 45 322 20 215 4038 24 25 838 1201 8890

@ -0,0 +1,50 @@
# Summary
Date : 2025-06-10 08:55:17
Directory : e:\\pycharm_projects\\AudioClassification-Pytorch-master
Total : 54 files, 6851 codes, 838 comments, 1201 blanks, all 8890 lines
Summary / [Details](details.md) / [Diff Summary](diff.md) / [Diff Details](diff-details.md)
## Languages
| language | files | code | comment | blank | total |
| :--- | ---: | ---: | ---: | ---: | ---: |
| Vue | 7 | 4,038 | 206 | 691 | 4,935 |
| Python | 25 | 1,509 | 350 | 240 | 2,099 |
| Markdown | 3 | 629 | 0 | 158 | 787 |
| YAML | 8 | 322 | 252 | 40 | 614 |
| JavaScript | 4 | 215 | 26 | 44 | 285 |
| Shell Script | 2 | 45 | 4 | 18 | 67 |
| HTML | 1 | 25 | 0 | 2 | 27 |
| pip requirements | 2 | 24 | 0 | 2 | 26 |
| JSON | 1 | 24 | 0 | 1 | 25 |
| Batch | 1 | 20 | 0 | 5 | 25 |
## Directories
| path | files | code | comment | blank | total |
| :--- | ---: | ---: | ---: | ---: | ---: |
| . | 54 | 6,851 | 838 | 1,201 | 8,890 |
| . (Files) | 11 | 793 | 23 | 183 | 999 |
| audio-classification-platform | 19 | 4,728 | 278 | 864 | 5,870 |
| audio-classification-platform (Files) | 3 | 142 | 3 | 54 | 199 |
| audio-classification-platform\\backend | 3 | 284 | 43 | 72 | 399 |
| audio-classification-platform\\frontend | 13 | 4,302 | 232 | 738 | 5,272 |
| audio-classification-platform\\frontend (Files) | 3 | 81 | 0 | 5 | 86 |
| audio-classification-platform\\frontend\\src | 10 | 4,221 | 232 | 733 | 5,186 |
| audio-classification-platform\\frontend\\src (Files) | 2 | 101 | 6 | 19 | 126 |
| audio-classification-platform\\frontend\\src\\components | 5 | 3,258 | 161 | 567 | 3,986 |
| audio-classification-platform\\frontend\\src\\router | 1 | 14 | 0 | 4 | 18 |
| audio-classification-platform\\frontend\\src\\utils | 1 | 156 | 25 | 33 | 214 |
| audio-classification-platform\\frontend\\src\\views | 1 | 692 | 40 | 110 | 842 |
| configs | 8 | 322 | 252 | 40 | 614 |
| macls | 15 | 989 | 284 | 104 | 1,377 |
| macls (Files) | 3 | 463 | 146 | 28 | 637 |
| macls\\data_utils | 4 | 219 | 73 | 24 | 316 |
| macls\\metric | 2 | 9 | 1 | 4 | 14 |
| macls\\optimizer | 2 | 68 | 0 | 14 | 82 |
| macls\\utils | 4 | 230 | 64 | 34 | 328 |
| tools | 1 | 19 | 1 | 10 | 30 |
Summary / [Details](details.md) / [Diff Summary](diff.md) / [Diff Details](diff-details.md)

@ -0,0 +1,107 @@
Date : 2025-06-10 08:55:17
Directory : e:\pycharm_projects\AudioClassification-Pytorch-master
Total : 54 files, 6851 codes, 838 comments, 1201 blanks, all 8890 lines
Languages
+------------------+------------+------------+------------+------------+------------+
| language | files | code | comment | blank | total |
+------------------+------------+------------+------------+------------+------------+
| Vue | 7 | 4,038 | 206 | 691 | 4,935 |
| Python | 25 | 1,509 | 350 | 240 | 2,099 |
| Markdown | 3 | 629 | 0 | 158 | 787 |
| YAML | 8 | 322 | 252 | 40 | 614 |
| JavaScript | 4 | 215 | 26 | 44 | 285 |
| Shell Script | 2 | 45 | 4 | 18 | 67 |
| HTML | 1 | 25 | 0 | 2 | 27 |
| pip requirements | 2 | 24 | 0 | 2 | 26 |
| JSON | 1 | 24 | 0 | 1 | 25 |
| Batch | 1 | 20 | 0 | 5 | 25 |
+------------------+------------+------------+------------+------------+------------+
Directories
+------------------------------------------------------------------------------------------------------------------------------------+------------+------------+------------+------------+------------+
| path | files | code | comment | blank | total |
+------------------------------------------------------------------------------------------------------------------------------------+------------+------------+------------+------------+------------+
| . | 54 | 6,851 | 838 | 1,201 | 8,890 |
| . (Files) | 11 | 793 | 23 | 183 | 999 |
| audio-classification-platform | 19 | 4,728 | 278 | 864 | 5,870 |
| audio-classification-platform (Files) | 3 | 142 | 3 | 54 | 199 |
| audio-classification-platform\backend | 3 | 284 | 43 | 72 | 399 |
| audio-classification-platform\frontend | 13 | 4,302 | 232 | 738 | 5,272 |
| audio-classification-platform\frontend (Files) | 3 | 81 | 0 | 5 | 86 |
| audio-classification-platform\frontend\src | 10 | 4,221 | 232 | 733 | 5,186 |
| audio-classification-platform\frontend\src (Files) | 2 | 101 | 6 | 19 | 126 |
| audio-classification-platform\frontend\src\components | 5 | 3,258 | 161 | 567 | 3,986 |
| audio-classification-platform\frontend\src\router | 1 | 14 | 0 | 4 | 18 |
| audio-classification-platform\frontend\src\utils | 1 | 156 | 25 | 33 | 214 |
| audio-classification-platform\frontend\src\views | 1 | 692 | 40 | 110 | 842 |
| configs | 8 | 322 | 252 | 40 | 614 |
| macls | 15 | 989 | 284 | 104 | 1,377 |
| macls (Files) | 3 | 463 | 146 | 28 | 637 |
| macls\data_utils | 4 | 219 | 73 | 24 | 316 |
| macls\metric | 2 | 9 | 1 | 4 | 14 |
| macls\optimizer | 2 | 68 | 0 | 14 | 82 |
| macls\utils | 4 | 230 | 64 | 34 | 328 |
| tools | 1 | 19 | 1 | 10 | 30 |
+------------------------------------------------------------------------------------------------------------------------------------+------------+------------+------------+------------+------------+
Files
+------------------------------------------------------------------------------------------------------------------------------------+------------------+------------+------------+------------+------------+
| filename | language | code | comment | blank | total |
+------------------------------------------------------------------------------------------------------------------------------------+------------------+------------+------------+------------+------------+
| e:\pycharm_projects\AudioClassification-Pytorch-master\README.md | Markdown | 302 | 0 | 72 | 374 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\README_en.md | Markdown | 231 | 0 | 45 | 276 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\README.md | Markdown | 96 | 0 | 41 | 137 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\app.py | Python | 255 | 35 | 61 | 351 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\config.py | Python | 22 | 8 | 10 | 40 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\backend\requirements.txt | pip requirements | 7 | 0 | 1 | 8 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\index.html | HTML | 25 | 0 | 2 | 27 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\package.json | JSON | 24 | 0 | 1 | 25 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\App.vue | Vue | 88 | 5 | 14 | 107 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioRecorder.vue | Vue | 551 | 38 | 102 | 691 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioRecorder_new.vue | Vue | 811 | 48 | 162 | 1,021 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\AudioUpload.vue | Vue | 580 | 28 | 107 | 715 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\HistoryList.vue | Vue | 513 | 25 | 79 | 617 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\components\PredictionResult.vue | Vue | 803 | 22 | 117 | 942 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\main.js | JavaScript | 13 | 1 | 5 | 19 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\router\index.js | JavaScript | 14 | 0 | 4 | 18 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\utils\api.js | JavaScript | 156 | 25 | 33 | 214 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\src\views\HomePage.vue | Vue | 692 | 40 | 110 | 842 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\frontend\vite.config.js | JavaScript | 32 | 0 | 2 | 34 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\start.bat | Batch | 20 | 0 | 5 | 25 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\audio-classification-platform\start.sh | Shell Script | 26 | 3 | 8 | 37 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\augmentation.yml | YAML | 21 | 21 | 5 | 47 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\cam++.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\ecapa_tdnn.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\eres2net.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\panns.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\res2net.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\resnet_se.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\configs\tdnn.yml | YAML | 43 | 33 | 5 | 81 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\create_data.py | Python | 75 | 9 | 16 | 100 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\eval.py | Python | 20 | 2 | 5 | 27 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\extract_features.py | Python | 13 | 2 | 5 | 20 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\infer.py | Python | 17 | 1 | 6 | 24 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\infer_record.py | Python | 40 | 7 | 12 | 59 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\__init__.py | Python | 1 | 0 | 1 | 2 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\__init__.py | Python | 0 | 0 | 1 | 1 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\collate_fn.py | Python | 17 | 4 | 3 | 24 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\featurizer.py | Python | 88 | 36 | 9 | 133 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\data_utils\reader.py | Python | 114 | 33 | 11 | 158 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\metric\__init__.py | Python | 0 | 0 | 1 | 1 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\metric\metrics.py | Python | 9 | 1 | 3 | 13 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\optimizer\__init__.py | Python | 26 | 0 | 7 | 33 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\optimizer\scheduler.py | Python | 42 | 0 | 7 | 49 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\predict.py | Python | 124 | 47 | 7 | 178 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\trainer.py | Python | 338 | 99 | 20 | 457 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\__init__.py | Python | 0 | 0 | 1 | 1 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\checkpoint.py | Python | 113 | 40 | 10 | 163 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\record.py | Python | 18 | 8 | 6 | 32 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\macls\utils\utils.py | Python | 99 | 16 | 17 | 132 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\record_audio.py | Python | 10 | 0 | 5 | 15 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\requirements.txt | pip requirements | 17 | 0 | 1 | 18 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\setup.py | Python | 43 | 1 | 11 | 55 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\tools\download_language_data.sh | Shell Script | 19 | 1 | 10 | 30 |
| e:\pycharm_projects\AudioClassification-Pytorch-master\train.py | Python | 25 | 1 | 5 | 31 |
| Total | | 6,851 | 838 | 1,201 | 8,890 |
+------------------------------------------------------------------------------------------------------------------------------------+------------------+------------+------------+------------+------------+

@ -0,0 +1,296 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
macls.egg-info/
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django/Flask stuff
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
instance/
.webassets-cache
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Node.js dependencies
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
package-lock.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Coverage directory used by tools like istanbul
coverage/
# nyc test coverage
.nyc_output
# Grunt intermediate storage
.grunt
# Bower dependency directory
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons
build/Release
# Dependency directories
jspm_packages/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variables file
.env
.env.test
.env.local
.env.production
# parcel-bundler cache
.cache
.parcel-cache
# Next.js build output
.next
# Nuxt.js build / generate output
.nuxt
# Gatsby files
# Uncomment the "public" line below if your project uses Gatsby
# public
# vuepress build output
.vuepress/dist
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# Audio Classification specific ignores
# Model files (usually large)
*.pth
*.pt
*.h5
*.ckpt
*.pb
*.onnx
*.pkl
*.joblib
# Dataset directories (usually large audio files)
dataset/
dataset/*/audio/
dataset/*/wav/
dataset/*/mp3/
dataset/*/flac/
# Uncomment if you want to ignore all audio files
# *.wav
# *.mp3
# *.flac
# *.ogg
# *.m4a
# *.aac
# Training artifacts and logs
log/
logs/
output/
outputs/
uploads/
results/
checkpoints/
models/
pretrained_models/
feature_models/
runs/
wandb/
mlruns/
.mlflow/
# Temporary files
temp/
tmp/
*.tmp
test*.py
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# IDE files
.idea/
.vscode/
*.swp
*.swo
*~
# Audio processing temporary files
*.spec
*.mfcc
*.mel
# Frontend build files
audio-classification-platform/frontend/dist/
audio-classification-platform/frontend/build/
audio-classification-platform/frontend/.vite/
# Uploaded files directory
audio-classification-platform/backend/uploads/
# Local development configuration
audio-classification-platform/backend/.env
audio-classification-platform/frontend/.env.local

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -0,0 +1,373 @@
简体中文 | [English](./README_en.md)
# 基于Pytorch实现的声音分类系统
![python version](https://img.shields.io/badge/python-3.8+-orange.svg)
![GitHub forks](https://img.shields.io/github/forks/yeyupiaoling/AudioClassification-Pytorch)
![GitHub Repo stars](https://img.shields.io/github/stars/yeyupiaoling/AudioClassification-Pytorch)
![GitHub](https://img.shields.io/github/license/yeyupiaoling/AudioClassification-Pytorch)
![支持系统](https://img.shields.io/badge/支持系统-Win/Linux/MAC-9cf)
# 前言
本项目是基于Pytorch的声音分类项目，旨在实现对各种环境声音、动物叫声和语种的识别。项目提供了多种声音分类模型，如EcapaTdnn、PANNS、ResNetSE、CAMPPlus和ERes2Net，以支持不同的应用场景。此外，项目还提供了常用的Urbansound8K数据集测试报告和一些方言数据集的下载和使用例子。用户可以根据自己的需求选择适合的模型和数据集，以实现更准确的声音分类。项目的应用场景广泛，可以用于室外的环境监测、野生动物保护、语音识别等领域。同时，项目也鼓励用户探索更多的使用场景，以推动声音分类技术的发展和应用。
**欢迎大家扫码入知识星球或者QQ群讨论，知识星球里面提供项目的模型文件和博主其他相关项目的模型文件，也包括其他一些资源。**
<div align="center">
<img src="https://yeyupiaoling.cn/zsxq.png" alt="知识星球" width="400">
<img src="https://yeyupiaoling.cn/qq.png" alt="QQ群" width="400">
</div>
# 目录
- [前言](#前言)
- [项目特性](#项目特性)
- [模型测试表](#模型测试表)
- [安装环境](#安装环境)
- [创建数据](#创建数据)
- [修改预处理方法(可选)](#修改预处理方法可选)
- [提取特征(可选)](#提取特征可选)
- [训练模型](#训练模型)
- [评估模型](#评估模型)
- [预测](#预测)
- [其他功能](#其他功能)
# 使用准备
- Anaconda 3
- Python 3.11
- Pytorch 2.0.1
- Windows 11 or Ubuntu 22.04
# 项目特性
1. 支持模型：EcapaTdnn、PANNS、TDNN、Res2Net、ResNetSE、CAMPPlus、ERes2Net
2. 支持池化层：AttentiveStatsPool(ASP)、SelfAttentivePooling(SAP)、TemporalStatisticsPooling(TSP)、TemporalAveragePooling(TAP)
3. 支持预处理方法：MelSpectrogram、Spectrogram、MFCC、Fbank、Wav2vec2.0、WavLM
**模型论文:**
- EcapaTdnn：[ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification](https://arxiv.org/abs/2005.07143v3)
- PANNS：[PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/abs/1912.10211v5)
- TDNN：[Prediction of speech intelligibility with DNN-based performance measures](https://arxiv.org/abs/2203.09148)
- Res2Net：[Res2Net: A New Multi-scale Backbone Architecture](https://arxiv.org/abs/1904.01169)
- ResNetSE：[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
- CAMPPlus：[CAM++: A Fast and Efficient Network for Speaker Verification Using Context-Aware Masking](https://arxiv.org/abs/2303.00332v3)
- ERes2Net：[An Enhanced Res2Net with Local and Global Feature Fusion for Speaker Verification](https://arxiv.org/abs/2305.12838v1)
# 模型测试表
| 模型 | Params(M) | 预处理方法 | 数据集 | 类别数量 | 准确率 | 获取模型 |
|:------------:|:---------:|:-----:|:------------:|:----:|:-------:|:--------:|
| ResNetSE | 7.8 | Fbank | UrbanSound8K | 10 | 0.96233 | 加入知识星球获取 |
| ERes2NetV2 | 5.4 | Fbank | UrbanSound8K | 10 | 0.95662 | 加入知识星球获取 |
| CAMPPlus | 7.1 | Fbank | UrbanSound8K | 10 | 0.95454 | 加入知识星球获取 |
| EcapaTdnn | 6.4 | Fbank | UrbanSound8K | 10 | 0.95227 | 加入知识星球获取 |
| ERes2Net | 6.6 | Fbank | UrbanSound8K | 10 | 0.94292 | 加入知识星球获取 |
| TDNN | 2.6 | Fbank | UrbanSound8K | 10 | 0.93977 | 加入知识星球获取 |
| PANNSCNN10 | 5.2 | Fbank | UrbanSound8K | 10 | 0.92954 | 加入知识星球获取 |
| Res2Net | 5.0 | Fbank | UrbanSound8K | 10 | 0.92580 | 加入知识星球获取 |
**说明:**
1. 使用的测试集为从数据集中每10条音频取一条，共874条。
## 安装环境
- 首先安装的是Pytorch的GPU版本，如果已经安装过了，请跳过。
```shell
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia
```
- 安装macls库。
使用pip安装，命令如下：
```shell
python -m pip install macls -U -i https://pypi.tuna.tsinghua.edu.cn/simple
```
**建议源码安装**,源码安装能保证使用最新代码。
```shell
git clone https://github.com/yeyupiaoling/AudioClassification-Pytorch.git
cd AudioClassification-Pytorch/
pip install .
```
## 创建数据
生成数据列表，用于下一步的读取需要，`audio_path`为音频文件路径，用户需要提前把音频数据集存放在`dataset/audio`目录下，每个文件夹存放一个类别的音频数据，每条音频数据长度在3秒以上，如 `dataset/audio/鸟叫声/······`。`audio`是数据列表存放的位置，生成的数据列表的格式为 `音频路径\t音频对应的类别标签`，音频路径和标签用制表符 `\t`分开。读者也可以根据自己存放数据的方式修改以下函数。
以Urbansound8K为例，Urbansound8K是目前应用较为广泛的用于自动城市环境声分类研究的公共数据集，包含10个分类：空调声、汽车鸣笛声、儿童玩耍声、狗叫声、钻孔声、引擎空转声、枪声、手提钻、警笛声和街道音乐声。数据集下载地址：[UrbanSound8K.tar.gz](https://aistudio.baidu.com/aistudio/datasetdetail/36625)。以下是针对Urbansound8K生成数据列表的函数。如果读者想使用该数据集，请下载并解压到 `dataset`目录下，把生成数据列表代码改为以下代码。
执行`create_data.py`即可生成数据列表,里面提供了生成多种数据集列表方式,具体看代码。
```shell
python create_data.py
```
生成的列表是长这样的，前面是音频的路径，后面是该音频对应的标签，从0开始，路径和标签之间用`\t`隔开。
```shell
dataset/UrbanSound8K/audio/fold2/104817-4-0-2.wav 4
dataset/UrbanSound8K/audio/fold9/105029-7-2-5.wav 7
dataset/UrbanSound8K/audio/fold3/107228-5-0-0.wav 5
dataset/UrbanSound8K/audio/fold4/109711-3-2-4.wav 3
```
# 修改预处理方法(可选)
配置文件中默认使用的是Fbank预处理方法，如果要使用其他预处理方法，可以按照下面方式修改配置文件，具体的值可以根据自己情况修改。如果不清楚如何设置参数，可以直接删除该部分，直接使用默认值。
```yaml
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时，支持：MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时，指定的是HuggingFace的模型或者本地路径，比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时，设置API参数，更多参数查看对应API，不清楚的可以直接删除该部分，直接使用默认值。
# 当use_hf_model为True时，可以设置参数use_gpu，指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
```
# 提取特征(可选)
在训练过程中，首先是要读取音频数据，然后提取特征，最后再进行训练。其中读取音频数据、提取特征也是比较消耗时间的，所以我们可以选择提前提取好特征，训练模型时就可以直接加载提取好的特征，这样训练速度会更快。这个提取特征步骤是可选的，如果没有提取好的特征，训练模型的时候就会从读取音频数据、提取特征开始。提取特征步骤如下：
1. 执行`extract_features.py`,提取特征,特征会保存在`dataset/features`目录下,并生成新的数据列表`train_list_features.txt`和`test_list_features.txt`。
```shell
python extract_features.py --configs=configs/cam++.yml --save_dir=dataset/features
```
2. 修改配置文件,将`dataset_conf.train_list`和`dataset_conf.test_list`修改为`train_list_features.txt`和`test_list_features.txt`。
## 训练模型
接着就可以开始训练模型了，执行 `train.py`。配置文件里面的参数一般不需要修改，但是这几个是需要根据自己实际的数据集进行调整的，首先最重要的就是分类大小`dataset_conf.num_class`，每个数据集的分类大小可能不一样，根据自己的实际情况设定。然后是`dataset_conf.batch_size`，如果显存不够的话，可以减小这个参数。
```shell
# 单卡训练
CUDA_VISIBLE_DEVICES=0 python train.py
# 多卡训练
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py
```
训练输出日志:
```
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:14 - ----------- 额外配置参数 -----------
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - configs: configs/ecapa_tdnn.yml
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - local_rank: 0
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - pretrained_model: None
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - resume_model: None
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - save_model_path: models/
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - use_gpu: True
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:17 - ------------------------------------------------
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:19 - ----------- 配置文件参数 -----------
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:22 - dataset_conf:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - aug_conf:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - noise_aug_prob: 0.2
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - noise_dir: dataset/noise
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - speed_perturb: True
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - volume_aug_prob: 0.2
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - volume_perturb: False
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - dataLoader:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - batch_size: 64
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - num_workers: 4
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - do_vad: False
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - eval_conf:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - batch_size: 1
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - max_duration: 20
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - label_list_path: dataset/label_list.txt
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - max_duration: 3
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - min_duration: 0.5
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - sample_rate: 16000
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - spec_aug_args:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - freq_mask_width: [0, 8]
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - time_mask_width: [0, 10]
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - target_dB: -20
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - test_list: dataset/test_list.txt
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - train_list: dataset/train_list.txt
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - use_dB_normalization: True
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - use_spec_aug: True
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:22 - model_conf:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - num_class: 10
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - pooling_type: ASP
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:22 - optimizer_conf:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - learning_rate: 0.001
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - optimizer: Adam
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - scheduler: WarmupCosineSchedulerLR
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:25 - scheduler_args:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:27 - max_lr: 0.001
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:27 - min_lr: 1e-05
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:27 - warmup_epoch: 5
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - weight_decay: 1e-06
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:22 - preprocess_conf:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - feature_method: Fbank
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:25 - method_args:
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:27 - num_mel_bins: 80
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:27 - sample_frequency: 16000
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:22 - train_conf:
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:29 - log_interval: 10
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:29 - max_epoch: 30
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:31 - use_model: EcapaTdnn
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:32 - ------------------------------------------------
[2023-08-07 22:54:22.213166 WARNING] trainer:__init__:67 - Windows系统不支持多线程读取数据已自动关闭
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
EcapaTdnn [1, 10] --
├─Conv1dReluBn: 1-1 [1, 512, 98] --
│ └─Conv1d: 2-1 [1, 512, 98] 204,800
│ └─BatchNorm1d: 2-2 [1, 512, 98] 1,024
├─Sequential: 1-2 [1, 512, 98] --
│ └─Conv1dReluBn: 2-3 [1, 512, 98] --
│ │ └─Conv1d: 3-1 [1, 512, 98] 262,144
│ │ └─BatchNorm1d: 3-2 [1, 512, 98] 1,024
│ └─Res2Conv1dReluBn: 2-4 [1, 512, 98] --
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
···································
│ │ └─ModuleList: 3-56 -- (recursive)
│ │ └─ModuleList: 3-55 -- (recursive)
│ │ └─ModuleList: 3-56 -- (recursive)
│ │ └─ModuleList: 3-55 -- (recursive)
│ │ └─ModuleList: 3-56 -- (recursive)
│ └─Conv1dReluBn: 2-13 [1, 512, 98] --
│ │ └─Conv1d: 3-57 [1, 512, 98] 262,144
│ │ └─BatchNorm1d: 3-58 [1, 512, 98] 1,024
│ └─SE_Connect: 2-14 [1, 512, 98] --
│ │ └─Linear: 3-59 [1, 256] 131,328
│ │ └─Linear: 3-60 [1, 512] 131,584
├─Conv1d: 1-5 [1, 1536, 98] 2,360,832
├─AttentiveStatsPool: 1-6 [1, 3072] --
│ └─Conv1d: 2-15 [1, 128, 98] 196,736
│ └─Conv1d: 2-16 [1, 1536, 98] 198,144
├─BatchNorm1d: 1-7 [1, 3072] 6,144
├─Linear: 1-8 [1, 192] 590,016
├─BatchNorm1d: 1-9 [1, 192] 384
├─Linear: 1-10 [1, 10] 1,930
==========================================================================================
Total params: 6,188,490
Trainable params: 6,188,490
Non-trainable params: 0
Total mult-adds (M): 470.96
==========================================================================================
Input size (MB): 0.03
Forward/backward pass size (MB): 10.28
Params size (MB): 24.75
Estimated Total Size (MB): 35.07
==========================================================================================
[2023-08-07 22:54:26.726095 INFO ] trainer:train:344 - 训练数据8644
[2023-08-07 22:54:30.092504 INFO ] trainer:__train_epoch:296 - Train epoch: [1/30], batch: [0/4], loss: 2.57033, accuracy: 0.06250, learning rate: 0.00001000, speed: 19.02 data/sec, eta: 0:06:43
```
**训练可视化:**
在项目的根目录执行下面命令，并在浏览器访问`http://localhost:8040/`，如果是服务器，需要修改`localhost`为服务器的IP地址。
```shell
visualdl --logdir=log --host=0.0.0.0
```
打开的网页如下:
<br/>
<div align="center">
<img src="docs/images/log.jpg" alt="训练日志" width="600">
</div>
# 评估模型
执行下面命令执行评估。
```shell
python eval.py --configs=configs/bi_lstm.yml
```
评估输出如下:
```shell
[2024-02-03 15:13:25.469242 INFO ] trainer:evaluate:461 - 成功加载模型models/CAMPPlus_Fbank/best_model/model.pth
100%|██████████████████████████████| 150/150 [00:00<00:00, 1281.96it/s]
评估消耗时间：1s，loss：0.61840，accuracy：0.87333
```
评估会输出准确率，还保存了混淆矩阵图片，保存路径`output/images/`，如下。
<br/>
<div align="center">
<img src="docs/images/image1.png" alt="混淆矩阵" width="600">
</div>
注意：如果类别标签是中文的，需要安装字体才能正常显示，一般情况下Windows无需安装，Ubuntu需要安装。如果Windows确实缺少字体，只需要在[字体文件](https://github.com/tracyone/program_font)这里下载`.ttf`格式的文件，复制到`C:\Windows\Fonts`即可。Ubuntu系统操作如下。
1. 安装字体
```shell
git clone https://github.com/tracyone/program_font && cd program_font && ./install.sh
```
2. 执行下面Python代码
```python
import matplotlib
import shutil
import os
path = matplotlib.matplotlib_fname()
path = path.replace('matplotlibrc', 'fonts/ttf/')
print(path)
shutil.copy('/usr/share/fonts/MyFonts/simhei.ttf', path)
user_dir = os.path.expanduser('~')
shutil.rmtree(f'{user_dir}/.cache/matplotlib', ignore_errors=True)
```
# 预测
在训练结束之后,我们得到了一个模型参数文件,我们使用这个模型预测音频。
```shell
python infer.py --audio_path=dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav
```
# 其他功能
- 为了方便读取录制数据和制作数据集，这里提供了录音程序`record_audio.py`，这个用于录制音频，录制的音频采样率为16000，单通道，16bit。
```shell
python record_audio.py
```
- `infer_record.py`这个程序是用来不断进行录音识别，我们可以大致理解为这个程序在实时录音识别。通过这个应用我们可以做一些比较有趣的事情，比如把麦克风放在小鸟经常来的地方，通过实时录音识别，一旦识别到有鸟叫的声音，如果你的数据集足够强大，有每种鸟叫的声音数据集，这样你还能准确识别是哪种鸟叫。如果识别到目标鸟类，就启动程序，例如拍照等等。
```shell
python infer_record.py --record_seconds=3
```
## 打赏作者
<br/>
<div align="center">
<p>打赏一块钱支持一下作者</p>
<img src="https://yeyupiaoling.cn/reward.png" alt="打赏作者" width="400">
</div>
# 参考资料
1. https://github.com/PaddlePaddle/PaddleSpeech
2. https://github.com/yeyupiaoling/PaddlePaddle-MobileFaceNets
3. https://github.com/yeyupiaoling/PPASR
4. https://github.com/alibaba-damo-academy/3D-Speaker

@ -0,0 +1,275 @@
[简体中文](./README.md) | English
# Sound classification system implemented in Pytorch
![python version](https://img.shields.io/badge/python-3.8+-orange.svg)
![GitHub forks](https://img.shields.io/github/forks/yeyupiaoling/AudioClassification-Pytorch)
![GitHub Repo stars](https://img.shields.io/github/stars/yeyupiaoling/AudioClassification-Pytorch)
![GitHub](https://img.shields.io/github/license/yeyupiaoling/AudioClassification-Pytorch)
![支持系统](https://img.shields.io/badge/支持系统-Win/Linux/MAC-9cf)
**Disclaimer, this document was obtained through machine translation, please check the original document [here](./README.md).**
# Introduction
This project is a sound classification project based on Pytorch, aiming to realize the recognition of various environmental sounds, animal calls and languages. Several sound classification models such as EcapaTdnn, PANNS, ResNetSE, CAMPPlus, and ERes2Net are provided to support different application scenarios. In addition, the project also provides the commonly used Urbansound8K dataset test report and some dialect datasets download and use examples. Users can choose suitable models and datasets according to their needs to achieve more accurate sound classification. The project has a wide range of application scenarios, and can be used in outdoor environmental monitoring, wildlife protection, speech recognition and other fields. At the same time, the project also encourages users to explore more usage scenarios to promote the development and application of sound classification technology.
# Environment
- Anaconda 3
- Python 3.11
- Pytorch 2.0.1
- Windows 11 or Ubuntu 22.04
# Project Features
1. Supported models: EcapaTdnn, PANNS, TDNN, Res2Net, ResNetSE, CAMPPlus, ERes2Net
2. Supported pooling layers: AttentiveStatsPool (ASP), SelfAttentivePooling (SAP), TemporalStatisticsPooling (TSP), TemporalAveragePooling (TAP)
3. Supported preprocessing methods: MelSpectrogram, Spectrogram, MFCC, Fbank, Wav2vec2.0, WavLM
**Model Paper**
- EcapaTdnn: [ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification](https://arxiv.org/abs/2005.07143v3)
- PANNS: [PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/abs/1912.10211v5)
- TDNN: [Prediction of speech intelligibility with DNN-based performance measures](https://arxiv.org/abs/2203.09148)
- Res2Net: [Res2Net: A New Multi-scale Backbone Architecture](https://arxiv.org/abs/1904.01169)
- ResNetSE: [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
- CAMPPlus: [CAM++: A Fast and Efficient Network for Speaker Verification Using Context-Aware Masking](https://arxiv.org/abs/2303.00332v3)
- ERes2Net: [An Enhanced Res2Net with Local and Global Feature Fusion for Speaker Verification](https://arxiv.org/abs/2305.12838v1)
# Model Test
| Model | Params(M) | Preprocessing method | Dataset | Number Class | Accuracy |
|:------------:|:---------:|:--------------------:|:------------:|:------------:|:--------:|
| ResNetSE | 7.8 | Fbank | UrbanSound8K | 10 | 0.96233 |
| ERes2NetV2 | 5.4 | Fbank | UrbanSound8K | 10 | 0.95662 |
| CAMPPlus | 7.1 | Fbank | UrbanSound8K | 10 | 0.95454 |
| EcapaTdnn | 6.4 | Fbank | UrbanSound8K | 10 | 0.95227 |
| ERes2Net | 6.6 | Fbank | UrbanSound8K | 10 | 0.94292 |
| TDNN | 2.6 | Fbank | UrbanSound8K | 10 | 0.93977 |
| PANNSCNN10 | 5.2 | Fbank | UrbanSound8K | 10 | 0.92954 |
| Res2Net | 5.0 | Fbank | UrbanSound8K | 10 | 0.92580 |
## Installation Environment
- The GPU version of Pytorch will be installed first, please skip it if you already have it installed.
```shell
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia
```
- Install macls.
Install it using pip with the following command:
```shell
python -m pip install macls -U -i https://pypi.tuna.tsinghua.edu.cn/simple
```
**Source installation is recommended**, which ensures that the latest code is used.
```shell
git clone https://github.com/yeyupiaoling/AudioClassification-Pytorch.git
cd AudioClassification-Pytorch/
python setup.py install
```
## Preparing Data
`audio_path` is the path of an audio file. Store your audio dataset under the `dataset/audio` directory in advance, with one folder per category, and make sure each clip is longer than 3 seconds, for example `dataset/audio/bird_song/······`. `dataset` is where the data list is stored, and each line of the generated list has the format `audio_path\tcategory_label`, i.e. the audio path and its label are separated by a tab character `\t`. You can also modify the following functions depending on how you store your data:
Taking Urbansound8K as an example, it is a widely used public dataset for automatic urban environmental sound classification research. Urbansound8K contains 10 categories: air condition sound, car whistle sound, children playing sound, dog bark, drilling sound, engine idling sound, gun sound, jackdrill, siren sound, and street music sound. Data set download address: [UrbanSound8K](https://zenodo.org/record/1203745/files/UrbanSound8K.tar.gz). Here is the function to generate a list of data for Urbansound8K. If you want to use this dataset, please download and unzip it into the `dataset` directory and change the code to generate the list of data as follows.
`create_data.py` can be used to generate a list of data sets. There are many ways to generate a list of data sets.
```shell
python create_data.py
```
The resulting list looks like this, with the path to the audio followed by the tag for that audio, starting at 0, and separated by `\t`.
```shell
dataset/UrbanSound8K/audio/fold2/104817-4-0-2.wav 4
dataset/UrbanSound8K/audio/fold9/105029-7-2-5.wav 7
dataset/UrbanSound8K/audio/fold3/107228-5-0-0.wav 5
dataset/UrbanSound8K/audio/fold4/109711-3-2-4.wav 3
```
# Change preprocessing methods
By default, the Fbank preprocessing method is used in the configuration file. If you want to use another preprocessing method, modify the following settings in the configuration file; the specific values can be adjusted to your own situation. If it is not clear how to set the parameters, you can remove that section and just use the default values.
```yaml
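# Data preprocessing parameters
preprocess_conf:
  # Whether to extract audio features with a Wav2Vec2-like model from HuggingFace
  use_hf_model: False
  # Audio preprocessing method, also known as the feature extraction method
  # When use_hf_model is False, supported values are: MelSpectrogram, Spectrogram, MFCC, Fbank
  # When use_hf_model is True, this is a HuggingFace model name or a local path, e.g. facebook/w2v-bert-2.0 or ./feature_models/w2v-bert-2.0
  feature_method: 'Fbank'
  # When use_hf_model is False, these are the API arguments; see the corresponding API for more options. If unsure, delete this section and use the defaults.
  # When use_hf_model is True, the use_gpu argument can be set to choose whether features are extracted on the GPU
  method_args:
    sample_frequency: 16000
    num_mel_bins: 80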
```
## Training
Now we can train the model by running `train.py`. The parameters in the configuration file generally do not need to be modified, but a few must be adjusted to your actual dataset. The first and most important is the number of classes, `dataset_conf.num_class`, which may be different for each dataset. Then there is `dataset_conf.batch_size`, which can be reduced if memory is insufficient.
```shell
# Single GPU training
CUDA_VISIBLE_DEVICES=0 python train.py
# Multi GPU training
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py
```
Train log:
```
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:14 - ----------- 额外配置参数 -----------
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - configs: configs/ecapa_tdnn.yml
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - local_rank: 0
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - pretrained_model: None
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - resume_model: None
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - save_model_path: models/
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:16 - use_gpu: True
[2023-08-07 22:54:22.148973 INFO ] utils:print_arguments:17 - ------------------------------------------------
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:19 - ----------- 配置文件参数 -----------
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:22 - dataset_conf:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - aug_conf:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - noise_aug_prob: 0.2
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - noise_dir: dataset/noise
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - speed_perturb: True
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - volume_aug_prob: 0.2
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - volume_perturb: False
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - dataLoader:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - batch_size: 64
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - num_workers: 4
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - do_vad: False
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - eval_conf:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - batch_size: 1
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - max_duration: 20
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - label_list_path: dataset/label_list.txt
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - max_duration: 3
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - min_duration: 0.5
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:29 - sample_rate: 16000
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:25 - spec_aug_args:
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - freq_mask_width: [0, 8]
[2023-08-07 22:54:22.202166 INFO ] utils:print_arguments:27 - time_mask_width: [0, 10]
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - target_dB: -20
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - test_list: dataset/test_list.txt
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - train_list: dataset/train_list.txt
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - use_dB_normalization: True
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:29 - use_spec_aug: True
[2023-08-07 22:54:22.203167 INFO ] utils:print_arguments:22 - model_conf:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - num_class: 10
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - pooling_type: ASP
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:22 - optimizer_conf:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - learning_rate: 0.001
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - optimizer: Adam
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - scheduler: WarmupCosineSchedulerLR
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:25 - scheduler_args:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:27 - max_lr: 0.001
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:27 - min_lr: 1e-05
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:27 - warmup_epoch: 5
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - weight_decay: 1e-06
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:22 - preprocess_conf:
[2023-08-07 22:54:22.207167 INFO ] utils:print_arguments:29 - feature_method: Fbank
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:25 - method_args:
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:27 - num_mel_bins: 80
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:27 - sample_frequency: 16000
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:22 - train_conf:
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:29 - log_interval: 10
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:29 - max_epoch: 30
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:31 - use_model: EcapaTdnn
[2023-08-07 22:54:22.208167 INFO ] utils:print_arguments:32 - ------------------------------------------------
[2023-08-07 22:54:22.213166 WARNING] trainer:__init__:67 - Windows系统不支持多线程读取数据已自动关闭
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
EcapaTdnn [1, 10] --
├─Conv1dReluBn: 1-1 [1, 512, 98] --
│ └─Conv1d: 2-1 [1, 512, 98] 204,800
│ └─BatchNorm1d: 2-2 [1, 512, 98] 1,024
├─Sequential: 1-2 [1, 512, 98] --
│ └─Conv1dReluBn: 2-3 [1, 512, 98] --
│ │ └─Conv1d: 3-1 [1, 512, 98] 262,144
│ │ └─BatchNorm1d: 3-2 [1, 512, 98] 1,024
│ └─Res2Conv1dReluBn: 2-4 [1, 512, 98] --
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
│ │ └─ModuleList: 3-15 -- (recursive)
│ │ └─ModuleList: 3-16 -- (recursive)
···································
│ │ └─ModuleList: 3-56 -- (recursive)
│ │ └─ModuleList: 3-55 -- (recursive)
│ │ └─ModuleList: 3-56 -- (recursive)
│ │ └─ModuleList: 3-55 -- (recursive)
│ │ └─ModuleList: 3-56 -- (recursive)
│ └─Conv1dReluBn: 2-13 [1, 512, 98] --
│ │ └─Conv1d: 3-57 [1, 512, 98] 262,144
│ │ └─BatchNorm1d: 3-58 [1, 512, 98] 1,024
│ └─SE_Connect: 2-14 [1, 512, 98] --
│ │ └─Linear: 3-59 [1, 256] 131,328
│ │ └─Linear: 3-60 [1, 512] 131,584
├─Conv1d: 1-5 [1, 1536, 98] 2,360,832
├─AttentiveStatsPool: 1-6 [1, 3072] --
│ └─Conv1d: 2-15 [1, 128, 98] 196,736
│ └─Conv1d: 2-16 [1, 1536, 98] 198,144
├─BatchNorm1d: 1-7 [1, 3072] 6,144
├─Linear: 1-8 [1, 192] 590,016
├─BatchNorm1d: 1-9 [1, 192] 384
├─Linear: 1-10 [1, 10] 1,930
==========================================================================================
Total params: 6,188,490
Trainable params: 6,188,490
Non-trainable params: 0
Total mult-adds (M): 470.96
==========================================================================================
Input size (MB): 0.03
Forward/backward pass size (MB): 10.28
Params size (MB): 24.75
Estimated Total Size (MB): 35.07
==========================================================================================
[2023-08-07 22:54:26.726095 INFO ] trainer:train:344 - 训练数据8644
[2023-08-07 22:54:30.092504 INFO ] trainer:__train_epoch:296 - Train epoch: [1/30], batch: [0/4], loss: 2.57033, accuracy: 0.06250, learning rate: 0.00001000, speed: 19.02 data/sec, eta: 0:06:43
```
# Eval
At the end of each training round, we can perform an evaluation, which will output the accuracy. The confusion matrix image is also saved under `output/images/`, as shown below.
![Confusion matrix](docs/images/image1.png)
# Inference
At the end of the training, we are given a model parameter file, and we use this model to predict the audio.
```shell
python infer.py --audio_path=dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav
```
# Other Functions
- In order to read the recorded data and make a dataset easily, we provide the recording program `record_audio.py`, which is used to record audio with a sample rate of 16,000, single channel, 16bit.
```shell
python record_audio.py
```
- `infer_record.py`: this program continuously records and recognizes audio, so it can roughly be thought of as real-time recognition. This allows some interesting applications: for example, put a microphone in a place where birds often come and recognize the recordings in real time. Once a bird call is detected, and if your dataset is rich enough to cover the calls of every species, you can identify exactly which bird is calling. If the target bird is identified, a follow-up action can be triggered, such as taking a photo.
```shell
python infer_record.py --record_seconds=3
```
# Reference
1. https://github.com/PaddlePaddle/PaddleSpeech
2. https://github.com/yeyupiaoling/PaddlePaddle-MobileFaceNets
3. https://github.com/yeyupiaoling/PPASR
4. https://github.com/alibaba-damo-academy/3D-Speaker

@ -0,0 +1,136 @@
# 声纹识别系统
这是一个基于深度学习的声纹识别系统,包含Flask后端和Vue前端。
## 功能特性
- 🎵 支持多种音频格式上传识别
- 🎤 实时录音识别功能
- 📊 置信度可视化展示
- 📱 响应式设计,支持移动端
- 📝 识别历史记录
- 📋 结果导出功能
## 技术栈
### 后端
- Flask + Flask-CORS
- PyTorch + MAClsPredictor
- librosa + soundfile
### 前端
- Vue 3 + Composition API
- Element Plus UI框架
- Axios HTTP客户端
- Vite构建工具
## 快速开始
### 环境要求
- Python 3.8+
- Node.js 16+
- 已训练好的音频分类模型
### 1. 安装后端依赖
```bash
cd audio-classification-platform/backend
pip install -r requirements.txt
```
### 2. 安装前端依赖
```bash
cd audio-classification-platform/frontend
npm install
```
### 3. 启动后端服务
```bash
cd audio-classification-platform/backend
python app.py
```
后端服务将在 http://localhost:5000 启动
### 4. 启动前端服务
```bash
cd audio-classification-platform/frontend
npm run dev
```
前端服务将在 http://localhost:3000 启动
## 使用说明
1. 打开浏览器访问 http://localhost:3000
2. 首次使用需要初始化模型(确保模型文件路径正确)
3. 使用文件上传功能或实时录音功能进行音频识别
4. 查看识别结果和置信度
5. 可以查看历史识别记录
## 配置说明
### 后端配置
`backend/config.py` 中可以配置:
- 模型路径
- 上传文件限制
- GPU使用设置
- 音频处理参数
### 前端配置
`frontend/vite.config.js` 中可以配置:
- 开发服务器端口
- 代理设置
- 构建选项
## API接口
- `GET /api/health` - 健康检查
- `POST /api/init` - 初始化模型
- `POST /api/upload` - 上传音频文件识别
- `POST /api/predict` - 录音数据识别
- `GET /api/labels` - 获取分类标签
- `GET /api/model/info` - 获取模型信息
## 部署
### 生产环境构建
```bash
# 构建前端
cd frontend
npm run build
# 启动后端(推荐使用gunicorn)
cd backend
gunicorn -w 4 -b 0.0.0.0:5000 app:app
```
## 故障排除
1. **模型初始化失败**
- 检查模型文件路径是否正确
- 确保配置文件存在
- 检查GPU驱动(如果使用GPU)
2. **音频上传失败**
- 检查文件格式是否支持
- 检查文件大小是否超限
- 检查网络连接
3. **录音功能不可用**
- 检查浏览器麦克风权限
- 使用HTTPS协议(录音功能需要)
- 检查设备麦克风是否正常
## 许可证
Apache License 2.0

@ -0,0 +1,350 @@
import os
import sys
import tempfile
import uuid
from datetime import datetime
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from werkzeug.utils import secure_filename
import librosa
import soundfile as sf
import numpy as np
from loguru import logger
# 添加项目根目录到Python路径
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from macls.predict import MAClsPredictor
# Flask application object shared by every route below.
app = Flask(__name__)
CORS(app) # allow cross-origin requests
# Application settings
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024 # request bodies are capped at 50MB
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['ALLOWED_EXTENSIONS'] = {'wav', 'mp3', 'flac', 'm4a', 'ogg', 'aac'}
# Make sure the upload directory exists (relative to the current working directory)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# Global predictor instance; stays None until /api/init creates it
predictor = None
def allowed_file(filename):
    """Return True when ``filename`` ends with an extension listed in ALLOWED_EXTENSIONS.

    The text after the last dot is compared case-insensitively; a name that
    contains no dot at all is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def convert_audio_format(input_path, target_sample_rate=16000):
    """Decode an audio file and re-save it as a temporary WAV file.

    Args:
        input_path: Path of the audio file to decode.
        target_sample_rate: Sample rate handed to ``librosa.load`` (which
            resamples while loading) and used when writing the WAV file.

    Returns:
        A ``(temp_path, audio_data, target_sample_rate)`` tuple. ``temp_path``
        names a ``.wav`` file created with ``delete=False``; the caller is
        responsible for removing it.

    Raises:
        Exception: Any loading or writing error is logged and re-raised.
    """
    temp_path = None
    try:
        # Load the audio with librosa, converting the sample rate on the way in.
        audio_data, _ = librosa.load(input_path, sr=target_sample_rate)
        # Only the name is needed: the handle is closed straight away so that
        # soundfile can reopen the path itself.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_path = temp_file.name
        sf.write(temp_path, audio_data, target_sample_rate)
        return temp_path, audio_data, target_sample_rate
    except Exception as e:
        logger.error(f"音频格式转换失败: {str(e)}")
        # The temp file was created with delete=False, so nothing else will
        # remove it when writing fails part-way.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best effort; the original error is the one worth reporting
        raise
@app.route('/api/health', methods=['GET'])
def health_check():
    """Health-check endpoint: reports that the service is running plus the current server time."""
    payload = {'status': 'success', 'message': '服务正常运行'}
    payload['timestamp'] = datetime.now().isoformat()
    return jsonify(payload)
@app.route('/api/init', methods=['POST'])
def init_model():
    """Create the global ``MAClsPredictor`` from settings in the JSON request body.

    Request JSON (every key optional):
        configs: Path of the model configuration file.
        model_path: Path of the trained model directory.
        use_gpu: Whether the predictor should use the GPU (defaults to True).
    Relative paths are resolved against the directory of this file.

    Returns:
        200 with the resolved settings on success; 400 when the body is not a
        JSON object or one of the paths does not exist; 500 when constructing
        the predictor fails.
    """
    global predictor
    try:
        # silent=True returns None for a missing or malformed JSON body instead
        # of raising, so a bare POST uses the defaults rather than becoming a 500.
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            return jsonify({
                'status': 'error',
                'message': '请求体必须是JSON对象'
            }), 400
        # Default paths
        configs = data.get('configs', '../../configs/cam++.yml')
        model_path = data.get('model_path', '../../models/CAMPPlus_Fbank/best_model/')
        use_gpu = data.get('use_gpu', True)
        # NOTE(review): both paths come from the client and are used unchecked,
        # so any caller can point the server at arbitrary files. Restrict them
        # to an allow-list before exposing this endpoint publicly.
        base_dir = os.path.dirname(__file__)
        configs = os.path.abspath(os.path.join(base_dir, configs))
        model_path = os.path.abspath(os.path.join(base_dir, model_path))
        if not os.path.exists(configs):
            return jsonify({
                'status': 'error',
                'message': f'配置文件不存在: {configs}'
            }), 400
        if not os.path.exists(model_path):
            return jsonify({
                'status': 'error',
                'message': f'模型路径不存在: {model_path}'
            }), 400
        # Build the predictor; it replaces any previously initialised instance.
        predictor = MAClsPredictor(
            configs=configs,
            model_path=model_path,
            use_gpu=use_gpu
        )
        logger.info("模型初始化成功")
        return jsonify({
            'status': 'success',
            'message': '模型初始化成功',
            'config': {
                'configs': configs,
                'model_path': model_path,
                'use_gpu': use_gpu
            }
        })
    except Exception as e:
        logger.error(f"模型初始化失败: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': f'模型初始化失败: {str(e)}'
        }), 500
@app.route('/api/upload', methods=['POST'])
def upload_and_predict():
    """Accept an uploaded audio file, convert it to WAV and classify it.

    The file is read from the multipart field ``file``. It is stored in
    UPLOAD_FOLDER under a UUID-prefixed name, converted with
    ``convert_audio_format`` and passed to the global predictor.

    Returns:
        200 with the predicted label, score, timing and audio info; 400 when
        the model is not initialised, no file was sent, the filename is empty
        or the extension is not allowed; 500 when saving, conversion or
        prediction fails.
    """
    global predictor
    if predictor is None:
        return jsonify({
            'status': 'error',
            'message': '模型未初始化,请先调用 /api/init 接口'
        }), 400
    if 'file' not in request.files:
        return jsonify({
            'status': 'error',
            'message': '没有上传文件'
        }), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({
            'status': 'error',
            'message': '没有选择文件'
        }), 400
    if not (file and allowed_file(file.filename)):
        return jsonify({
            'status': 'error',
            'message': '不支持的文件格式支持的格式wav, mp3, flac, m4a, ogg, aac'
        }), 400

    converted_path = None
    try:
        # Build a unique name so that concurrent uploads never collide.
        filename = secure_filename(file.filename)
        unique_filename = f"{uuid.uuid4()}_{filename}"
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
        # Save the upload. The original file is deliberately kept; remove
        # `filepath` here if uploads should not be retained.
        file.save(filepath)
        # Convert to the format the model expects.
        converted_path, audio_data, sample_rate = convert_audio_format(filepath)
        # Run the prediction and time it.
        start_time = datetime.now()
        label, score = predictor.predict(audio_data=converted_path)
        prediction_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"预测完成 - 文件: {filename}, 结果: {label}, 得分: {score:.4f}, 耗时: {prediction_time:.3f}s")
        return jsonify({
            'status': 'success',
            'result': {
                'predicted_class': label,  # field name the frontend expects
                'confidence': float(score),  # field name the frontend expects
                'label': label,  # kept for backward compatibility
                'score': float(score),  # kept for backward compatibility
                'filename': filename,
                'prediction_time': prediction_time,
                'audio_info': {
                    'sample_rate': sample_rate,
                    'duration': len(audio_data) / sample_rate
                }
            }
        })
    except Exception as e:
        logger.error(f"预测失败: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': f'预测失败: {str(e)}'
        }), 500
    finally:
        # Remove the converted temp file on every path, including a failed
        # prediction, so that errors do not accumulate files on disk.
        if converted_path is not None and os.path.exists(converted_path):
            try:
                os.remove(converted_path)
            except OSError as cleanup_error:
                logger.warning(f"临时文件清理失败: {cleanup_error}")
@app.route('/api/predict', methods=['POST'])
def predict_audio_data():
    """Classify raw audio samples sent as JSON (used by the recording feature).

    Request JSON:
        audio_data: Sequence of samples; converted to a float32 array.
        sample_rate: Sample rate of ``audio_data`` (defaults to 16000).

    The samples are written to a temporary WAV file, which is passed to the
    global predictor and removed afterwards.

    Returns:
        200 with the predicted label, score, timing and audio info; 400 when
        the model is not initialised or the body carries no ``audio_data``;
        500 when writing or prediction fails.
    """
    global predictor
    if predictor is None:
        return jsonify({
            'status': 'error',
            'message': '模型未初始化,请先调用 /api/init 接口'
        }), 400

    temp_path = None
    try:
        # silent=True returns None for a missing or malformed body; treat that,
        # and any non-object payload, as "no audio data" instead of a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'audio_data' not in data:
            return jsonify({
                'status': 'error',
                'message': '缺少音频数据'
            }), 400
        audio_data = np.array(data['audio_data'], dtype=np.float32)
        sample_rate = data.get('sample_rate', 16000)
        # Only the name is needed; close the handle so soundfile can reopen it.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_path = temp_file.name
        sf.write(temp_path, audio_data, sample_rate)
        # Run the prediction and time it.
        start_time = datetime.now()
        label, score = predictor.predict(audio_data=temp_path)
        prediction_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"录音预测完成 - 结果: {label}, 得分: {score:.4f}, 耗时: {prediction_time:.3f}s")
        return jsonify({
            'status': 'success',
            'result': {
                'predicted_class': label,  # field name the frontend expects
                'confidence': float(score),  # field name the frontend expects
                'label': label,  # kept for backward compatibility
                'score': float(score),  # kept for backward compatibility
                'prediction_time': prediction_time,
                'audio_info': {
                    'sample_rate': sample_rate,
                    'duration': len(audio_data) / sample_rate
                }
            }
        })
    except Exception as e:
        logger.error(f"录音预测失败: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': f'录音预测失败: {str(e)}'
        }), 500
    finally:
        # The temp file is created with delete=False, so it must be removed
        # here on both the success and the failure path.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"临时文件清理失败: {cleanup_error}")
@app.route('/api/labels', methods=['GET'])
def get_labels():
    """Return the class labels read from ``dataset/label_list.txt``.

    The file is located relative to this module. Blank lines are skipped and
    surrounding whitespace is stripped from each label. Responds 404 when the
    file is missing and 500 when reading it fails.
    """
    try:
        module_dir = os.path.dirname(__file__)
        label_file_path = os.path.abspath(os.path.join(module_dir, '../../dataset/label_list.txt'))
        if os.path.exists(label_file_path):
            labels = []
            with open(label_file_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    entry = raw_line.strip()
                    if entry:
                        labels.append(entry)
            return jsonify({
                'status': 'success',
                'labels': labels,
                'count': len(labels)
            })
        return jsonify({
            'status': 'error',
            'message': '标签文件不存在'
        }), 404
    except Exception as e:
        logger.error(f"读取标签失败: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': f'读取标签失败: {str(e)}'
        }), 500
@app.route('/api/model/info', methods=['GET'])
def get_model_info():
    """Report static information about the model, or 400 when no predictor has been initialised."""
    global predictor
    if predictor is None:
        not_ready = {'status': 'error', 'message': '模型未初始化'}
        return jsonify(not_ready), 400
    try:
        # The reported fields are fixed values; nothing is read from the predictor itself.
        model_info = {
            'initialized': True,
            'model_type': 'AudioClassification',
            'framework': 'PyTorch'
        }
        return jsonify({'status': 'success', 'model_info': model_info})
    except Exception as e:
        logger.error(f"获取模型信息失败: {str(e)}")
        failure = {
            'status': 'error',
            'message': f'获取模型信息失败: {str(e)}'
        }
        return jsonify(failure), 500
@app.errorhandler(413)
def too_large(e):
    """Turn a 413 (request entity too large) error into the API's JSON error shape."""
    body = {'status': 'error', 'message': '文件过大最大支持50MB'}
    return jsonify(body), 413
@app.errorhandler(404)
def not_found(e):
    """Turn a 404 error into the API's JSON error shape."""
    body = {'status': 'error', 'message': '接口不存在'}
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(e):
    """Turn a 500 error into the API's JSON error shape."""
    body = {'status': 'error', 'message': '服务器内部错误'}
    return jsonify(body), 500
if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the port execute code,
    # so it must not be on by default for a server bound to all interfaces.
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '0').lower() in ('1', 'true')
    logger.info("启动音频分类API服务器...")
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)

@ -0,0 +1,39 @@
import os
import sys
# 配置
class Config:
    """Base settings shared by every environment.

    Model-related paths are resolved to absolute paths relative to the
    directory containing this file.
    """
    # Model settings
    MODEL_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../configs/cam++.yml'))
    MODEL_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../models/CAMPPlus_Fbank/best_model/'))
    LABEL_FILE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../dataset/label_list.txt'))
    # Flask settings
    # The literal below is only a placeholder: set the SECRET_KEY environment
    # variable to supply a real secret instead of committing one to source.
    SECRET_KEY = os.environ.get('SECRET_KEY', 'your-secret-key-here')
    MAX_CONTENT_LENGTH = 50 * 1024 * 1024  # 50MB
    # File upload settings
    UPLOAD_FOLDER = 'uploads'
    ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a', 'ogg', 'aac'}
    # Audio processing settings
    TARGET_SAMPLE_RATE = 16000
    # Logging settings
    LOG_LEVEL = 'INFO'
    # GPU settings
    USE_GPU = True
class DevelopmentConfig(Config):
    """Development settings: the base configuration with DEBUG switched on."""
    DEBUG = True


class ProductionConfig(Config):
    """Production settings: the base configuration with DEBUG switched off."""
    DEBUG = False


# Maps an environment name to its configuration class; 'default' resolves to
# the development settings.
config = dict(
    development=DevelopmentConfig,
    production=ProductionConfig,
    default=DevelopmentConfig,
)

@ -0,0 +1,26 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>声纹识别系统</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Microsoft YaHei', sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
}
</style>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.js"></script>
</body>
</html>

@ -0,0 +1,24 @@
{
"name": "audio-classification-frontend",
"version": "1.0.0",
"private": true,
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview",
"serve": "vite preview"
}, "dependencies": {
"vue": "^3.3.4",
"vue-router": "^4.2.4",
"axios": "^1.5.0",
"element-plus": "^2.3.9",
"@element-plus/icons-vue": "^2.1.0",
"echarts": "^5.4.3"
},
"devDependencies": {
"@vitejs/plugin-vue": "^4.3.4",
"vite": "^4.4.9",
"unplugin-vue-components": "^0.25.2",
"unplugin-auto-import": "^0.16.6"
}
}

@ -0,0 +1,106 @@
<template>
<div id="app">
<router-view />
</div>
</template>
<script setup>
//
</script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
#app {
min-height: 100vh;
background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #f093fb 100%);
font-family: 'Inter', 'Helvetica Neue', Helvetica, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', '微软雅黑', Arial, sans-serif;
color: #333;
overflow-x: hidden;
}
/* 全局美化样式 */
body {
margin: 0;
padding: 0;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
/* 自定义滚动条 */
::-webkit-scrollbar {
width: 8px;
}
::-webkit-scrollbar-track {
background: rgba(255, 255, 255, 0.1);
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: rgba(255, 255, 255, 0.3);
border-radius: 4px;
transition: background 0.3s ease;
}
::-webkit-scrollbar-thumb:hover {
background: rgba(255, 255, 255, 0.5);
}
/* 全局动画 */
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes slideInLeft {
from {
opacity: 0;
transform: translateX(-30px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
@keyframes slideInRight {
from {
opacity: 0;
transform: translateX(30px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
@keyframes pulse {
0% {
transform: scale(1);
}
50% {
transform: scale(1.05);
}
100% {
transform: scale(1);
}
}
/* 响应式设计 */
@media (max-width: 768px) {
#app {
font-size: 14px;
}
}
</style>

@ -0,0 +1,743 @@
<!-- 音频录制组件模板 -->
<template>
<div class="audio-recorder">
<div class="recorder-content">
<!-- 录音状态显示区域 -->
<div class="recorder-status">
<!-- 状态指示器根据录音状态动态切换样式 -->
<div class="status-indicator" :class="{
active: isRecording,
ready: isReady,
processing: isProcessing
}">
<div class="status-bg"></div>
<!-- 图标切换动画 -->
<transition name="icon-switch" mode="out-in">
<!-- 处理中状态图标 -->
<el-icon class="status-icon processing" v-if="isProcessing" key="processing">
<Loading />
</el-icon>
<!-- 录音中状态图标 -->
<el-icon class="status-icon recording" v-else-if="isRecording" key="recording">
<Microphone />
</el-icon>
<!-- 就绪状态图标 -->
<el-icon class="status-icon ready" v-else-if="isReady" key="ready">
<Microphone />
</el-icon>
<!-- 默认状态图标 -->
<el-icon class="status-icon default" v-else key="default">
<MicrophoneOne />
</el-icon>
</transition>
<!-- 录音波纹动画效果仅在录音时显示 -->
<div v-if="isRecording" class="recording-waves">
<div class="wave wave-1"></div>
<div class="wave wave-2"></div>
<div class="wave wave-3"></div>
</div>
</div>
<!-- 状态文本信息显示区域 -->
<div class="status-text">
<h3 class="status-title">{{ statusTitle }}</h3>
<p class="primary-text">{{ statusText }}</p>
<!-- 录音时长显示仅在录音时显示 -->
<transition name="slide-down">
<p class="secondary-text" v-if="recordingTime > 0">
<el-icon><Timer /></el-icon>
录音时长: {{ formatTime(recordingTime) }}
</p>
</transition>
</div>
</div>
<!-- 录音控制按钮区域 -->
<div class="recorder-controls">
<!-- 按钮切换动画 -->
<transition name="button-switch" mode="out-in">
<!-- 开始录音按钮仅在未录音且未处理时显示 -->
<el-button
v-if="!isRecording && !isProcessing"
type="primary"
size="large"
:disabled="disabled || !isReady"
@click="startRecording"
class="record-button"
key="start"
>
<template #icon>
<el-icon><Microphone /></el-icon>
</template>
<span>开始录音</span>
<div class="button-glow"></div>
</el-button>
<!-- 停止录音按钮仅在录音时显示 -->
<el-button
v-else-if="isRecording"
type="danger"
size="large"
@click="stopRecording"
class="stop-button"
key="stop"
>
<template #icon>
<el-icon><VideoPause /></el-icon>
</template>
<span>停止录音</span>
</el-button>
<!-- 处理中按钮显示加载状态 -->
<el-button
v-else
size="large"
loading
class="processing-button"
key="processing"
>
<span>正在处理...</span>
</el-button>
</transition>
</div>
<!-- 录音设置区域 -->
<div class="recorder-settings">
<!-- 设置区域标题 -->
<div class="settings-header">
<el-icon><Setting /></el-icon>
<span>录音设置</span>
</div>
<!-- 录音设置表单 -->
<el-form :model="settings" label-width="80px" size="small" class="settings-form">
<!-- 录音时长设置项 -->
<el-form-item label="录音时长">
<!-- 录音时长选择器录音时禁用 -->
<el-select v-model="settings.duration" :disabled="isRecording" class="duration-select">
<!-- 3秒选项 -->
<el-option label="3秒" :value="3">
<div class="option-content">
<span>3</span>
<el-tag size="small" type="info">快速</el-tag>
</div>
</el-option>
<!-- 5秒选项推荐 -->
<el-option label="5秒" :value="5">
<div class="option-content">
<span>5</span>
<el-tag size="small" type="success">推荐</el-tag>
</div>
</el-option>
<!-- 10秒选项 -->
<el-option label="10秒" :value="10">
<div class="option-content">
<span>10</span>
<el-tag size="small" type="warning">详细</el-tag>
</div>
</el-option>
<!-- 15秒和30秒选项 -->
<el-option label="15秒" :value="15" />
<el-option label="30秒" :value="30" />
</el-select>
</el-form-item>
<!-- 自动停止设置项 -->
<el-form-item label="自动停止">
<!-- 自动停止开关录音时禁用 -->
<el-switch
v-model="settings.autoStop"
:disabled="isRecording"
active-color="#52c41a"
inactive-color="rgba(255, 255, 255, 0.3)"
/>
</el-form-item>
</el-form>
</div>
<!-- 音频可视化区域仅在录音时显示 -->
<transition name="fade">
<div v-if="isRecording" class="audio-visualizer">
<!-- 可视化区域标题 -->
<div class="visualizer-header">
<el-icon><DataAnalysis /></el-icon>
<span>实时音频波形</span>
</div>
<!-- 波形画布 -->
<canvas ref="canvasRef" class="visualizer-canvas"></canvas>
<!-- 音量信息显示 -->
<div class="visualizer-info">
<div class="info-item">
<span class="label">音量:</span>
<div class="volume-bar">
<div class="volume-fill" :style="{ width: `${volume}%` }"></div>
</div>
</div>
</div>
</div>
</transition>
<!-- 录音预览区域有录音时显示 -->
<transition name="slide-up">
<div v-if="recordedAudio" class="audio-preview">
<!-- 预览区域头部 -->
<div class="preview-header">
<div class="preview-title">
<el-icon class="preview-icon"><Headphone /></el-icon>
<span>录音预览</span>
</div>
<!-- 预览操作按钮组 -->
<div class="preview-actions">
<!-- 播放录音按钮 -->
<el-button type="text" size="small" @click="playRecording" class="action-button">
<el-icon><VideoPlay /></el-icon>
播放
</el-button>
<!-- 删除录音按钮 -->
<el-button type="text" size="small" @click="clearRecording" class="action-button danger">
<el-icon><Delete /></el-icon>
删除
</el-button>
</div>
</div>
<!-- 音频播放器容器 -->
<div class="audio-player-container">
<!-- HTML5音频播放器 -->
<audio
ref="audioPlayerRef"
controls
:src="recordedAudio.url"
class="audio-player"
>
您的浏览器不支持音频播放
</audio>
</div>
<!-- 录音信息显示区域 -->
<div class="preview-info">
<!-- 录音时长信息 -->
<div class="info-item">
<el-icon><Timer /></el-icon>
<span>时长: {{ formatTime(recordedAudio.duration) }}</span>
</div>
<!-- 文件大小信息 -->
<div class="info-item">
<el-icon><DataBoard /></el-icon>
<span>大小: {{ formatFileSize(recordedAudio.size) }}</span>
</div>
</div>
<!-- 识别录音按钮 -->
<el-button
type="primary"
@click="submitRecording"
:loading="isSubmitting"
class="submit-button"
>
识别录音
</el-button>
</div>
</transition>
</div>
</div>
</template>
<script setup>
// Vue 3 Composition API
import { ref, computed, nextTick, onMounted, onUnmounted } from 'vue'
// Element Plus
import { ElMessage } from 'element-plus'
// API
import { apiService } from '../utils/api'
//
// Component props.
//   disabled - Boolean flag supplied by the parent; false when omitted.
const props = defineProps({
disabled: {
type: Boolean,
default: false // not disabled unless the parent says so
}
})
// Events sent to the parent. submitRecording emits 'record-success' with the
// recognition result and 'record-error' with an error message.
const emit = defineEmits(['record-success', 'record-error'])
// ---- Reactive state ----
const isRecording = ref(false) // true between startRecording() and stopRecording()
const isReady = ref(false) // flipped to true at the end of a successful initRecorder()
const isSubmitting = ref(false) // true while submitRecording() is in flight
const isProcessing = ref(false) // read by statusTitle/statusText; NOTE(review): never set to true in this section of the script, confirm it is still needed
const recordingTime = ref(0) // elapsed recording time, incremented once per second by recordingTimer
const recordedAudio = ref(null) // { blob, url, duration, size } of the latest recording, or null
const canvasRef = ref() // template ref for the visualizer <canvas>
const audioPlayerRef = ref() // template ref for the preview <audio> element
const volume = ref(0) // bound to the volume bar width (as a percentage) in the template
// ---- Non-reactive handles, created by initRecorder() and the start/stop helpers ----
let mediaRecorder = null // MediaRecorder wrapping the microphone stream
let audioChunks = [] // Blob chunks collected from ondataavailable for the current recording
let recordingTimer = null // setInterval id of the one-second recording clock
let audioContext = null // Web Audio API context feeding the analyser
let analyser = null // AnalyserNode read by the waveform drawing loop
let microphone = null // MediaStreamAudioSourceNode connected to the analyser
let animationId = null // requestAnimationFrame id of the drawing loop
// User-adjustable recording settings (bound to the settings form in the template).
const settings = ref({
duration: 5, // auto-stop threshold; compared against recordingTime, which counts seconds
autoStop: true // when true, the timer stops the recording once `duration` is reached
})
//
/**
 * Short headline describing the recorder state.
 * Precedence: processing, then recording, then ready, otherwise still preparing.
 */
const statusTitle = computed(() =>
  isProcessing.value
    ? '处理中'
    : isRecording.value
      ? '录音中'
      : isReady.value
        ? '就绪'
        : '准备中'
)
/**
 * Longer hint shown under the headline.
 * Precedence differs from statusTitle: "not ready" is checked before "recording".
 */
const statusText = computed(() =>
  isProcessing.value
    ? '正在处理录音数据...'
    : !isReady.value
      ? '正在初始化麦克风...'
      : isRecording.value
        ? '正在录音中,请对着麦克风说话'
        : '点击开始录音按钮进行音频识别'
)
//
/**
 * Request microphone access and build everything recording depends on:
 * the MediaRecorder, the Web Audio analyser chain used for the waveform,
 * and the recorder event handlers. Sets `isReady` on success.
 *
 * On any failure the partially created resources (microphone stream, audio
 * context) are released again so the browser does not keep the microphone
 * open for a recorder that can never be used.
 */
const initRecorder = async () => {
  let stream = null
  try {
    // Ask for a mono 16 kHz stream with echo cancellation and noise suppression.
    stream = await navigator.mediaDevices.getUserMedia({
      audio: {
        sampleRate: 16000,
        channelCount: 1,
        echoCancellation: true,
        noiseSuppression: true
      }
    })
    // Prefer WebM/Opus, but fall back to the browser's default container when
    // it is not supported: passing an unsupported mimeType makes the
    // MediaRecorder constructor throw.
    const preferredMimeType = 'audio/webm;codecs=opus'
    const canUsePreferred =
      typeof MediaRecorder.isTypeSupported === 'function' &&
      MediaRecorder.isTypeSupported(preferredMimeType)
    mediaRecorder = canUsePreferred
      ? new MediaRecorder(stream, { mimeType: preferredMimeType })
      : new MediaRecorder(stream)
    // Label the final Blob with the container that was actually recorded.
    const blobType = canUsePreferred ? 'audio/webm' : (mediaRecorder.mimeType || 'audio/webm')
    // Analyser chain for the live waveform.
    audioContext = new (window.AudioContext || window.webkitAudioContext)()
    analyser = audioContext.createAnalyser()
    microphone = audioContext.createMediaStreamSource(stream)
    microphone.connect(analyser)
    analyser.fftSize = 256
    // Collect non-empty chunks while recording.
    mediaRecorder.ondataavailable = (event) => {
      if (event.data.size > 0) {
        audioChunks.push(event.data)
      }
    }
    // When recording stops, assemble the chunks into a single Blob for preview.
    mediaRecorder.onstop = () => {
      const audioBlob = new Blob(audioChunks, { type: blobType })
      createAudioPreview(audioBlob)
      audioChunks = []
    }
    isReady.value = true
    ElMessage.success('麦克风初始化成功')
  } catch (error) {
    // Release whatever was acquired before the failure.
    if (stream) {
      stream.getTracks().forEach(track => track.stop())
    }
    if (audioContext && audioContext.state !== 'closed') {
      audioContext.close()
    }
    mediaRecorder = null
    audioContext = null
    analyser = null
    microphone = null
    console.error('麦克风初始化失败:', error)
    ElMessage.error('无法访问麦克风,请检查权限设置')
  }
}
//
/**
 * Begin a new recording: reset the chunk buffer and clock, start the
 * MediaRecorder, run the one-second timer (which enforces auto-stop) and
 * kick off the waveform drawing loop.
 * Does nothing unless a recorder exists and is currently inactive.
 */
const startRecording = () => {
  if (!mediaRecorder || mediaRecorder.state !== 'inactive') return
  audioChunks = []
  recordingTime.value = 0
  // Start the recorder before flipping UI state, so a throwing start()
  // cannot leave the component stuck in "recording".
  mediaRecorder.start()
  isRecording.value = true
  // One-second clock; also stops the recording when the configured duration is reached.
  recordingTimer = setInterval(() => {
    recordingTime.value++
    if (settings.value.autoStop && recordingTime.value >= settings.value.duration) {
      stopRecording()
    }
  }, 1000)
  // The canvas is rendered under v-if="isRecording", so it only exists after
  // Vue has flushed the DOM update. Calling startVisualization synchronously
  // would find no canvas and bail out.
  nextTick(() => startVisualization())
  ElMessage.info('开始录音')
}
//
/**
 * Finish the current recording: clear the recording flag, stop the
 * MediaRecorder, cancel the clock and halt the waveform animation.
 * Ignored when no recording is in progress.
 */
const stopRecording = () => {
  const recordingInProgress = mediaRecorder && mediaRecorder.state === 'recording'
  if (!recordingInProgress) return

  isRecording.value = false
  mediaRecorder.stop()

  // Tear down the one-second clock, if it is running.
  if (recordingTimer) {
    clearInterval(recordingTimer)
    recordingTimer = null
  }

  stopVisualization()
  ElMessage.success('录音完成')
}
//
/**
 * Publish a finished recording for preview.
 * Stores the Blob, an object URL for the <audio> element, the recorded
 * duration (current value of recordingTime) and the Blob size.
 *
 * @param {Blob} audioBlob - the assembled recording
 */
const createAudioPreview = (audioBlob) => {
  // Revoke the URL of any recording being replaced; otherwise every
  // re-record leaks one object URL (and the Blob it pins in memory).
  const previous = recordedAudio.value
  if (previous && previous.url) {
    URL.revokeObjectURL(previous.url)
  }
  recordedAudio.value = {
    blob: audioBlob,
    url: URL.createObjectURL(audioBlob),
    duration: recordingTime.value,
    size: audioBlob.size
  }
}
//
/** Start playback on the preview <audio> element, when it is mounted. */
const playRecording = () => {
  const player = audioPlayerRef.value
  if (!player) return
  player.play()
}
//
/**
 * Discard the current recording (revoking its object URL) and reset the
 * recording clock to zero. The clock is reset even when nothing was recorded.
 */
const clearRecording = () => {
  const current = recordedAudio.value
  if (current) {
    URL.revokeObjectURL(current.url)
    recordedAudio.value = null
  }
  recordingTime.value = 0
}
//
/**
 * Decode the recorded Blob to raw PCM samples and send them to the
 * prediction API.
 *
 * Emits 'record-success' with the result (and clears the recording) when the
 * API reports status "success"; emits 'record-error' with the error message
 * for any failure. `isSubmitting` is true for the duration of the call.
 */
const submitRecording = async () => {
  if (!recordedAudio.value) return
  isSubmitting.value = true
  try {
    // Read the encoded recording into memory.
    const arrayBuffer = await recordedAudio.value.blob.arrayBuffer()
    // A throwaway context is used purely for decoding.
    const tempAudioContext = new (window.AudioContext || window.webkitAudioContext)()
    let audioData
    let sampleRate
    try {
      const audioBuffer = await tempAudioContext.decodeAudioData(arrayBuffer)
      // Only the first channel is sent.
      audioData = Array.from(audioBuffer.getChannelData(0))
      sampleRate = audioBuffer.sampleRate
    } finally {
      // Close the context even when decoding fails; browsers cap the number
      // of live AudioContexts, so leaking one per failed attempt adds up.
      await tempAudioContext.close()
    }
    const response = await apiService.predictAudioData({
      audio_data: audioData,
      sample_rate: sampleRate
    })
    if (response.data.status === 'success') {
      emit('record-success', response.data.result)
      clearRecording()
    } else {
      throw new Error(response.data.message)
    }
  } catch (error) {
    emit('record-error', error.message)
  } finally {
    isSubmitting.value = false
  }
}
//
/**
 * Run the live frequency-bar animation on the visualizer canvas and keep the
 * `volume` ref (percentage shown by the volume bar) up to date.
 *
 * The loop reschedules itself with requestAnimationFrame and ends on the
 * first frame where `isRecording` is false. Does nothing when the canvas or
 * analyser is not available.
 */
const startVisualization = () => {
  if (!canvasRef.value || !analyser) return
  const canvas = canvasRef.value
  const ctx = canvas.getContext('2d')
  // getContext may return null; nothing can be drawn in that case.
  if (!ctx) return
  const bufferLength = analyser.frequencyBinCount
  const dataArray = new Uint8Array(bufferLength)

  const draw = () => {
    if (!isRecording.value) return
    animationId = requestAnimationFrame(draw)
    analyser.getByteFrequencyData(dataArray)

    // Clear the previous frame.
    ctx.fillStyle = 'rgb(255, 255, 255)'
    ctx.fillRect(0, 0, canvas.width, canvas.height)

    const barWidth = (canvas.width / bufferLength) * 2.5
    let x = 0
    let levelSum = 0
    for (let i = 0; i < bufferLength; i++) {
      const barHeight = dataArray[i] / 255 * canvas.height
      levelSum += dataArray[i]
      ctx.fillStyle = `rgb(${barHeight + 100}, 102, 234)`
      ctx.fillRect(x, canvas.height - barHeight, barWidth, barHeight)
      x += barWidth + 1
    }

    // The template binds `volume` to the volume bar width; without this
    // update the bar would stay at 0%. Use the mean bin level scaled to 0-100.
    volume.value = bufferLength > 0
      ? Math.round(levelSum / bufferLength / 255 * 100)
      : 0
  }

  draw()
}
//
/** Cancel the pending animation frame, if any, and forget its id. */
const stopVisualization = () => {
  if (!animationId) return
  cancelAnimationFrame(animationId)
  animationId = null
}
// :
/**
 * Render a number of seconds as "MM:SS", zero-padding both parts to two digits.
 *
 * @param {number} seconds - elapsed seconds
 * @returns {string} e.g. 65 -> "01:05"
 */
const formatTime = (seconds) => {
  const twoDigits = (value) => value.toString().padStart(2, '0')
  const minutePart = Math.floor(seconds / 60)
  const secondPart = seconds % 60
  return `${twoDigits(minutePart)}:${twoDigits(secondPart)}`
}
//
/**
 * Format a byte count as a human-readable size using 1024-based units,
 * with at most two decimals (trailing zeros dropped).
 *
 * Zero, negative, NaN and non-numeric input all render as "0 B". The unit
 * index is clamped to the table, so fractional byte counts stay in "B" and
 * anything beyond the largest unit is expressed in that unit instead of
 * printing "undefined".
 *
 * @param {number} bytes - size in bytes
 * @returns {string} e.g. 1536 -> "1.5 KB"
 */
const formatFileSize = (bytes) => {
  if (!Number.isFinite(bytes) || bytes <= 0) return '0 B'
  const k = 1024
  const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
  const rawIndex = Math.floor(Math.log(bytes) / Math.log(k))
  // Keep the index inside the unit table.
  const i = Math.min(Math.max(rawIndex, 0), sizes.length - 1)
  return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
}
//
//
// Set up the microphone and recorder as soon as the component is mounted.
onMounted(function setUpRecorder() {
  initRecorder()
})
//
// Release every resource the recorder holds when the component goes away:
// timer, animation frame, recorder, microphone tracks, audio context and the
// preview object URL.
onUnmounted(() => {
  if (recordingTimer) {
    clearInterval(recordingTimer)
    recordingTimer = null
  }
  if (animationId) {
    cancelAnimationFrame(animationId)
    animationId = null
  }
  if (mediaRecorder) {
    // Detach the handlers first. If a recording is still active, stopping the
    // recorder or ending its tracks can fire `stop`, and the handler would
    // create a fresh object URL after teardown that nobody revokes.
    mediaRecorder.ondataavailable = null
    mediaRecorder.onstop = null
    if (mediaRecorder.state !== 'inactive') {
      mediaRecorder.stop()
    }
    // Stop the microphone tracks so the browser releases the device.
    if (mediaRecorder.stream) {
      mediaRecorder.stream.getTracks().forEach(track => track.stop())
    }
  }
  if (audioContext && audioContext.state !== 'closed') {
    audioContext.close()
  }
  // Revoke the preview URL, if a recording is still held.
  clearRecording()
})
</script>
<style scoped>
.audio-recorder {
width: 100%;
}
.recorder-content {
text-align: center;
}
.recorder-status {
margin-bottom: 30px;
}
.status-indicator {
width: 80px;
height: 80px;
border-radius: 50%;
border: 3px solid #e4e7ed;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 15px;
transition: all 0.3s ease;
background: #f5f5f5;
}
.status-indicator.ready {
border-color: #67c23a;
background: #f0f9ff;
}
.status-indicator.active {
border-color: #e6a23c;
background: #fef0e6;
animation: pulse 2s infinite;
}
@keyframes pulse {
0% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(230, 162, 60, 0.7);
}
70% {
transform: scale(1.05);
box-shadow: 0 0 0 10px rgba(230, 162, 60, 0);
}
100% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(230, 162, 60, 0);
}
}
.status-icon {
font-size: 2.5rem;
color: #999;
}
.status-indicator.ready .status-icon {
color: #67c23a;
}
.status-indicator.active .status-icon {
color: #e6a23c;
}
.status-text {
text-align: center;
}
.primary-text {
font-size: 1.1rem;
font-weight: 500;
color: #333;
margin-bottom: 5px;
}
.secondary-text {
font-size: 0.9rem;
color: #666;
}
.recorder-controls {
margin-bottom: 25px;
}
.record-button,
.stop-button {
min-width: 120px;
height: 45px;
font-size: 1rem;
}
.recorder-settings {
margin-bottom: 25px;
padding: 15px;
background: #f9f9f9;
border-radius: 8px;
}
.audio-visualizer {
margin-bottom: 25px;
padding: 15px;
background: #f0f0f0;
border-radius: 8px;
}
.visualizer-canvas {
width: 100%;
height: 100px;
border-radius: 4px;
background: white;
}
.audio-preview {
margin-top: 25px;
padding: 20px;
border: 1px solid #e4e7ed;
border-radius: 8px;
background: #f9f9f9;
}
.preview-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 15px;
font-weight: 500;
color: #333;
}
.preview-header span {
display: flex;
align-items: center;
gap: 8px;
}
.preview-actions {
display: flex;
gap: 10px;
}
.audio-player {
width: 100%;
margin-bottom: 15px;
}
.preview-info {
display: flex;
justify-content: space-between;
font-size: 0.9rem;
color: #666;
margin-bottom: 15px;
}
.submit-button {
width: 100%;
}
/* 响应式设计 */
@media (max-width: 480px) {
.status-indicator {
width: 60px;
height: 60px;
}
.status-icon {
font-size: 2rem;
}
.record-button,
.stop-button {
min-width: 100px;
height: 40px;
font-size: 0.9rem;
}
}
</style>

@ -0,0 +1,741 @@
<!-- 音频上传组件模板 -->
<template>
<div class="audio-upload">
<!-- Element Plus上传组件支持拖拽上传 -->
<el-upload
ref="uploadRef"
class="upload-dragger"
drag
:action="uploadUrl"
:headers="uploadHeaders"
:data="uploadData"
:on-success="handleSuccess"
:on-error="handleError"
:on-progress="handleProgress"
:before-upload="beforeUpload"
:disabled="disabled"
accept="audio/*,.wav,.mp3,.flac,.m4a,.ogg,.aac"
:show-file-list="false"
>
<div class="upload-content">
<!-- 上传图标和动画区域 -->
<div class="upload-icon-wrapper">
<!-- 图标切换动画 -->
<transition name="icon-switch" mode="out-in">
<!-- 上传成功图标 -->
<el-icon class="upload-icon success" v-if="uploadSuccess" key="success">
<Check />
</el-icon>
<!-- 上传中图标 -->
<el-icon class="upload-icon uploading" v-else-if="uploading" key="uploading">
<Loading />
</el-icon>
<!-- 默认上传图标 -->
<el-icon class="upload-icon default" v-else key="default">
<Upload />
</el-icon>
</transition>
<!-- 图标发光效果 -->
<div class="icon-glow" :class="{ active: uploading || uploadSuccess }"></div>
</div>
<!-- 上传文本信息区域 -->
<div class="upload-text">
<transition name="fade" mode="out-in">
<!-- 上传成功状态文本 -->
<div v-if="uploadSuccess" key="success" class="success-message">
<p class="primary-text success"> 文件上传成功</p>
<p class="secondary-text">正在进行音频识别...</p>
</div>
<!-- 上传中状态文本 -->
<div v-else-if="uploading" key="uploading" class="uploading-message">
<p class="primary-text uploading">正在上传并识别中...</p>
<p class="secondary-text">请稍等模型正在分析您的音频</p>
</div>
<!-- 默认状态文本 -->
<div v-else key="default" class="default-message">
<p class="primary-text">
<span class="highlight">点击</span> <span class="highlight">拖拽</span> 音频文件到此处
</p>
<p class="secondary-text">
支持 <span class="format-tag">WAV</span><span class="format-tag">MP3</span><span class="format-tag">FLAC</span><span class="format-tag">M4A</span> 等格式
</p>
<p class="secondary-text size-limit">
文件大小限制<span class="limit-value">50MB</span>
</p>
</div>
</transition>
</div>
<!-- 上传进度显示区域 -->
<transition name="slide-down">
<div v-if="uploading" class="progress-container">
<!-- 进度条容器 -->
<div class="progress-wrapper">
<!-- Element Plus进度条组件 -->
<el-progress
:percentage="uploadProgress"
:stroke-width="8"
:show-text="false"
:color="progressColor"
class="custom-progress"
/>
<!-- 进度指示器 -->
<div class="progress-indicator">
<span class="progress-text">{{ uploadProgress }}%</span>
<!-- 进度动画点 -->
<div class="progress-dots">
<span class="dot"></span>
<span class="dot"></span>
<span class="dot"></span>
</div>
</div>
</div>
</div>
</transition>
</div>
</el-upload>
<!-- 音频预览区域 -->
<transition name="slide-up">
<div v-if="audioPreview" class="audio-preview">
<!-- 预览区域头部 -->
<div class="preview-header">
<div class="preview-title">
<el-icon class="preview-icon"><Headphone /></el-icon>
<span>音频预览</span>
</div>
<!-- 关闭预览按钮 -->
<el-button type="text" size="small" @click="clearPreview" class="close-button">
<el-icon><Close /></el-icon>
</el-button>
</div>
<!-- 音频播放器容器 -->
<div class="audio-player-container">
<!-- HTML5音频播放器 -->
<audio
ref="audioPlayerRef"
controls
:src="audioPreview.url"
class="audio-player"
>
您的浏览器不支持音频播放
</audio>
</div>
<!-- 音频信息显示区域 -->
<div class="audio-info">
<!-- 文件名信息 -->
<div class="info-item">
<el-icon><Document /></el-icon>
<span>{{ audioPreview.name }}</span>
</div>
<!-- 文件大小信息 -->
<div class="info-item">
<el-icon><DataBoard /></el-icon>
<span>{{ formatFileSize(audioPreview.size) }}</span>
</div>
</div>
</div>
</transition>
</div>
</template>
<script setup>
// Vue 3 Composition API
import { ref, computed } from 'vue'
// Element Plus
import { ElMessage } from 'element-plus'
//
// Component props.
//   disabled - Boolean flag forwarded to el-upload's `disabled`; false when omitted.
const props = defineProps({
disabled: {
type: Boolean,
default: false // uploads are enabled unless the parent says otherwise
}
})
// Events sent to the parent: 'upload-success' (result, file name) and
// 'upload-error' (error message).
const emit = defineEmits(['upload-success', 'upload-error'])
// ---- Reactive state ----
const uploadRef = ref() // template ref for the el-upload component
const audioPlayerRef = ref() // template ref for the preview <audio> element
const uploading = ref(false) // true from beforeUpload() until the success/error callback runs
const uploadSuccess = ref(false) // drives the "success" icon and text; NOTE(review): never set to true in this script, confirm intended behaviour
const uploadProgress = ref(0) // upload progress percentage shown by the progress bar
const audioPreview = ref(null) // { name, size, url } of the selected file, or null
//
/**
 * Progress bar colour by completion: below 30% blue, below 70% orange,
 * otherwise green.
 */
const progressColor = computed(() => {
  const percent = uploadProgress.value
  return percent < 30
    ? '#409eff'
    : percent < 70
      ? '#e6a23c'
      : '#67c23a'
})
//
const uploadUrl = '/api/upload' // endpoint passed to el-upload's `action`
// Extra request headers for el-upload; intentionally empty.
// NOTE(review): presumably Content-Type is left unset so the browser can add
// the multipart boundary itself; confirm before adding headers here.
const uploadHeaders = {
}
const uploadData = {} // extra form fields sent with the file (none at present)
//
/**
 * el-upload `before-upload` hook: validate the chosen file, create the local
 * preview and switch the component into the "uploading" state.
 *
 * A file is accepted when either its MIME type or its extension is on the
 * allow-list, and its size does not exceed 50 MB.
 *
 * @param {File} file - file selected by the user
 * @returns {boolean} true to proceed with the upload, false to cancel it
 */
const beforeUpload = (file) => {
  const allowedTypes = ['audio/wav', 'audio/mpeg', 'audio/flac', 'audio/m4a', 'audio/ogg', 'audio/aac']
  const allowedExtensions = ['wav', 'mp3', 'flac', 'm4a', 'ogg', 'aac']
  const fileExtension = file.name.split('.').pop().toLowerCase()

  // Either signal is enough: browsers report MIME types inconsistently.
  const formatAccepted =
    allowedTypes.includes(file.type) || allowedExtensions.includes(fileExtension)
  if (!formatAccepted) {
    ElMessage.error('不支持的文件格式,请上传音频文件')
    return false
  }

  // Size cap: 50 MB.
  const maxSize = 50 * 1024 * 1024
  const tooLarge = file.size > maxSize
  if (tooLarge) {
    ElMessage.error('文件大小不能超过50MB')
    return false
  }

  createAudioPreview(file)

  // Enter the uploading state with a fresh progress bar.
  uploading.value = true
  uploadProgress.value = 0
  return true
}
//
/**
 * Expose the selected file for local playback: stores its name, size and an
 * object URL for the <audio> element.
 *
 * @param {File} file - file selected by the user
 */
const createAudioPreview = (file) => {
  // Revoke the URL of a preview being replaced; selecting file after file
  // would otherwise leak one object URL per selection.
  const previous = audioPreview.value
  if (previous && previous.url) {
    URL.revokeObjectURL(previous.url)
  }
  audioPreview.value = {
    name: file.name,
    size: file.size,
    url: URL.createObjectURL(file)
  }
}
//
/** Close the preview panel: revoke the object URL and drop the preview data. */
const clearPreview = () => {
  const current = audioPreview.value
  if (!current) return
  URL.revokeObjectURL(current.url)
  audioPreview.value = null
}
//
/** el-upload `on-progress` hook: store the reported percentage as a whole number. */
const handleProgress = function (event) {
  const rounded = Math.round(event.percent)
  uploadProgress.value = rounded
}
//
/**
 * el-upload `on-success` hook. The HTTP request succeeded; the response body
 * says whether recognition did.
 *
 * @param {object} response - parsed server response ({ status, result, message })
 * @param {object} file - uploaded file entry; only `name` is used
 */
const handleSuccess = (response, file) => {
  uploading.value = false
  uploadProgress.value = 100

  const recognised = response.status === 'success'
  if (!recognised) {
    // Transport succeeded but the server reported a recognition failure.
    emit('upload-error', response.message)
    ElMessage.error(`识别失败: ${response.message}`)
    return
  }

  emit('upload-success', response.result, file.name)
  ElMessage.success('音频上传并识别成功')
}
//
/**
 * el-upload `on-error` hook: reset the upload state and report the failure.
 *
 * The error message is first treated as a JSON payload whose `message` field
 * is shown; if it is not valid JSON, the raw message is used. A generic text
 * is the last resort.
 *
 * @param {Error} error - error raised by the upload request
 * @param {object} file - file entry (unused)
 */
const handleError = (error, file) => {
  uploading.value = false
  uploadProgress.value = 0

  const fallbackMessage = '上传失败'
  let errorMessage
  try {
    const payload = JSON.parse(error.message)
    errorMessage = payload.message || fallbackMessage
  } catch (e) {
    // Not JSON (or no usable payload): fall back to the raw message.
    errorMessage = error.message || fallbackMessage
  }

  emit('upload-error', errorMessage)
  ElMessage.error(`上传失败: ${errorMessage}`)
}
//
/**
 * Format a byte count as a human-readable size using 1024-based units,
 * with at most two decimals (trailing zeros dropped).
 *
 * Zero, negative, NaN and non-numeric input all render as "0 B". The unit
 * index is clamped to the table, so fractional byte counts stay in "B" and
 * anything beyond the largest unit is expressed in that unit instead of
 * printing "undefined".
 *
 * @param {number} bytes - size in bytes
 * @returns {string} e.g. 1536 -> "1.5 KB"
 */
const formatFileSize = (bytes) => {
  if (!Number.isFinite(bytes) || bytes <= 0) return '0 B'
  const k = 1024
  const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
  const rawIndex = Math.floor(Math.log(bytes) / Math.log(k))
  // Keep the index inside the unit table.
  const i = Math.min(Math.max(rawIndex, 0), sizes.length - 1)
  return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
}
</script>
<style scoped>
.audio-upload {
width: 100%;
}
.upload-dragger {
width: 100%;
}
:deep(.el-upload) {
width: 100%;
}
:deep(.el-upload-dragger) {
width: 100%;
height: 280px;
border: 2px dashed rgba(255, 255, 255, 0.3);
border-radius: 20px;
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
overflow: hidden;
}
:deep(.el-upload-dragger:hover) {
border-color: rgba(255, 255, 255, 0.6);
background: rgba(255, 255, 255, 0.1);
transform: translateY(-2px);
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
}
:deep(.el-upload-dragger.is-dragover) {
border-color: #52c41a;
background: rgba(82, 196, 26, 0.1);
transform: scale(1.02);
}
.upload-content {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
padding: 30px;
position: relative;
}
/* 图标样式 */
.upload-icon-wrapper {
position: relative;
margin-bottom: 25px;
}
.upload-icon {
font-size: 4rem;
transition: all 0.4s ease;
display: block;
}
.upload-icon.default {
color: rgba(255, 255, 255, 0.6);
}
.upload-icon.uploading {
color: #409eff;
animation: rotating 2s linear infinite;
}
.upload-icon.success {
color: #52c41a;
animation: bounce 0.6s ease;
}
.icon-glow {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 100px;
height: 100px;
border-radius: 50%;
background: radial-gradient(circle, rgba(255, 255, 255, 0.2) 0%, transparent 70%);
opacity: 0;
transition: opacity 0.3s ease;
}
.icon-glow.active {
opacity: 1;
animation: pulse 2s ease-in-out infinite;
}
/* 动画 */
@keyframes rotating {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
@keyframes bounce {
0%, 20%, 53%, 80%, 100% { transform: scale(1); }
40%, 43% { transform: scale(1.2); }
70% { transform: scale(1.1); }
}
@keyframes pulse {
0%, 100% { opacity: 0.6; transform: translate(-50%, -50%) scale(1); }
50% { opacity: 1; transform: translate(-50%, -50%) scale(1.1); }
}
/* 文本样式 */
.upload-text {
text-align: center;
width: 100%;
}
.primary-text {
font-size: 1.3rem;
font-weight: 600;
color: white;
margin-bottom: 12px;
line-height: 1.4;
}
.primary-text.uploading {
color: #409eff;
}
.primary-text.success {
color: #52c41a;
}
.highlight {
color: #409eff;
font-weight: 700;
text-shadow: 0 0 10px rgba(64, 158, 255, 0.3);
}
.secondary-text {
font-size: 1rem;
color: rgba(255, 255, 255, 0.7);
margin-bottom: 8px;
line-height: 1.3;
}
.format-tag {
display: inline-block;
padding: 2px 8px;
background: rgba(255, 255, 255, 0.2);
border-radius: 12px;
font-size: 0.85rem;
font-weight: 500;
margin: 0 2px;
}
.size-limit {
font-size: 0.9rem;
}
.limit-value {
color: #409eff;
font-weight: 600;
}
/* 进度条样式 */
.progress-container {
width: 100%;
max-width: 350px;
margin-top: 25px;
}
.progress-wrapper {
position: relative;
}
.custom-progress {
margin-bottom: 15px;
}
:deep(.el-progress-bar__outer) {
background: rgba(255, 255, 255, 0.2);
border-radius: 20px;
overflow: hidden;
}
:deep(.el-progress-bar__inner) {
border-radius: 20px;
background: linear-gradient(90deg, #409eff 0%, #52c41a 100%);
transition: all 0.3s ease;
}
.progress-indicator {
display: flex;
align-items: center;
justify-content: space-between;
}
.progress-text {
font-size: 1rem;
font-weight: 600;
color: #409eff;
}
.progress-dots {
display: flex;
gap: 4px;
}
.dot {
width: 6px;
height: 6px;
border-radius: 50%;
background: #409eff;
animation: dot-pulse 1.5s ease-in-out infinite;
}
.dot:nth-child(2) {
animation-delay: 0.2s;
}
.dot:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes dot-pulse {
0%, 60%, 100% {
opacity: 0.3;
transform: scale(1);
}
30% {
opacity: 1;
transform: scale(1.3);
}
}
/* 音频预览样式 */
.audio-preview {
margin-top: 25px;
padding: 25px;
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 16px;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
transition: all 0.3s ease;
}
.audio-preview:hover {
background: rgba(255, 255, 255, 0.15);
transform: translateY(-2px);
}
.preview-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 20px;
}
.preview-title {
display: flex;
align-items: center;
gap: 10px;
font-weight: 600;
color: white;
font-size: 1.1rem;
}
.preview-icon {
color: #409eff;
font-size: 1.2rem;
}
.close-button {
color: rgba(255, 255, 255, 0.6);
transition: all 0.3s ease;
border-radius: 50%;
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
}
.close-button:hover {
color: #ff4d4f;
background: rgba(255, 77, 79, 0.1);
transform: scale(1.1);
}
.audio-player-container {
margin-bottom: 20px;
padding: 15px;
background: rgba(255, 255, 255, 0.05);
border-radius: 12px;
}
.audio-player {
width: 100%;
height: 40px;
border-radius: 8px;
background: rgba(255, 255, 255, 0.1);
}
:deep(.audio-player::-webkit-media-controls-panel) {
background: rgba(255, 255, 255, 0.1);
}
.audio-info {
display: flex;
justify-content: space-between;
gap: 20px;
}
.info-item {
display: flex;
align-items: center;
gap: 8px;
font-size: 0.95rem;
color: rgba(255, 255, 255, 0.8);
padding: 8px 12px;
background: rgba(255, 255, 255, 0.1);
border-radius: 20px;
flex: 1;
justify-content: center;
}
.info-item .el-icon {
color: #409eff;
}
/* 过渡动画 */
.icon-switch-enter-active,
.icon-switch-leave-active {
transition: all 0.3s ease;
}
.icon-switch-enter-from {
opacity: 0;
transform: scale(0.8) rotate(90deg);
}
.icon-switch-leave-to {
opacity: 0;
transform: scale(0.8) rotate(-90deg);
}
.fade-enter-active,
.fade-leave-active {
transition: all 0.4s ease;
}
.fade-enter-from,
.fade-leave-to {
opacity: 0;
transform: translateY(10px);
}
.slide-down-enter-active,
.slide-down-leave-active {
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.slide-down-enter-from {
opacity: 0;
transform: translateY(-20px);
max-height: 0;
}
.slide-down-leave-to {
opacity: 0;
transform: translateY(-20px);
max-height: 0;
}
.slide-up-enter-active,
.slide-up-leave-active {
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.slide-up-enter-from {
opacity: 0;
transform: translateY(20px);
}
.slide-up-leave-to {
opacity: 0;
transform: translateY(-20px);
}
/* 禁用状态样式 */
:deep(.el-upload.is-disabled .el-upload-dragger) {
background: rgba(255, 255, 255, 0.02);
border-color: rgba(255, 255, 255, 0.1);
cursor: not-allowed;
}
:deep(.el-upload.is-disabled .upload-icon) {
color: rgba(255, 255, 255, 0.3);
}
:deep(.el-upload.is-disabled .primary-text) {
color: rgba(255, 255, 255, 0.4);
}
:deep(.el-upload.is-disabled .secondary-text) {
color: rgba(255, 255, 255, 0.3);
}
/* 响应式设计 */
@media (max-width: 768px) {
:deep(.el-upload-dragger) {
height: 240px;
}
.upload-content {
padding: 20px;
}
.upload-icon {
font-size: 3rem;
}
.primary-text {
font-size: 1.1rem;
}
.secondary-text {
font-size: 0.9rem;
}
.audio-info {
flex-direction: column;
gap: 10px;
}
}
@media (max-width: 480px) {
:deep(.el-upload-dragger) {
height: 200px;
}
.upload-content {
padding: 15px;
}
.upload-icon {
font-size: 2.5rem;
}
.primary-text {
font-size: 1rem;
}
.secondary-text {
font-size: 0.85rem;
}
.audio-preview {
padding: 20px;
}
}</style>

@ -0,0 +1,635 @@
<!-- 音频识别历史记录列表组件模板 -->
<template>
<div class="history-list">
<!-- 空状态显示当没有历史记录时显示 -->
<div v-if="history.length === 0" class="empty-state">
<el-icon class="empty-icon"><Document /></el-icon>
<p>暂无识别历史</p>
</div>
<!-- 历史记录列表有记录时显示 -->
<div v-else class="history-items">
<!-- 遍历分页后的历史记录项 -->
<div
v-for="(item, index) in paginatedHistory"
:key="index"
class="history-item"
@click="selectItem(item)"
> <!-- 历史记录项头部信息 -->
<div class="item-header">
<!-- 识别结果标题区域 -->
<div class="item-title">
<!-- 识别结果标签根据置信度显示不同颜色 -->
<el-tag
:type="getResultType(item.confidence || item.score)"
size="small"
class="result-tag"
>
{{ item.predicted_class || item.label }}
</el-tag>
<!-- 置信度百分比显示 -->
<span class="confidence">
{{ ((item.confidence || item.score) * 100).toFixed(1) }}%
</span>
</div>
<!-- 元数据信息区域 -->
<div class="item-meta">
<!-- 来源标识上传或录音 -->
<span class="source-badge" :class="item.source">
<el-icon>
<Upload v-if="item.source === 'upload'" />
<Microphone v-else />
</el-icon>
{{ item.source === 'upload' ? '上传' : '录音' }}
</span>
<!-- 时间戳显示 -->
<span class="timestamp">{{ formatTime(item.timestamp) }}</span>
</div>
</div>
<!-- 详细信息区域 -->
<div class="item-details">
<!-- 文件名信息如果有 -->
<div v-if="item.filename" class="detail-row">
<el-icon><Document /></el-icon>
<span>{{ truncateFilename(item.filename) }}</span>
</div>
<!-- 音频时长信息如果有 -->
<div v-if="item.audio_info" class="detail-row">
<el-icon><Timer /></el-icon>
<span>{{ formatDuration(item.audio_info.duration) }}</span>
</div>
<!-- 预测耗时信息 -->
<div class="detail-row">
<el-icon><Cpu /></el-icon>
<span>{{ item.prediction_time?.toFixed(3) }}s</span>
</div>
</div>
<!-- 置信度可视化进度条 -->
<div class="confidence-bar">
<!-- 置信度填充条根据置信度值动态调整宽度和颜色 -->
<div
class="confidence-fill"
:style="{
width: `${(item.confidence || item.score) * 100}%`,
backgroundColor: getConfidenceColor(item.confidence || item.score)
}"
></div>
</div>
</div>
</div>
<!-- 分页组件当历史记录数量超过页面大小时显示 -->
<div v-if="history.length > pageSize" class="pagination">
<!-- Element Plus分页组件 -->
<el-pagination
:current-page="currentPage"
:page-size="pageSize"
:total="history.length"
layout="prev, pager, next"
@current-change="handlePageChange"
small
/>
</div>
</div>
</template>
<script setup>
// Vue 3 Composition API
import { ref, computed } from 'vue'
// Element Plus
import {
Document,
Upload,
Microphone,
Timer,
Cpu
} from '@element-plus/icons-vue'
//
// Component props.
//   history - array of recognition records to list; empty array when omitted.
const props = defineProps({
history: {
type: Array,
default: () => [] // factory so each instance gets its own empty array
}
})
// Event sent to the parent when a history entry is clicked; the payload is the clicked item.
const emit = defineEmits(['select-item'])
// ---- Pagination state ----
const currentPage = ref(1) // 1-based index of the page being shown
const pageSize = ref(10) // number of history entries per page
//
/** Slice of `props.history` that belongs to the current page. */
const paginatedHistory = computed(() => {
  const firstIndex = (currentPage.value - 1) * pageSize.value
  const lastIndexExclusive = firstIndex + pageSize.value
  const pageItems = props.history.slice(firstIndex, lastIndexExclusive)
  return pageItems
})
//
/**
 * Map a confidence score to an Element Plus tag type.
 *
 * @param {number} score - confidence in the 0..1 range
 * @returns {string} 'success' from 0.8, 'warning' from 0.6, otherwise 'danger'
 */
const getResultType = (score) => {
  // Thresholds are checked from highest to lowest; the first match wins.
  const tiers = [
    [0.8, 'success'],
    [0.6, 'warning']
  ]
  const match = tiers.find(([threshold]) => score >= threshold)
  return match ? match[1] : 'danger'
}
//
/**
 * Map a confidence score to the fill colour of the confidence bar.
 *
 * @param {number} score - confidence in the 0..1 range
 * @returns {string} green from 0.8, orange from 0.6, otherwise red
 */
const getConfidenceColor = (score) =>
  score >= 0.8
    ? '#67c23a'
    : score >= 0.6
      ? '#e6a23c'
      : '#f56c6c'
//
/**
 * Describe how long ago a record was created.
 *
 * Under a minute -> "刚刚"; under an hour -> "N分钟前"; under a day ->
 * "N小时前"; otherwise the local date followed by hours and minutes.
 * An unparseable timestamp yields an empty string.
 *
 * @param {number|string|Date} timestamp - anything accepted by `new Date()`
 * @returns {string} relative or absolute time label
 */
const formatTime = (timestamp) => {
  const date = new Date(timestamp)
  // Without this guard an invalid date would fall through to the last
  // branch and render as "Invalid Date Invalid".
  if (Number.isNaN(date.getTime())) return ''
  const now = new Date()
  const diffInSeconds = Math.floor((now - date) / 1000)
  if (diffInSeconds < 60) {
    return '刚刚'
  } else if (diffInSeconds < 3600) {
    return `${Math.floor(diffInSeconds / 60)}分钟前`
  } else if (diffInSeconds < 86400) {
    return `${Math.floor(diffInSeconds / 3600)}小时前`
  } else {
    // Ask the formatter for hours and minutes explicitly. Cutting the first
    // five characters off the default string breaks in locales whose time
    // does not start with "HH:MM" (for example "9:05:03 AM" became "9:05:").
    const timePart = date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
    return date.toLocaleDateString() + ' ' + timePart
  }
}
//
/**
 * Format an audio duration given in seconds.
 *
 * Falsy input (0, undefined, null, NaN) -> "N/A"; under a minute -> "S.Ss";
 * otherwise "MmS.Ss", with seconds shown to one decimal.
 *
 * @param {number} seconds - duration in seconds
 * @returns {string} formatted duration
 */
const formatDuration = (seconds) => {
  if (!seconds) return 'N/A'
  if (seconds >= 60) {
    const wholeMinutes = Math.floor(seconds / 60)
    const leftoverSeconds = seconds % 60
    return `${wholeMinutes}m${leftoverSeconds.toFixed(1)}s`
  }
  return `${seconds.toFixed(1)}s`
}
//
/**
 * Shorten a file name to at most `maxLength` characters for display.
 *
 * Names that already fit (or are empty) are returned unchanged. When the
 * name has an extension that leaves room for at least one character of the
 * stem, the stem is cut and the result is "<stem>....<ext>". Names with no
 * usable extension (no dot, a leading or trailing dot, or an extension too
 * long to keep) are cut and suffixed with "...".
 *
 * @param {string} filename - original file name
 * @param {number} [maxLength=20] - maximum length of the result
 * @returns {string} the possibly shortened name
 */
const truncateFilename = (filename, maxLength = 20) => {
  if (!filename || filename.length <= maxLength) return filename
  // Only a dot after the first character separates a stem from an extension.
  const dotIndex = filename.lastIndexOf('.')
  const extension = dotIndex > 0 ? filename.slice(dotIndex + 1) : ''
  // Room left for the stem once "..." + "." + extension are accounted for.
  const stemBudget = maxLength - extension.length - 4
  if (!extension || stemBudget < 1) {
    // Splitting on a missing or oversized extension used to give a result
    // longer than the input, or a negative slice bound.
    return filename.slice(0, Math.max(maxLength - 3, 0)) + '...'
  }
  const stem = filename.slice(0, dotIndex)
  return stem.slice(0, stemBudget) + '...' + '.' + extension
}
//
/** Click handler for a history entry: forwards the clicked record to the parent. */
const selectItem = function (item) {
  emit('select-item', item)
}
//
/** Pagination handler: remember the page number chosen in el-pagination. */
const handlePageChange = function (page) {
  currentPage.value = page
}
</script>
<style scoped>
/* 历史列表主容器 */
.history-list {
width: 100%;
animation: fadeIn 0.6s ease-out;
position: relative;
}
/* 空状态 */
.empty-state {
text-align: center;
padding: 80px 30px;
background: rgba(255, 255, 255, 0.08);
backdrop-filter: blur(25px);
border: 1px solid rgba(255, 255, 255, 0.15);
border-radius: 24px;
animation: fadeInUp 0.6s ease-out;
box-shadow:
0 8px 32px rgba(0, 0, 0, 0.1),
inset 0 1px 0 rgba(255, 255, 255, 0.2);
position: relative;
overflow: hidden;
}
.empty-state::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(135deg, rgba(255, 255, 255, 0.1) 0%, transparent 50%, rgba(255, 255, 255, 0.05) 100%);
pointer-events: none;
}
.empty-icon {
font-size: 5rem;
color: rgba(255, 255, 255, 0.5);
margin-bottom: 25px;
animation: float 3s ease-in-out infinite;
filter: drop-shadow(0 4px 15px rgba(255, 255, 255, 0.2));
}
.empty-state p {
font-size: 1.2rem;
margin: 0;
color: rgba(255, 255, 255, 0.8);
font-weight: 500;
}
/* 历史项目列表 */
.history-items {
display: flex;
flex-direction: column;
gap: 20px;
}
/* 历史项目卡片 */
.history-item {
padding: 25px;
background: rgba(255, 255, 255, 0.12);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 16px;
cursor: pointer;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
overflow: hidden;
animation: slideInLeft 0.6s ease-out;
animation-fill-mode: both;
}
.history-item:nth-child(even) {
animation-delay: 0.1s;
}
.history-item:nth-child(odd) {
animation-delay: 0.2s;
}
.history-item::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.1), transparent);
transition: left 0.5s ease;
}
.history-item:hover::before {
left: 100%;
}
.history-item:hover {
background: rgba(255, 255, 255, 0.18);
border-color: rgba(102, 126, 234, 0.5);
transform: translateY(-5px) scale(1.02);
box-shadow: 0 15px 40px rgba(102, 126, 234, 0.2);
}
/* 项目头部 */
.item-header {
display: flex;
justify-content: space-between;
align-items: flex-start;
margin-bottom: 20px;
}
.item-title {
display: flex;
align-items: center;
gap: 15px;
}
.result-tag {
font-weight: 600 !important;
border-radius: 12px !important;
padding: 6px 14px !important;
border: none !important;
backdrop-filter: blur(10px);
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
.result-tag.el-tag--success {
background: rgba(82, 196, 26, 0.2) !important;
color: #52c41a !important;
border: 1px solid rgba(82, 196, 26, 0.3) !important;
}
.result-tag.el-tag--warning {
background: rgba(250, 219, 20, 0.2) !important;
color: #fadb14 !important;
border: 1px solid rgba(250, 219, 20, 0.3) !important;
}
.result-tag.el-tag--danger {
background: rgba(255, 77, 79, 0.2) !important;
color: #ff4d4f !important;
border: 1px solid rgba(255, 77, 79, 0.3) !important;
}
.confidence {
font-size: 1rem;
font-weight: 700;
color: white;
background: rgba(64, 158, 255, 0.2);
padding: 6px 12px;
border-radius: 20px;
backdrop-filter: blur(10px);
border: 1px solid rgba(64, 158, 255, 0.3);
}
/* 项目元数据 */
.item-meta {
display: flex;
flex-direction: column;
align-items: flex-end;
gap: 8px;
}
.source-badge {
display: flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
backdrop-filter: blur(10px);
border: 1px solid;
transition: all 0.3s ease;
}
.source-badge:hover {
transform: scale(1.05);
}
.source-badge.upload {
background: rgba(82, 196, 26, 0.15);
color: #52c41a;
border-color: rgba(82, 196, 26, 0.3);
}
.source-badge.record {
background: rgba(250, 219, 20, 0.15);
color: #fadb14;
border-color: rgba(250, 219, 20, 0.3);
}
.timestamp {
font-size: 0.85rem;
color: rgba(255, 255, 255, 0.7);
font-weight: 500;
}
/* 项目详情 */
.item-details {
display: flex;
flex-wrap: wrap;
gap: 20px;
margin-bottom: 15px;
}
.detail-row {
display: flex;
align-items: center;
gap: 8px;
font-size: 0.9rem;
color: rgba(255, 255, 255, 0.8);
background: rgba(255, 255, 255, 0.1);
padding: 8px 14px;
border-radius: 20px;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.2);
transition: all 0.3s ease;
}
.detail-row:hover {
background: rgba(255, 255, 255, 0.15);
transform: translateY(-2px);
}
.detail-row .el-icon {
color: #409eff;
font-size: 1rem;
}
/* 置信度进度条 */
.confidence-bar {
width: 100%;
height: 8px;
background: rgba(255, 255, 255, 0.1);
border-radius: 4px;
overflow: hidden;
position: relative;
}
.confidence-bar::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: rgba(255, 255, 255, 0.05);
border-radius: 4px;
}
.confidence-fill {
height: 100%;
border-radius: 4px;
transition: all 0.8s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
background: linear-gradient(90deg, #52c41a, #73d13d);
box-shadow: 0 0 15px rgba(82, 196, 26, 0.3);
animation: progressGlow 2s ease-in-out infinite alternate;
}
.confidence-fill::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.3), transparent);
animation: shimmer 2s ease-in-out infinite;
}
/* 分页 */
.pagination {
margin-top: 30px;
display: flex;
justify-content: center;
padding: 20px;
background: rgba(255, 255, 255, 0.08);
border-radius: 16px;
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.15);
}
/* 自定义分页样式 */
:deep(.el-pagination) {
--el-pagination-bg-color: rgba(255, 255, 255, 0.1);
--el-pagination-text-color: white;
--el-pagination-border-radius: 8px;
}
:deep(.el-pagination .btn-prev),
:deep(.el-pagination .btn-next),
:deep(.el-pagination .el-pager li) {
background: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
color: white !important;
backdrop-filter: blur(10px);
transition: all 0.3s ease;
}
:deep(.el-pagination .btn-prev:hover),
:deep(.el-pagination .btn-next:hover),
:deep(.el-pagination .el-pager li:hover) {
background: rgba(255, 255, 255, 0.2) !important;
transform: translateY(-2px);
}
:deep(.el-pagination .el-pager li.is-active) {
background: rgba(64, 158, 255, 0.3) !important;
border-color: rgba(64, 158, 255, 0.5) !important;
color: white !important;
box-shadow: 0 0 15px rgba(64, 158, 255, 0.4);
}
/* 动画 */
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(30px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes slideInLeft {
from {
opacity: 0;
transform: translateX(-30px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
@keyframes float {
0%, 100% {
transform: translateY(0);
}
50% {
transform: translateY(-10px);
}
}
@keyframes progressGlow {
0% {
box-shadow: 0 0 5px rgba(82, 196, 26, 0.3);
}
100% {
box-shadow: 0 0 20px rgba(82, 196, 26, 0.6);
}
}
@keyframes shimmer {
0% {
transform: translateX(-100%);
}
100% {
transform: translateX(100%);
}
}
/* 响应式设计 */
@media (max-width: 768px) {
.history-item {
padding: 20px;
}
.item-header {
flex-direction: column;
align-items: flex-start;
gap: 15px;
}
.item-meta {
flex-direction: row;
align-items: center;
gap: 15px;
width: 100%;
justify-content: space-between;
}
.item-details {
flex-direction: column;
gap: 12px;
}
.detail-row {
justify-content: center;
}
}
@media (max-width: 480px) {
.history-item {
padding: 15px;
}
.item-title {
flex-direction: column;
align-items: flex-start;
gap: 10px;
}
.confidence {
font-size: 0.9rem;
padding: 4px 10px;
}
.source-badge {
font-size: 0.8rem;
padding: 4px 10px;
}
.detail-row {
font-size: 0.85rem;
padding: 6px 12px;
}
.empty-state {
padding: 40px 15px;
}
.empty-icon {
font-size: 3rem;
}
.empty-state p {
font-size: 1rem;
}
}
</style>

@ -0,0 +1,995 @@
<!-- 音频预测结果展示组件模板 -->
<template>
<div class="prediction-result" v-if="result">
<!-- 主要预测结果展示卡片 -->
<div class="main-result-card">
<!-- 结果卡片头部区域 -->
<div class="result-header">
<!-- 结果图标 -->
<div class="result-icon">
<el-icon><TrophyBase /></el-icon>
</div>
<!-- 结果标题和描述 -->
<div class="result-title">
<h3>识别结果</h3>
<p>深度学习模型分析完成</p>
</div>
<!-- 操作按钮区域 -->
<div class="result-actions">
<!-- 导出结果按钮 -->
<el-button type="primary" size="small" @click="exportResult" class="action-btn">
<template #icon><el-icon><Download /></el-icon></template>
导出
</el-button>
<!-- 分享结果按钮 -->
<el-button size="small" @click="shareResult" class="action-btn">
<template #icon><el-icon><Share /></el-icon></template>
分享
</el-button>
</div>
</div>
<!-- 主要预测结果展示区域 -->
<div class="main-prediction">
<!-- 预测结果徽章 -->
<div class="prediction-badge">
<!-- 徽章图标 -->
<div class="badge-icon">
<el-icon><Star /></el-icon>
</div>
<!-- 徽章内容区域 -->
<div class="badge-content">
<!-- 预测类别名称 -->
<div class="predicted-class">{{ result.predicted_class || result.label }}</div>
<!-- 置信度分数显示 -->
<div class="confidence-score">
置信度: <span class="confidence-value">{{ ((result.confidence || result.score) * 100).toFixed(2) }}%</span>
</div>
</div>
<!-- 置信度环形进度条 -->
<div class="confidence-ring">
<!-- SVG环形进度图 -->
<svg class="ring-svg" viewBox="0 0 100 100">
<!-- 背景圆环 -->
<circle
class="ring-background"
cx="50"
cy="50"
r="40"
fill="none"
stroke="rgba(255, 255, 255, 0.2)"
stroke-width="8"
/>
<!-- 进度圆环 -->
<circle
class="ring-progress"
cx="50"
cy="50"
r="40"
fill="none"
stroke="url(#gradient)"
stroke-width="8"
stroke-linecap="round"
:stroke-dasharray="circumference"
:stroke-dashoffset="strokeDashoffset"
transform="rotate(-90 50 50)"
/>
<!-- 渐变定义 -->
<defs>
<linearGradient id="gradient" x1="0%" y1="0%" x2="100%" y2="0%">
<stop offset="0%" style="stop-color:#52c41a"/>
<stop offset="100%" style="stop-color:#73d13d"/>
</linearGradient>
</defs>
</svg>
<!-- 环形中心文本显示 -->
<div class="ring-text">{{ Math.round((result.confidence || result.score) * 100) }}%</div>
</div>
</div>
</div>
</div>
<!-- 预测结果详细信息卡片 -->
<div class="details-card">
<!-- 详细信息卡片头部 -->
<div class="details-header">
<el-icon><InfoFilled /></el-icon>
<span>详细信息</span>
</div>
<!-- 详细信息网格布局 -->
<div class="details-grid">
<!-- 预测类别信息项 -->
<div class="detail-item">
<div class="detail-icon">
<el-icon><Flag /></el-icon>
</div>
<div class="detail-content">
<div class="detail-label">预测类别</div>
<div class="detail-value">{{ result.predicted_class || result.label }}</div>
</div>
</div>
<!-- 置信度信息项 -->
<div class="detail-item">
<div class="detail-icon">
<el-icon><DataAnalysis /></el-icon>
</div>
<div class="detail-content">
<div class="detail-label">置信度</div>
<div class="detail-value">{{ ((result.confidence || result.score) * 100).toFixed(2) }}%</div>
</div>
</div>
<!-- 预测时间信息项 -->
<div class="detail-item">
<div class="detail-icon">
<el-icon><Timer /></el-icon>
</div>
<div class="detail-content">
<div class="detail-label">预测时间</div>
<div class="detail-value">{{ formatTime(result.timestamp) }}</div>
</div>
</div>
<!-- 音频时长信息项 -->
<div class="detail-item">
<div class="detail-icon">
<el-icon><Headphone /></el-icon>
</div>
<div class="detail-content">
<div class="detail-label">音频时长</div>
<div class="detail-value">{{ formatDuration(result.audio_info?.duration) }}</div>
</div>
</div>
</div>
</div>
<!-- 所有类别概率分布展示卡片 -->
<div v-if="result.all_probabilities" class="probability-card">
<!-- 概率分布卡片头部 -->
<div class="probability-header">
<el-icon><PieChart /></el-icon>
<span>所有类别概率分布</span>
<!-- 视图切换按钮 -->
<el-button type="text" size="small" @click="toggleChartView" class="toggle-view">
<el-icon><Switch /></el-icon>
{{ showChart ? '列表视图' : '图表视图' }}
</el-button>
</div>
<!-- ECharts图表视图 -->
<transition name="fade">
<div v-if="showChart" class="chart-container">
<div ref="chartContainer" class="echarts-chart"></div>
</div>
</transition>
<!-- 概率列表视图 -->
<transition name="fade">
<div v-if="!showChart" class="probability-list">
<!-- 遍历排序后的概率数据 -->
<div
v-for="(prob, className) in sortedProbabilities"
:key="className"
class="probability-item"
:class="{ active: className === (result.predicted_class || result.label) }"
>
<!-- 概率信息显示区域 -->
<div class="prob-info">
<!-- 类别名称和最高标识 -->
<div class="class-name">
<span class="name-text">{{ className }}</span>
<!-- 最高概率标识标签 -->
<el-tag v-if="className === (result.predicted_class || result.label)"
size="small" type="success" class="winner-tag">
<el-icon><Trophy /></el-icon>
最高
</el-tag>
</div>
<!-- 概率百分比值 -->
<div class="prob-value">{{ (prob * 100).toFixed(2) }}%</div>
</div> <!-- 概率进度条容器 -->
<div class="prob-bar-container">
<div class="prob-bar">
<!-- 概率填充条根据概率值动态调整宽度和颜色 -->
<div
class="prob-fill"
:style="{
width: `${prob * 100}%`,
background: getProgressColor(prob, className === (result.predicted_class || result.label))
}"
></div>
</div>
</div>
</div>
</div>
</transition>
</div>
</div>
<!-- 无结果状态显示 -->
<div v-else class="no-result">
<div class="empty-state">
<div class="empty-icon">
<el-icon><Document /></el-icon>
</div>
<h3>暂无预测结果</h3>
<p>请上传音频文件或录制音频进行识别</p>
</div>
</div>
</template>
<script setup>
// Vue 3 Composition API
import { ref, computed, watch, onMounted, nextTick, onBeforeUnmount } from 'vue'
// Element Plus
import { ElMessage } from 'element-plus'
// ECharts
import * as echarts from 'echarts'
// Component props
const props = defineProps({
// Prediction payload to display; the template shows the empty state while it is null
result: {
type: Object,
default: null // no result yet
}
})
// Component state
const chartContainer = ref(null) // template ref of the element ECharts renders into
const showChart = ref(false) // false = list view, true = chart view
let chartInstance = null // current ECharts instance (kept outside Vue reactivity)
// Circumference of the confidence ring; the SVG circles use r = 40
const circumference = computed(() => Math.PI * 2 * 40)
// stroke-dashoffset for the progress circle: a full circumference hides the
// stroke completely, and the offset shrinks as confidence grows
const strokeDashoffset = computed(() => {
  const ringLength = circumference.value
  const current = props.result
  if (!current) {
    return ringLength
  }
  const ratio = current.confidence || current.score || 0
  return ringLength - (ratio * ringLength)
})
// Class probabilities ordered from most to least likely, rebuilt as an object
// so the template can iterate it with (value, key)
const sortedProbabilities = computed(() => {
  const probabilities = props.result?.all_probabilities
  if (!probabilities) {
    return {}
  }
  const ranked = Object.entries(probabilities)
  ranked.sort(([, left], [, right]) => right - left)
  return Object.fromEntries(ranked)
})
/**
 * Format a prediction timestamp for display.
 *
 * @param {string|number|Date} timestamp - Strings are assumed to be already
 *   formatted and are returned unchanged; anything else is passed to `Date`.
 * @returns {string} A zh-CN formatted date-time, or 'N/A' when the value is
 *   missing or does not describe a valid date.
 */
const formatTime = (timestamp) => {
  if (!timestamp) return 'N/A'
  if (typeof timestamp === 'string') {
    return timestamp
  }
  try {
    const date = new Date(timestamp)
    // An invalid Date does not throw in toLocaleString; it would render the
    // literal text "Invalid Date", so it has to be rejected explicitly.
    if (Number.isNaN(date.getTime())) return 'N/A'
    return date.toLocaleString('zh-CN', {
      year: 'numeric',
      month: '2-digit',
      day: '2-digit',
      hour: '2-digit',
      minute: '2-digit',
      second: '2-digit'
    })
  } catch (error) {
    console.error('时间格式化错误:', error)
    return 'N/A'
  }
}
/**
 * Format an audio duration given in seconds.
 *
 * The value is rounded to one decimal place before it is split into minutes
 * and seconds, so a value such as 119.96 is shown as "2分0.0秒" rather than
 * "1分60.0秒". Minutes and seconds carry explicit units; without them the two
 * numbers run together (90 s would read "130.0").
 *
 * @param {number|string} seconds - Duration in seconds; numeric strings are parsed.
 * @returns {string} e.g. "12.3秒" or "1分30.0秒"; 'N/A' when the value is
 *   missing, not a finite number, or negative.
 */
const formatDuration = (seconds) => {
  if (!seconds && seconds !== 0) return 'N/A'
  try {
    const duration = parseFloat(seconds)
    if (!Number.isFinite(duration) || duration < 0) return 'N/A'
    // Round first so the minute carry is decided on the displayed value
    const total = Math.round(duration * 10) / 10
    if (total < 60) {
      return `${total.toFixed(1)}秒`
    }
    const minutes = Math.floor(total / 60)
    const remainingSeconds = (total - minutes * 60).toFixed(1)
    return `${minutes}分${remainingSeconds}秒`
  } catch (error) {
    console.error('时长格式化错误:', error)
    return 'N/A'
  }
}
/**
 * Pick the gradient used to fill a probability bar.
 *
 * @param {number} prob - Probability in the range 0..1.
 * @param {boolean} [isWinner=false] - True for the predicted class, which is always green.
 * @returns {string} A CSS linear-gradient value.
 */
const getProgressColor = (prob, isWinner = false) => {
  const green = 'linear-gradient(135deg, #52c41a, #73d13d)'
  // Checked in order: the first threshold the probability strictly exceeds wins
  const tiers = [
    [0.7, green],
    [0.4, 'linear-gradient(135deg, #fadb14, #ffec3d)'],
    [0.2, 'linear-gradient(135deg, #fa8c16, #ffa940)']
  ]
  if (isWinner) {
    return green
  }
  const match = tiers.find(([threshold]) => prob > threshold)
  return match ? match[1] : 'linear-gradient(135deg, #ff4d4f, #ff7875)'
}
// Build (or rebuild) the ECharts pie chart of class probabilities.
// Does nothing when there are no probabilities or the chart container is not
// currently rendered.
const initChart = () => {
  const probabilities = props.result?.all_probabilities
  const mountPoint = chartContainer.value
  if (!probabilities || !mountPoint) return

  // Drop the previous instance before creating a new one
  if (chartInstance) {
    chartInstance.dispose()
  }
  chartInstance = echarts.init(mountPoint)

  // One slice per class, valued in percent with two decimals, largest first
  const slices = []
  for (const [name, probability] of Object.entries(probabilities)) {
    slices.push({ name, value: (probability * 100).toFixed(2) })
  }
  slices.sort((first, second) => second.value - first.value)

  const pieSeries = {
    name: '置信度',
    type: 'pie',
    radius: '50%',
    data: slices,
    emphasis: {
      itemStyle: {
        shadowBlur: 10,
        shadowOffsetX: 0,
        shadowColor: 'rgba(0, 0, 0, 0.5)'
      }
    }
  }

  chartInstance.setOption({
    title: {
      text: '类别置信度分布',
      left: 'center'
    },
    tooltip: {
      trigger: 'item',
      formatter: '{a} <br/>{b}: {c}%'
    },
    legend: {
      orient: 'vertical',
      left: 'left'
    },
    series: [pieSeries]
  })
}
// Export the current prediction as a downloadable JSON file.
// Reads the same fallback fields as the template (`label` / `score`) so that
// results using either naming are exported completely.
const exportResult = () => {
  try {
    const exportData = {
      predicted_class: props.result.predicted_class || props.result.label,
      confidence: props.result.confidence ?? props.result.score,
      timestamp: props.result.timestamp,
      all_probabilities: props.result.all_probabilities
    }
    const blob = new Blob([JSON.stringify(exportData, null, 2)], { type: 'application/json' })
    const url = URL.createObjectURL(blob)
    const a = document.createElement('a')
    a.href = url
    a.download = `prediction_result_${Date.now()}.json` // timestamped file name
    // Attach the link while clicking, as utils.downloadFile does; some browsers
    // ignore clicks on detached anchors.
    document.body.appendChild(a)
    a.click()
    document.body.removeChild(a)
    URL.revokeObjectURL(url) // release the object URL
    ElMessage.success('结果导出成功')
  } catch (error) {
    ElMessage.error('导出失败: ' + error.message)
  }
}
// Share the prediction summary through the Web Share API when available,
// otherwise copy it to the clipboard.
const shareResult = () => {
  // Same fallback fields as the template, so label/score-style results work too
  const label = props.result.predicted_class || props.result.label
  const confidence = props.result.confidence || props.result.score || 0
  const shareText = `音频分类结果: ${label} (置信度: ${(confidence * 100).toFixed(2)}%)`
  if (navigator.share) {
    navigator.share({
      title: '音频分类结果',
      text: shareText
    }).catch((error) => {
      // Dismissing the share sheet rejects with AbortError; that is not a failure
      if (error?.name !== 'AbortError') {
        ElMessage.error('分享失败')
      }
    })
    return
  }
  // navigator.clipboard is undefined in insecure contexts and older browsers
  if (!navigator.clipboard?.writeText) {
    ElMessage.error('复制失败')
    return
  }
  navigator.clipboard.writeText(shareText).then(() => {
    ElMessage.success('结果已复制到剪贴板')
  }).catch(() => {
    ElMessage.error('复制失败')
  })
}
// Dispose the current chart instance, if any
const disposeChart = () => {
  if (chartInstance) {
    chartInstance.dispose()
    chartInstance = null
  }
}
// Switch between the list view and the chart view. The template binds this to
// the toggle button. The chart container only exists while the chart view is
// shown, so the chart is built after the DOM update and disposed when hidden.
const toggleChartView = () => {
  showChart.value = !showChart.value
  if (showChart.value) {
    nextTick(() => {
      initChart()
    })
  } else {
    disposeChart()
  }
}
// Keep the chart sized to its container when the window is resized
const handleResize = () => {
  if (chartInstance) {
    chartInstance.resize()
  }
}
// Rebuild the chart whenever a new result arrives; drop it when the result is cleared
watch(() => props.result, () => {
  if (props.result) {
    nextTick(() => {
      initChart() // wait for the DOM to reflect the new result
    })
  } else {
    disposeChart()
  }
})
// Initial render: build the chart once the component is mounted
onMounted(() => {
  window.addEventListener('resize', handleResize)
  if (props.result) {
    nextTick(() => {
      initChart()
    })
  }
})
// Release the chart and the resize listener with the component
onBeforeUnmount(() => {
  window.removeEventListener('resize', handleResize)
  disposeChart()
})
</script>
<style scoped>
/* 预测结果主容器 */
.prediction-result {
margin-top: 20px;
display: flex;
flex-direction: column;
gap: 25px;
animation: fadeInUp 0.6s ease-out;
}
/* 主要结果卡片 */
.main-result-card {
background: rgba(255, 255, 255, 0.15);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 20px;
padding: 30px;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
overflow: hidden;
}
.main-result-card::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
transition: left 0.5s ease;
}
.main-result-card:hover::before {
left: 100%;
}
.main-result-card:hover {
transform: translateY(-5px);
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.15);
}
/* 结果头部 */
.result-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 30px;
}
.result-icon {
width: 50px;
height: 50px;
background: linear-gradient(135deg, #52c41a, #73d13d);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-size: 1.5rem;
box-shadow: 0 8px 25px rgba(82, 196, 26, 0.3);
animation: pulse 2s ease-in-out infinite;
}
.result-title h3 {
color: white;
font-size: 1.8rem;
font-weight: 700;
margin: 0 0 5px 0;
text-shadow: 0 2px 10px rgba(0, 0, 0, 0.3);
}
.result-title p {
color: rgba(255, 255, 255, 0.8);
margin: 0;
font-size: 1rem;
}
.result-actions {
display: flex;
gap: 15px;
}
.action-btn {
background: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 255, 255, 0.3) !important;
color: white !important;
transition: all 0.3s ease;
backdrop-filter: blur(10px);
border-radius: 12px !important;
}
.action-btn:hover {
background: rgba(255, 255, 255, 0.2) !important;
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
}
/* 主要预测结果 */
.main-prediction {
display: flex;
justify-content: center;
margin-bottom: 20px;
}
.prediction-badge {
display: flex;
align-items: center;
gap: 25px;
padding: 25px;
background: rgba(255, 255, 255, 0.1);
border-radius: 20px;
backdrop-filter: blur(15px);
border: 1px solid rgba(255, 255, 255, 0.2);
}
.badge-icon {
width: 60px;
height: 60px;
background: linear-gradient(135deg, #fadb14, #ffec3d);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-size: 1.8rem;
box-shadow: 0 8px 25px rgba(250, 219, 20, 0.4);
}
.badge-content {
flex: 1;
}
.predicted-class {
font-size: 2rem;
font-weight: 700;
color: white;
margin-bottom: 8px;
text-shadow: 0 2px 10px rgba(0, 0, 0, 0.3);
}
.confidence-score {
font-size: 1.2rem;
color: rgba(255, 255, 255, 0.9);
}
.confidence-value {
font-weight: 700;
color: #52c41a;
text-shadow: 0 0 10px rgba(82, 196, 26, 0.5);
}
/* Confidence ring: fixed-size box that positions the SVG ring and its centered label */
.confidence-ring {
position: relative;
width: 100px;
height: 100px;
}
/* The SVG fills the box and is rotated as a whole.
   NOTE(review): the progress circle in the template also carries
   transform="rotate(-90 50 50)", so the two rotations stack; confirm the
   intended start angle of the progress arc. */
.ring-svg {
width: 100%;
height: 100%;
transform: rotate(-90deg);
}
/* Track circle (same stroke as the attribute set in the template) */
.ring-background {
stroke: rgba(255, 255, 255, 0.2);
}
/* Progress arc: gradient stroke, with changes of the dash offset animated over 1s */
.ring-progress {
stroke: url(#gradient);
stroke-linecap: round;
transition: stroke-dashoffset 1s ease-in-out;
}
/* Percentage label centered over the ring */
.ring-text {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
font-size: 1.2rem;
font-weight: 700;
color: white;
text-shadow: 0 2px 10px rgba(0, 0, 0, 0.3);
}
/* 详细信息卡片 */
.details-card {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 16px;
padding: 25px;
transition: all 0.3s ease;
}
.details-card:hover {
transform: translateY(-3px);
box-shadow: 0 15px 35px rgba(0, 0, 0, 0.1);
}
.details-header {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 20px;
color: white;
font-size: 1.3rem;
font-weight: 600;
}
.details-header .el-icon {
color: #409eff;
font-size: 1.4rem;
}
.details-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 20px;
}
.detail-item {
display: flex;
align-items: center;
gap: 15px;
padding: 20px;
background: rgba(255, 255, 255, 0.1);
border-radius: 12px;
border: 1px solid rgba(255, 255, 255, 0.2);
transition: all 0.3s ease;
}
.detail-item:hover {
background: rgba(255, 255, 255, 0.15);
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1);
}
.detail-icon {
width: 45px;
height: 45px;
background: linear-gradient(135deg, #409eff, #52c41a);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-size: 1.2rem;
flex-shrink: 0;
}
.detail-content {
flex: 1;
}
.detail-label {
font-size: 0.9rem;
color: rgba(255, 255, 255, 0.7);
margin-bottom: 5px;
}
.detail-value {
font-size: 1.1rem;
font-weight: 600;
color: white;
}
/* 概率分布卡片 */
.probability-card {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 16px;
padding: 25px;
transition: all 0.3s ease;
}
.probability-card:hover {
transform: translateY(-3px);
box-shadow: 0 15px 35px rgba(0, 0, 0, 0.1);
}
.probability-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 25px;
color: white;
font-size: 1.3rem;
font-weight: 600;
}
.probability-header .el-icon {
color: #fa8c16;
font-size: 1.4rem;
margin-right: 10px;
}
.toggle-view {
background: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 255, 255, 0.3) !important;
color: white !important;
border-radius: 8px !important;
}
.toggle-view:hover {
background: rgba(255, 255, 255, 0.2) !important;
}
/* 图表容器 */
.chart-container {
background: rgba(255, 255, 255, 0.05);
border-radius: 12px;
padding: 20px;
min-height: 400px;
}
.echarts-chart {
width: 100%;
height: 400px;
}
/* 概率列表 */
.probability-list {
display: flex;
flex-direction: column;
gap: 15px;
}
.probability-item {
padding: 20px;
background: rgba(255, 255, 255, 0.1);
border-radius: 12px;
border: 1px solid rgba(255, 255, 255, 0.2);
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
.probability-item:hover {
background: rgba(255, 255, 255, 0.15);
transform: translateX(5px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1);
}
.probability-item.active {
background: rgba(82, 196, 26, 0.2);
border-color: rgba(82, 196, 26, 0.5);
box-shadow: 0 0 20px rgba(82, 196, 26, 0.3);
}
.prob-info {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 12px;
}
.class-name {
display: flex;
align-items: center;
gap: 10px;
}
.name-text {
font-size: 1.1rem;
font-weight: 600;
color: white;
}
.winner-tag {
background: rgba(82, 196, 26, 0.2) !important;
border-color: rgba(82, 196, 26, 0.5) !important;
color: #52c41a !important;
}
.prob-value {
font-size: 1.2rem;
font-weight: 700;
color: white;
}
.prob-bar-container {
width: 100%;
height: 8px;
background: rgba(255, 255, 255, 0.1);
border-radius: 4px;
overflow: hidden;
}
.prob-bar {
width: 100%;
height: 100%;
position: relative;
}
.prob-fill {
height: 100%;
border-radius: 4px;
transition: all 0.8s cubic-bezier(0.4, 0, 0.2, 1);
background: linear-gradient(90deg, #52c41a, #73d13d);
position: relative;
animation: progressFill 1s ease-out;
}
/* 空状态 */
.no-result {
text-align: center;
padding: 60px 20px;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(20px);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 16px;
}
.empty-state {
animation: fadeIn 0.6s ease-out;
}
.empty-icon {
font-size: 4rem;
color: rgba(255, 255, 255, 0.4);
margin-bottom: 20px;
animation: float 3s ease-in-out infinite;
}
.empty-state h3 {
color: white;
font-size: 1.5rem;
font-weight: 600;
margin-bottom: 10px;
text-shadow: 0 2px 10px rgba(0, 0, 0, 0.3);
}
.empty-state p {
color: rgba(255, 255, 255, 0.7);
font-size: 1.1rem;
margin: 0;
}
/* 动画 */
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(30px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
@keyframes pulse {
0%, 100% {
transform: scale(1);
}
50% {
transform: scale(1.05);
}
}
@keyframes float {
0%, 100% {
transform: translateY(0);
}
50% {
transform: translateY(-10px);
}
}
/* Grow a probability bar from empty to its own width.
   Only the start state is declared: the end state then resolves to the width
   already set on the element (the inline width bound in the template). The
   previous end state read --target-width, which nothing sets, so every bar
   animated to 100% and then snapped back to its real width. */
@keyframes progressFill {
  from {
    width: 0%;
  }
}
/* 过渡动画 */
.fade-enter-active,
.fade-leave-active {
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.fade-enter-from,
.fade-leave-to {
opacity: 0;
transform: translateY(20px);
}
/* 响应式设计 */
@media (max-width: 768px) {
.prediction-result {
gap: 20px;
}
.main-result-card {
padding: 20px;
}
.result-header {
flex-direction: column;
gap: 15px;
text-align: center;
}
.prediction-badge {
flex-direction: column;
text-align: center;
gap: 20px;
}
.details-grid {
grid-template-columns: 1fr;
gap: 15px;
}
.prob-info {
flex-direction: column;
gap: 10px;
text-align: center;
}
.chart-container {
min-height: 300px;
}
.echarts-chart {
height: 300px;
}
}
@media (max-width: 480px) {
.main-result-card,
.details-card,
.probability-card {
padding: 15px;
}
.predicted-class {
font-size: 1.5rem;
}
.confidence-ring {
width: 80px;
height: 80px;
}
.detail-item {
padding: 15px;
}
.probability-item {
padding: 15px;
}
}
</style>

@ -0,0 +1,18 @@
import { createApp } from 'vue'
import App from './App.vue'
import router from './router'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
const app = createApp(App)

// Register every Element Plus icon as a global component under its export name
Object.entries(ElementPlusIconsVue).forEach(([iconName, iconComponent]) => {
  app.component(iconName, iconComponent)
})

// Install the UI library and the router, then mount the application
app.use(ElementPlus).use(router).mount('#app')

@ -0,0 +1,17 @@
import { createRouter, createWebHistory } from 'vue-router'
import HomePage from '../views/HomePage.vue'
// Application router: HTML5 history mode with a single route for the home page
const router = createRouter({
  history: createWebHistory(),
  routes: [
    { path: '/', name: 'Home', component: HomePage }
  ]
})

export default router

@ -0,0 +1,213 @@
import axios from 'axios'
import { ElMessage } from 'element-plus'
// Shared axios instance: every request goes to /api, times out after 30 s,
// and defaults to a JSON content type
const api = axios.create({
  baseURL: '/api',
  timeout: 30000,
  headers: { 'Content-Type': 'application/json' }
})

// Request interceptor: currently a pass-through; authentication data such as a
// token can be attached to the config here
api.interceptors.request.use(
  function forwardRequest(requestConfig) {
    return requestConfig
  },
  function rejectRequest(requestError) {
    return Promise.reject(requestError)
  }
)
// Response interceptor: passes successful responses through unchanged. On
// failure it derives a user-facing message, shows it with ElMessage, and
// rejects with an Error carrying that message. The original axios error is
// kept on `cause` and the HTTP response on `response`, so callers can still
// inspect them.
api.interceptors.response.use(
  (response) => {
    return response
  },
  (error) => {
    let errorMessage = '请求失败'
    if (error.response) {
      // The server answered with an error status code
      const { status, data } = error.response
      // The body may be empty or not an object, so read the message defensively
      const serverMessage = data?.message
      switch (status) {
        case 400:
          errorMessage = serverMessage || '请求参数错误'
          break
        case 401:
          errorMessage = '未授权访问'
          break
        case 403:
          errorMessage = '禁止访问'
          break
        case 404:
          errorMessage = '请求的资源不存在'
          break
        case 413:
          errorMessage = '文件过大，最大支持50MB'
          break
        case 500:
          errorMessage = serverMessage || '服务器内部错误'
          break
        default:
          errorMessage = serverMessage || `请求失败 (${status})`
      }
    } else if (error.request) {
      // The request was sent but no response arrived
      errorMessage = '网络连接失败,请检查网络设置'
    } else {
      // Anything else, e.g. a failure while building the request
      errorMessage = error.message || '未知错误'
    }
    // Show the error to the user
    ElMessage.error(errorMessage)
    const wrapped = new Error(errorMessage)
    wrapped.cause = error
    wrapped.response = error.response
    return Promise.reject(wrapped)
  }
)
// API service: thin wrappers around the backend endpoints. Every method
// returns the axios promise for the request.
export const apiService = {
  // Health check
  healthCheck: () => {
    return api.get('/health')
  },
  // Initialize the model with the given configuration object
  initModel: (config) => {
    return api.post('/init', config)
  },
  /**
   * Upload an audio file and run a prediction on it.
   *
   * @param {File|Blob} file - Sent as the `file` form field.
   * @param {(percent: number) => void} [onProgress] - Receives the upload
   *   progress as an integer percentage whenever the total size is known.
   */
  uploadAndPredict: (file, onProgress) => {
    const formData = new FormData()
    formData.append('file', file)
    return api.post('/upload', formData, {
      // Override the instance-wide 'application/json' default: with a JSON
      // content type axios 1.x serialises FormData to JSON instead of sending
      // a multipart body. The multipart boundary is still filled in by
      // axios/the browser.
      headers: {
        'Content-Type': 'multipart/form-data'
      },
      onUploadProgress: (progressEvent) => {
        if (onProgress && progressEvent.total) {
          const percent = Math.round((progressEvent.loaded * 100) / progressEvent.total)
          onProgress(percent)
        }
      }
    })
  },
  // Run a prediction on recorded audio data
  predictAudioData: (audioData) => {
    return api.post('/predict', audioData)
  },
  // Fetch the label list
  getLabels: () => {
    return api.get('/labels')
  },
  // Fetch model information
  getModelInfo: () => {
    return api.get('/model/info')
  }
}
// Utility helpers
export const utils = {
  /**
   * Format a byte count with a binary unit (B, KB, MB, GB, TB).
   * Missing, non-finite, zero or negative input yields '0 B'; values below one
   * byte stay in B and values beyond the largest unit stay in TB, so the unit
   * lookup can never fall outside the table.
   */
  formatFileSize: (bytes) => {
    const size = Number(bytes)
    if (!Number.isFinite(size) || size <= 0) return '0 B'
    const k = 1024
    const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
    const exponent = Math.floor(Math.log(size) / Math.log(k))
    const i = Math.min(Math.max(exponent, 0), sizes.length - 1)
    return parseFloat((size / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
  },
  /**
   * Format a duration in seconds as "h:mm:ss", "m:ss" or "Ns"
   * (whole seconds, fractions dropped). Falsy input yields '0s'.
   */
  formatDuration: (seconds) => {
    if (!seconds) return '0s'
    const hrs = Math.floor(seconds / 3600)
    const mins = Math.floor((seconds % 3600) / 60)
    const secs = Math.floor(seconds % 60)
    if (hrs > 0) {
      return `${hrs}:${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`
    } else if (mins > 0) {
      return `${mins}:${secs.toString().padStart(2, '0')}`
    } else {
      return `${secs}s`
    }
  },
  /**
   * Check whether a file looks like a supported audio file, by MIME type or
   * by file extension. A name without a dot has no extension, so a file that
   * is merely called "wav" is not accepted on its name alone.
   */
  isAudioFile: (file) => {
    const allowedTypes = ['audio/wav', 'audio/mpeg', 'audio/flac', 'audio/m4a', 'audio/ogg', 'audio/aac']
    const allowedExtensions = ['wav', 'mp3', 'flac', 'm4a', 'ogg', 'aac']
    const name = typeof file?.name === 'string' ? file.name : ''
    const dotIndex = name.lastIndexOf('.')
    const fileExtension = dotIndex >= 0 ? name.slice(dotIndex + 1).toLowerCase() : ''
    return allowedTypes.includes(file?.type) || allowedExtensions.includes(fileExtension)
  },
  // Check that the file is no larger than maxSizeMB megabytes (default 50)
  isFileSizeValid: (file, maxSizeMB = 50) => {
    const maxSize = maxSizeMB * 1024 * 1024
    return file.size <= maxSize
  },
  // Generate an id from the current time plus a random suffix (not cryptographically secure)
  generateId: () => {
    return Date.now().toString(36) + Math.random().toString(36).slice(2)
  },
  // Trigger a browser download of `content` under `filename`
  downloadFile: (content, filename, type = 'application/json') => {
    const blob = new Blob([content], { type })
    const url = URL.createObjectURL(blob)
    const link = document.createElement('a')
    link.href = url
    link.download = filename
    document.body.appendChild(link)
    link.click()
    document.body.removeChild(link)
    URL.revokeObjectURL(url)
  },
  // Copy text to the clipboard; resolves to true on success and false on failure
  copyToClipboard: async (text) => {
    try {
      await navigator.clipboard.writeText(text)
      return true
    } catch (error) {
      console.error('复制失败:', error)
      return false
    }
  },
  // Resolve with the duration of an audio file in seconds, read from its
  // metadata; rejects with the error event when the file cannot be loaded
  getAudioDuration: (file) => {
    return new Promise((resolve, reject) => {
      const audio = new Audio()
      const url = URL.createObjectURL(file)
      audio.addEventListener('loadedmetadata', () => {
        URL.revokeObjectURL(url)
        resolve(audio.duration)
      })
      audio.addEventListener('error', (error) => {
        URL.revokeObjectURL(url)
        reject(error)
      })
      audio.src = url
    })
  }
}
export default api

@ -0,0 +1,841 @@
<template>
<div class="home-page">
<!-- 装饰性背景元素 -->
<div class="background-decoration">
<div class="decoration-circle circle-1"></div>
<div class="decoration-circle circle-2"></div>
<div class="decoration-circle circle-3"></div>
</div>
<!-- 顶部导航 -->
<header class="header">
<div class="container">
<div class="header-content">
<h1 class="title">
<div class="title-wrapper">
<el-icon class="title-icon"><Microphone /></el-icon>
<span class="title-text">声纹识别系统</span>
</div>
<div class="title-glow"></div>
</h1>
<div class="subtitle">
<span class="subtitle-text">基于深度学习的音频分类平台</span>
<div class="subtitle-decoration"></div>
</div>
</div>
</div>
</header>
<!-- 主要内容区域 -->
<main class="main-content">
<div class="container">
<!-- 系统状态卡片 -->
<el-card class="status-card glass-card" shadow="never">
<div class="status-info">
<div class="status-display">
<div class="status-icon-wrapper">
<el-icon class="status-icon" :class="{ active: modelStatus === 'ready' }">
<Cpu />
</el-icon>
</div>
<div class="status-text">
<h3>模型状态</h3>
<el-tag
:type="modelStatus === 'ready' ? 'success' : 'warning'"
size="large"
class="status-tag"
>
{{ modelStatus === 'ready' ? '✓ 模型已就绪' : '⚠ 模型未初始化' }}
</el-tag>
</div>
</div>
<el-button
v-if="modelStatus !== 'ready'"
type="primary"
@click="initModel"
:loading="initLoading"
class="init-button"
size="large"
>
<template #icon>
<el-icon><Setting /></el-icon>
</template>
初始化模型
</el-button>
</div>
</el-card>
<!-- 功能区域 -->
<div class="function-area">
<!-- 文件上传区域 -->
<el-card class="upload-card glass-card" shadow="never">
<template #header>
<div class="card-header">
<div class="header-icon">
<el-icon><Upload /></el-icon>
</div>
<div class="header-text">
<h3>音频文件上传</h3>
<p>支持多种音频格式快速识别</p>
</div>
</div>
</template>
<AudioUpload
@upload-success="handleUploadSuccess"
@upload-error="handleUploadError"
:disabled="modelStatus !== 'ready'"
/>
</el-card>
<!-- 录音区域 -->
<el-card class="record-card glass-card" shadow="never">
<template #header>
<div class="card-header">
<div class="header-icon">
<el-icon><Microphone /></el-icon>
</div>
<div class="header-text">
<h3>实时录音识别</h3>
<p>直接录制音频进行实时分析</p>
</div>
</div>
</template>
<AudioRecorder
@record-success="handleRecordSuccess"
@record-error="handleRecordError"
:disabled="modelStatus !== 'ready'"
/>
</el-card>
</div>
<!-- 结果展示区域 -->
<transition name="slide-up" appear>
<el-card v-if="predictionResult" class="result-card glass-card" shadow="never">
<template #header>
<div class="card-header">
<div class="header-icon result-icon">
<el-icon><DataAnalysis /></el-icon>
</div>
<div class="header-text">
<h3>预测结果</h3>
<p>深度学习模型分析结果</p>
</div>
<div class="result-badge">
<el-icon><TrophyBase /></el-icon>
</div>
</div>
</template>
<PredictionResult :result="predictionResult" />
</el-card>
</transition>
<!-- 历史记录 -->
<el-card class="history-card glass-card" shadow="never">
<template #header>
<div class="card-header">
<div class="header-icon">
<el-icon><Clock /></el-icon>
</div>
<div class="header-text">
<h3>识别历史</h3>
<p>查看所有识别记录和结果</p>
</div>
<el-button
type="text"
size="small"
@click="clearHistory"
v-if="historyList.length > 0"
class="clear-button"
>
<el-icon><Delete /></el-icon>
清空历史
</el-button>
</div>
</template>
<HistoryList :history="historyList" @select-item="handleSelectHistory" />
</el-card>
</div>
</main>
<!-- 底部装饰 -->
<div class="footer-decoration">
<div class="wave wave-1"></div>
<div class="wave wave-2"></div>
</div>
</div>
</template>
<script setup>
import { ref, onMounted } from 'vue'
import { ElMessage, ElNotification } from 'element-plus'
import AudioUpload from '../components/AudioUpload.vue'
import AudioRecorder from '../components/AudioRecorder_new.vue'
import PredictionResult from '../components/PredictionResult.vue'
import HistoryList from '../components/HistoryList.vue'
import { apiService } from '../utils/api'
// Reactive page state
const modelStatus = ref('unknown') // unknown, ready, loading (initModel also sets 'error' on failure)
// True while the model initialization request is in flight
const initLoading = ref(false)
// Result currently shown in the result card; null hides the card
const predictionResult = ref(null)
// Recognition history, newest first
const historyList = ref([])
// Ask the backend to load the model. On success the status becomes 'ready';
// any failure is reported to the user and the status becomes 'error'.
const initModel = async () => {
  const modelOptions = {
    configs: '../../configs/cam++.yml',
    model_path: '../../models/CAMPPlus_Fbank/best_model/',
    use_gpu: true
  }
  initLoading.value = true
  try {
    const payload = (await apiService.initModel(modelOptions)).data
    if (payload.status !== 'success') {
      throw new Error(payload.message)
    }
    modelStatus.value = 'ready'
    ElMessage.success('模型初始化成功')
  } catch (initError) {
    ElMessage.error(`模型初始化失败: ${initError.message}`)
    modelStatus.value = 'error'
  } finally {
    // Always release the button's loading state
    initLoading.value = false
  }
}
// Probe the backend and, when it answers, initialize the model right away
const checkServerStatus = async () => {
  try {
    await apiService.healthCheck()
    // The server is reachable, so go straight on to model initialization
    await initModel()
  } catch {
    ElMessage.error('无法连接到服务器,请确保后端服务已启动')
  }
}
// A file upload finished with a prediction: show it, record it in the history,
// and notify the user
const handleUploadSuccess = (result, filename) => {
  const recognizedAt = new Date().toLocaleString()
  predictionResult.value = { ...result, filename, source: 'upload', timestamp: recognizedAt }

  addToHistory(predictionResult.value)

  const label = result.predicted_class || result.label
  const percent = ((result.confidence || result.score) * 100).toFixed(2)
  ElNotification({
    title: '识别成功',
    message: `预测结果: ${label} (置信度: ${percent}%)`,
    type: 'success'
  })
}
// Report a failed upload to the user
const handleUploadError = (error) => {
  const notice = `上传失败: ${error}`
  ElMessage.error(notice)
}
// A recording was recognized: show it, record it in the history, and notify the user
const handleRecordSuccess = (result) => {
  const recognizedAt = new Date().toLocaleString()
  predictionResult.value = { ...result, source: 'record', timestamp: recognizedAt }

  addToHistory(predictionResult.value)

  const label = result.predicted_class || result.label
  const percent = ((result.confidence || result.score) * 100).toFixed(2)
  ElNotification({
    title: '识别成功',
    message: `预测结果: ${label} (置信度: ${percent}%)`,
    type: 'success'
  })
}
// Report a failed recording recognition to the user
const handleRecordError = (error) => {
  const notice = `录音识别失败: ${error}`
  ElMessage.error(notice)
}
// Put a result at the front of the history, keep only the 20 newest entries,
// and persist the list to localStorage.
const addToHistory = (result) => {
  historyList.value.unshift(result)
  // Cap the history at 20 entries
  if (historyList.value.length > 20) {
    historyList.value = historyList.value.slice(0, 20)
  }
  // Persisting is best-effort: setItem throws when storage is full or
  // unavailable, and that must not abort the caller's success handling.
  try {
    localStorage.setItem('audio-classification-history', JSON.stringify(historyList.value))
  } catch (error) {
    console.error('保存历史记录失败:', error)
  }
}
// Show the history entry the user picked in the result card
const handleSelectHistory = (item) => {
  const chosen = item
  predictionResult.value = chosen
}
// Empty the history both in memory and in localStorage, then confirm to the user
const clearHistory = () => {
  const storageKey = 'audio-classification-history'
  historyList.value = []
  localStorage.removeItem(storageKey)
  ElMessage.success('历史记录已清空')
}
// Restore the history saved in localStorage, if any.
// Only an array is accepted: any other stored JSON value (for example `null`)
// would break the template's use of historyList.length.
const loadHistory = () => {
  try {
    const saved = localStorage.getItem('audio-classification-history')
    if (!saved) return
    const parsed = JSON.parse(saved)
    if (Array.isArray(parsed)) {
      historyList.value = parsed
    } else {
      console.error('加载历史记录失败:', new Error('stored history is not an array'))
    }
  } catch (error) {
    // Covers both unreadable storage and malformed JSON
    console.error('加载历史记录失败:', error)
  }
}
// On mount: restore the saved history first, then probe the server
// (which also triggers model initialization)
onMounted(function bootstrapPage() {
  loadHistory()
  checkServerStatus()
})
</script>
<style scoped>
.home-page {
min-height: 100vh;
position: relative;
overflow-x: hidden;
}
/* Decorative background layer: covers the viewport, sits at z-index 0, and ignores pointer events */
.background-decoration {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
pointer-events: none;
z-index: 0;
}
/* Shared look of the floating circles; each one runs the float animation */
.decoration-circle {
position: absolute;
border-radius: 50%;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
animation: float 6s ease-in-out infinite;
}
/* Individual circles: size, position, and a staggered animation start */
.circle-1 {
width: 200px;
height: 200px;
top: 10%;
right: 10%;
animation-delay: 0s;
}
.circle-2 {
width: 150px;
height: 150px;
bottom: 20%;
left: 5%;
animation-delay: 2s;
}
.circle-3 {
width: 100px;
height: 100px;
top: 50%;
left: 80%;
animation-delay: 4s;
}
/* Gentle bob: rise 20px while tilting 10 degrees, then return */
@keyframes float {
0%, 100% {
transform: translateY(0px) rotate(0deg);
}
50% {
transform: translateY(-20px) rotate(10deg);
}
}
/* Header styles */
.header {
  background: rgba(255, 255, 255, 0.1);
  /* Prefixed form for WebKit browsers that only support the prefixed property. */
  -webkit-backdrop-filter: blur(20px);
  backdrop-filter: blur(20px);
  border-bottom: 1px solid rgba(255, 255, 255, 0.2);
  padding: 40px 0;
  position: sticky;
  top: 0;
  z-index: 100;
  box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
}
.container {
max-width: 1400px;
margin: 0 auto;
padding: 0 20px;
position: relative;
z-index: 1;
}
.header-content {
text-align: center;
animation: fadeIn 1s ease-out;
}
.title {
position: relative;
margin-bottom: 15px;
}
.title-wrapper {
display: flex;
align-items: center;
justify-content: center;
gap: 20px;
font-size: 3.5rem;
font-weight: 800;
color: white;
text-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
animation: slideInLeft 1s ease-out;
}
.title-icon {
font-size: 3.5rem;
color: #fff;
filter: drop-shadow(0 4px 8px rgba(0, 0, 0, 0.3));
animation: pulse 2s ease-in-out infinite;
}
.title-text {
background: linear-gradient(135deg, #fff 0%, #f0f0f0 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.title-glow {
position: absolute;
top: 0;
left: 50%;
transform: translateX(-50%);
width: 100%;
height: 100%;
background: radial-gradient(ellipse at center, rgba(255, 255, 255, 0.3) 0%, transparent 70%);
z-index: -1;
animation: glow 3s ease-in-out infinite alternate;
}
@keyframes glow {
from {
opacity: 0.5;
}
to {
opacity: 1;
}
}
.subtitle {
position: relative;
animation: slideInRight 1s ease-out;
}
.subtitle-text {
font-size: 1.3rem;
color: rgba(255, 255, 255, 0.9);
font-weight: 300;
letter-spacing: 1px;
text-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
}
.subtitle-decoration {
width: 100px;
height: 2px;
background: linear-gradient(90deg, transparent 0%, rgba(255, 255, 255, 0.7) 50%, transparent 100%);
margin: 15px auto 0;
animation: fadeIn 2s ease-out;
}
/* 主内容区域 */
.main-content {
padding: 50px 0;
position: relative;
z-index: 1;
}
/* Glass-effect card */
.glass-card {
  background: rgba(255, 255, 255, 0.15);
  /* Prefixed form for WebKit browsers that only support the prefixed property. */
  -webkit-backdrop-filter: blur(20px);
  backdrop-filter: blur(20px);
  border: 1px solid rgba(255, 255, 255, 0.2);
  border-radius: 20px;
  box-shadow: 0 15px 35px rgba(0, 0, 0, 0.1);
  transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
  animation: fadeIn 0.8s ease-out;
}
.glass-card:hover {
transform: translateY(-10px);
box-shadow: 0 25px 50px rgba(0, 0, 0, 0.15);
background: rgba(255, 255, 255, 0.2);
}
/* 状态卡片 */
.status-card {
margin-bottom: 40px;
animation-delay: 0.2s;
}
.status-info {
display: flex;
align-items: center;
justify-content: space-between;
padding: 20px;
}
.status-display {
display: flex;
align-items: center;
gap: 20px;
}
.status-icon-wrapper {
width: 60px;
height: 60px;
border-radius: 50%;
background: rgba(255, 255, 255, 0.2);
display: flex;
align-items: center;
justify-content: center;
transition: all 0.3s ease;
}
.status-icon {
font-size: 28px;
color: rgba(255, 255, 255, 0.8);
transition: all 0.3s ease;
}
.status-icon.active {
color: #52c41a;
animation: pulse 2s ease-in-out infinite;
}
.status-text h3 {
color: white;
font-size: 1.4rem;
font-weight: 600;
margin-bottom: 8px;
}
.status-tag {
font-size: 1rem;
padding: 8px 16px;
border-radius: 25px;
font-weight: 500;
}
.init-button {
border-radius: 25px;
padding: 12px 24px;
font-weight: 600;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
transition: all 0.3s ease;
}
.init-button:hover {
transform: translateY(-2px);
box-shadow: 0 12px 25px rgba(102, 126, 234, 0.6);
}
/* 功能区域 */
.function-area {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 40px;
margin-bottom: 40px;
}
.upload-card {
animation-delay: 0.4s;
}
.record-card {
animation-delay: 0.6s;
}
/* 卡片头部 */
.card-header {
display: flex;
align-items: center;
gap: 15px;
padding: 20px 0;
color: white;
}
.header-icon {
width: 50px;
height: 50px;
border-radius: 15px;
background: rgba(255, 255, 255, 0.2);
display: flex;
align-items: center;
justify-content: center;
transition: all 0.3s ease;
}
.header-icon .el-icon {
font-size: 24px;
color: white;
}
.result-icon {
background: linear-gradient(135deg, #52c41a 0%, #73d13d 100%);
}
.header-text h3 {
font-size: 1.3rem;
font-weight: 600;
margin-bottom: 4px;
color: white;
}
.header-text p {
font-size: 0.95rem;
color: rgba(255, 255, 255, 0.8);
margin: 0;
}
.result-badge {
margin-left: auto;
width: 40px;
height: 40px;
border-radius: 50%;
background: linear-gradient(135deg, #ffd700 0%, #ffed4e 100%);
display: flex;
align-items: center;
justify-content: center;
animation: pulse 2s ease-in-out infinite;
}
.result-badge .el-icon {
color: #fff;
font-size: 20px;
}
/* 结果卡片 */
.result-card {
margin-bottom: 40px;
animation-delay: 0.8s;
}
/* 历史记录卡片 */
.history-card {
animation-delay: 1s;
}
.clear-button {
margin-left: auto;
color: rgba(255, 255, 255, 0.8);
transition: all 0.3s ease;
}
.clear-button:hover {
color: #ff4d4f;
background: rgba(255, 77, 79, 0.1);
}
/* 动画效果 */
.slide-up-enter-active,
.slide-up-leave-active {
transition: all 0.5s cubic-bezier(0.4, 0, 0.2, 1);
}
.slide-up-enter-from {
opacity: 0;
transform: translateY(30px) scale(0.95);
}
.slide-up-leave-to {
opacity: 0;
transform: translateY(-30px) scale(0.95);
}
/* 底部装饰 */
.footer-decoration {
position: fixed;
bottom: 0;
left: 0;
width: 100%;
height: 200px;
pointer-events: none;
z-index: 0;
overflow: hidden;
}
.wave {
position: absolute;
bottom: 0;
left: 0;
width: 200%;
height: 100px;
background: rgba(255, 255, 255, 0.1);
animation: wave 10s linear infinite;
}
.wave-1 {
animation-delay: 0s;
opacity: 0.6;
}
.wave-2 {
animation-delay: -5s;
opacity: 0.4;
height: 80px;
}
@keyframes wave {
0% {
transform: translateX(-50%);
}
100% {
transform: translateX(-25%);
}
}
/* 响应式设计 */
@media (max-width: 1200px) {
.container {
max-width: 1000px;
}
.title-wrapper {
font-size: 3rem;
}
.title-icon {
font-size: 3rem;
}
}
@media (max-width: 992px) {
.function-area {
grid-template-columns: 1fr;
gap: 30px;
}
.title-wrapper {
font-size: 2.5rem;
}
.title-icon {
font-size: 2.5rem;
}
.subtitle-text {
font-size: 1.1rem;
}
}
@media (max-width: 768px) {
.header {
padding: 30px 0;
}
.title-wrapper {
font-size: 2rem;
flex-direction: column;
gap: 10px;
}
.title-icon {
font-size: 2rem;
}
.subtitle-text {
font-size: 1rem;
}
.status-info {
flex-direction: column;
gap: 20px;
}
.card-header {
flex-direction: column;
text-align: center;
gap: 10px;
}
.decoration-circle {
display: none;
}
}
@media (max-width: 480px) {
.container {
padding: 0 15px;
}
.main-content {
padding: 30px 0;
}
.glass-card {
border-radius: 15px;
}
.title-wrapper {
font-size: 1.8rem;
}
.title-icon {
font-size: 1.8rem;
}
}
/* Element Plus 组件样式重写 */
:deep(.el-card) {
background: transparent;
border: none;
box-shadow: none;
}
:deep(.el-card__header) {
background: transparent;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
padding: 0;
}
:deep(.el-card__body) {
padding: 30px;
background: transparent;
}
:deep(.el-tag) {
border-radius: 20px;
border: none;
font-weight: 500;
}
:deep(.el-button) {
border-radius: 20px;
font-weight: 500;
transition: all 0.3s ease;
}
:deep(.el-button:hover) {
transform: translateY(-2px);
}
</style>

@ -0,0 +1,33 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import AutoImport from 'unplugin-auto-import/vite'
import Components from 'unplugin-vue-components/vite'
import { ElementPlusResolver } from 'unplugin-vue-components/resolvers'
// Vite configuration for the frontend: Vue support, Element Plus
// auto-import, a dev server that proxies API calls, and build output.
export default defineConfig({
  plugins: [
    vue(),
    // Both unplugin plugins are given the Element Plus resolver.
    AutoImport({
      resolvers: [ElementPlusResolver()],
    }),
    Components({
      resolvers: [ElementPlusResolver()],
    }),
  ],
  server: {
    port: 3000,
    // NOTE(review): '0.0.0.0' presumably makes the dev server reachable
    // from other machines on the network — confirm this is intended.
    host: '0.0.0.0',
    proxy: {
      // Requests whose path starts with /api are forwarded to the backend
      // at localhost:5000.
      '/api': {
        target: 'http://localhost:5000',
        changeOrigin: true,
        secure: false,
      }
    }
  },
  build: {
    outDir: 'dist',
    assetsDir: 'assets',
    // Chunk-size warning threshold (Vite expresses this limit in kB).
    chunkSizeWarningLimit: 1000
  }
})

@ -0,0 +1,24 @@
@echo off
rem Launcher for the voiceprint recognition system on Windows.
rem Starts the backend and the frontend, each in its own console window,
rem then prints the URLs of both services.
echo ================================
echo 声纹识别系统启动脚本
echo ================================
echo.
echo 正在启动后端服务...
rem %~dp0 is the directory of this script, so the launcher works from any
rem current directory; /d also switches drive if needed.
cd /d "%~dp0backend"
rem "cmd /k" keeps the new window open after the command ends.
start "后端服务" cmd /k "python app.py"
echo 等待后端服务启动...
rem Fixed 3-second pause before starting the frontend; the backend is not
rem actually probed for readiness.
timeout /t 3 /nobreak >nul
echo 正在启动前端服务...
cd /d "%~dp0frontend"
start "前端服务" cmd /k "npm run dev"
echo.
echo 启动完成!
echo 后端服务: http://localhost:5000
echo 前端服务: http://localhost:3000
echo.
echo 请等待服务完全启动后访问前端地址
rem Keep this window open until the user presses a key.
pause

@ -0,0 +1,36 @@
#!/bin/bash
# Launcher for the voiceprint recognition system.
# Starts the backend (python app.py) and the frontend (npm run dev) as
# background jobs, prints their URLs and PIDs, and stops both services when
# the script is interrupted or terminated.
echo "================================"
echo " 声纹识别系统启动脚本"
echo "================================"
echo
# Resolve the directory containing this script so it works from any cwd.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BACKEND_PID=""
FRONTEND_PID=""

# Stop whichever services have been started so far, then exit.
stop_services() {
    echo "正在停止服务..."
    for pid in "$BACKEND_PID" "$FRONTEND_PID"; do
        # A service may not have been started yet or may already have
        # exited; neither case should produce an error message.
        [ -n "$pid" ] && kill "$pid" 2>/dev/null
    done
    exit 0
}
# Install the handler before launching anything, so an interrupt during
# startup still stops the backend; handle SIGTERM as well as Ctrl+C.
trap stop_services INT TERM

echo "正在启动后端服务..."
# Abort instead of running app.py from the wrong directory.
cd "$SCRIPT_DIR/backend" || { echo "无法进入目录: $SCRIPT_DIR/backend" >&2; exit 1; }
python app.py &
BACKEND_PID=$!
echo "等待后端服务启动..."
# Fixed pause; the backend is not actually probed for readiness.
sleep 3
echo "正在启动前端服务..."
# If the frontend directory is missing, do not leave the backend running.
cd "$SCRIPT_DIR/frontend" || {
    echo "无法进入目录: $SCRIPT_DIR/frontend" >&2
    kill "$BACKEND_PID" 2>/dev/null
    exit 1
}
npm run dev &
FRONTEND_PID=$!
echo
echo "启动完成!"
echo "后端服务: http://localhost:5000"
echo "前端服务: http://localhost:3000"
echo
echo "后端进程 PID: $BACKEND_PID"
echo "前端进程 PID: $FRONTEND_PID"
echo
echo "按 Ctrl+C 停止所有服务"
# Block until the background jobs exit or a trapped signal arrives.
wait

@ -0,0 +1,46 @@
# 语速增强
speed:
# 增强概率
prob: 1.0
# 音量增强
volume:
# 增强概率
prob: 0.0
# 最小增益
min_gain_dBFS: -15
# 最大增益
max_gain_dBFS: 15
# 噪声增强
noise:
# 增强概率
prob: 0.5
# 噪声增强的噪声文件夹
noise_dir: 'dataset/noise'
# 混入噪声时的最小信噪比,单位dB
min_snr_dB: 10
# 混入噪声时的最大信噪比,单位dB
max_snr_dB: 50
# 混响增强
reverb:
# 增强概率
prob: 0.5
# 混响增强的混响文件夹
reverb_dir: 'dataset/reverb'
# Spec增强
spec_aug:
# 增强概率
prob: 0.5
# 频域掩蔽的比例
freq_mask_ratio: 0.1
# 频域掩蔽次数
n_freq_masks: 1
# 时域掩蔽的比例
time_mask_ratio: 0.05
# 时域掩蔽次数
n_time_masks: 1
# 最大时间扭曲
max_time_warp: 0

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 64
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 8
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型
model: 'CAMPPlus'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 128
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 16
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型
model: 'EcapaTdnn'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 32
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 4
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型支持ERes2Net、ERes2NetV2
model: 'ERes2Net'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 64
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 8
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型支持PANNS_CNN6、PANNS_CNN10、PANNS_CNN14
model: 'PANNS_CNN10'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 32
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 4
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型
model: 'Res2Net'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 32
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 4
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型
model: 'ResNetSE'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,80 @@
# 数据集参数
dataset_conf:
dataset:
# 过滤最短的音频长度
min_duration: 0.4
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
dataLoader:
# 训练的批量大小
batch_size: 64
# 是否丢弃最后一个不完整的批次
drop_last: True
# 读取数据的线程数量
num_workers: 8
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 8
# 最长的音频长度
max_duration: 20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 测试数据的数据列表路径
test_list: 'dataset/test_list.txt'
# 标签列表
label_list_path: 'dataset/label_list.txt'
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时,设置API参数,更多参数查看对应API,不清楚的可以直接删除该部分,直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
model_conf:
# 所使用的模型
model: 'TDNN'
# 模型参数
model_args:
# 分类大小如果为null自动通过标签列表获取
num_class: null
optimizer_conf:
# 优化方法
optimizer: 'Adam'
# 优化方法参数
optimizer_args:
lr: 0.001
weight_decay: !!float 1e-5
# 学习率衰减函数支持Pytorch支持的和项目提供的WarmupCosineSchedulerLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 1e-5
max_lr: 0.001
warmup_epoch: 5
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# CrossEntropyLoss类的label_smoothing参数
label_smoothing: 0.0
# 训练的轮数
max_epoch: 60
log_interval: 10

@ -0,0 +1,99 @@
import os
# Generate the data lists for a dataset laid out as one sub-folder per class
def get_data_list(audio_path, list_path):
    """Write train/test/label list files for a folder-per-class dataset.

    Each sub-directory of ``audio_path`` is treated as one class. Its name is
    written to ``label_list.txt`` and its position in the sorted directory
    listing becomes the numeric label. Every file inside it is written as
    ``<path>\\t<label>``; every 10th file overall (counting from 0) goes to
    ``test_list.txt`` and the rest go to ``train_list.txt``.

    :param audio_path: root directory containing one sub-directory per class
    :param list_path: directory that receives the three list files; created
        if it does not exist
    """
    # Sort the class folders so that label ids are reproducible (os.listdir
    # order is arbitrary), and ignore stray non-directory entries, which
    # would otherwise make the inner os.listdir call fail.
    audios = sorted(name for name in os.listdir(audio_path)
                    if os.path.isdir(os.path.join(audio_path, name)))
    os.makedirs(list_path, exist_ok=True)
    sound_sum = 0
    # Context managers close all three files even if an error occurs.
    with open(os.path.join(list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test, \
            open(os.path.join(list_path, 'label_list.txt'), 'w', encoding='utf-8') as f_label:
        for i, audio in enumerate(audios):
            f_label.write(f'{audio}\n')
            # Sorted for a reproducible train/test split.
            sounds = sorted(os.listdir(os.path.join(audio_path, audio)))
            for sound in sounds:
                # Normalise Windows separators so the lists are portable.
                sound_path = os.path.join(audio_path, audio, sound).replace('\\', '/')
                if sound_sum % 10 == 0:
                    f_test.write(f'{sound_path}\t{i}\n')
                else:
                    f_train.write(f'{sound_path}\t{i}\n')
                sound_sum += 1
            print(f"Audio{i + 1}/{len(audios)}")
# To download the data, run: ./tools/download_3dspeaker_data.sh
def _write_language_list(data_dir, list_file):
    """Write one ``<wav path>\\t<label>`` line to ``list_file`` per .wav file under ``data_dir``."""
    with open(list_file, 'w', encoding='utf-8') as f:
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if not file.endswith('.wav'):
                    continue
                # The label is the last two characters of the final
                # underscore-separated part of the file name (without .wav).
                label = int(file.split('_')[-1].replace('.wav', '')[-2:])
                f.write(f'{os.path.join(root, file)}\t{label}\n')


# Generate the dialect (language identification) data lists
def get_language_identification_data_list(audio_path, list_path):
    """Write train/test/label list files for the dialect dataset.

    ``<audio_path>/train`` and ``<audio_path>/test`` are walked recursively
    and every ``.wav`` file is listed together with the numeric label encoded
    in its file name. ``label_list.txt`` receives the dialect names in label
    order.

    :param audio_path: dataset root containing ``train`` and ``test`` folders
    :param list_path: directory that receives the three list files; created
        if it does not exist
    """
    labels_dict = {0: 'Standard Mandarin', 3: 'Southwestern Mandarin', 6: 'Central Plains Mandarin',
                   4: 'JiangHuai Mandarin', 2: 'Wu dialect', 8: 'Gan dialect', 9: 'Jin dialect',
                   11: 'LiaoJiao Mandarin', 12: 'JiLu Mandarin', 10: 'Min dialect', 7: 'Yue dialect',
                   5: 'Hakka dialect', 1: 'Xiang dialect', 13: 'Northern Mandarin'}
    # Create the output directory, as get_data_list does, instead of failing
    # when it does not exist yet.
    os.makedirs(list_path, exist_ok=True)
    _write_language_list(os.path.join(audio_path, 'train'),
                         os.path.join(list_path, 'train_list.txt'))
    _write_language_list(os.path.join(audio_path, 'test'),
                         os.path.join(list_path, 'test_list.txt'))
    with open(os.path.join(list_path, 'label_list.txt'), 'w', encoding='utf-8') as f:
        # Line i of the label list holds the name of label i.
        for i in range(len(labels_dict)):
            f.write(f'{labels_dict[i]}\n')
# Create the UrbanSound8K data lists
def create_UrbanSound8K_list(audio_path, metadata_path, list_path):
    """Write train/test/label list files from the UrbanSound8K metadata CSV.

    The first CSV line is treated as a header. For every other row, column 0
    is the file name, column 5 the fold number, column 6 the numeric class id
    and the last column the class name. Every 10th row (counting from 0) is
    written to ``test_list.txt`` and the rest to ``train_list.txt``, as
    ``<audio_path>/fold<fold>/<file>\\t<class id>``.

    :param audio_path: directory containing the ``fold*`` audio folders
    :param metadata_path: path of the metadata CSV file
    :param list_path: directory that receives the three list files; created
        if it does not exist
    """
    # Read the metadata before opening any output file, so a missing CSV
    # does not truncate lists that already exist.
    with open(metadata_path, encoding='utf-8') as f:
        lines = f.readlines()
    os.makedirs(list_path, exist_ok=True)
    labels = {}
    sound_sum = 0
    # Context managers close the files even if a row fails to parse.
    with open(os.path.join(list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test:
        for line in lines[1:]:
            line = line.rstrip('\r\n')
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line:
                continue
            data = line.split(',')
            class_id = int(data[6])
            if class_id not in labels:
                labels[class_id] = data[-1]
            sound_path = os.path.join(audio_path, f'fold{data[5]}', data[0]).replace('\\', '/')
            if sound_sum % 10 == 0:
                f_test.write(f'{sound_path}\t{data[6]}\n')
            else:
                f_train.write(f'{sound_path}\t{data[6]}\n')
            sound_sum += 1
    with open(os.path.join(list_path, 'label_list.txt'), 'w', encoding='utf-8') as f_label:
        # Line i of the label list must hold the name of class id i, so the
        # class ids are expected to be exactly 0..len(labels)-1.
        for i in range(len(labels)):
            f_label.write(f'{labels[i]}\n')
# Script entry point: only the UrbanSound8K list creation is enabled; the
# other two generators are kept as commented-out alternatives.
if __name__ == '__main__':
    # get_data_list('dataset/audio', 'dataset')
    # Generate the dialect data lists
    # get_language_identification_data_list(audio_path='dataset/language',
    # list_path='dataset/')
    # Create the UrbanSound8K data lists
    create_UrbanSound8K_list(audio_path='dataset/UrbanSound8K/audio',
                             metadata_path='dataset/UrbanSound8K/metadata/UrbanSound8K.csv',
                             list_path='dataset')

@ -0,0 +1,26 @@
import argparse
import functools
import time
from macls.trainer import MAClsTrainer
from macls.utils.utils import add_arguments, print_arguments
# Command-line options of the evaluation script.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/cam++.yml', "配置文件")
add_arg("use_gpu", bool, True, "是否使用GPU评估模型")
add_arg('save_matrix_path', str, 'output/images/', "保存混合矩阵的路径")
add_arg('resume_model', str, 'models/CAMPPlus_Fbank/best_model/', "模型的路径")
add_arg('overwrites', str, None, '覆盖配置文件中的参数,比如"train_conf.max_epoch=100",多个用逗号隔开')
args = parser.parse_args()
print_arguments(args=args)

# Build the trainer from the config file and the optional overrides
trainer = MAClsTrainer(configs=args.configs, use_gpu=args.use_gpu, overwrites=args.overwrites)

# Run the evaluation and measure how long it takes
start = time.time()
loss, accuracy = trainer.evaluate(resume_model=args.resume_model,
                                  save_matrix_path=args.save_matrix_path)
end = time.time()
# Separate the three fields; without delimiters the elapsed time, loss and
# accuracy were printed as one run-together token.
print('评估消耗时间:{}s,loss:{:.5f},accuracy:{:.5f}'.format(int(end - start), loss, accuracy))

@ -0,0 +1,19 @@
import argparse
import functools
from macls.trainer import MAClsTrainer
from macls.utils.utils import add_arguments, print_arguments
# Command-line options of the feature-extraction script.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg('save_dir', str, 'dataset/features', '保存特征的路径')
add_arg('max_duration', int, 100, '提取特征的最大时长,避免过长显存不足,单位秒')
args = parser.parse_args()
print_arguments(args=args)
# Build the trainer from the config file
trainer = MAClsTrainer(configs=args.configs)
# Extract the features and save them under save_dir
trainer.extract_features(save_dir=args.save_dir, max_duration=args.max_duration)

@ -0,0 +1,23 @@
import argparse
import functools
from macls.predict import MAClsPredictor
from macls.utils.utils import add_arguments, print_arguments
# Command-line options of the single-file prediction script.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg('use_gpu', bool, True, '是否使用GPU预测')
add_arg('audio_path', str, 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav', '音频路径')
add_arg('model_path', str, 'models/CAMPPlus_Fbank/best_model/', '导出的预测模型文件路径')
args = parser.parse_args()
print_arguments(args=args)
# Build the predictor
predictor = MAClsPredictor(configs=args.configs,
                           model_path=args.model_path,
                           use_gpu=args.use_gpu)
# Predict the given audio file and print its label and score
label, score = predictor.predict(audio_data=args.audio_path)
print(f'音频:{args.audio_path} 的预测结果标签为:{label},得分:{score}')

@ -0,0 +1,58 @@
import argparse
import functools
import threading
import time
import numpy as np
import soundcard as sc
from macls.predict import MAClsPredictor
from macls.utils.utils import add_arguments, print_arguments
# Command-line options of the live-recording prediction script.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg('use_gpu', bool, True, '是否使用GPU预测')
add_arg('record_seconds', float, 3, '录音长度')
add_arg('model_path', str, 'models/CAMPPlus_Fbank/best_model/', '导出的预测模型文件路径')
args = parser.parse_args()
print_arguments(args=args)
# Build the predictor
predictor = MAClsPredictor(configs=args.configs,
                           model_path=args.model_path,
                           use_gpu=args.use_gpu)
# Recorded chunks, shared between the recording loop and the inference thread
all_data = []
# Get the default microphone
default_mic = sc.default_microphone()
# Recording sample rate
samplerate = 16000
# Number of frames per recorded chunk
numframes = 1024
# Model input length, expressed as a number of recorded chunks
infer_len = int(samplerate * args.record_seconds / numframes)
def infer_thread():
    """Continuously predict on the most recent ``infer_len`` recorded chunks.

    Runs forever: waits until enough chunks are buffered in ``all_data``,
    predicts on the newest window, drops older chunks, and prints the result
    with the seconds elapsed since the thread started.
    """
    # No ``global`` statement is needed: all_data is only mutated in place,
    # never rebound.
    s = time.time()
    while True:
        if len(all_data) < infer_len:
            # Yield the CPU while waiting for more audio instead of spinning
            # in a tight loop.
            time.sleep(0.01)
            continue
        # Take the most recent audio chunks
        seg_data = all_data[-infer_len:]
        d = np.concatenate(seg_data)
        # Drop the chunks that are older than the current window
        del all_data[:len(all_data) - infer_len]
        label, score = predictor.predict(audio_data=d, sample_rate=samplerate)
        print(f'{int(time.time() - s)}s 预测结果标签为:{label},得分:{score}')
# Run inference in a daemon thread: infer_thread never returns, so a
# non-daemon thread would keep the process alive after the recording loop
# below is interrupted (e.g. with Ctrl+C).
thread = threading.Thread(target=infer_thread, args=(), daemon=True)
thread.start()
# Record from the default microphone chunk by chunk and hand every chunk to
# the inference thread through all_data.
with default_mic.recorder(samplerate=samplerate, channels=1) as mic:
    while True:
        data = mic.record(numframes=numframes)
        all_data.append(data)

@ -0,0 +1,23 @@
import torch
# Collate one batch of (feature, label) samples into padded tensors
def collate_fn(batch):
    """Zero-pad the features of a batch to a common length.

    :param batch: list of ``(feature, label)`` pairs, where each feature is a
        2-D tensor whose first dimension is the sequence length
    :return: ``(features, labels, input_lens)``: a float32 tensor of shape
        ``(batch, max_len, feature_dim)``, the int64 labels, and the int64
        unpadded length of every sample, all in the original batch order
    """
    # Only the longest sample is needed to size the padded tensor, so use
    # max() instead of sorting the whole batch.
    longest = max(batch, key=lambda sample: sample[0].size(0))[0]
    max_length, freq_size = longest.size(0), longest.size(1)
    # Start from an all-zero tensor of the maximum length
    features = torch.zeros((len(batch), max_length, freq_size), dtype=torch.float32)
    input_lens, labels = [], []
    for x, (tensor, label) in enumerate(batch):
        seq_length = tensor.size(0)
        # Copying into the zero tensor leaves the tail as padding
        features[x, :seq_length, :] = tensor[:, :]
        labels.append(label)
        input_lens.append(seq_length)
    labels = torch.tensor(labels, dtype=torch.int64)
    input_lens = torch.tensor(input_lens, dtype=torch.int64)
    return features, labels, input_lens

@ -0,0 +1,132 @@
import numpy as np
import torch
import torchaudio.compliance.kaldi as Kaldi
from torch import nn
from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC
from loguru import logger
class AudioFeaturizer(nn.Module):
    """Audio feature extractor.

    Converts waveforms into mean-normalised features, using either one of the
    built-in transforms (MelSpectrogram, Spectrogram, MFCC, Fbank) or a
    HuggingFace model loaded with ``AutoModel``.

    :param feature_method: name of the built-in transform, or, when
        ``use_hf_model`` is True, the HuggingFace model name or local path
    :type feature_method: str
    :param use_hf_model: whether to extract features with a HuggingFace model
    :type use_hf_model: bool
    :param method_args: keyword arguments for the transform; in HF mode only
        the ``use_gpu`` key is read. ``None`` means no arguments.
    :type method_args: dict
    """

    def __init__(self, feature_method='MelSpectrogram', use_hf_model=False, method_args=None):
        super().__init__()
        # A mutable ``{}`` default would be shared by every instance.
        if method_args is None:
            method_args = {}
        self._method_args = method_args
        self._feature_method = feature_method
        self.use_hf_model = use_hf_model
        if self.use_hf_model:
            from transformers import AutoModel, AutoFeatureExtractor
            # Decide whether the feature model runs on the GPU
            use_gpu = torch.cuda.is_available() and method_args.get('use_gpu', True)
            self.device = torch.device("cuda") if use_gpu else torch.device("cpu")
            # Load the feature extractor and the model
            self.processor = AutoFeatureExtractor.from_pretrained(feature_method)
            self.feature_model = AutoModel.from_pretrained(feature_method).to(self.device)
            logger.info(f'使用模型【{feature_method}】提取特征,使用【{self.device}】设备提取')
            # Run one second of dummy audio through the model to find out how
            # many output channels it produces
            inputs = self.processor(np.ones(16000 * 1, dtype=np.float32), sampling_rate=16000,
                                    return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.feature_model(**inputs)
                self.output_channels = outputs.extract_features.shape[2]
        else:
            if feature_method == 'MelSpectrogram':
                self.feat_fun = MelSpectrogram(**method_args)
            elif feature_method == 'Spectrogram':
                self.feat_fun = Spectrogram(**method_args)
            elif feature_method == 'MFCC':
                self.feat_fun = MFCC(**method_args)
            elif feature_method == 'Fbank':
                self.feat_fun = KaldiFbank(**method_args)
            else:
                raise Exception(f'预处理方法 {self._feature_method} 不存在!')
            logger.info(f'使用【{feature_method}】提取特征')

    def forward(self, waveforms, input_lens_ratio=None):
        """Extract features from a batch of waveforms.

        :param waveforms: waveform tensor; a 1-D input is treated as a batch
            of one
        :type waveforms: torch.Tensor
        :param input_lens_ratio: optional per-sample ratio of valid length to
            padded length; frames beyond the valid length are zeroed
        :type input_lens_ratio: torch.Tensor
        :return: features with the time axis at dimension 1 and the feature
            axis at dimension 2
        :rtype: torch.Tensor
        """
        if len(waveforms.shape) == 1:
            waveforms = waveforms.unsqueeze(0)
        if self.use_hf_model:
            # Extract features with the HuggingFace model
            if isinstance(waveforms, torch.Tensor):
                # detach()/cpu() make the conversion work for tensors that
                # track gradients or live on the GPU as well.
                waveforms = waveforms.detach().cpu().numpy()
            inputs = self.processor(waveforms, sampling_rate=16000,
                                    return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.feature_model(**inputs)
                feature = outputs.extract_features.cpu().detach()
        else:
            # Extract features with the built-in transform, then move the
            # time axis in front of the feature axis
            feature = self.feat_fun(waveforms)
            feature = feature.transpose(2, 1)
        # Normalise: subtract the mean over the time axis
        feature = feature - feature.mean(1, keepdim=True)
        if input_lens_ratio is not None:
            # Turn the length ratios into numbers of valid frames
            input_lens = (input_lens_ratio * feature.shape[1])
            mask_lens = torch.round(input_lens).long()
            mask_lens = mask_lens.unsqueeze(1)
            # Build the mask: True for frames inside the valid length
            idxs = torch.arange(feature.shape[1], device=feature.device).repeat(feature.shape[0], 1)
            mask = idxs < mask_lens
            mask = mask.unsqueeze(-1)
            # Zero out the padded frames
            feature = torch.where(mask, feature, torch.zeros_like(feature))
        return feature

    @property
    def feature_dim(self):
        """Return the size of the feature axis.

        :return: feature size
        :rtype: int
        """
        if self.use_hf_model:
            return self.output_channels
        if self._feature_method == 'MelSpectrogram':
            return self._method_args.get('n_mels', 128)
        elif self._feature_method == 'Spectrogram':
            return self._method_args.get('n_fft', 400) // 2 + 1
        elif self._feature_method == 'MFCC':
            return self._method_args.get('n_mfcc', 40)
        elif self._feature_method == 'Fbank':
            return self._method_args.get('num_mel_bins', 23)
        else:
            raise Exception('没有{}预处理方法'.format(self._feature_method))
class KaldiFbank(nn.Module):
    """Batched wrapper around ``Kaldi.fbank``.

    The keyword arguments given at construction are forwarded unchanged to
    every ``Kaldi.fbank`` call.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, waveforms):
        """
        :param waveforms: [Batch, Length]
        :return: [Batch, Feature, Length]
        """

        def _single_fbank(waveform):
            # Kaldi.fbank is called on a 2-D input, so give a 1-D waveform a
            # leading dimension first.
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            # Swap the two axes of the result so the feature axis comes first.
            return Kaldi.fbank(waveform, **self.kwargs).transpose(0, 1)

        return torch.stack([_single_fbank(waveform) for waveform in waveforms])

@ -0,0 +1,157 @@
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from yeaudio.audio import AudioSegment
from yeaudio.augmentation import SpeedPerturbAugmentor, VolumePerturbAugmentor, NoisePerturbAugmentor, \
ReverbPerturbAugmentor, SpecAugmentor
from macls.data_utils.featurizer import AudioFeaturizer
class MAClsDataset(Dataset):
    """Dataset yielding ``(feature, label)`` pairs from a tab-separated list file.

    Each line of the list is ``<path>\\t<label>``. A path ending in ``.npy`` is
    loaded as a pre-computed feature array; any other path is read as audio and
    converted with the supplied featurizer.
    """

    def __init__(self,
                 data_list_path,
                 audio_featurizer: 'AudioFeaturizer',
                 max_duration=3,
                 min_duration=0.5,
                 mode='train',
                 sample_rate=16000,
                 aug_conf=None,
                 use_dB_normalization=True,
                 target_dB=-20):
        """Audio data loader.

        Args:
            data_list_path: path of the list file holding audio paths and labels
            audio_featurizer: feature extractor applied to the audio samples
            max_duration: audio longer than this many seconds is cropped
            min_duration: audio shorter than this many seconds is skipped
                          (in 'train' and 'extract_feature' modes)
            aug_conf: audio augmentation configuration, only used in 'train' mode
            mode: one of 'train', 'eval', 'extract_feature'
            sample_rate: target sample rate; audio is resampled to it
            use_dB_normalization: whether to normalize the audio volume
            target_dB: target volume for the normalization
        """
        super(MAClsDataset, self).__init__()
        assert mode in ['train', 'eval', 'extract_feature']
        self.data_list_path = data_list_path
        self.max_duration = max_duration
        self.min_duration = min_duration
        self.mode = mode
        self._target_sample_rate = sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB
        self.speed_augment = None
        self.volume_augment = None
        self.noise_augment = None
        self.reverb_augment = None
        self.spec_augment = None
        # Feature extractor
        self.audio_featurizer = audio_featurizer
        # Number of feature frames that corresponds to max_duration
        self.max_feature_len = self.get_crop_feature_len()
        # Read the data list
        with open(self.data_list_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
        if mode == 'train' and aug_conf is not None:
            # Build the augmentors
            self.get_augmentor(aug_conf)
        # In eval mode the list is sorted by length
        if self.mode == 'eval':
            self.sort_list()

    def __getitem__(self, idx):
        # A too-short clip is replaced by the next entry (wrapping around). The
        # search is bounded to one pass over the list so a list made only of
        # too-short clips fails with a clear error instead of recursing forever.
        for _ in range(max(len(self.lines), 1)):
            sample = self._load_sample(idx)
            if sample is not None:
                return sample
            idx = idx + 1 if idx < len(self.lines) - 1 else 0
        raise RuntimeError(f'no audio in {self.data_list_path} is at least {self.min_duration}s long')

    def _load_sample(self, idx):
        """Load entry ``idx``; return ``None`` when its audio is shorter than ``min_duration``."""
        # Split the line into data path and label (parsed once for both branches)
        data_path, label = self.lines[idx].strip().split('\t')
        if data_path.endswith('.npy'):
            # Pre-computed features are read directly
            feature = np.load(data_path)
            if feature.shape[0] > self.max_feature_len:
                # Random crop while training, crop from the start otherwise
                crop_start = random.randint(0, feature.shape[0] - self.max_feature_len) \
                    if self.mode == 'train' else 0
                feature = feature[crop_start:crop_start + self.max_feature_len, :]
            feature = torch.tensor(feature, dtype=torch.float32)
        else:
            # Read the audio
            audio_segment = AudioSegment.from_file(data_path)
            # Very short audio is skipped in these two modes
            if self.mode == 'train' or self.mode == 'extract_feature':
                if audio_segment.duration < self.min_duration:
                    return None
            # Audio augmentation
            if self.mode == 'train':
                audio_segment = self.augment_audio(audio_segment)
            # Resample
            if audio_segment.sample_rate != self._target_sample_rate:
                audio_segment.resample(self._target_sample_rate)
            # Volume normalization
            if self._use_dB_normalization:
                audio_segment.normalize(target_db=self._target_dB)
            # Crop to the maximum duration
            if audio_segment.duration > self.max_duration:
                audio_segment.crop(duration=self.max_duration, mode=self.mode)
            samples = torch.tensor(audio_segment.samples, dtype=torch.float32)
            feature = self.audio_featurizer(samples)
            feature = feature.squeeze(0)
        if self.mode == 'train' and self.spec_augment is not None:
            feature = self.spec_augment(feature.cpu().numpy())
            feature = torch.tensor(feature, dtype=torch.float32)
        label = torch.tensor(int(label), dtype=torch.int64)
        return feature, label

    def __len__(self):
        return len(self.lines)

    def get_crop_feature_len(self):
        """Return the feature length obtained from ``max_duration`` seconds of audio."""
        # torch.randn needs an integer size; max_duration may be a float (e.g. 1.5).
        num_samples = int(self.max_duration * self._target_sample_rate)
        samples = torch.randn((1, num_samples))
        feature = self.audio_featurizer(samples).squeeze(0)
        freq_len = feature.size(0)
        return freq_len

    def sort_list(self):
        """Sort ``self.lines`` in ascending order of data length."""
        lengths = []
        for line in tqdm(self.lines, desc=f"对列表[{self.data_list_path}]进行长度排序"):
            # Split the line into data path and label
            data_path, _ = line.split('\t')
            if data_path.endswith('.npy'):
                # Length of a feature file is its first dimension
                feature = np.load(data_path)
                length = feature.shape[0]
                lengths.append(length)
            else:
                # Length of an audio file is its duration
                audio_segment = AudioSegment.from_file(data_path)
                length = audio_segment.duration
                lengths.append(length)
        # Sort by length and reorder the lines accordingly
        sorted_indexes = np.argsort(lengths)
        self.lines = [self.lines[i] for i in sorted_indexes]

    def get_augmentor(self, aug_conf):
        """Create one augmentor per non-None section of ``aug_conf``."""
        if aug_conf.speed is not None:
            self.speed_augment = SpeedPerturbAugmentor(**aug_conf.speed)
        if aug_conf.volume is not None:
            self.volume_augment = VolumePerturbAugmentor(**aug_conf.volume)
        if aug_conf.noise is not None:
            self.noise_augment = NoisePerturbAugmentor(**aug_conf.noise)
        if aug_conf.reverb is not None:
            self.reverb_augment = ReverbPerturbAugmentor(**aug_conf.reverb)
        if aug_conf.spec_aug is not None:
            self.spec_augment = SpecAugmentor(**aug_conf.spec_aug)

    def augment_audio(self, audio_segment):
        """Apply the configured speed, volume, noise and reverb augmentors in that order."""
        if self.speed_augment is not None:
            audio_segment = self.speed_augment(audio_segment)
        if self.volume_augment is not None:
            audio_segment = self.volume_augment(audio_segment)
        if self.noise_augment is not None:
            audio_segment = self.noise_augment(audio_segment)
        if self.reverb_augment is not None:
            audio_segment = self.reverb_augment(audio_segment)
        return audio_segment

@ -0,0 +1,12 @@
import numpy as np
import torch
# 计算准确率
def accuracy(output, label):
    """Fraction of rows in ``output`` whose highest-scoring column equals ``label``.

    :param output: tensor of per-class scores, one row per sample
    :param label: tensor of target class indices
    :return: mean of the per-sample hit indicators
    """
    probabilities = torch.nn.functional.softmax(output, dim=-1).data.cpu().numpy()
    predicted = probabilities.argmax(axis=1)
    expected = label.data.cpu().numpy()
    hits = (predicted == expected).astype(int)
    return np.mean(hits)

@ -0,0 +1,32 @@
import importlib
from loguru import logger
from torch.optim import *
from .scheduler import WarmupCosineSchedulerLR
from torch.optim.lr_scheduler import *
__all__ = ['build_optimizer', 'build_lr_scheduler']
def build_optimizer(params, configs):
    """Create the optimizer named in ``configs.optimizer_conf``.

    :param params: parameters handed to the optimizer constructor
    :param configs: configuration object; ``optimizer_conf`` may hold ``optimizer``
                    (class name, 'Adam' when absent) and ``optimizer_args`` (keyword arguments)
    :return: the constructed optimizer
    """
    conf = configs.optimizer_conf
    use_optimizer = conf.get('optimizer', 'Adam')
    optimizer_args = conf.get('optimizer_args', {})
    # The class is looked up by name on this very module.
    optimizer_cls = getattr(importlib.import_module(__name__), use_optimizer)
    optimizer = optimizer_cls(params=params, **optimizer_args)
    logger.info(f'成功创建优化方法:{use_optimizer},参数为:{optimizer_args}')
    return optimizer
def build_lr_scheduler(optimizer, step_per_epoch, configs):
    """Create the learning-rate scheduler named in ``configs.optimizer_conf``.

    :param optimizer: optimizer the scheduler will drive
    :param step_per_epoch: number of optimizer steps in one epoch
    :param configs: configuration object; ``optimizer_conf`` may hold ``scheduler``
                    (class name, 'WarmupCosineSchedulerLR' when absent) and
                    ``scheduler_args`` (keyword arguments); ``train_conf.max_epoch``
                    is read to fill in missing arguments
    :return: the constructed scheduler
    """
    use_scheduler = configs.optimizer_conf.get('scheduler', 'WarmupCosineSchedulerLR')
    scheduler_args = configs.optimizer_conf.get('scheduler_args', {})
    # Decide on the resolved name so the default above is honoured, and fill the
    # arguments by key so this also works when scheduler_args is the plain {} default.
    if use_scheduler == 'CosineAnnealingLR' and 'T_max' not in scheduler_args:
        scheduler_args['T_max'] = int(configs.train_conf.max_epoch * 1.2) * step_per_epoch
    if use_scheduler == 'WarmupCosineSchedulerLR':
        if 'fix_epoch' not in scheduler_args:
            scheduler_args['fix_epoch'] = configs.train_conf.max_epoch
        if 'step_per_epoch' not in scheduler_args:
            scheduler_args['step_per_epoch'] = step_per_epoch
    # The class is looked up by name on this very module.
    optim = importlib.import_module(__name__)
    scheduler = getattr(optim, use_scheduler)(optimizer=optimizer, **scheduler_args)
    logger.info(f'成功创建学习率衰减:{use_scheduler},参数为:{scheduler_args}')
    return scheduler

@ -0,0 +1,48 @@
import math
from typing import List
class WarmupCosineSchedulerLR:
    """Step-based learning-rate schedule.

    The rate rises linearly from ``min_lr`` to ``max_lr`` over the warm-up steps,
    follows a half cosine back down to ``min_lr`` until ``fix_step``, and stays
    at ``min_lr`` afterwards. Steps are counted in ``current_step``.
    """

    def __init__(
            self,
            optimizer,
            min_lr,
            max_lr,
            warmup_epoch,
            fix_epoch,
            step_per_epoch
    ):
        assert min_lr <= max_lr
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.max_lr = max_lr
        # Epoch counts are converted to step counts once, up front.
        self.warmup_step = warmup_epoch * step_per_epoch
        self.fix_step = fix_epoch * step_per_epoch
        self.current_step = 0.0

    def set_lr(self, ):
        """Write the rate for ``current_step`` into every parameter group and return it."""
        new_lr = self.clr(self.current_step)
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        return new_lr

    def step(self, step=None):
        """Apply the rate for the current (or the given) step, then advance the counter by one."""
        if step is not None:
            self.current_step = step
        applied_lr = self.set_lr()
        self.current_step += 1
        return applied_lr

    def clr(self, step):
        """Return the learning rate that the schedule assigns to ``step``."""
        lr_span = self.max_lr - self.min_lr
        if step < self.warmup_step:
            # Linear warm-up
            return self.min_lr + lr_span * (step / self.warmup_step)
        if step < self.fix_step:
            # Cosine decay between the end of warm-up and fix_step
            decay_steps = self.fix_step - self.warmup_step
            phase = math.pi * (step - self.warmup_step) / decay_steps
            return self.min_lr + 0.5 * lr_span * (1 + math.cos(phase))
        return self.min_lr

    def get_last_lr(self) -> List[float]:
        """Return a one-element list with the rate for ``current_step``."""
        return [self.clr(self.current_step)]

@ -0,0 +1,177 @@
import os
import sys
from io import BufferedReader
from typing import List
import numpy as np
import torch
import yaml
from loguru import logger
from yeaudio.audio import AudioSegment
from macls.data_utils.featurizer import AudioFeaturizer
from macls.models import build_model
from macls.utils.utils import dict_to_object, print_arguments, convert_string_based_on_type
class MAClsPredictor:
    """Sound-classification inference helper: loads a trained model and predicts labels."""

    def __init__(self,
                 configs,
                 model_path='models/CAMPPlus_Fbank/best_model/',
                 use_gpu=True,
                 overwrites=None,
                 log_level="info"):
        """Sound classification prediction tool.

        :param configs: path to a configuration file, or a model name; when a model
                        name is given the bundled default configuration is used
        :param model_path: folder (or file) of the exported model parameters
        :param use_gpu: whether to predict on GPU
        :param overwrites: configuration overrides such as "train_conf.max_epoch=100";
                           separate several overrides with commas
        :param log_level: log level, one of "debug", "info", "warning", "error"
        """
        if use_gpu:
            assert (torch.cuda.is_available()), 'GPU不可用'
            self.device = torch.device("cuda")
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
            self.device = torch.device("cpu")
        self.log_level = log_level.upper()
        logger.remove()
        logger.add(sink=sys.stdout, level=self.log_level)
        # Read the configuration file
        if isinstance(configs, str):
            # Directory of this module
            absolute_path = os.path.dirname(__file__)
            # Bundled default configuration for the given model name
            config_path = os.path.join(absolute_path, f"configs/{configs}.yml")
            configs = config_path if os.path.exists(config_path) else configs
            with open(configs, 'r', encoding='utf-8') as f:
                # NOTE(review): FullLoader is used here; only load configuration files you trust.
                configs = yaml.load(f.read(), Loader=yaml.FullLoader)
        self.configs = dict_to_object(configs)
        # Apply configuration overrides
        if overwrites:
            overwrites = overwrites.split(",")
            for overwrite in overwrites:
                # Split on the first '=' only so the value itself may contain '='.
                keys, value = overwrite.strip().split("=", 1)
                attrs = keys.split('.')
                current_level = self.configs
                for attr in attrs[:-1]:
                    current_level = getattr(current_level, attr)
                before_value = getattr(current_level, attrs[-1])
                setattr(current_level, attrs[-1], convert_string_based_on_type(before_value, value))
        # Print the configuration
        print_arguments(configs=self.configs)
        # Feature extractor
        self._audio_featurizer = AudioFeaturizer(feature_method=self.configs.preprocess_conf.feature_method,
                                                 use_hf_model=self.configs.preprocess_conf.get('use_hf_model', False),
                                                 method_args=self.configs.preprocess_conf.get('method_args', {}))
        # Class labels, one per line of the label list
        with open(self.configs.dataset_conf.label_list_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        self.class_labels = [l.replace('\n', '') for l in lines]
        # Derive the number of classes from the label list when it is not configured
        if self.configs.model_conf.model_args.get('num_class', None) is None:
            self.configs.model_conf.model_args.num_class = len(self.class_labels)
        # Build the model
        self.predictor = build_model(input_size=self._audio_featurizer.feature_dim, configs=self.configs)
        self.predictor.to(self.device)
        # Load the model parameters
        if os.path.isdir(model_path):
            model_path = os.path.join(model_path, 'model.pth')
        assert os.path.exists(model_path), f"{model_path} 模型不存在!"
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
        if torch.cuda.is_available() and use_gpu:
            model_state_dict = torch.load(model_path, weights_only=False)
        else:
            model_state_dict = torch.load(model_path, weights_only=False, map_location='cpu')
        self.predictor.load_state_dict(model_state_dict)
        logger.info(f"成功加载模型参数:{model_path}")
        self.predictor.eval()

    def _load_audio(self, audio_data, sample_rate=16000):
        """Load audio and apply resampling and volume normalization.

        :param audio_data: data to classify: a file path, a ``BufferedReader`` file object,
                           bytes (must be a complete audio file) or a numpy array
        :param sample_rate: sample rate of the data when a numpy array is passed
        :return: the preprocessed audio segment
        """
        # Load the audio according to the input type
        if isinstance(audio_data, str):
            audio_segment = AudioSegment.from_file(audio_data)
        elif isinstance(audio_data, BufferedReader):
            audio_segment = AudioSegment.from_file(audio_data)
        elif isinstance(audio_data, np.ndarray):
            audio_segment = AudioSegment.from_ndarray(audio_data, sample_rate)
        elif isinstance(audio_data, bytes):
            audio_segment = AudioSegment.from_bytes(audio_data)
        else:
            raise Exception(f'不支持该数据类型,当前数据类型为:{type(audio_data)}')
        # Resample
        if audio_segment.sample_rate != self.configs.dataset_conf.dataset.sample_rate:
            audio_segment.resample(self.configs.dataset_conf.dataset.sample_rate)
        # Decibel normalization
        if self.configs.dataset_conf.dataset.use_dB_normalization:
            audio_segment.normalize(target_db=self.configs.dataset_conf.dataset.target_dB)
        assert audio_segment.duration >= self.configs.dataset_conf.dataset.min_duration, \
            f'音频太短,最小应该为{self.configs.dataset_conf.dataset.min_duration}s,当前音频为{audio_segment.duration}s'
        return audio_segment

    def predict(self,
                audio_data,
                sample_rate=16000):
        """Predict the class of one audio.

        :param audio_data: data to classify: a file path, a file object, bytes
                           (a complete file including its format) or a numpy array
        :param sample_rate: sample rate of the data when a numpy array is passed
        :return: the predicted label and its score
        """
        # Load and preprocess the audio
        input_data = self._load_audio(audio_data=audio_data, sample_rate=sample_rate)
        input_data = torch.tensor(input_data.samples, dtype=torch.float32).unsqueeze(0)
        # Inference needs no autograd graph
        with torch.no_grad():
            audio_feature = self._audio_featurizer(input_data).to(self.device)
            output = self.predictor(audio_feature)
            result = torch.nn.functional.softmax(output, dim=-1)[0]
        result = result.data.cpu().numpy()
        # Label with the highest probability
        lab = np.argsort(result)[-1]
        score = result[lab]
        return self.class_labels[lab], round(float(score), 5)

    def predict_batch(self, audios_data: List, sample_rate=16000):
        """Predict the classes of a batch of audio.

        :param audios_data: list of data to classify; each item is a file path, a file
                            object, bytes (a complete file including its format) or a numpy array
        :param sample_rate: sample rate of the data when numpy arrays are passed
        :return: the list of predicted labels and the list of their scores
        """
        # Nothing to predict
        if len(audios_data) == 0:
            return [], []
        # Load and preprocess every audio
        samples_list = [self._load_audio(audio_data=audio_data, sample_rate=sample_rate).samples
                        for audio_data in audios_data]
        # Longest audio in the batch
        max_audio_length = max(samples.shape[0] for samples in samples_list)
        batch_size = len(samples_list)
        # Zero array of the maximum length; copying into it pads the shorter audio
        inputs = np.zeros((batch_size, max_audio_length), dtype=np.float32)
        input_lens_ratio = []
        for x, samples in enumerate(samples_list):
            seq_length = samples.shape[0]
            inputs[x, :seq_length] = samples[:]
            input_lens_ratio.append(seq_length / max_audio_length)
        inputs = torch.tensor(inputs, dtype=torch.float32)
        input_lens_ratio = torch.tensor(input_lens_ratio, dtype=torch.float32)
        # Inference needs no autograd graph
        with torch.no_grad():
            audio_feature = self._audio_featurizer(inputs, input_lens_ratio).to(self.device)
            output = self.predictor(audio_feature)
            results = torch.nn.functional.softmax(output, dim=-1)
        results = results.data.cpu().numpy()
        labels, scores = [], []
        for result in results:
            lab = np.argsort(result)[-1]
            score = result[lab]
            labels.append(self.class_labels[lab])
            scores.append(round(float(score), 5))
        return labels, scores

@ -0,0 +1,456 @@
import os
import platform
import sys
import time
import uuid
from datetime import timedelta
import numpy as np
import torch
import torch.distributed as dist
import yaml
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torchinfo import summary
from tqdm import tqdm
from loguru import logger
from visualdl import LogWriter
from macls.data_utils.collate_fn import collate_fn
from macls.data_utils.featurizer import AudioFeaturizer
from macls.data_utils.reader import MAClsDataset
from macls.metric.metrics import accuracy
from macls.models import build_model
from macls.optimizer import build_optimizer, build_lr_scheduler
from macls.utils.checkpoint import load_pretrained, load_checkpoint, save_checkpoint
from macls.utils.utils import dict_to_object, plot_confusion_matrix, print_arguments, convert_string_based_on_type
class MAClsTrainer(object):
    """Driver for training, evaluating, feature extraction and export of the sound classifier."""

    def __init__(self,
                 configs,
                 use_gpu=True,
                 data_augment_configs=None,
                 num_class=None,
                 overwrites=None,
                 log_level="info"):
        """Sound classification training tool.

        :param configs: path to a configuration file, or a model name; when a model
                        name is given the bundled default configuration is used
        :param use_gpu: whether to train on GPU
        :param data_augment_configs: augmentation configuration dict, or the path to its file
        :param num_class: number of classes, maps to model_conf.model_args.num_class
        :param overwrites: configuration overrides such as "train_conf.max_epoch=100";
                           separate several overrides with commas
        :param log_level: log level, one of "debug", "info", "warning", "error"
        """
        if use_gpu:
            assert (torch.cuda.is_available()), 'GPU不可用'
            self.device = torch.device("cuda")
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
            self.device = torch.device("cpu")
        self.use_gpu = use_gpu
        self.log_level = log_level.upper()
        logger.remove()
        logger.add(sink=sys.stdout, level=self.log_level)
        # Read the configuration file
        if isinstance(configs, str):
            # Directory of this module
            absolute_path = os.path.dirname(__file__)
            # Bundled default configuration for the given model name
            config_path = os.path.join(absolute_path, f"configs/{configs}.yml")
            configs = config_path if os.path.exists(config_path) else configs
            with open(configs, 'r', encoding='utf-8') as f:
                # NOTE(review): FullLoader is used here; only load configuration files you trust.
                configs = yaml.load(f.read(), Loader=yaml.FullLoader)
        self.configs = dict_to_object(configs)
        if num_class is not None:
            self.configs.model_conf.model_args.num_class = num_class
        # Apply configuration overrides
        if overwrites:
            overwrites = overwrites.split(",")
            for overwrite in overwrites:
                # Split on the first '=' only so the value itself may contain '='.
                keys, value = overwrite.strip().split("=", 1)
                attrs = keys.split('.')
                current_level = self.configs
                for attr in attrs[:-1]:
                    current_level = getattr(current_level, attr)
                before_value = getattr(current_level, attrs[-1])
                setattr(current_level, attrs[-1], convert_string_based_on_type(before_value, value))
        # Print the configuration
        print_arguments(configs=self.configs)
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.audio_featurizer = None
        self.train_dataset = None
        self.train_loader = None
        self.test_dataset = None
        self.test_loader = None
        self.amp_scaler = None
        # Read the augmentation configuration file
        if isinstance(data_augment_configs, str):
            with open(data_augment_configs, 'r', encoding='utf-8') as f:
                data_augment_configs = yaml.load(f.read(), Loader=yaml.FullLoader)
            print_arguments(configs=data_augment_configs, title='数据增强配置')
        self.data_augment_configs = dict_to_object(data_augment_configs)
        # Class labels, one per line of the label list
        with open(self.configs.dataset_conf.label_list_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        self.class_labels = [l.replace('\n', '') for l in lines]
        if platform.system().lower() == 'windows':
            self.configs.dataset_conf.dataLoader.num_workers = 0
            logger.warning('Windows系统不支持多线程读取数据,已自动关闭!')
        if self.configs.preprocess_conf.get('use_hf_model', False):
            self.configs.dataset_conf.dataLoader.num_workers = 0
            logger.warning('使用HuggingFace模型不支持多线程进行特征提取,已自动关闭!')
        self.max_step, self.train_step = None, None
        self.train_loss, self.train_acc = None, None
        self.train_eta_sec = None
        self.eval_loss, self.eval_acc = None, None
        self.test_log_step, self.train_log_step = 0, 0
        self.stop_train, self.stop_eval = False, False

    def __create_featurizer(self):
        """Build an AudioFeaturizer from ``preprocess_conf``."""
        return AudioFeaturizer(feature_method=self.configs.preprocess_conf.feature_method,
                               use_hf_model=self.configs.preprocess_conf.get('use_hf_model', False),
                               method_args=self.configs.preprocess_conf.get('method_args', {}))

    def __load_state_dict(self, model_path):
        """Load a saved state dict, mapping it to CPU when the GPU is not in use."""
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
        if torch.cuda.is_available() and self.use_gpu:
            return torch.load(model_path, weights_only=False)
        return torch.load(model_path, weights_only=False, map_location='cpu')

    def __setup_dataloader(self, is_train=False):
        """Create the data loaders.

        :param is_train: whether to also create the training data loader
        """
        # Feature extractor
        self.audio_featurizer = self.__create_featurizer()
        # Work on copies: the evaluation overrides below must not leak into
        # self.configs, whose dataLoader.batch_size is read again when the
        # training speed is logged.
        dataset_args = dict(self.configs.dataset_conf.get('dataset', {}))
        data_loader_args = dict(self.configs.dataset_conf.get('dataLoader', {}))
        if is_train:
            self.train_dataset = MAClsDataset(data_list_path=self.configs.dataset_conf.train_list,
                                              audio_featurizer=self.audio_featurizer,
                                              aug_conf=self.data_augment_configs,
                                              mode='train',
                                              **dataset_args)
            train_sampler = RandomSampler(self.train_dataset)
            if torch.cuda.device_count() > 1:
                # Multi-GPU training
                train_sampler = DistributedSampler(dataset=self.train_dataset)
            self.train_loader = DataLoader(dataset=self.train_dataset,
                                           collate_fn=collate_fn,
                                           sampler=train_sampler,
                                           **data_loader_args)
        # Test data
        data_loader_args['drop_last'] = False
        dataset_args['max_duration'] = self.configs.dataset_conf.eval_conf.max_duration
        data_loader_args['batch_size'] = self.configs.dataset_conf.eval_conf.batch_size
        self.test_dataset = MAClsDataset(data_list_path=self.configs.dataset_conf.test_list,
                                         audio_featurizer=self.audio_featurizer,
                                         mode='eval',
                                         **dataset_args)
        self.test_loader = DataLoader(dataset=self.test_dataset,
                                      collate_fn=collate_fn,
                                      shuffle=False,
                                      **data_loader_args)

    def extract_features(self, save_dir='dataset/features', max_duration=100):
        """Extract features and save them to files.

        :param save_dir: directory the feature files are written to
        :param max_duration: maximum audio duration in seconds used for extraction
        """
        self.audio_featurizer = self.__create_featurizer()
        # Copies, so the overrides below do not alter self.configs.
        dataset_args = dict(self.configs.dataset_conf.get('dataset', {}))
        dataset_args['max_duration'] = max_duration
        data_loader_args = dict(self.configs.dataset_conf.get('dataLoader', {}))
        data_loader_args['drop_last'] = False
        for data_list in [self.configs.dataset_conf.train_list, self.configs.dataset_conf.test_list]:
            test_dataset = MAClsDataset(data_list_path=data_list,
                                        audio_featurizer=self.audio_featurizer,
                                        mode='extract_feature',
                                        **dataset_args)
            test_loader = DataLoader(dataset=test_dataset,
                                     collate_fn=collate_fn,
                                     shuffle=False,
                                     **data_loader_args)
            # The new list is written next to the original one
            save_data_list = data_list.replace('.txt', '_features.txt')
            with open(save_data_list, 'w', encoding='utf-8') as f:
                for features, labels, input_lens in tqdm(test_loader):
                    for i in range(len(features)):
                        feature, label, input_len = features[i], labels[i], input_lens[i]
                        # Keep only the first input_len frames of the batch item
                        feature = feature.numpy()[:input_len]
                        label = int(label)
                        save_path = os.path.join(save_dir, str(label),
                                                 f'{str(uuid.uuid4())}.npy').replace('\\', '/')
                        os.makedirs(os.path.dirname(save_path), exist_ok=True)
                        np.save(save_path, feature)
                        f.write(f'{save_path}\t{label}\n')
            logger.info(f'{data_list}列表中的数据已提取特征完成,新列表为:{save_data_list}')

    def __setup_model(self, input_size, is_train=False):
        """Create the model, the loss and, for training, the optimizer and scheduler.

        :param input_size: feature size the model takes as input
        :param is_train: whether to also create the training-only objects
        """
        # Derive the number of classes from the label list when it is not configured
        if self.configs.model_conf.model_args.get('num_class', None) is None:
            self.configs.model_conf.model_args.num_class = len(self.class_labels)
        # Build the model
        self.model = build_model(input_size=input_size, configs=self.configs)
        self.model.to(self.device)
        if self.log_level == "DEBUG" or self.log_level == "INFO":
            # Print the model summary; 98 is the sequence length, which depends on the audio length
            summary(self.model, input_size=(1, 98, input_size))
        # PyTorch 2.0 compiler
        # NOTE(review): this enables torch.compile only on Windows, which looks inverted
        # ('!=' was probably intended); left unchanged pending confirmation.
        if self.configs.train_conf.use_compile and torch.__version__ >= "2" and platform.system().lower() == 'windows':
            self.model = torch.compile(self.model, mode="reduce-overhead")
        # Loss function
        label_smoothing = self.configs.train_conf.get('label_smoothing', 0.0)
        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        if is_train:
            if self.configs.train_conf.enable_amp:
                self.amp_scaler = torch.GradScaler(init_scale=1024)
            # Optimizer
            self.optimizer = build_optimizer(params=self.model.parameters(), configs=self.configs)
            # Learning-rate scheduler
            self.scheduler = build_lr_scheduler(optimizer=self.optimizer, step_per_epoch=len(self.train_loader),
                                                configs=self.configs)

    def __train_epoch(self, epoch_id, local_rank, writer, nranks=0):
        """Train one epoch.

        :param epoch_id: current epoch
        :param local_rank: id of the current GPU
        :param writer: VisualDL writer
        :param nranks: number of GPUs in use
        """
        train_times, accuracies, loss_sum = [], [], []
        start = time.time()
        for batch_id, (features, label, input_len) in enumerate(self.train_loader):
            if self.stop_train: break
            if nranks > 1:
                features = features.to(local_rank)
                label = label.to(local_rank).long()
            else:
                features = features.to(self.device)
                label = label.to(self.device).long()
            # Forward pass, optionally with automatic mixed precision
            with torch.autocast('cuda', enabled=self.configs.train_conf.enable_amp):
                output = self.model(features)
            # Loss
            los = self.loss(output, label)
            if self.configs.train_conf.enable_amp:
                # Scale the loss before the backward pass
                scaled = self.amp_scaler.scale(los)
                scaled.backward()
            else:
                los.backward()
            if self.configs.train_conf.enable_amp:
                self.amp_scaler.unscale_(self.optimizer)
                self.amp_scaler.step(self.optimizer)
                self.amp_scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
            # Accuracy
            acc = accuracy(output, label)
            accuracies.append(acc)
            loss_sum.append(los.data.cpu().numpy())
            train_times.append((time.time() - start) * 1000)
            self.train_step += 1
            # With several GPUs only one process logs
            if batch_id % self.configs.train_conf.log_interval == 0 and local_rank == 0:
                batch_id = batch_id + 1
                # Samples processed per second
                train_speed = self.configs.dataset_conf.dataLoader.batch_size / (
                        sum(train_times) / len(train_times) / 1000)
                # Estimated remaining time
                self.train_eta_sec = (sum(train_times) / len(train_times)) * (self.max_step - self.train_step) / 1000
                eta_str = str(timedelta(seconds=int(self.train_eta_sec)))
                self.train_loss = sum(loss_sum) / len(loss_sum)
                self.train_acc = sum(accuracies) / len(accuracies)
                logger.info(f'Train epoch: [{epoch_id}/{self.configs.train_conf.max_epoch}], '
                            f'batch: [{batch_id}/{len(self.train_loader)}], '
                            f'loss: {self.train_loss:.5f}, accuracy: {self.train_acc:.5f}, '
                            f'learning rate: {self.scheduler.get_last_lr()[0]:>.8f}, '
                            f'speed: {train_speed:.2f} data/sec, eta: {eta_str}')
                writer.add_scalar('Train/Loss', self.train_loss, self.train_log_step)
                writer.add_scalar('Train/Accuracy', self.train_acc, self.train_log_step)
                # Learning rate
                writer.add_scalar('Train/lr', self.scheduler.get_last_lr()[0], self.train_log_step)
                train_times, accuracies, loss_sum = [], [], []
                self.train_log_step += 1
            start = time.time()
            self.scheduler.step()

    def train(self,
              save_model_path='models/',
              log_dir='log/',
              max_epoch=None,
              resume_model=None,
              pretrained_model=None):
        """
        Train the model.

        :param save_model_path: directory the model is saved to
        :param log_dir: directory for the VisualDL log files
        :param max_epoch: maximum number of epochs, maps to train_conf.max_epoch
        :param resume_model: checkpoint to resume training from; None to not resume
        :param pretrained_model: path of a pretrained model; None to not use one
        """
        # Number of GPUs
        nranks = torch.cuda.device_count()
        local_rank = 0
        if nranks > 1 and self.use_gpu:
            # Initialize the NCCL environment
            dist.init_process_group(backend='nccl')
            local_rank = int(os.environ["LOCAL_RANK"])
        # The log writer is created only once the real rank is known, so that
        # only rank 0 opens one.
        writer = None
        if local_rank == 0:
            writer = LogWriter(logdir=log_dir)
        # Data
        self.__setup_dataloader(is_train=True)
        # Model
        self.__setup_model(input_size=self.audio_featurizer.feature_dim, is_train=True)
        # Pretrained model
        self.model = load_pretrained(model=self.model, pretrained_model=pretrained_model, use_gpu=self.use_gpu)
        # Resume from a checkpoint
        self.model, self.optimizer, self.amp_scaler, self.scheduler, last_epoch, best_acc = \
            load_checkpoint(configs=self.configs, model=self.model, optimizer=self.optimizer,
                            amp_scaler=self.amp_scaler, scheduler=self.scheduler, step_epoch=len(self.train_loader),
                            save_model_path=save_model_path, resume_model=resume_model)
        # Multi-GPU training
        if nranks > 1 and self.use_gpu:
            self.model.to(local_rank)
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank])
        logger.info('训练数据:{}'.format(len(self.train_dataset)))

        self.train_loss, self.train_acc = None, None
        self.eval_loss, self.eval_acc = None, None
        self.test_log_step, self.train_log_step = 0, 0
        if local_rank == 0:
            writer.add_scalar('Train/lr', self.scheduler.get_last_lr()[0], last_epoch)
        if max_epoch is not None:
            self.configs.train_conf.max_epoch = max_epoch
        # Total number of steps
        self.max_step = len(self.train_loader) * self.configs.train_conf.max_epoch
        self.train_step = max(last_epoch, 0) * len(self.train_loader)
        # Training loop
        for epoch_id in range(last_epoch, self.configs.train_conf.max_epoch):
            if self.stop_train: break
            epoch_id += 1
            start_epoch = time.time()
            # A DistributedSampler reshuffles only when it is told the epoch.
            train_sampler = self.train_loader.sampler
            if isinstance(train_sampler, DistributedSampler):
                train_sampler.set_epoch(epoch_id)
            # Train one epoch
            self.__train_epoch(epoch_id=epoch_id, local_rank=local_rank, writer=writer, nranks=nranks)
            # With several GPUs only one process evaluates and saves
            if local_rank == 0:
                if self.stop_eval: continue
                logger.info('=' * 70)
                self.eval_loss, self.eval_acc = self.evaluate()
                logger.info('Test epoch: {}, time/epoch: {}, loss: {:.5f}, accuracy: {:.5f}'.format(
                    epoch_id, str(timedelta(seconds=(time.time() - start_epoch))), self.eval_loss, self.eval_acc))
                logger.info('=' * 70)
                writer.add_scalar('Test/Accuracy', self.eval_acc, self.test_log_step)
                writer.add_scalar('Test/Loss', self.eval_loss, self.test_log_step)
                self.test_log_step += 1
                self.model.train()
                # Save the best model
                if self.eval_acc >= best_acc:
                    best_acc = self.eval_acc
                    save_checkpoint(configs=self.configs, model=self.model, optimizer=self.optimizer,
                                    amp_scaler=self.amp_scaler, save_model_path=save_model_path, epoch_id=epoch_id,
                                    accuracy=self.eval_acc, best_model=True)
                # Save the latest model
                save_checkpoint(configs=self.configs, model=self.model, optimizer=self.optimizer,
                                amp_scaler=self.amp_scaler, save_model_path=save_model_path, epoch_id=epoch_id,
                                accuracy=self.eval_acc)

    def evaluate(self, resume_model=None, save_matrix_path=None):
        """
        Evaluate the model.

        :param resume_model: model (folder or file) to evaluate; None to use the current model
        :param save_matrix_path: directory the confusion matrix image is saved to
        :return: mean loss and mean accuracy; -1 for either when no batch was evaluated
        """
        if self.test_loader is None:
            self.__setup_dataloader()
        if self.model is None:
            self.__setup_model(input_size=self.audio_featurizer.feature_dim)
        if resume_model is not None:
            if os.path.isdir(resume_model):
                resume_model = os.path.join(resume_model, 'model.pth')
            assert os.path.exists(resume_model), f"{resume_model} 模型不存在!"
            model_state_dict = self.__load_state_dict(resume_model)
            self.model.load_state_dict(model_state_dict)
            logger.info(f'成功加载模型:{resume_model}')
        self.model.eval()
        # Evaluate the wrapped module when the model is distributed
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            eval_model = self.model.module
        else:
            eval_model = self.model

        accuracies, losses, preds, labels = [], [], [], []
        with torch.no_grad():
            for batch_id, (features, label, input_lens) in enumerate(tqdm(self.test_loader, desc='执行评估')):
                if self.stop_eval: break
                features = features.to(self.device)
                label = label.to(self.device).long()
                output = eval_model(features)
                los = self.loss(output, label)
                # Accuracy
                acc = accuracy(output, label)
                accuracies.append(acc)
                # Predicted labels
                label = label.data.cpu().numpy()
                output = output.data.cpu().numpy()
                pred = np.argmax(output, axis=1)
                preds.extend(pred.tolist())
                # True labels
                labels.extend(label.tolist())
                losses.append(los.data.cpu().numpy())
        loss = float(sum(losses) / len(losses)) if len(losses) > 0 else -1
        acc = float(sum(accuracies) / len(accuracies)) if len(accuracies) > 0 else -1
        # Save the confusion matrix
        if save_matrix_path is not None:
            try:
                cm = confusion_matrix(labels, preds)
                plot_confusion_matrix(cm=cm, save_path=os.path.join(save_matrix_path, f'{int(time.time())}.png'),
                                      class_labels=self.class_labels)
            except Exception as e:
                logger.error(f'保存混淆矩阵失败:{e}')
        self.model.train()
        return loss, acc

    def export(self, save_model_path='models/', resume_model='models/EcapaTdnn_Fbank/best_model/'):
        """
        Export the inference model.

        :param save_model_path: directory the exported model is saved to
        :param resume_model: model (folder or file) to convert
        :return:
        """
        # export() may be the first call on a fresh trainer, in which case no
        # featurizer exists yet.
        if self.audio_featurizer is None:
            self.audio_featurizer = self.__create_featurizer()
        self.__setup_model(input_size=self.audio_featurizer.feature_dim)
        # Load the model parameters
        if os.path.isdir(resume_model):
            resume_model = os.path.join(resume_model, 'model.pth')
        assert os.path.exists(resume_model), f"{resume_model} 模型不存在!"
        model_state_dict = self.__load_state_dict(resume_model)
        self.model.load_state_dict(model_state_dict)
        logger.info('成功恢复模型参数和优化方法参数:{}'.format(resume_model))
        self.model.eval()
        # Static model
        infer_model = self.model.export()
        # Prefer the top-level 'use_model' entry; fall back to model_conf.model when
        # the configuration has no such entry.
        try:
            model_name = self.configs.use_model
        except (AttributeError, KeyError):
            model_name = self.configs.model_conf.model
        infer_model_path = os.path.join(save_model_path,
                                        f'{model_name}_{self.configs.preprocess_conf.feature_method}',
                                        'inference.pth')
        os.makedirs(os.path.dirname(infer_model_path), exist_ok=True)
        torch.jit.save(infer_model, infer_model_path)
        logger.info("预测模型已保存:{}".format(infer_model_path))

@ -0,0 +1,162 @@
import json
import os
import shutil
import torch
from loguru import logger
from macls import __version__
def load_pretrained(model, pretrained_model, use_gpu=True):
    """Load pretrained parameters into ``model``, ignoring entries whose shape differs.

    :param model: model to load into
    :param pretrained_model: path of the pretrained model (folder or file); None to skip loading
    :param use_gpu: whether the model uses the GPU
    :return: the model that was passed in
    """
    if pretrained_model is None:
        return model
    # A folder is expected to contain the parameters as 'model.pth'
    if os.path.isdir(pretrained_model):
        pretrained_model = os.path.join(pretrained_model, 'model.pth')
    assert os.path.exists(pretrained_model), f"{pretrained_model} 模型不存在!"
    # Parameters live on the wrapped module when the model is distributed
    target = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    model_dict = target.state_dict()
    # Map the saved tensors to CPU unless the GPU is available and requested
    load_kwargs = {'weights_only': False}
    if not (torch.cuda.is_available() and use_gpu):
        load_kwargs['map_location'] = 'cpu'
    model_state_dict = torch.load(pretrained_model, **load_kwargs)
    # Drop saved entries whose shape does not match the model
    for name, weight in model_dict.items():
        if name not in model_state_dict.keys():
            continue
        saved_shape = list(model_state_dict[name].shape)
        model_shape = list(weight.shape)
        if model_shape != saved_shape:
            logger.warning(f'{name} not used, shape {saved_shape} '
                           f'unmatched with {model_shape} in model.')
            model_state_dict.pop(name, None)
    # Non-strict load: the differing keys are reported below instead of raising
    missing_keys, unexpected_keys = target.load_state_dict(model_state_dict, strict=False)
    if len(unexpected_keys) > 0:
        logger.warning('Unexpected key(s) in state_dict: {}. '
                       .format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
    if len(missing_keys) > 0:
        logger.warning('Missing key(s) in state_dict: {}. '
                       .format(', '.join('"{}"'.format(k) for k in missing_keys)))
    logger.info('成功加载预训练模型:{}'.format(pretrained_model))
    return model
def load_checkpoint(configs, model, optimizer, amp_scaler, scheduler,
                    step_epoch, save_model_path, resume_model):
    """Restore training state from a checkpoint directory.

    When ``resume_model`` is given it is loaded and any failure propagates. Otherwise the
    ``last_model`` directory under ``save_model_path`` is tried, and a failure there is only
    logged as a warning.

    :param configs: configuration object; ``preprocess_conf.feature_method`` and
                    ``model_conf.model`` are used to build the checkpoint directory name
    :param model: model to restore (a DistributedDataParallel wrapper is unwrapped)
    :param optimizer: optimizer whose state is restored
    :param amp_scaler: automatic-mixed-precision scaler, or None
    :param scheduler: learning-rate scheduler, fast-forwarded to the restored epoch
    :param step_epoch: number of steps in one epoch
    :param save_model_path: root directory under which checkpoints are saved
    :param resume_model: explicit checkpoint directory to resume from, or None
    :return: ``(model, optimizer, amp_scaler, scheduler, last_epoch, accuracy)``; the last two
             are ``0`` and ``0.`` when nothing was restored
    """
    last_epoch1 = 0
    accuracy1 = 0.

    def load_model(model_path):
        assert os.path.exists(os.path.join(model_path, 'model.pth')), "模型参数文件不存在!"
        assert os.path.exists(os.path.join(model_path, 'optimizer.pth')), "优化方法参数文件不存在!"
        state_dict = torch.load(os.path.join(model_path, 'model.pth'), weights_only=False)
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model.module.load_state_dict(state_dict)
        else:
            model.load_state_dict(state_dict)
        optimizer.load_state_dict(torch.load(os.path.join(model_path, 'optimizer.pth'), weights_only=False))
        # Mixed-precision scaler state is optional.
        scaler_path = os.path.join(model_path, 'scaler.pth')
        if amp_scaler is not None and os.path.exists(scaler_path):
            # weights_only is an argument of torch.load, not of load_state_dict.
            amp_scaler.load_state_dict(torch.load(scaler_path, weights_only=False))
        with open(os.path.join(model_path, 'model.state'), 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        last_epoch = json_data['last_epoch']
        accuracy = json_data['accuracy']
        logger.info('成功恢复模型参数和优化方法参数:{}'.format(model_path))
        # NOTE(review): presumably stepped once so the scheduler steps below do not
        # precede any optimizer step - confirm.
        optimizer.step()
        # Fast-forward the scheduler to where training stopped.
        for _ in range(last_epoch * step_epoch):
            scheduler.step()
        return last_epoch, accuracy

    # Locate the most recently saved checkpoint directory.
    save_feature_method = configs.preprocess_conf.feature_method
    if configs.preprocess_conf.get('use_hf_model', False):
        # Keep only the final path component of the feature-method path.
        save_feature_method = save_feature_method[:-1] if save_feature_method[-1] == '/' else save_feature_method
        save_feature_method = os.path.basename(save_feature_method)
    last_model_dir = os.path.join(save_model_path,
                                  f'{configs.model_conf.model}_{save_feature_method}',
                                  'last_model')
    has_last_model = (os.path.exists(os.path.join(last_model_dir, 'model.pth'))
                      and os.path.exists(os.path.join(last_model_dir, 'optimizer.pth')))
    if resume_model is not None:
        last_epoch1, accuracy1 = load_model(resume_model)
    elif has_last_model:
        try:
            # Automatic resume is best effort.
            last_epoch1, accuracy1 = load_model(last_model_dir)
        except Exception as e:
            logger.warning(f'尝试自动恢复最新模型失败,错误信息:{e}')
    return model, optimizer, amp_scaler, scheduler, last_epoch1, accuracy1
# 保存模型
def save_checkpoint(configs, model, optimizer, amp_scaler, save_model_path, epoch_id,
                    accuracy=0., best_model=False):
    """Write model, optimizer and (optionally) AMP scaler state to disk.

    :param configs: configuration object; ``preprocess_conf.feature_method`` and
                    ``model_conf.model`` name the checkpoint directory
    :param model: model to save (a DistributedDataParallel wrapper is unwrapped)
    :param optimizer: optimizer whose state is saved
    :param amp_scaler: automatic-mixed-precision scaler, or None
    :param save_model_path: root directory for checkpoints
    :param epoch_id: current epoch
    :param accuracy: current accuracy, stored in ``model.state``
    :param best_model: when True save to ``best_model``; otherwise save to ``epoch_<id>``,
                       refresh ``last_model`` and remove the checkpoint from three epochs ago
    """
    is_distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    state_dict = model.module.state_dict() if is_distributed else model.state_dict()

    # Directory name component derived from the feature method.
    save_feature_method = configs.preprocess_conf.feature_method
    if configs.preprocess_conf.get('use_hf_model', False):
        if save_feature_method[-1] == '/':
            save_feature_method = save_feature_method[:-1]
        save_feature_method = os.path.basename(save_feature_method)

    sub_dir = 'best_model' if best_model else 'epoch_{}'.format(epoch_id)
    model_path = os.path.join(save_model_path,
                              f'{configs.model_conf.model}_{save_feature_method}', sub_dir)
    os.makedirs(model_path, exist_ok=True)

    # Parameters.
    torch.save(optimizer.state_dict(), os.path.join(model_path, 'optimizer.pth'))
    torch.save(state_dict, os.path.join(model_path, 'model.pth'))
    # Mixed-precision scaler state.
    if amp_scaler is not None:
        torch.save(amp_scaler.state_dict(), os.path.join(model_path, 'scaler.pth'))

    # Metadata read back when resuming.
    data = {"last_epoch": epoch_id, "accuracy": accuracy, "version": __version__,
            "model": configs.model_conf.model, "feature_method": save_feature_method}
    with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
        f.write(json.dumps(data, indent=4, ensure_ascii=False))

    if not best_model:
        # Mirror this checkpoint as the latest one.
        last_model_path = os.path.join(save_model_path,
                                       f'{configs.model_conf.model}_{save_feature_method}', 'last_model')
        shutil.rmtree(last_model_path, ignore_errors=True)
        shutil.copytree(model_path, last_model_path)
        # Prune the checkpoint written three epochs earlier.
        old_model_path = os.path.join(save_model_path,
                                      f'{configs.model_conf.model}_{save_feature_method}',
                                      'epoch_{}'.format(epoch_id - 3))
        if os.path.exists(old_model_path):
            shutil.rmtree(old_model_path)
    logger.info('已保存模型:{}'.format(model_path))

@ -0,0 +1,31 @@
import os
import soundcard
import soundfile
class RecordAudio:
    """Record audio from the default microphone reported by ``soundcard``."""

    def __init__(self, channels=1, sample_rate=16000):
        """
        :param channels: number of channels to record
        :param sample_rate: sampling rate passed to the microphone and to the saved file
        """
        # Recording parameters
        self.channels = channels
        self.sample_rate = sample_rate
        # Default microphone
        self.default_mic = soundcard.default_microphone()

    def record(self, record_seconds=3, save_path=None):
        """Record for a fixed duration.

        :param record_seconds: recording length in seconds, default 3
        :param save_path: optional output path (expected to end in ``.wav``); may be a bare
                          file name, in which case no directory is created
        :return: the recorded data with ``squeeze()`` applied
        """
        print("开始录音......")
        num_frames = int(record_seconds * self.sample_rate)
        data = self.default_mic.record(samplerate=self.sample_rate, numframes=num_frames, channels=self.channels)
        audio_data = data.squeeze()
        print("录音已结束!")
        if save_path is not None:
            # dirname is '' for a bare file name and os.makedirs('') raises.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            soundfile.write(save_path, data=data, samplerate=self.sample_rate)
        return audio_data

@ -0,0 +1,131 @@
import distutils.util
import os
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
def print_arguments(args=None, configs=None, title=None):
    """Log command-line arguments and/or a configuration mapping.

    :param args: object accepted by ``vars()`` (e.g. an argparse namespace); skipped when falsy
    :param configs: mapping of configuration entries; nested dicts are expanded on the first
                    two levels only; skipped when falsy
    :param title: heading of the configuration section; a default heading is used when falsy
    """
    rule = "------------------------------------------------"

    def log_entry(key, value, depth):
        # Dicts deeper than two levels are printed as plain values.
        indent = "\t" * depth
        if depth < 2 and isinstance(value, dict):
            logger.info(f"{indent}{key}:")
            for child_key, child_value in sorted(value.items()):
                log_entry(child_key, child_value, depth + 1)
        else:
            logger.info(indent + "%s: %s" % (key, value))

    if args:
        logger.info("----------- 额外配置参数 -----------")
        for name, value in sorted(vars(args).items()):
            logger.info("%s: %s" % (name, value))
        logger.info(rule)
    if configs:
        heading = title if title else "配置文件参数"
        logger.info(f"----------- {heading} -----------")
        for name, value in sorted(configs.items()):
            log_entry(name, value, 0)
        logger.info(rule)
def _strtobool(value):
    """Convert a truth-value string to 1 or 0.

    Local stand-in for ``distutils.util.strtobool`` (distutils was removed from the standard
    library in Python 3.12); accepts the same spellings and returns the same 1/0 integers.

    :raises ValueError: when ``value`` is not a recognised truth value
    """
    text = value.lower()
    if text in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if text in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError("invalid truth value %r" % (text,))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register ``--<argname>`` on ``argparser``.

    :param argname: option name without the leading dashes
    :param type: value type; ``bool`` is parsed from truth-value strings such as "true"/"no"
    :param default: default value
    :param help: help text; the default value is appended to it
    :param argparser: the ``argparse.ArgumentParser`` to extend
    :param kwargs: forwarded to ``ArgumentParser.add_argument``
    """
    # bool("False") is True, so booleans need a real string parser.
    type = _strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' 默认: %(default)s.',
                           **kwargs)
class Dict(dict):
    """A dict whose keys can also be read and written as attributes."""

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. A missing key must surface as
        # AttributeError (not KeyError) so that hasattr(), getattr(obj, name, default)
        # and copy.deepcopy() work.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
def dict_to_object(dict_obj):
    """Recursively convert a dict (and the dicts stored as its values) into ``Dict``.

    Non-dict inputs are returned unchanged; containers such as lists are not descended into.
    """
    if isinstance(dict_obj, dict):
        converted = Dict()
        for key in dict_obj:
            converted[key] = dict_to_object(dict_obj[key])
        return converted
    return dict_obj
def plot_confusion_matrix(cm, save_path, class_labels, show=False):
    """Draw a column-normalised confusion matrix and save it as a PNG.

    :param cm: confusion matrix as a 2-D array (nested lists are converted); rows are drawn on
               the "actual" axis and columns on the "predicted" axis
    :param save_path: where the PNG is written; missing parent directories are created
    :param class_labels: list of class names, one per row/column
    :param show: when True, also display the figure with ``plt.show()``
    """
    # Switch to a CJK-capable font when any label contains non-ASCII characters.
    s = ''.join(str(label) for label in class_labels)
    is_ascii = all(ord(c) < 128 for c in s)
    if not is_ascii:
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
    fig = plt.figure(figsize=(12, 8), dpi=100)
    try:
        np.set_printoptions(precision=2)
        cm = np.asarray(cm)
        # Each column is divided by its sum; the epsilon avoids division by zero.
        normalized = cm / (np.sum(cm, axis=0) + 1e-6)
        # Write the value inside every cell.
        ind_array = np.arange(len(class_labels))
        x, y = np.meshgrid(ind_array, ind_array)
        for x_val, y_val in zip(x.flatten(), y.flatten()):
            c = normalized[y_val][x_val]
            # Skip cells whose value is negligible.
            if c < 1e-4:
                continue
            plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=15, va='center', ha='center')
        plt.imshow(normalized, interpolation='nearest', cmap=plt.cm.binary)
        plt.title('Confusion Matrix' if is_ascii else '混淆矩阵')
        plt.colorbar()
        # Class labels on both axes.
        xlocations = np.array(range(len(class_labels)))
        plt.xticks(xlocations, class_labels, rotation=90)
        plt.yticks(xlocations, class_labels)
        plt.ylabel('Actual label' if is_ascii else '实际标签')
        plt.xlabel('Predict label' if is_ascii else '预测标签')
        # Minor ticks between cells draw the grid lines; the tick marks themselves are hidden.
        tick_marks = np.array(range(len(class_labels))) + 0.5
        plt.gca().set_xticks(tick_marks, minor=True)
        plt.gca().set_yticks(tick_marks, minor=True)
        plt.gca().xaxis.set_ticks_position('none')
        plt.gca().yaxis.set_ticks_position('none')
        plt.grid(True, which='minor', linestyle='-')
        plt.gcf().subplots_adjust(bottom=0.15)
        # Save the image; dirname is '' for a bare file name and os.makedirs('') raises.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, format='png')
        if show:
            plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
# 根据a的类型将b转换为相应的类型
def convert_string_based_on_type(a, b):
    """Convert string ``b`` to the type of the reference value ``a``.

    :param a: reference value whose type selects the conversion
    :param b: string to convert
    :return: the converted value; when an int/float/other conversion fails the error is
             logged and ``b`` is returned unchanged
    """
    # bool must be tested before int: bool is a subclass of int, so isinstance(True, int)
    # is True and the bool branch would otherwise never run.
    if isinstance(a, bool):
        b = b.lower() == 'true'
    elif isinstance(a, int):
        try:
            b = int(b)
        except ValueError:
            logger.error("无法将字符串转换为整数")
    elif isinstance(a, float):
        try:
            b = float(b)
        except ValueError:
            logger.error("无法将字符串转换为浮点数")
    elif isinstance(a, str):
        return b
    else:
        try:
            # NOTE(review): eval executes arbitrary code; only safe if b comes from a
            # trusted source - confirm against callers.
            b = eval(b)
        except Exception:
            logger.exception("无法将字符串转换为其他类型,将忽略该参数类型转换")
    return b

@ -0,0 +1,14 @@
import time
from macls.utils.record import RecordAudio
# Ask for the recording length in whole seconds.
record_seconds = int(input('请输入你计划录音多少秒:'))

# Output file named after the current time in milliseconds.
save_path = "dataset/save_audio/%s.wav" % int(time.time() * 1000)

record_audio = RecordAudio()
# Wait for Enter before recording starts.
input(f"按下回车键开机录音,录音{record_seconds}秒中:")
record_audio.record(record_seconds=record_seconds, save_path=save_path)
print('文件保存在:%s' % save_path)

@ -0,0 +1,17 @@
numpy>=1.19.2
scipy>=1.6.3
librosa>=0.9.1
soundfile>=0.12.1
soundcard>=0.4.2
resampy>=0.2.2
numba>=0.53.0
pydub~=0.25.1
matplotlib>=3.5.2
pillow>=10.3.0
tqdm>=4.66.3
visualdl==2.5.3
pyyaml>=5.4.1
scikit-learn>=1.0.2
torchinfo>=1.7.2
loguru>=0.7.2
yeaudio>=0.0.7

@ -0,0 +1,54 @@
import shutil
from setuptools import setup, find_packages
import macls
# The distribution version is read from the macls package itself.
VERSION = macls.__version__

# Copy the configuration files into the package directory (replacing any earlier copy).
# NOTE(review): presumably so that package_data={'': ['configs/*']} below can ship them - confirm.
shutil.rmtree('./macls/configs/', ignore_errors=True)
shutil.copytree('./configs/', './macls/configs/')
def readme():
    """Return the text of ``README.md`` (read as UTF-8 from the current directory)."""
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
def parse_requirements():
    """Read ``./requirements.txt`` and return its requirement specifiers.

    :return: list of requirement strings with surrounding whitespace (including the trailing
             newline) removed; blank lines and ``#`` comment lines are skipped
    """
    with open('./requirements.txt', encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line and not line.startswith('#')]
if __name__ == "__main__":
    # Package metadata. Requirements come from requirements.txt and the long description
    # from README.md.
    setup(
        name='macls',
        packages=find_packages(),
        package_data={'': ['configs/*']},
        author='yeyupiaoling',
        version=VERSION,
        install_requires=parse_requirements(),
        description='Audio Classification toolkit on Pytorch',
        long_description=readme(),
        long_description_content_type='text/markdown',
        url='https://github.com/yeyupiaoling/AudioClassification-Pytorch',
        download_url='https://github.com/yeyupiaoling/AudioClassification-Pytorch.git',
        keywords=['audio', 'pytorch'],
        classifiers=[
            'Intended Audience :: Developers',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Natural Language :: Chinese (Simplified)',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.5',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Programming Language :: Python :: 3.9', 'Topic :: Utilities'
        ],
        license='Apache License 2.0',
        ext_modules=[])
    # Remove the copy of the config files that this script placed inside the package.
    # NOTE(review): not reached if setup() raises; the next run deletes the stale copy first.
    shutil.rmtree('./macls/configs/', ignore_errors=True)

@ -0,0 +1,29 @@
#!/bin/bash
# Download the 3D-Speaker train/test archives into ${download_dir}, verify their MD5
# checksums and extract them in place.

download_dir=dataset/language
mkdir -p "${download_dir}"

# NOTE(review): --no-check-certificate disables TLS verification; the MD5 checks below are
# the only integrity guard.
if [ ! -f "${download_dir}/test.tar.gz" ]; then
    echo "准备下载测试集"
    wget --no-check-certificate https://speech-lab-share-data.oss-cn-shanghai.aliyuncs.com/3D-Speaker/test.tar.gz -P "${download_dir}"
    md5=$(md5sum "${download_dir}/test.tar.gz" | awk '{print $1}')
    if [ "$md5" != "45972606dd10d3f7c1c31f27acdfbed7" ]; then
        echo "Wrong md5sum of 3dspeaker test.tar.gz"
        # Remove the bad file so the next run downloads it again instead of skipping it.
        rm -f "${download_dir}/test.tar.gz"
        exit 1
    fi
fi

if [ ! -f "${download_dir}/train.tar.gz" ]; then
    echo "准备下载训练集"
    wget --no-check-certificate https://speech-lab-share-data.oss-cn-shanghai.aliyuncs.com/3D-Speaker/train.tar.gz -P "${download_dir}"
    md5=$(md5sum "${download_dir}/train.tar.gz" | awk '{print $1}')
    if [ "$md5" != "c2cea55fd22a2b867d295fb35a2d3340" ]; then
        echo "Wrong md5sum of 3dspeaker train.tar.gz"
        rm -f "${download_dir}/train.tar.gz"
        exit 1
    fi
fi

echo "下载完成!"

echo "准备解压"

# Extract next to the archives. The original used ${rawdata_dir}, which is never defined,
# so tar was invoked with "-C /".
tar -zxvf "${download_dir}/train.tar.gz" -C "${download_dir}/"
tar -xzvf "${download_dir}/test.tar.gz" -C "${download_dir}/"

echo "解压完成!"

@ -0,0 +1,30 @@
import argparse
import functools
from macls.trainer import MAClsTrainer
from macls.utils.utils import add_arguments, print_arguments
# Command-line interface: every option is registered through add_arguments with the
# parser pre-bound.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg('data_augment_configs', str, 'configs/augmentation.yml', '数据增强配置文件')
# local_rank is parsed here but not passed to the trainer below.
add_arg("local_rank", int, 0, '多卡训练需要的参数')
add_arg("use_gpu", bool, True, '是否使用GPU训练')
add_arg('save_model_path', str, 'models/', '模型保存的路径')
add_arg('log_dir', str, 'log/', '保存VisualDL日志文件的路径')
add_arg('resume_model', str, None, '恢复训练当为None则不使用预训练模型')
add_arg('pretrained_model', str, None, '预训练模型的路径当为None则不使用预训练模型')
add_arg('overwrites', str, None, '覆盖配置文件中的参数,比如"train_conf.max_epoch=100",多个用逗号隔开')
args = parser.parse_args()
print_arguments(args=args)

# Build the trainer from the parsed options.
trainer = MAClsTrainer(configs=args.configs,
                       use_gpu=args.use_gpu,
                       data_augment_configs=args.data_augment_configs,
                       overwrites=args.overwrites)

# Start training.
trainer.train(save_model_path=args.save_model_path,
              log_dir=args.log_dir,
              resume_model=args.resume_model,
              pretrained_model=args.pretrained_model)

@ -0,0 +1,298 @@
# 数据库配置文件说明
## 概述
本目录包含基于多模态的战场隐蔽威胁发现和定位系统的数据库配置文件。系统支持多种数据库类型,包括SQLite、MySQL、PostgreSQL和MongoDB。
## 文件结构
```
database/
├── config.py # 数据库配置文件
├── db_utils.py # 数据库操作工具
├── init_database.py # 数据库初始化脚本
├── README.md # 说明文档
└── database_config.json # 数据库配置文件(自动生成)
```
## 功能特性
### 支持的数据库类型
- **SQLite** (默认) - 轻量级文件数据库
- **MySQL** - 关系型数据库
- **PostgreSQL** - 高级关系型数据库
- **MongoDB** - NoSQL文档数据库
### 数据库表结构
系统包含以下主要数据表:
1. **target_detections** - 目标检测记录表
- 存储检测到的目标信息
- 包含目标类型、置信度、位置、威胁等级等
2. **uav_flights** - 无人机飞行记录表
- 记录无人机飞行数据
- 包含飞行时间、高度、距离、模式等
3. **sound_locations** - 声源定位记录表
- 存储声源定位数据
- 包含声源坐标、强度、角度、频率等
4. **system_logs** - 系统日志表
- 记录系统运行日志
- 包含日志级别、模块、消息等
5. **user_actions** - 用户操作记录表
- 记录用户操作行为
- 包含操作类型、详情、IP地址等
6. **device_status** - 设备状态记录表
- 记录设备运行状态
- 包含设备类型、状态、电池、信号等
7. **map_markers** - 地图标记表
- 存储地图标记信息
- 包含标记类型、坐标、标题等
## 使用方法
### 1. 基本配置
```python
from database.config import get_db_config, get_connection_string
# 获取默认数据库配置
config = get_db_config()
# 获取指定数据库配置
mysql_config = get_db_config("mysql")
# 获取连接字符串
conn_str = get_connection_string()
```
### 2. 数据库操作
```python
from database.db_utils import (
DatabaseOperations,
target_detection_dao,
sound_location_dao,
system_log_dao
)
# 插入目标检测记录
detection_id = target_detection_dao.insert_detection(
target_type="vehicle",
confidence=0.95,
distance=150.5,
latitude=39.9042,
longitude=116.4074,
threat_level="high"
)
# 插入声源定位记录
sound_id = sound_location_dao.insert_sound_location(
source_x=100.0,
source_y=200.0,
strength=0.8,
angle=45.0,
frequency=1000.0
)
# 记录系统日志
system_log_dao.insert_log(
level="INFO",
module="target_detection",
message="检测到可疑目标"
)
```
### 3. 初始化数据库
```bash
# 运行数据库初始化脚本
python init_database.py
```
### 4. 自定义配置
编辑 `database_config.json` 文件来自定义数据库配置:
```json
{
"default_database": "sqlite",
"databases": {
"sqlite": {
"type": "sqlite",
"database": "battlefield_system.db",
"echo": false,
"pool_size": 5,
"max_overflow": 10
},
"mysql": {
"type": "mysql",
"host": "localhost",
"port": 3306,
"database": "battlefield_system",
"username": "root",
"password": "your_password",
"charset": "utf8mb4"
}
}
}
```
## 配置说明
### 数据库配置参数
- **host**: 数据库服务器地址
- **port**: 数据库端口
- **database**: 数据库名称
- **username**: 用户名
- **password**: 密码
- **charset**: 字符集
- **pool_size**: 连接池大小
- **max_overflow**: 最大溢出连接数
- **pool_timeout**: 连接池超时时间
- **pool_recycle**: 连接回收时间
- **echo**: 是否显示SQL语句
### 连接池配置
系统使用连接池来管理数据库连接,提高性能和稳定性:
- **pool_size**: 连接池中保持的连接数
- **max_overflow**: 超过pool_size后最多可以创建的连接数
- **pool_timeout**: 从连接池获取连接的超时时间
- **pool_recycle**: 连接在连接池中的回收时间
## 性能优化
### 索引优化
系统自动创建以下索引以提高查询性能:
- 时间戳索引:用于时间范围查询
- 类型索引:用于分类查询
- 置信度索引:用于阈值筛选
- 坐标索引:用于地理位置查询
### 查询优化建议
1. **使用参数化查询**:避免SQL注入,提高性能
2. **限制查询结果**:使用LIMIT子句限制返回记录数
3. **合理使用索引**:在WHERE子句中使用索引字段
4. **批量操作**:使用executemany进行批量插入
## 错误处理
### 常见错误及解决方案
1. **连接失败**
- 检查数据库服务是否启动
- 验证连接参数是否正确
- 确认网络连接正常
2. **权限错误**
- 检查用户权限
- 确认数据库用户存在
- 验证密码是否正确
3. **表不存在**
- 运行初始化脚本创建表
- 检查表名是否正确
- 确认数据库选择正确
### 日志记录
系统自动记录数据库操作日志:
```python
from database.db_utils import log_system_event
# 记录系统事件
log_system_event(
level="ERROR",
module="database",
message="数据库连接失败",
details="连接超时"
)
```
## 安全考虑
### 数据安全
1. **密码加密**:数据库密码应加密存储
2. **访问控制**:限制数据库访问权限
3. **数据备份**:定期备份重要数据
4. **审计日志**:记录所有数据库操作
### SQL注入防护
1. **参数化查询**:使用参数化查询避免SQL注入
2. **输入验证**:验证所有用户输入
3. **权限最小化**:使用最小权限原则
## 扩展开发
### 添加新的数据表
1. 在 `DatabaseTables` 类中定义表结构
2. 创建对应的DAO类
3. 添加索引定义
4. 更新初始化脚本
### 添加新的数据库类型
1. 在配置文件中添加新数据库配置
2. 实现连接逻辑
3. 更新连接字符串生成方法
4. 测试连接和操作
## 测试
### 运行测试
```bash
# 测试数据库配置
python config.py
# 测试数据库操作
python db_utils.py
# 完整初始化测试
python init_database.py
```
### 测试内容
- 数据库连接测试
- 表创建测试
- 数据插入测试
- 查询操作测试
- 性能测试
## 维护
### 日常维护
1. **监控连接池**:检查连接池使用情况
2. **清理日志**:定期清理系统日志
3. **优化查询**:分析慢查询并优化
4. **备份数据**:定期备份数据库
### 版本升级
1. **备份数据**:升级前备份所有数据
2. **测试升级**:在测试环境验证升级
3. **更新配置**:更新配置文件
4. **验证功能**:验证所有功能正常
## 技术支持
如有问题,请联系系统开发团队或查看系统日志获取详细信息。

@ -0,0 +1,432 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库配置文件
基于多模态的战场隐蔽威胁发现和定位系统数据库配置
支持多种数据库类型
- SQLite (默认)
- MySQL
- PostgreSQL
- MongoDB (NoSQL)
作者: 系统开发团队
创建时间: 2024
"""
import os
import json
import logging
from typing import Dict, Any, Optional
from dataclasses import dataclass
from pathlib import Path
# Logging setup: basicConfig is called at import time with level INFO, and this module
# logs through a logger named after the module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class DatabaseConfig:
    """Connection settings for one configured database."""
    host: str = "localhost"
    port: int = 3306
    database: str = "battlefield_system"
    username: str = "root"
    password: str = ""
    charset: str = "utf8mb4"
    pool_size: int = 10
    max_overflow: int = 20
    pool_timeout: int = 30
    pool_recycle: int = 3600
    echo: bool = False
    # Backend kind: "sqlite", "mysql", "postgresql" or "mongodb". Every entry of the
    # configuration file carries this key, so the dataclass must accept it. Declared last
    # so positional construction of the earlier fields is unaffected.
    type: str = "mysql"


class DatabaseManager:
    """Loads the database configuration file and derives per-database settings from it."""

    def __init__(self, config_file: str = "database_config.json"):
        """
        Create the manager and load its configuration.

        Args:
            config_file: configuration file name, resolved relative to this module's directory
        """
        self.config_file = config_file
        self.config = self._load_config()
        self.connections = {}

    def _load_config(self) -> Dict[str, Any]:
        """
        Load the configuration, creating a default file first when none exists.

        Returns:
            The parsed configuration, or the built-in default when the file cannot be read
        """
        config_path = Path(__file__).parent / self.config_file
        # Write a default configuration file on first use.
        if not config_path.exists():
            self._create_default_config(config_path)
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            logger.info(f"成功加载数据库配置: {config_path}")
            return config
        except Exception as e:
            logger.error(f"加载配置文件失败: {e}")
            return self._get_default_config()

    def _create_default_config(self, config_path: Path):
        """
        Write the default configuration to ``config_path``; failures are logged, not raised.

        Args:
            config_path: where to write the file
        """
        default_config = self._get_default_config()
        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(default_config, f, indent=4, ensure_ascii=False)
            logger.info(f"创建默认配置文件: {config_path}")
        except Exception as e:
            logger.error(f"创建配置文件失败: {e}")

    def _get_default_config(self) -> Dict[str, Any]:
        """
        Build the default configuration.

        Returns:
            A fresh configuration dict with one entry per supported backend
        """
        return {
            "default_database": "sqlite",
            "databases": {
                "sqlite": {
                    "type": "sqlite",
                    "database": "battlefield_system.db",
                    "echo": False,
                    "pool_size": 5,
                    "max_overflow": 10
                },
                "mysql": {
                    "type": "mysql",
                    "host": "localhost",
                    "port": 3306,
                    "database": "battlefield_system",
                    "username": "root",
                    "password": "",
                    "charset": "utf8mb4",
                    "pool_size": 10,
                    "max_overflow": 20,
                    "pool_timeout": 30,
                    "pool_recycle": 3600,
                    "echo": False
                },
                "postgresql": {
                    "type": "postgresql",
                    "host": "localhost",
                    "port": 5432,
                    "database": "battlefield_system",
                    "username": "postgres",
                    "password": "",
                    "pool_size": 10,
                    "max_overflow": 20,
                    "pool_timeout": 30,
                    "pool_recycle": 3600,
                    "echo": False
                },
                "mongodb": {
                    "type": "mongodb",
                    "host": "localhost",
                    "port": 27017,
                    "database": "battlefield_system",
                    "username": "",
                    "password": "",
                    "pool_size": 10,
                    "max_overflow": 20,
                    "echo": False
                }
            },
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            }
        }

    def get_database_config(self, db_name: Optional[str] = None) -> DatabaseConfig:
        """
        Return the settings of one configured database.

        Args:
            db_name: configuration entry name; None selects the configured default

        Returns:
            The settings as a DatabaseConfig

        Raises:
            ValueError: when no entry named ``db_name`` exists
        """
        if db_name is None:
            db_name = self.config.get("default_database", "sqlite")
        db_config = self.config.get("databases", {}).get(db_name)
        if not db_config:
            raise ValueError(f"未找到数据库配置: {db_name}")
        # Copy so the loaded configuration is not mutated; an entry without an explicit
        # "type" is treated as the backend it is named after.
        settings = dict(db_config)
        settings.setdefault("type", db_name)
        return DatabaseConfig(**settings)

    def get_connection_string(self, db_name: Optional[str] = None) -> str:
        """
        Build the connection URL of one configured database.

        Args:
            db_name: configuration entry name; None selects the configured default

        Returns:
            The connection URL

        Raises:
            ValueError: when the entry is missing or its type is not supported
        """
        from urllib.parse import quote_plus

        config = self.get_database_config(db_name)
        # Dispatch on the backend type, not on the database name.
        if config.type == "sqlite":
            return f"sqlite:///{config.database}"
        # Percent-encode credentials so characters such as '@' or '/' cannot break the URL.
        credentials = f"{quote_plus(config.username)}:{quote_plus(config.password)}"
        if config.type == "mysql":
            return (f"mysql+pymysql://{credentials}"
                    f"@{config.host}:{config.port}/{config.database}"
                    f"?charset={config.charset}")
        if config.type == "postgresql":
            return (f"postgresql://{credentials}"
                    f"@{config.host}:{config.port}/{config.database}")
        if config.type == "mongodb":
            return f"mongodb://{config.host}:{config.port}/{config.database}"
        raise ValueError(f"不支持的数据库类型: {config.type}")
class DatabaseTables:
    """CREATE TABLE statements for the system's tables, one class attribute per table.

    NOTE(review): the statements use ``INTEGER PRIMARY KEY AUTOINCREMENT``, which looks
    SQLite-specific - confirm before running them against the other configured backends.
    """

    # Target detection records
    TARGET_DETECTION_TABLE = """
CREATE TABLE IF NOT EXISTS target_detections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
target_type VARCHAR(50) NOT NULL,
confidence FLOAT NOT NULL,
distance FLOAT,
latitude DOUBLE,
longitude DOUBLE,
altitude FLOAT,
image_path VARCHAR(255),
threat_level VARCHAR(20),
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    # UAV flight records
    UAV_FLIGHT_TABLE = """
CREATE TABLE IF NOT EXISTS uav_flights (
id INTEGER PRIMARY KEY AUTOINCREMENT,
flight_id VARCHAR(50) UNIQUE NOT NULL,
start_time DATETIME,
end_time DATETIME,
duration INTEGER,
max_altitude FLOAT,
max_distance FLOAT,
flight_mode VARCHAR(20),
status VARCHAR(20),
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    # Sound source localisation records
    SOUND_LOCATION_TABLE = """
CREATE TABLE IF NOT EXISTS sound_locations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
source_x FLOAT NOT NULL,
source_y FLOAT NOT NULL,
strength FLOAT NOT NULL,
angle FLOAT NOT NULL,
frequency FLOAT,
confidence FLOAT,
latitude DOUBLE,
longitude DOUBLE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    # System log
    SYSTEM_LOG_TABLE = """
CREATE TABLE IF NOT EXISTS system_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
level VARCHAR(10) NOT NULL,
module VARCHAR(50) NOT NULL,
message TEXT NOT NULL,
details TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    # User action records
    USER_ACTIONS_TABLE = """
CREATE TABLE IF NOT EXISTS user_actions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
user_id VARCHAR(50),
action_type VARCHAR(50) NOT NULL,
action_detail TEXT,
ip_address VARCHAR(45),
user_agent TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    # Device status records
    DEVICE_STATUS_TABLE = """
CREATE TABLE IF NOT EXISTS device_status (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
device_type VARCHAR(50) NOT NULL,
device_id VARCHAR(50) NOT NULL,
status VARCHAR(20) NOT NULL,
battery_level INTEGER,
signal_strength INTEGER,
temperature FLOAT,
humidity FLOAT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    # Map markers
    MAP_MARKERS_TABLE = """
CREATE TABLE IF NOT EXISTS map_markers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
marker_type VARCHAR(50) NOT NULL,
latitude DOUBLE NOT NULL,
longitude DOUBLE NOT NULL,
altitude FLOAT,
title VARCHAR(100),
description TEXT,
icon_path VARCHAR(255),
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""

    @classmethod
    def get_all_tables(cls) -> Dict[str, str]:
        """
        Return every table definition.

        Returns:
            Dict mapping table name to its CREATE TABLE statement
        """
        return {
            "target_detections": cls.TARGET_DETECTION_TABLE,
            "uav_flights": cls.UAV_FLIGHT_TABLE,
            "sound_locations": cls.SOUND_LOCATION_TABLE,
            "system_logs": cls.SYSTEM_LOG_TABLE,
            "user_actions": cls.USER_ACTIONS_TABLE,
            "device_status": cls.DEVICE_STATUS_TABLE,
            "map_markers": cls.MAP_MARKERS_TABLE
        }
class DatabaseInitializer:
    """Initialises a configured database (currently it only logs start and completion)."""

    def __init__(self, db_manager: DatabaseManager):
        """
        Args:
            db_manager: manager used to look up database settings
        """
        self.db_manager = db_manager

    def initialize_database(self, db_name: Optional[str] = None):
        """
        Initialise one database.

        Any exception is logged as an error and re-raised.

        Args:
            db_name: configuration entry name, passed to the manager's lookup
        """
        try:
            database = self.db_manager.get_database_config(db_name).database
            logger.info(f"开始初始化数据库: {database}")
            # Placeholder: the actual initialisation work (creating tables, inserting
            # initial data, ...) is not implemented here.
            logger.info(f"数据库初始化完成: {database}")
        except Exception as exc:
            logger.error(f"数据库初始化失败: {exc}")
            raise
# Module-wide manager instance, created at import time; constructing it loads the
# configuration file (and writes a default one when none exists).
db_manager = DatabaseManager()
# 导出常用函数
def get_db_config(db_name: Optional[str] = None) -> DatabaseConfig:
    """
    Return the settings of a configured database via the module-wide manager.

    Args:
        db_name: configuration entry name, forwarded unchanged to the manager

    Returns:
        The database settings object
    """
    return db_manager.get_database_config(db_name)
def get_connection_string(db_name: Optional[str] = None) -> str:
    """
    Return the connection string of a configured database via the module-wide manager.

    Args:
        db_name: configuration entry name, forwarded unchanged to the manager

    Returns:
        The connection string
    """
    return db_manager.get_connection_string(db_name)
def initialize_database(db_name: Optional[str] = None):
    """
    Initialise a configured database using a DatabaseInitializer bound to the module-wide manager.

    Args:
        db_name: configuration entry name, forwarded unchanged to the initializer
    """
    initializer = DatabaseInitializer(db_manager)
    initializer.initialize_database(db_name)
if __name__ == "__main__":
    # Smoke test of the database configuration helpers.
    try:
        # Default database settings
        default_config = get_db_config()
        print(f"默认数据库配置: {default_config}")

        # Connection string
        connection_string = get_connection_string()
        print(f"连接字符串: {connection_string}")

        # List every table definition
        table_definitions = DatabaseTables.get_all_tables()
        print(f"数据库表数量: {len(table_definitions)}")
        for table_name in table_definitions:
            print(f" - {table_name}")

        print("数据库配置测试完成")
    except Exception as exc:
        print(f"测试失败: {exc}")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save