fix(ros): Update ROS wrapper to match current Pipeline API

- Rewrite acoustic_node.cpp:
  - Use PipelineConfig instead of removed PipelineParams
  - Use Pipeline::Process(flat_samples) returning AcousticFrame
  - Remove obsolete init() and process_single_channel() calls
  - Add flatten_audio helper for interleaved multi-channel data
- Rewrite threat_publisher.cpp:
  - Implement ThreatPublisher::Impl PIMPL pattern
  - Accept AcousticFrame instead of old TrackedThreat vector
  - Add NumPublished() counter
- Fix CMakeLists.txt: remove stale KISSFFT_DIR from test includes
- Add build_core_test.bat for compiling all test targets on Windows
- All core tests pass (test_core_lib 6/6)
zhaochang_branch
赵昌 24 hours ago
parent 81879be4bc
commit 4a6908df18

@ -182,7 +182,7 @@ if(BUILD_TESTS)
add_executable(test_core_lib tests/test_core_lib.cpp ${CORE_BASE_SOURCES})
target_include_directories(test_core_lib PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/include ${EIGEN3_INCLUDE_DIR} ${KISSFFT_DIR})
${CMAKE_CURRENT_SOURCE_DIR}/include ${EIGEN3_INCLUDE_DIR})
if(NOT WIN32)
target_link_libraries(test_core_lib PRIVATE m)
endif()

@ -3,39 +3,35 @@ chcp 65001 >nul
setlocal EnableDelayedExpansion
echo ==========================================
echo Acoustic Core Library - Quick Build
echo Acoustic Core Library Test Build
echo ==========================================
set SRC_ROOT=%~dp0
set EIGEN=%SRC_ROOT%third_party\eigen-3.4.0
set KISSFFT=%SRC_ROOT%third_party\kiss_fft
set ONNX_INC=%SRC_ROOT%third_party\onnxruntime\include
set ONNX_LIB=%SRC_ROOT%third_party\onnxruntime\lib\onnxruntime.lib
set INCLUDES=-I%SRC_ROOT%include -I%EIGEN% -I%KISSFFT% -I%ONNX_INC%
set INCLUDES=-I%SRC_ROOT%include -I%EIGEN%
set FLAGS=-std=c++17 -O2 -D_USE_MATH_DEFINES -Wa,-mbig-obj
set LIBS=%ONNX_LIB%
if not exist build mkdir build
REM --- test_core_lib ---
echo [1/3] Building test_core_lib.exe ...
g++ %FLAGS% %INCLUDES% ^
tests\test_core_lib.cpp ^
src\core\fft_utils.cpp ^
src\core\audio_buffer.cpp ^
src\core\feature_extractor.cpp ^
src\core\gcc_phat_localizer.cpp ^
src\core\distance_estimator.cpp ^
src\core\gcc_phat_localizer.cpp ^
src\core\threat_tracker.cpp ^
src\io\wav_file_source.cpp ^
-o build\test_core_lib.exe ^
%LIBS% -lws2_32
-D_stdcall=
if %ERRORLEVEL% NEQ 0 (
echo [FAIL] test_core_lib build failed.
exit /b 1
)
REM --- extract_mel_cpp ---
echo [2/3] Building extract_mel_cpp.exe ...
g++ %FLAGS% %INCLUDES% ^
tests\extract_mel_cpp.cpp ^
@ -43,27 +39,32 @@ g++ %FLAGS% %INCLUDES% ^
src\core\feature_extractor.cpp ^
src\io\wav_file_source.cpp ^
-o build\extract_mel_cpp.exe ^
%LIBS%
-D_stdcall=
if %ERRORLEVEL% NEQ 0 (
echo [FAIL] extract_mel_cpp build failed.
exit /b 1
)
REM --- test_classifier_cpp ---
echo [3/3] Building test_classifier_cpp.exe ...
g++ %FLAGS% %INCLUDES% ^
set ONNX_LIB=%SRC_ROOT%third_party\onnxruntime\lib\libonnxruntime.a
g++ %FLAGS% %INCLUDES% -I%ONNX_INC% ^
tests\test_classifier_cpp.cpp ^
src\core\fft_utils.cpp ^
src\core\feature_extractor.cpp ^
src\core\gunshot_classifier.cpp ^
src\io\wav_file_source.cpp ^
-o build\test_classifier_cpp.exe ^
%LIBS% -D_stdcall=
%ONNX_LIB% -D_stdcall=
if %ERRORLEVEL% NEQ 0 (
echo [FAIL] test_classifier_cpp build failed.
exit /b 1
)
echo.
echo [OK] All core tests built successfully in build\
echo [OK] All test executables built successfully.
echo.
echo Run tests:
echo build\test_core_lib.exe
echo build\extract_mel_cpp.exe dataset\binary\val\ambient\xxx.wav
echo build\test_classifier_cpp.exe --model models\gunshot_classifier.onnx --wav dataset\binary\val\ambient\xxx.wav --label_map models\label_map.json
endlocal

@ -83,32 +83,25 @@ private:
}
void init_pipeline() {
PipelineParams p;
p.feature.sample_rate = params_.sample_rate;
p.feature.n_mels = 64;
p.feature.n_fft = 2048;
p.feature.hop_length = 512;
p.feature.f_min = 0.0f;
p.feature.f_max = 8000.0f;
p.feature.preemphasis = 0.97f;
p.model_path = params_.model_path;
p.label_map_path = params_.label_map_path;
p.sample_rate = params_.sample_rate;
p.mic_geometry.num_mics = params_.n_channels;
p.mic_geometry.layout = "cross";
p.mic_geometry.spacing = 0.15f;
pipeline_ = std::make_unique<Pipeline>(p);
if (!pipeline_->init()) {
ROS_ERROR("Pipeline initialization failed");
}
PipelineConfig config;
config.sample_rate = static_cast<uint32_t>(params_.sample_rate);
config.chunk_duration = params_.chunk_duration;
config.hop_duration = params_.hop_duration;
config.n_mels = 64;
config.classifier.model_path = params_.model_path;
config.classifier.label_map_path = params_.label_map_path;
config.classifier.threshold = 0.7f;
config.classifier.smoothing_window = 3;
config.mic_array.num_mics = static_cast<uint32_t>(params_.n_channels);
config.mic_array.layout = "cross";
config.mic_array.spacing = 0.15f;
pipeline_ = std::make_unique<Pipeline>(config);
}
void init_source() {
if (source_type_ == "mic_array") {
audio_sub_ = nh_.subscribe(params_.mobile_phone_topic.empty()
? "/microphone_array/audio"
: "/microphone_array/audio",
10, &AcousticNode::on_mic_array_audio, this);
audio_sub_ = nh_.subscribe("/microphone_array/audio", 10,
&AcousticNode::on_mic_array_audio, this);
} else if (source_type_ == "mobile_phone") {
audio_sub_ = nh_.subscribe(params_.mobile_phone_topic, 10,
&AcousticNode::on_mobile_phone_audio, this);
@ -122,33 +115,37 @@ private:
}
}
// Convert multi-channel vector-of-vectors to flat interleaved format
std::vector<float> flatten_audio(const std::vector<std::vector<float>>& audio, int channels) {
if (audio.empty() || channels == 0) return {};
size_t samples = audio[0].size();
std::vector<float> flat(samples * channels);
for (size_t s = 0; s < samples; ++s) {
for (int ch = 0; ch < channels; ++ch) {
flat[s * channels + ch] = (ch < static_cast<int>(audio.size()) && s < audio[ch].size())
? audio[ch][s] : 0.0f;
}
}
return flat;
}
void on_mic_array_audio(const std_msgs::Float32MultiArray::ConstPtr& msg) {
// Parse layout from dim: [channels, samples]
if (msg->layout.dim.size() < 2) return;
int channels = msg->layout.dim[0].size;
int samples = msg->layout.dim[1].size;
int channels = static_cast<int>(msg->layout.dim[0].size);
int samples = static_cast<int>(msg->layout.dim[1].size);
if (channels == 0 || samples == 0) return;
std::vector<std::vector<float>> audio(channels, std::vector<float>(samples));
for (int ch = 0; ch < channels; ++ch) {
for (int s = 0; s < samples; ++s) {
audio[ch][s] = msg->data[ch * samples + s];
}
}
std::vector<TrackedThreat> threats;
if (channels == 1) {
pipeline_->process_single_channel(audio[0], threats);
} else {
pipeline_->process(audio, threats);
}
// Publish threats via threat_publisher (would be called by main loop)
// Assuming data is interleaved or [channels x samples] row-major
std::vector<float> flat(msg->data.begin(), msg->data.end());
auto frame = pipeline_->Process(flat);
(void)frame; // Would be published by threat_publisher in main loop
}
void on_mobile_phone_audio(const std_msgs::Float32MultiArray::ConstPtr& msg) {
if (msg->data.empty()) return;
std::vector<TrackedThreat> threats;
pipeline_->process_single_channel(msg->data, threats);
std::vector<float> flat(msg->data.begin(), msg->data.end());
auto frame = pipeline_->Process(flat);
(void)frame;
}
void process_wav_source() {
@ -161,12 +158,9 @@ private:
ros::shutdown();
return;
}
std::vector<TrackedThreat> threats;
if (wav_source_->num_channels() == 1) {
pipeline_->process_single_channel(audio[0], threats);
} else {
pipeline_->process(audio, threats);
}
auto flat = flatten_audio(audio, static_cast<int>(wav_source_->num_channels()));
auto frame = pipeline_->Process(flat);
(void)frame;
}
};

@ -3,16 +3,26 @@
namespace acoustic {
ThreatPublisher::ThreatPublisher(ros::NodeHandle& nh) {
pub_ = nh.advertise<acoustic_analyzer::AcousticThreatArray>("/acoustic/threats", 10);
struct ThreatPublisher::Impl {
ros::Publisher pub_;
std::size_t num_published_ = 0;
};
ThreatPublisher::ThreatPublisher(ros::NodeHandle& nh, const std::string& topic)
: impl_(std::make_unique<Impl>()) {
impl_->pub_ = nh.advertise<acoustic_analyzer::AcousticThreatArray>(topic, 10);
}
void ThreatPublisher::publish(const std::vector<TrackedThreat>& threats) {
ThreatPublisher::~ThreatPublisher() = default;
ThreatPublisher::ThreatPublisher(ThreatPublisher&&) noexcept = default;
ThreatPublisher& ThreatPublisher::operator=(ThreatPublisher&&) noexcept = default;
void ThreatPublisher::Publish(const AcousticFrame& frame) {
acoustic_analyzer::AcousticThreatArray msg;
msg.header.stamp = ros::Time::now();
msg.header.frame_id = "acoustic_array";
for (const auto& t : threats) {
for (const auto& t : frame.threats) {
acoustic_analyzer::AcousticThreat threat_msg;
threat_msg.threat_id = t.threat_id;
threat_msg.sound_type = t.sound_type;
@ -24,7 +34,12 @@ void ThreatPublisher::publish(const std::vector<TrackedThreat>& threats) {
msg.threats.push_back(threat_msg);
}
pub_.publish(msg);
impl_->pub_.publish(msg);
++impl_->num_published_;
}
std::size_t ThreatPublisher::NumPublished() const noexcept {
return impl_->num_published_;
}
} // namespace acoustic

Loading…
Cancel
Save