|
|
|
|
@ -6,6 +6,7 @@
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <numeric>
|
|
|
|
|
#include <map>
|
|
|
|
|
|
|
|
|
|
#include "acoustic_analyzer/core/feature_extractor.h"
|
|
|
|
|
#include "acoustic_analyzer/core/gunshot_classifier.h"
|
|
|
|
|
@ -15,6 +16,14 @@
|
|
|
|
|
namespace fs = std::filesystem;
|
|
|
|
|
using namespace acoustic;
|
|
|
|
|
|
|
|
|
|
struct Prediction {
|
|
|
|
|
std::string file_path;
|
|
|
|
|
std::string true_label;
|
|
|
|
|
std::string pred_label;
|
|
|
|
|
float confidence = 0.0f;
|
|
|
|
|
float distance = -1.0f;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void print_usage(const char* prog) {
|
|
|
|
|
std::cerr << "Usage: " << prog << " <file_or_dir> [--model <onnx>] [--label_map <json>]" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
@ -24,6 +33,14 @@ bool ends_with(const std::string& s, const std::string& suffix) {
|
|
|
|
|
return s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string get_parent_folder_name(const std::string& path) {
|
|
|
|
|
fs::path p(path);
|
|
|
|
|
if (p.has_parent_path()) {
|
|
|
|
|
return p.parent_path().filename().string();
|
|
|
|
|
}
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float compute_spl(const std::vector<float>& audio) {
|
|
|
|
|
if (audio.empty()) return -100.0f;
|
|
|
|
|
float rms = 0.0f;
|
|
|
|
|
@ -32,58 +49,134 @@ float compute_spl(const std::vector<float>& audio) {
|
|
|
|
|
return 20.0f * std::log10(rms + 1e-10f) + 94.0f;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void process_file(const std::string& path,
|
|
|
|
|
const std::string& model_path,
|
|
|
|
|
const std::string& label_map_path) {
|
|
|
|
|
Prediction process_file(const std::string& path,
|
|
|
|
|
GunshotClassifier& classifier) {
|
|
|
|
|
Prediction result;
|
|
|
|
|
result.file_path = path;
|
|
|
|
|
result.true_label = get_parent_folder_name(path);
|
|
|
|
|
|
|
|
|
|
WavFileSource wav(path);
|
|
|
|
|
if (!wav.open()) {
|
|
|
|
|
std::cerr << "[SKIP] Cannot open: " << path << std::endl;
|
|
|
|
|
return;
|
|
|
|
|
result.pred_label = "error";
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int sr = wav.sample_rate();
|
|
|
|
|
std::vector<std::vector<float>> audio;
|
|
|
|
|
size_t chunk = static_cast<size_t>(sr * 2.0);
|
|
|
|
|
size_t got = wav.read(audio, chunk);
|
|
|
|
|
if (got == 0 || audio.empty()) return;
|
|
|
|
|
if (got == 0 || audio.empty()) {
|
|
|
|
|
result.pred_label = "empty";
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Feature extraction
|
|
|
|
|
FeatureExtractor extractor(sr, 2048, 512, 64, 0.0f, 8000.0f, 0.97f);
|
|
|
|
|
Eigen::MatrixXf mel = extractor.MelSpectrogram(audio[0]);
|
|
|
|
|
|
|
|
|
|
// Classification
|
|
|
|
|
ClassifierConfig cc;
|
|
|
|
|
cc.model_path = model_path;
|
|
|
|
|
cc.label_map_path = label_map_path;
|
|
|
|
|
cc.threshold = 0.5f;
|
|
|
|
|
cc.smoothing_window = 1;
|
|
|
|
|
GunshotClassifier classifier(cc);
|
|
|
|
|
if (!classifier.IsLoaded()) {
|
|
|
|
|
std::cerr << "[ERROR] Failed to load classifier model" << std::endl;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto [label, confidence] = classifier.Predict(mel);
|
|
|
|
|
result.pred_label = label;
|
|
|
|
|
result.confidence = confidence;
|
|
|
|
|
|
|
|
|
|
// Distance estimation
|
|
|
|
|
float distance = -1.0f;
|
|
|
|
|
float distance_conf = 0.0f;
|
|
|
|
|
if (label != "ambient" && confidence > 0.5f) {
|
|
|
|
|
DistanceConfig dc;
|
|
|
|
|
dc.ref_spl_gunshot = 150.0f;
|
|
|
|
|
dc.attenuation_alpha = 0.6f;
|
|
|
|
|
DistanceEstimator de(dc);
|
|
|
|
|
float spl = compute_spl(audio[0]);
|
|
|
|
|
distance = de.Estimate(spl, label);
|
|
|
|
|
distance = de.UpdateKalman(distance);
|
|
|
|
|
distance_conf = 0.8f;
|
|
|
|
|
float dist = de.Estimate(spl, label);
|
|
|
|
|
dist = de.UpdateKalman(dist);
|
|
|
|
|
result.distance = dist;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::cout << "File: " << path << std::endl;
|
|
|
|
|
std::cout << " -> Label: " << label
|
|
|
|
|
std::cout << "File: " << fs::path(path).filename().string()
|
|
|
|
|
<< " | True: " << result.true_label
|
|
|
|
|
<< " | Pred: " << label
|
|
|
|
|
<< " | Conf: " << std::fixed << std::setprecision(4) << confidence
|
|
|
|
|
<< " | Dist: " << std::setprecision(2) << distance << "m"
|
|
|
|
|
<< " | DConf: " << distance_conf << std::endl;
|
|
|
|
|
<< " | Dist: " << std::setprecision(2) << result.distance << "m" << std::endl;
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void collect_wav_files(const std::string& target, std::vector<std::string>& out) {
|
|
|
|
|
if (fs::is_regular_file(target) && ends_with(target, ".wav")) {
|
|
|
|
|
out.push_back(target);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (!fs::is_directory(target)) return;
|
|
|
|
|
|
|
|
|
|
for (const auto& entry : fs::recursive_directory_iterator(target)) {
|
|
|
|
|
if (entry.is_regular_file() && ends_with(entry.path().string(), ".wav")) {
|
|
|
|
|
out.push_back(entry.path().string());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::sort(out.begin(), out.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void print_report(const std::vector<Prediction>& results) {
|
|
|
|
|
std::map<std::string, int> total_by_true;
|
|
|
|
|
std::map<std::string, int> correct_by_true;
|
|
|
|
|
std::map<std::string, float> conf_sum_by_true;
|
|
|
|
|
std::map<std::string, std::map<std::string, int>> confusion;
|
|
|
|
|
|
|
|
|
|
int total = 0, correct = 0;
|
|
|
|
|
for (const auto& r : results) {
|
|
|
|
|
if (r.pred_label == "error" || r.pred_label == "empty") continue;
|
|
|
|
|
total++;
|
|
|
|
|
total_by_true[r.true_label]++;
|
|
|
|
|
conf_sum_by_true[r.true_label] += r.confidence;
|
|
|
|
|
confusion[r.true_label][r.pred_label]++;
|
|
|
|
|
if (r.true_label == r.pred_label) {
|
|
|
|
|
correct++;
|
|
|
|
|
correct_by_true[r.true_label]++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::cout << "\n==========================================" << std::endl;
|
|
|
|
|
std::cout << " VALIDATION REPORT" << std::endl;
|
|
|
|
|
std::cout << "==========================================" << std::endl;
|
|
|
|
|
std::cout << "Total samples: " << total << std::endl;
|
|
|
|
|
std::cout << "Correct: " << correct << std::endl;
|
|
|
|
|
std::cout << "Accuracy: " << std::fixed << std::setprecision(2)
|
|
|
|
|
<< (total > 0 ? 100.0f * correct / total : 0.0f) << "%" << std::endl;
|
|
|
|
|
|
|
|
|
|
std::cout << "\nPer-class breakdown:" << std::endl;
|
|
|
|
|
for (const auto& kv : total_by_true) {
|
|
|
|
|
const std::string& cls = kv.first;
|
|
|
|
|
int cls_total = kv.second;
|
|
|
|
|
int cls_correct = correct_by_true[cls];
|
|
|
|
|
float avg_conf = conf_sum_by_true[cls] / cls_total;
|
|
|
|
|
std::cout << " " << std::setw(10) << std::left << cls
|
|
|
|
|
<< " Count: " << std::setw(3) << cls_total
|
|
|
|
|
<< " Correct: " << std::setw(3) << cls_correct
|
|
|
|
|
<< " Acc: " << std::setw(6) << std::fixed << std::setprecision(2)
|
|
|
|
|
<< (100.0f * cls_correct / cls_total) << "%"
|
|
|
|
|
<< " AvgConf: " << std::setprecision(4) << avg_conf << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::cout << "\nConfusion matrix (rows=true, cols=pred):" << std::endl;
|
|
|
|
|
std::vector<std::string> labels;
|
|
|
|
|
for (const auto& row : confusion) labels.push_back(row.first);
|
|
|
|
|
for (const auto& row : confusion) {
|
|
|
|
|
for (const auto& col : row.second) {
|
|
|
|
|
if (std::find(labels.begin(), labels.end(), col.first) == labels.end()) {
|
|
|
|
|
labels.push_back(col.first);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::sort(labels.begin(), labels.end());
|
|
|
|
|
|
|
|
|
|
std::cout << std::setw(12) << " ";
|
|
|
|
|
for (const auto& l : labels) std::cout << std::setw(10) << l;
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
for (const auto& true_l : labels) {
|
|
|
|
|
std::cout << std::setw(10) << std::left << true_l << " ";
|
|
|
|
|
for (const auto& pred_l : labels) {
|
|
|
|
|
int count = confusion.count(true_l) ? confusion[true_l].count(pred_l) ? confusion[true_l].at(pred_l) : 0 : 0;
|
|
|
|
|
std::cout << std::setw(10) << count;
|
|
|
|
|
}
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
}
|
|
|
|
|
std::cout << "==========================================" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int main(int argc, char** argv) {
|
|
|
|
|
@ -101,19 +194,35 @@ int main(int argc, char** argv) {
|
|
|
|
|
else if (std::strcmp(argv[i], "--label_map") == 0 && i + 1 < argc) label_map_path = argv[++i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (fs::is_directory(target)) {
|
|
|
|
|
std::vector<fs::path> files;
|
|
|
|
|
for (const auto& entry : fs::directory_iterator(target)) {
|
|
|
|
|
if (entry.is_regular_file() && ends_with(entry.path().string(), ".wav")) {
|
|
|
|
|
files.push_back(entry.path());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::sort(files.begin(), files.end());
|
|
|
|
|
for (const auto& f : files) {
|
|
|
|
|
process_file(f.string(), model_path, label_map_path);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
process_file(target, model_path, label_map_path);
|
|
|
|
|
ClassifierConfig cc;
|
|
|
|
|
cc.model_path = model_path;
|
|
|
|
|
cc.label_map_path = label_map_path;
|
|
|
|
|
cc.threshold = 0.5f;
|
|
|
|
|
cc.smoothing_window = 1;
|
|
|
|
|
GunshotClassifier classifier(cc);
|
|
|
|
|
if (!classifier.IsLoaded()) {
|
|
|
|
|
std::cerr << "[ERROR] Failed to load classifier model: " << model_path << std::endl;
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
std::cout << "Model loaded: " << model_path << std::endl;
|
|
|
|
|
std::cout << "Labels: ";
|
|
|
|
|
for (const auto& l : classifier.Labels()) std::cout << l << " ";
|
|
|
|
|
std::cout << "\n" << std::endl;
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> files;
|
|
|
|
|
collect_wav_files(target, files);
|
|
|
|
|
if (files.empty()) {
|
|
|
|
|
std::cerr << "No .wav files found in: " << target << std::endl;
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
std::cout << "Found " << files.size() << " WAV file(s)." << std::endl;
|
|
|
|
|
|
|
|
|
|
std::vector<Prediction> results;
|
|
|
|
|
results.reserve(files.size());
|
|
|
|
|
for (const auto& f : files) {
|
|
|
|
|
results.push_back(process_file(f, classifier));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
print_report(results);
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|