diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..d0f07203 --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# Python +__pycache__/ +*.pyc +*.pyo +*.egg-info/ +venv/ + +# Database +zhitu.db +*.db + +# Audio data (large files) +*.wav + +# Generated results +results/ + +# IDE +.idea/ +.vscode/ +*.iml diff --git a/src/drone-software/src/acoustic/build_demo_multichannel.bat b/src/drone-software/src/acoustic/build_demo_multichannel.bat new file mode 100644 index 00000000..429b4d94 --- /dev/null +++ b/src/drone-software/src/acoustic/build_demo_multichannel.bat @@ -0,0 +1,47 @@ +@echo off +chcp 65001 >nul +setlocal EnableDelayedExpansion + +echo ========================================== +echo Acoustic Offline Multichannel Demo Build +echo ========================================== + +set SRC_ROOT=%~dp0 +set EIGEN=%SRC_ROOT%third_party\eigen-3.4.0 +set ONNX_INC=%SRC_ROOT%third_party\onnxruntime\include +set ONNX_LIB=%SRC_ROOT%third_party\onnxruntime\lib\libonnxruntime.a + +set INCLUDES=-I%SRC_ROOT%include -I%EIGEN% -I%ONNX_INC% +set FLAGS=-std=c++17 -O2 -D_USE_MATH_DEFINES -Wa,-mbig-obj +set LIBS=%ONNX_LIB% + +if not exist build mkdir build + +echo [1/1] Building demo_offline_multichannel.exe ... +g++ %FLAGS% %INCLUDES% ^ + tests\demo_offline_multichannel.cpp ^ + src\core\pipeline.cpp ^ + src\core\audio_buffer.cpp ^ + src\core\fft_utils.cpp ^ + src\core\feature_extractor.cpp ^ + src\core\gunshot_classifier.cpp ^ + src\core\gcc_phat_localizer.cpp ^ + src\core\distance_estimator.cpp ^ + src\core\threat_tracker.cpp ^ + src\io\wav_file_source.cpp ^ + -o build\demo_offline_multichannel.exe ^ + %LIBS% -D_stdcall= +if %ERRORLEVEL% NEQ 0 ( + echo [FAIL] demo_offline_multichannel build failed. + exit /b 1 +) + +echo. +echo [OK] demo_offline_multichannel.exe built successfully in build\^ +echo. +echo Usage: build\demo_offline_multichannel.exe dataset\multichannel_test.wav --num_mics 4 --layout cross +echo build\demo_offline_multichannel.exe dataset\real\threat\ --threshold 0.6 --num_mics 4 +echo. +echo With ground-truth for error analysis: +echo build\demo_offline_multichannel.exe synth_90deg_100m.wav --ground_azimuth 90 --ground_distance 100 +endlocal diff --git a/src/drone-software/src/acoustic/include/acoustic_analyzer/core/types.h b/src/drone-software/src/acoustic/include/acoustic_analyzer/core/types.h index 8b7164f6..f4a58616 100644 --- a/src/drone-software/src/acoustic/include/acoustic_analyzer/core/types.h +++ b/src/drone-software/src/acoustic/include/acoustic_analyzer/core/types.h @@ -86,6 +86,17 @@ struct PipelineConfig { float confidence_threshold = 0.7f; ///< 输出置信度阈值 float min_event_interval = 0.3f; ///< 同一威胁最小上报间隔 (秒) + // 特征提取参数(原 hard-coded,现可配置) + uint32_t n_fft = 2048; ///< FFT 长度 + uint32_t hop_length = 512; ///< 帧移样本数 + float f_min = 0.0f; ///< 最低频率 (Hz) + float f_max = 8000.0f; ///< 最高频率 (Hz) + float preemphasis = 0.97f; ///< 预加重系数 + + // 定位参数(原 hard-coded,现可配置) + float max_tdoa = 0.00044f; ///< 最大允许时延 (秒) + int interpolation_factor = 4; ///< GCC-PHAT 插值倍数 + MicArrayConfig mic_array; ClassifierConfig classifier; DistanceConfig distance; diff --git a/src/drone-software/src/acoustic/include/acoustic_analyzer/ros/acoustic_node.h b/src/drone-software/src/acoustic/include/acoustic_analyzer/ros/acoustic_node.h index b5063ea8..9e44dd23 100644 --- a/src/drone-software/src/acoustic/include/acoustic_analyzer/ros/acoustic_node.h +++ b/src/drone-software/src/acoustic/include/acoustic_analyzer/ros/acoustic_node.h @@ -4,8 +4,12 @@ namespace acoustic { +class ThreatPublisher; class AcousticNode; -std::unique_ptr create_acoustic_node(ros::NodeHandle& nh, ros::NodeHandle& pnh); +std::unique_ptr create_acoustic_node( + ros::NodeHandle& nh, + ros::NodeHandle& pnh, + ThreatPublisher* publisher = nullptr); } // namespace acoustic diff --git a/src/drone-software/src/acoustic/scripts/generate_multichannel_test.py b/src/drone-software/src/acoustic/scripts/generate_multichannel_test.py index d922dbfe..70816414 100644 --- a/src/drone-software/src/acoustic/scripts/generate_multichannel_test.py +++ b/src/drone-software/src/acoustic/scripts/generate_multichannel_test.py @@ -5,8 +5,7 @@ """ import argparse import numpy as np -import torch -import torchaudio +import soundfile as sf def generate_gunshot_signal(sr=16000, duration=2.0, freq=1000, decay=0.05): @@ -64,8 +63,8 @@ def main(): if max_val > 0: multich = multich / max_val * 0.9 - tensor = torch.from_numpy(multich) - torchaudio.save(args.output, tensor, args.sr) + # soundfile expects [samples, channels] for multi-channel + sf.write(args.output, multich.T, args.sr, subtype="PCM_16") print(f"Generated {args.output}: azimuth={args.azimuth}°, distance={args.distance}m, shape={multich.shape}") diff --git a/src/drone-software/src/acoustic/scripts/inject_noise_eval.py b/src/drone-software/src/acoustic/scripts/inject_noise_eval.py new file mode 100644 index 00000000..459ab5e1 --- /dev/null +++ b/src/drone-software/src/acoustic/scripts/inject_noise_eval.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +""" +噪声注入批量测试脚本 + +功能: +1. 对干净音频按指定 SNR 注入噪声,生成测试样本 +2. 调用 C++ 离线推理程序批量检测 +3. 输出 CSV 统计结果,便于绘制 Pd-SNR / 定位误差-SNR 曲线 + +用法示例: + # 单通道分类测试 + python inject_noise_eval.py \\ + --clean_dir dataset/real/threat \\ + --noise noise_samples/drone_hover.wav \\ + --snrs 20 10 5 0 -5 \\ + --exe build/demo_offline.exe \\ + --output results/threat_vs_drone_noise.csv + + # 多通道定位测试 + python inject_noise_eval.py \\ + --clean_dir dataset/multichannel \\ + --noise noise_samples/white.wav \\ + --snrs 20 10 5 0 \\ + --exe build/demo_offline_multichannel.exe \\ + --num_mics 4 --layout cross \\ + --output results/loc_snr_sweep.csv +""" + +import argparse +import csv +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +import numpy as np +import soundfile as sf + + +def list_wav_files(directory: str): + """递归列出目录下所有 wav 文件。""" + return sorted([str(p) for p in Path(directory).rglob("*.wav")]) + + +def compute_rms(signal: np.ndarray): + """计算信号 RMS。""" + return np.sqrt(np.mean(signal ** 2)) + + +def mix_with_noise(clean: np.ndarray, noise: np.ndarray, target_snr_db: float): + """ + 将噪声以指定 SNR 混入干净信号。 + 如果噪声比干净信号短,会循环拼接。 + 支持多通道 clean(noise 会先混为单通道再扩展到匹配通道数,或直接用同通道数)。 + """ + # Ensure 2D: [channels, samples] + if clean.ndim == 1: + clean = clean.reshape(1, -1) + if noise.ndim == 1: + noise = noise.reshape(1, -1) + + n_ch, n_samples = clean.shape + + # Tile or truncate noise to match length + noise_tiled = np.tile(noise, (1, int(np.ceil(n_samples / noise.shape[1])) + 1)) + noise_tiled = noise_tiled[:, :n_samples] + + # If noise has fewer channels, replicate first channel + if noise_tiled.shape[0] < n_ch: + noise_tiled = np.repeat(noise_tiled[:1, :], n_ch, axis=0) + elif noise_tiled.shape[0] > n_ch: + noise_tiled = noise_tiled[:n_ch, :] + + # Compute per-channel RMS and average for a global scaling factor + clean_rms = compute_rms(clean) + if clean_rms < 1e-10: + return clean # silent signal, return as-is + + noise_rms = compute_rms(noise_tiled) + if noise_rms < 1e-10: + return clean + + target_noise_rms = clean_rms / (10 ** (target_snr_db / 20.0)) + scale = target_noise_rms / noise_rms + mixed = clean + noise_tiled * scale + + # Normalize to prevent clipping + max_val = np.max(np.abs(mixed)) + if max_val > 1.0: + mixed = mixed / max_val * 0.95 + + return mixed + + +def run_classifier(exe_path: str, wav_path: str, extra_args: list) -> dict: + """ + 调用 C++ 离线推理程序,解析其标准输出。 + 返回字典包含 pred_label, confidence 等。 + """ + import re + cmd = [exe_path, wav_path] + extra_args + try: + result = subprocess.run(cmd, capture_output=True, timeout=30) + # Decode with GBK first (Windows console default), fallback to UTF-8 + try: + output = result.stdout.decode("gbk", errors="replace") + except Exception: + output = result.stdout.decode("utf-8", errors="replace") + except Exception as e: + return {"error": str(e)} + + parsed = {} + # Robust regex parsing for output like: + # File: xxx.wav | True: threat | Pred: threat | Conf: 0.9234 | Az: 90.00° | El: 0.00° | Dist: 95.20m + for line in output.splitlines(): + if line.startswith("File:") and "Pred:" in line: + m = re.search(r"Pred:\s*(\S+)", line) + if m: + parsed["pred"] = m.group(1) + m = re.search(r"Conf:\s*([0-9.]+)", line) + if m: + parsed["conf"] = float(m.group(1)) + m = re.search(r"Az:\s*([0-9.+-]+)", line) + if m: + parsed["azimuth"] = float(m.group(1)) + m = re.search(r"El:\s*([0-9.+-]+)", line) + if m: + parsed["elevation"] = float(m.group(1)) + m = re.search(r"Dist:\s*([0-9.+-]+)", line) + if m: + parsed["distance"] = float(m.group(1)) + break + return parsed + + +def main(): + parser = argparse.ArgumentParser(description="Noise injection evaluation for acoustic analyzer") + parser.add_argument("--clean_dir", required=True, help="Directory containing clean WAV files") + parser.add_argument("--noise", required=True, help="Noise WAV file path") + parser.add_argument("--snrs", nargs="+", type=float, required=True, + help="List of SNR values in dB, e.g. 20 10 5 0 -5") + parser.add_argument("--exe", required=True, help="Path to demo_offline or demo_offline_multichannel exe") + parser.add_argument("--output", required=True, help="Output CSV path") + parser.add_argument("--num_mics", type=int, default=1, help="Number of channels (for multichannel exe)") + parser.add_argument("--layout", default="cross", help="Array layout (for multichannel exe)") + parser.add_argument("--threshold", type=float, default=0.5, help="Detection threshold") + parser.add_argument("--ground_azimuth", type=float, default=None, help="Ground-truth azimuth for error calc") + parser.add_argument("--ground_distance", type=float, default=None, help="Ground-truth distance for error calc") + args = parser.parse_args() + + clean_files = list_wav_files(args.clean_dir) + if not clean_files: + print(f"[ERROR] No WAV files found in {args.clean_dir}") + sys.exit(1) + + noise_data, noise_sr = sf.read(args.noise, dtype="float32") + if noise_data.ndim == 1: + noise_data = noise_data.reshape(1, -1) + else: + noise_data = noise_data.T # [channels, samples] + + # Prepare extra args for C++ exe + extra_args = [f"--threshold", str(args.threshold)] + if "multichannel" in os.path.basename(args.exe).lower(): + extra_args += ["--num_mics", str(args.num_mics), "--layout", args.layout] + if args.ground_azimuth is not None: + extra_args += ["--ground_azimuth", str(args.ground_azimuth)] + if args.ground_distance is not None: + extra_args += ["--ground_distance", str(args.ground_distance)] + + # CSV header + fieldnames = ["file", "true_label", "snr_db", "pred_label", "confidence", + "azimuth", "elevation", "distance"] + if args.ground_azimuth is not None: + fieldnames.append("azimuth_error") + if args.ground_distance is not None: + fieldnames.append("distance_error") + + os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) + + with open(args.output, "w", newline="", encoding="utf-8") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + + total_runs = len(clean_files) * len(args.snrs) + run_idx = 0 + + with tempfile.TemporaryDirectory() as tmpdir: + for clean_path in clean_files: + clean_data, clean_sr = sf.read(clean_path, dtype="float32") + true_label = Path(clean_path).parent.name + + # Ensure 2D + if clean_data.ndim == 1: + clean_data = clean_data.reshape(1, -1) + else: + clean_data = clean_data.T + + for snr in args.snrs: + run_idx += 1 + mixed = mix_with_noise(clean_data, noise_data, snr) + + # Save temp WAV (interleaved) + if mixed.shape[0] == 1: + mixed_interleaved = mixed[0] + else: + mixed_interleaved = mixed.T # [samples, channels] + + tmp_wav = os.path.join(tmpdir, f"mix_{run_idx:04d}.wav") + sf.write(tmp_wav, mixed_interleaved, clean_sr, subtype="PCM_16") + + # Run inference + parsed = run_classifier(args.exe, tmp_wav, extra_args) + + row = { + "file": Path(clean_path).name, + "true_label": true_label, + "snr_db": snr, + "pred_label": parsed.get("pred", "error"), + "confidence": parsed.get("conf", 0.0), + "azimuth": parsed.get("azimuth", ""), + "elevation": parsed.get("elevation", ""), + "distance": parsed.get("distance", ""), + } + + if args.ground_azimuth is not None and "azimuth" in parsed: + err = abs(parsed["azimuth"] - args.ground_azimuth) + if err > 180: + err = 360 - err + row["azimuth_error"] = round(err, 2) + + if args.ground_distance is not None and "distance" in parsed: + row["distance_error"] = round(abs(parsed["distance"] - args.ground_distance), 2) + + writer.writerow(row) + print(f"[{run_idx}/{total_runs}] {Path(clean_path).name} SNR={snr}dB -> {row['pred_label']} ({row['confidence']:.3f})") + + print(f"\n[OK] Results saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/src/drone-software/src/acoustic/src/core/gcc_phat_localizer.cpp b/src/drone-software/src/acoustic/src/core/gcc_phat_localizer.cpp index b8915ba7..635623aa 100644 --- a/src/drone-software/src/acoustic/src/core/gcc_phat_localizer.cpp +++ b/src/drone-software/src/acoustic/src/core/gcc_phat_localizer.cpp @@ -127,7 +127,9 @@ struct GccPhatLocalizer::Impl { } float delay_samples = max_idx + p; float delay_sec = delay_samples / static_cast(sample_rate); - return delay_sec; + // Note: GCC-PHAT IFFT peak at tau corresponds to ch2[n] = ch1[n - tau], + // so tau is actually the negative of the physical delay of ch2 relative to ch1. + return -delay_sec; } bool SolveDirection(const std::vector& delays, float& azimuth, float& elevation) { diff --git a/src/drone-software/src/acoustic/src/core/pipeline.cpp b/src/drone-software/src/acoustic/src/core/pipeline.cpp index 7fca0274..4a1f7d73 100644 --- a/src/drone-software/src/acoustic/src/core/pipeline.cpp +++ b/src/drone-software/src/acoustic/src/core/pipeline.cpp @@ -6,7 +6,17 @@ #include "acoustic_analyzer/core/distance_estimator.h" #include "acoustic_analyzer/core/threat_tracker.h" +// yaml-cpp is optional: if absent, Pipeline still works but FromYaml returns defaults +#if defined(__has_include) +# if __has_include() +# define ACOUSTIC_HAS_YAML_CPP 1 +# endif +#endif + +#ifdef ACOUSTIC_HAS_YAML_CPP #include +#endif + #include #include #include @@ -40,8 +50,13 @@ struct Pipeline::Impl { static_cast(config.mic_array.num_mics)); feature_extractor = std::make_unique( - static_cast(config.sample_rate), 2048, 512, - static_cast(config.n_mels)); + static_cast(config.sample_rate), + static_cast(config.n_fft), + static_cast(config.hop_length), + static_cast(config.n_mels), + config.f_min, + config.f_max, + config.preemphasis); classifier = std::make_unique(config.classifier); @@ -49,7 +64,8 @@ struct Pipeline::Impl { localizer = std::make_unique( config.mic_array, static_cast(config.sample_rate), - 0.00044f, 4); + config.max_tdoa, + config.interpolation_factor); } distance_estimator = std::make_unique(config.distance); @@ -239,6 +255,7 @@ AcousticFrame Pipeline::Process(const std::vector& audio_samples) { PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) { PipelineConfig config; +#ifdef ACOUSTIC_HAS_YAML_CPP try { YAML::Node root = YAML::LoadFile(yaml_path); @@ -283,19 +300,19 @@ PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) { auto feat = root["features"]; if (feat["n_mels"]) config.n_mels = feat["n_mels"].as(); if (feat["n_fft"]) { - // Not stored in PipelineConfig directly; FeatureExtractor uses hard-coded 2048 + config.n_fft = feat["n_fft"].as(); } if (feat["hop_length"]) { - // Not stored in PipelineConfig directly + config.hop_length = feat["hop_length"].as(); } if (feat["f_min"]) { - // Not stored directly + config.f_min = feat["f_min"].as(); } if (feat["f_max"]) { - // Not stored directly + config.f_max = feat["f_max"].as(); } if (feat["preemphasis"]) { - // Not stored directly + config.preemphasis = feat["preemphasis"].as(); } } @@ -314,10 +331,10 @@ PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) { if (root["localization"]) { auto loc = root["localization"]; if (loc["max_tdoa"]) { - // Not stored directly in PipelineConfig + config.max_tdoa = loc["max_tdoa"].as(); } if (loc["interpolation_factor"]) { - // Not stored directly + config.interpolation_factor = loc["interpolation_factor"].as(); } } @@ -356,6 +373,9 @@ PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) { (void)e; // On parse error, return defaults } +#else + (void)yaml_path; +#endif return config; } diff --git a/src/drone-software/src/acoustic/src/main.cpp b/src/drone-software/src/acoustic/src/main.cpp index 2d443707..0e9056be 100644 --- a/src/drone-software/src/acoustic/src/main.cpp +++ b/src/drone-software/src/acoustic/src/main.cpp @@ -7,8 +7,8 @@ int main(int argc, char** argv) { ros::NodeHandle nh; ros::NodeHandle pnh("~"); - auto node = acoustic::create_acoustic_node(nh, pnh); acoustic::ThreatPublisher publisher(nh); + auto node = acoustic::create_acoustic_node(nh, pnh, &publisher); node->run(); return 0; diff --git a/src/drone-software/src/acoustic/src/ros/acoustic_node.cpp b/src/drone-software/src/acoustic/src/ros/acoustic_node.cpp index e1a39a79..8f4a0eb9 100644 --- a/src/drone-software/src/acoustic/src/ros/acoustic_node.cpp +++ b/src/drone-software/src/acoustic/src/ros/acoustic_node.cpp @@ -15,7 +15,8 @@ namespace acoustic { class AcousticNode { public: - AcousticNode(ros::NodeHandle& nh, ros::NodeHandle& pnh) : nh_(nh), pnh_(pnh) { + AcousticNode(ros::NodeHandle& nh, ros::NodeHandle& pnh, ThreatPublisher* publisher) + : nh_(nh), pnh_(pnh), threat_publisher_(publisher) { load_params(); init_pipeline(); init_source(); @@ -36,6 +37,7 @@ private: ros::NodeHandle nh_; ros::NodeHandle pnh_; ros::Subscriber audio_sub_; + ThreatPublisher* threat_publisher_ = nullptr; std::string source_type_; std::unique_ptr pipeline_; @@ -129,6 +131,12 @@ private: return flat; } + void publish_if_available(const AcousticFrame& frame) { + if (threat_publisher_ && !frame.is_clear) { + threat_publisher_->Publish(frame); + } + } + void on_mic_array_audio(const std_msgs::Float32MultiArray::ConstPtr& msg) { if (msg->layout.dim.size() < 2) return; int channels = static_cast(msg->layout.dim[0].size); @@ -138,14 +146,14 @@ private: // Assuming data is interleaved or [channels x samples] row-major std::vector flat(msg->data.begin(), msg->data.end()); auto frame = pipeline_->Process(flat); - (void)frame; // Would be published by threat_publisher in main loop + publish_if_available(frame); } void on_mobile_phone_audio(const std_msgs::Float32MultiArray::ConstPtr& msg) { if (msg->data.empty()) return; std::vector flat(msg->data.begin(), msg->data.end()); auto frame = pipeline_->Process(flat); - (void)frame; + publish_if_available(frame); } void process_wav_source() { @@ -160,12 +168,15 @@ private: } auto flat = flatten_audio(audio, static_cast(wav_source_->num_channels())); auto frame = pipeline_->Process(flat); - (void)frame; + publish_if_available(frame); } }; -std::unique_ptr create_acoustic_node(ros::NodeHandle& nh, ros::NodeHandle& pnh) { - return std::make_unique(nh, pnh); +std::unique_ptr create_acoustic_node( + ros::NodeHandle& nh, + ros::NodeHandle& pnh, + ThreatPublisher* publisher) { + return std::make_unique(nh, pnh, publisher); } } // namespace acoustic diff --git a/src/drone-software/src/acoustic/tests/demo_offline_multichannel.cpp b/src/drone-software/src/acoustic/tests/demo_offline_multichannel.cpp new file mode 100644 index 00000000..e02ad3dd --- /dev/null +++ b/src/drone-software/src/acoustic/tests/demo_offline_multichannel.cpp @@ -0,0 +1,331 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "acoustic_analyzer/core/pipeline.h" +#include "acoustic_analyzer/io/wav_file_source.h" + +namespace fs = std::filesystem; +using namespace acoustic; + +struct Prediction { + std::string file_path; + std::string true_label; + std::string pred_label; + float confidence = 0.0f; + float azimuth = 0.0f; + float elevation = 0.0f; + float distance = -1.0f; + bool detected = false; +}; + +void print_usage(const char* prog) { + std::cerr << "Usage: " << prog << " [options]" << std::endl; + std::cerr << "Options:" << std::endl; + std::cerr << " --model ONNX model path (default: models/gunshot_classifier.onnx)" << std::endl; + std::cerr << " --label_map Label map file (default: models/label_map.json)" << std::endl; + std::cerr << " --threshold Detection threshold (default: 0.5)" << std::endl; + std::cerr << " --num_mics Number of channels in WAV (default: 4)" << std::endl; + std::cerr << " --spacing Mic spacing in meters (default: 0.15)" << std::endl; + std::cerr << " --layout Array layout: cross/linear/circular (default: cross)" << std::endl; + std::cerr << " --ref_spl Reference SPL for distance estimation (default: 150)" << std::endl; + std::cerr << " --ground_azimuth Ground-truth azimuth for error calc (optional)" << std::endl; + std::cerr << " --ground_distance Ground-truth distance for error calc (optional)" << std::endl; +} + +bool ends_with(const std::string& s, const std::string& suffix) { + if (s.size() < suffix.size()) return false; + return s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +std::string get_parent_folder_name(const std::string& path) { + fs::path p(path); + if (p.has_parent_path()) { + return p.parent_path().filename().string(); + } + return ""; +} + +// Convert vector-of-vectors [channels][samples] to flat interleaved [ch0_s0, ch1_s0, ...] +std::vector flatten_audio(const std::vector>& audio, int channels) { + if (audio.empty() || channels == 0) return {}; + size_t samples = audio[0].size(); + std::vector flat(samples * channels); + for (size_t s = 0; s < samples; ++s) { + for (int ch = 0; ch < channels; ++ch) { + flat[s * channels + ch] = (ch < static_cast(audio.size()) && s < audio[ch].size()) + ? audio[ch][s] : 0.0f; + } + } + return flat; +} + +Prediction process_file(const std::string& path, + Pipeline& pipeline, + int num_mics, + float ground_azimuth, + float ground_distance) { + Prediction result; + result.file_path = path; + result.true_label = get_parent_folder_name(path); + + WavFileSource wav(path); + if (!wav.open()) { + std::cerr << "[SKIP] Cannot open: " << path << std::endl; + result.pred_label = "error"; + return result; + } + + int sr = wav.sample_rate(); + int file_ch = wav.num_channels(); + if (file_ch != num_mics) { + std::cerr << "[WARN] " << path << " has " << file_ch + << " channels, expected " << num_mics << std::endl; + } + // Use actual file channels if mismatch, but pipeline config should match + int effective_channels = (file_ch < num_mics) ? file_ch : num_mics; + + size_t chunk_samples = static_cast(sr * pipeline.Config().chunk_duration); + std::vector> audio; + size_t got = wav.read(audio, chunk_samples); + if (got == 0 || audio.empty()) { + result.pred_label = "empty"; + return result; + } + + // Process only the first chunk (like demo_offline) + // For sliding-window analysis, call Process() on hop-sized chunks + auto flat = flatten_audio(audio, effective_channels); + auto frame = pipeline.Process(flat); + + if (!frame.is_clear && !frame.threats.empty()) { + const auto& t = frame.threats[0]; + result.detected = true; + result.pred_label = t.sound_type; + result.confidence = t.confidence; + result.azimuth = t.azimuth; + result.elevation = t.elevation; + result.distance = t.distance; + } else { + result.pred_label = "ambient"; + result.confidence = 0.0f; + } + + std::cout << "File: " << fs::path(path).filename().string() + << " | True: " << result.true_label + << " | Pred: " << result.pred_label + << " | Conf: " << std::fixed << std::setprecision(4) << result.confidence; + if (result.detected) { + std::cout << " | Az: " << std::setprecision(2) << result.azimuth << "°" + << " | El: " << std::setprecision(2) << result.elevation << "°" + << " | Dist: " << std::setprecision(2) << result.distance << "m"; + if (ground_azimuth >= 0.0f) { + float az_err = std::fabs(result.azimuth - ground_azimuth); + if (az_err > 180.0f) az_err = 360.0f - az_err; + std::cout << " | AzErr: " << az_err << "°"; + } + if (ground_distance >= 0.0f) { + std::cout << " | DistErr: " << std::fabs(result.distance - ground_distance) << "m"; + } + } + std::cout << std::endl; + return result; +} + +void collect_wav_files(const std::string& target, std::vector& out) { + if (fs::is_regular_file(target) && ends_with(target, ".wav")) { + out.push_back(target); + return; + } + if (!fs::is_directory(target)) return; + + for (const auto& entry : fs::recursive_directory_iterator(target)) { + if (entry.is_regular_file() && ends_with(entry.path().string(), ".wav")) { + out.push_back(entry.path().string()); + } + } + std::sort(out.begin(), out.end()); +} + +void print_report(const std::vector& results, + float ground_azimuth, + float ground_distance) { + std::map total_by_true; + std::map correct_by_true; + std::map conf_sum_by_true; + std::map> confusion; + + int total = 0, correct = 0; + int detected_count = 0; + float az_err_sum = 0.0f, dist_err_sum = 0.0f; + int az_err_count = 0, dist_err_count = 0; + + for (const auto& r : results) { + if (r.pred_label == "error" || r.pred_label == "empty") continue; + total++; + total_by_true[r.true_label]++; + conf_sum_by_true[r.true_label] += r.confidence; + confusion[r.true_label][r.pred_label]++; + if (r.true_label == r.pred_label) { + correct++; + correct_by_true[r.true_label]++; + } + if (r.detected) { + detected_count++; + if (ground_azimuth >= 0.0f) { + float az_err = std::fabs(r.azimuth - ground_azimuth); + if (az_err > 180.0f) az_err = 360.0f - az_err; + az_err_sum += az_err; + az_err_count++; + } + if (ground_distance >= 0.0f) { + dist_err_sum += std::fabs(r.distance - ground_distance); + dist_err_count++; + } + } + } + + std::cout << "\n==========================================" << std::endl; + std::cout << " MULTICHANNEL VALIDATION REPORT" << std::endl; + std::cout << "==========================================" << std::endl; + std::cout << "Total samples: " << total << std::endl; + std::cout << "Correct: " << correct << std::endl; + std::cout << "Accuracy: " << std::fixed << std::setprecision(2) + << (total > 0 ? 100.0f * correct / total : 0.0f) << "%" << std::endl; + std::cout << "Detected frames: " << detected_count << std::endl; + + if (az_err_count > 0) { + std::cout << "Azimuth RMSE: " << std::setprecision(2) + << std::sqrt(az_err_sum / az_err_count) << "°" << std::endl; + } + if (dist_err_count > 0) { + std::cout << "Distance MAE: " << std::setprecision(2) + << (dist_err_sum / dist_err_count) << "m" << std::endl; + } + + std::cout << "\nPer-class breakdown:" << std::endl; + for (const auto& kv : total_by_true) { + const std::string& cls = kv.first; + int cls_total = kv.second; + int cls_correct = correct_by_true[cls]; + float avg_conf = conf_sum_by_true[cls] / cls_total; + std::cout << " " << std::setw(10) << std::left << cls + << " Count: " << std::setw(3) << cls_total + << " Correct: " << std::setw(3) << cls_correct + << " Acc: " << std::setw(6) << std::fixed << std::setprecision(2) + << (100.0f * cls_correct / cls_total) << "%" + << " AvgConf: " << std::setprecision(4) << avg_conf << std::endl; + } + + std::cout << "\nConfusion matrix (rows=true, cols=pred):" << std::endl; + std::vector labels; + for (const auto& row : confusion) labels.push_back(row.first); + for (const auto& row : confusion) { + for (const auto& col : row.second) { + if (std::find(labels.begin(), labels.end(), col.first) == labels.end()) { + labels.push_back(col.first); + } + } + } + std::sort(labels.begin(), labels.end()); + + std::cout << std::setw(12) << " "; + for (const auto& l : labels) std::cout << std::setw(10) << l; + std::cout << std::endl; + for (const auto& true_l : labels) { + std::cout << std::setw(10) << std::left << true_l << " "; + for (const auto& pred_l : labels) { + int count = confusion.count(true_l) ? confusion[true_l].count(pred_l) ? confusion[true_l].at(pred_l) : 0 : 0; + std::cout << std::setw(10) << count; + } + std::cout << std::endl; + } + std::cout << "==========================================" << std::endl; +} + +int main(int argc, char** argv) { + if (argc < 2 || std::strcmp(argv[1], "--help") == 0 || std::strcmp(argv[1], "-h") == 0) { + print_usage(argv[0]); + return argc < 2 ? 1 : 0; + } + + std::string target = argv[1]; + std::string model_path = "models/gunshot_classifier.onnx"; + std::string label_map_path = "models/label_map.json"; + float threshold = 0.5f; + int num_mics = 4; + float spacing = 0.15f; + std::string layout = "cross"; + float ref_spl = 150.0f; + float ground_azimuth = -1.0f; + float ground_distance = -1.0f; + + for (int i = 2; i < argc; ++i) { + if (std::strcmp(argv[i], "--model") == 0 && i + 1 < argc) model_path = argv[++i]; + else if (std::strcmp(argv[i], "--label_map") == 0 && i + 1 < argc) label_map_path = argv[++i]; + else if (std::strcmp(argv[i], "--threshold") == 0 && i + 1 < argc) threshold = std::stof(argv[++i]); + else if (std::strcmp(argv[i], "--num_mics") == 0 && i + 1 < argc) num_mics = std::stoi(argv[++i]); + else if (std::strcmp(argv[i], "--spacing") == 0 && i + 1 < argc) spacing = std::stof(argv[++i]); + else if (std::strcmp(argv[i], "--layout") == 0 && i + 1 < argc) layout = argv[++i]; + else if (std::strcmp(argv[i], "--ref_spl") == 0 && i + 1 < argc) ref_spl = std::stof(argv[++i]); + else if (std::strcmp(argv[i], "--ground_azimuth") == 0 && i + 1 < argc) ground_azimuth = std::stof(argv[++i]); + else if (std::strcmp(argv[i], "--ground_distance") == 0 && i + 1 < argc) ground_distance = std::stof(argv[++i]); + } + + // Build PipelineConfig directly (no yaml-cpp needed) + PipelineConfig config; + config.sample_rate = 16000; + config.chunk_duration = 2.0f; + config.hop_duration = 0.5f; + config.n_mels = 64; + config.confidence_threshold = threshold; + config.classifier.model_path = model_path; + config.classifier.label_map_path = label_map_path; + config.classifier.threshold = threshold; + config.classifier.smoothing_window = 1; // offline: no temporal smoothing + config.mic_array.num_mics = static_cast(num_mics); + config.mic_array.layout = layout; + config.mic_array.spacing = spacing; + config.distance.ref_spl_gunshot = ref_spl; + config.distance.ref_spl_artillery = ref_spl; + config.distance.ref_spl_explosion = ref_spl; + + Pipeline pipeline(config); + + std::vector files; + collect_wav_files(target, files); + if (files.empty()) { + std::cerr << "No .wav files found in: " << target << std::endl; + return 1; + } + std::cout << "Found " << files.size() << " WAV file(s)." << std::endl; + std::cout << "Channels: " << num_mics << ", Layout: " << layout + << ", Spacing: " << spacing << "m" << std::endl; + std::cout << std::endl; + + std::vector results; + results.reserve(files.size()); + double total_ms = 0.0; + for (const auto& f : files) { + auto t0 = std::chrono::steady_clock::now(); + results.push_back(process_file(f, pipeline, num_mics, ground_azimuth, ground_distance)); + auto t1 = std::chrono::steady_clock::now(); + total_ms += std::chrono::duration(t1 - t0).count(); + } + + std::cout << "\nTotal inference time: " << std::fixed << std::setprecision(2) + << total_ms << " ms" + << " | Avg per file: " << (files.empty() ? 0.0 : total_ms / files.size()) + << " ms" << std::endl; + + print_report(results, ground_azimuth, ground_distance); + return 0; +} diff --git a/src/单兵终端APP/js/api.js b/src/单兵终端APP/js/api.js index 4316bcb8..a1f25d9c 100644 --- a/src/单兵终端APP/js/api.js +++ b/src/单兵终端APP/js/api.js @@ -4,7 +4,8 @@ */ const API = (() => { - let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://192.168.1.14:5000'; + let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://121.41.216.243'; + let TOKEN = localStorage.getItem('auth_token') || ''; async function request(url, options = {}) { // 每次请求时重新读取BASE,支持动态切换 @@ -15,14 +16,33 @@ const API = (() => { const controller = new AbortController(); const timeoutId = setTimeout(() => controller.abort(), 5000); + // 自动带上认证Token + const headers = { + 'Content-Type': 'application/json', + ...(options.headers || {}) + }; + if (TOKEN) { + headers['X-Auth-Token'] = TOKEN; + } + try { const resp = await fetch(fullUrl, { - headers: { 'Content-Type': 'application/json' }, + headers, signal: controller.signal, ...options }); clearTimeout(timeoutId); if (!resp.ok) { + if (resp.status === 401) { + // Token失效,清除登录状态并跳转登录页 + TOKEN = ''; + localStorage.removeItem('auth_token'); + localStorage.removeItem('soldier_session'); + if (typeof App !== 'undefined' && App.router) { + App.router('login'); + App.showToast('登录已过期,请重新登录'); + } + } const text = await resp.text(); throw new Error(`HTTP ${resp.status}: ${text}`); } @@ -187,6 +207,14 @@ const API = (() => { ]; } + function setBase(url) { BASE = url; } + function setToken(token) { + TOKEN = token; + if (token) localStorage.setItem('auth_token', token); + else localStorage.removeItem('auth_token'); + } + function getToken() { return TOKEN; } + return { updateLocation, getSoldiers, @@ -201,6 +229,9 @@ const API = (() => { sendSOS, login, register, - getAccounts + getAccounts, + setBase, + setToken, + getToken }; })(); diff --git a/src/单兵终端APP/js/app.js b/src/单兵终端APP/js/app.js index c5afd329..ea47faec 100644 --- a/src/单兵终端APP/js/app.js +++ b/src/单兵终端APP/js/app.js @@ -6,10 +6,10 @@ const App = (() => { // ===== 配置 ===== const CONFIG = { - apiBase: localStorage.getItem('api_base') || 'http://192.168.1.14:5000', - soldierId: localStorage.getItem('soldier_id') || 'soldier_001', - soldierName: localStorage.getItem('soldier_name') || '张三', - soldierUnit: localStorage.getItem('soldier_unit') || '第3步兵师/1连', + apiBase: localStorage.getItem('api_base') || 'http://121.41.216.243', + soldierId: '', + soldierName: '', + soldierUnit: '', pollInterval: 5000 }; @@ -171,7 +171,7 @@ const App = (() => { if (!url.startsWith('http')) url = 'http://' + url; CONFIG.apiBase = url; localStorage.setItem('api_base', url); - API.BASE = url; + API.setBase(url); updateSoldierInfo(); showToast('服务器地址已更新'); } @@ -186,31 +186,10 @@ const App = (() => { return; } - // 演示账号:不连后端也能直接登录 - const demoAccounts = { - 'soldier_001': { name: '张三', unit: '第3步兵师/1连', role: '狙击手' }, - 'soldier_002': { name: '李四', unit: '第3步兵师/2连', role: '机枪手' }, - 'soldier_003': { name: '王五', unit: '第3步兵师/3连', role: '通讯员' } - }; - if (demoAccounts[id] && pwd === '123456') { - const acc = demoAccounts[id]; - CONFIG.soldierId = id; - CONFIG.soldierName = acc.name; - CONFIG.soldierUnit = acc.unit; - localStorage.setItem('soldier_session', JSON.stringify({ - soldier_id: id, name: acc.name, unit: acc.unit, role: acc.role - })); - updateSoldierInfo(); - router('home'); - startPolling(); - startLocationReporting(); - showToast('登录成功,欢迎 ' + acc.name); - return; - } - try { const result = await API.login(id, pwd); if (result.ok) { + API.setToken(result.token); // 保存认证令牌到前后端统一中心 CONFIG.soldierId = result.soldier_id; CONFIG.soldierName = result.name; CONFIG.soldierUnit = result.unit; @@ -270,6 +249,7 @@ const App = (() => { function logout() { if (confirm('确定要退出登录吗?')) { localStorage.removeItem('soldier_session'); + API.setToken(''); // 清除认证令牌 CONFIG.soldierId = ''; CONFIG.soldierName = ''; stopPolling(); diff --git a/src/单兵终端APP/www/js/api.js b/src/单兵终端APP/www/js/api.js index 4316bcb8..a1f25d9c 100644 --- a/src/单兵终端APP/www/js/api.js +++ b/src/单兵终端APP/www/js/api.js @@ -4,7 +4,8 @@ */ const API = (() => { - let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://192.168.1.14:5000'; + let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://121.41.216.243'; + let TOKEN = localStorage.getItem('auth_token') || ''; async function request(url, options = {}) { // 每次请求时重新读取BASE,支持动态切换 @@ -15,14 +16,33 @@ const API = (() => { const controller = new AbortController(); const timeoutId = setTimeout(() => controller.abort(), 5000); + // 自动带上认证Token + const headers = { + 'Content-Type': 'application/json', + ...(options.headers || {}) + }; + if (TOKEN) { + headers['X-Auth-Token'] = TOKEN; + } + try { const resp = await fetch(fullUrl, { - headers: { 'Content-Type': 'application/json' }, + headers, signal: controller.signal, ...options }); clearTimeout(timeoutId); if (!resp.ok) { + if (resp.status === 401) { + // Token失效,清除登录状态并跳转登录页 + TOKEN = ''; + localStorage.removeItem('auth_token'); + localStorage.removeItem('soldier_session'); + if (typeof App !== 'undefined' && App.router) { + App.router('login'); + App.showToast('登录已过期,请重新登录'); + } + } const text = await resp.text(); throw new Error(`HTTP ${resp.status}: ${text}`); } @@ -187,6 +207,14 @@ const API = (() => { ]; } + function setBase(url) { BASE = url; } + function setToken(token) { + TOKEN = token; + if (token) localStorage.setItem('auth_token', token); + else localStorage.removeItem('auth_token'); + } + function getToken() { return TOKEN; } + return { updateLocation, getSoldiers, @@ -201,6 +229,9 @@ const API = (() => { sendSOS, login, register, - getAccounts + getAccounts, + setBase, + setToken, + getToken }; })(); diff --git a/src/单兵终端APP/www/js/app.js b/src/单兵终端APP/www/js/app.js index c5afd329..ea47faec 100644 --- a/src/单兵终端APP/www/js/app.js +++ b/src/单兵终端APP/www/js/app.js @@ -6,10 +6,10 @@ const App = (() => { // ===== 配置 ===== const CONFIG = { - apiBase: localStorage.getItem('api_base') || 'http://192.168.1.14:5000', - soldierId: localStorage.getItem('soldier_id') || 'soldier_001', - soldierName: localStorage.getItem('soldier_name') || '张三', - soldierUnit: localStorage.getItem('soldier_unit') || '第3步兵师/1连', + apiBase: localStorage.getItem('api_base') || 'http://121.41.216.243', + soldierId: '', + soldierName: '', + soldierUnit: '', pollInterval: 5000 }; @@ -171,7 +171,7 @@ const App = (() => { if (!url.startsWith('http')) url = 'http://' + url; CONFIG.apiBase = url; localStorage.setItem('api_base', url); - API.BASE = url; + API.setBase(url); updateSoldierInfo(); showToast('服务器地址已更新'); } @@ -186,31 +186,10 @@ const App = (() => { return; } - // 演示账号:不连后端也能直接登录 - const demoAccounts = { - 'soldier_001': { name: '张三', unit: '第3步兵师/1连', role: '狙击手' }, - 'soldier_002': { name: '李四', unit: '第3步兵师/2连', role: '机枪手' }, - 'soldier_003': { name: '王五', unit: '第3步兵师/3连', role: '通讯员' } - }; - if (demoAccounts[id] && pwd === '123456') { - const acc = demoAccounts[id]; - CONFIG.soldierId = id; - CONFIG.soldierName = acc.name; - CONFIG.soldierUnit = acc.unit; - localStorage.setItem('soldier_session', JSON.stringify({ - soldier_id: id, name: acc.name, unit: acc.unit, role: acc.role - })); - updateSoldierInfo(); - router('home'); - startPolling(); - startLocationReporting(); - showToast('登录成功,欢迎 ' + acc.name); - return; - } - try { const result = await API.login(id, pwd); if (result.ok) { + API.setToken(result.token); // 保存认证令牌到前后端统一中心 CONFIG.soldierId = result.soldier_id; CONFIG.soldierName = result.name; CONFIG.soldierUnit = result.unit; @@ -270,6 +249,7 @@ const App = (() => { function logout() { if (confirm('确定要退出登录吗?')) { localStorage.removeItem('soldier_session'); + API.setToken(''); // 清除认证令牌 CONFIG.soldierId = ''; CONFIG.soldierName = ''; stopPolling(); diff --git a/src/软件电脑端/server/app.py b/src/软件电脑端/server/app.py index 1ecd7253..a4f609cb 100644 --- a/src/软件电脑端/server/app.py +++ b/src/软件电脑端/server/app.py @@ -5,28 +5,135 @@ Flask 服务器 - 智途投送电脑端 """ import os +import sqlite3 +import secrets +from functools import wraps from datetime import datetime from flask import Flask, send_from_directory, request, jsonify from flask_cors import CORS +from werkzeug.security import generate_password_hash, check_password_hash -# 项目根目录(server 的上级目录) -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# 项目根目录(app.py 所在目录,部署时前端静态文件应放在同级目录) +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +DB_PATH = os.path.join(os.path.dirname(__file__), 'zhitu.db') app = Flask(__name__, static_folder=BASE_DIR) CORS(app) # 允许跨域请求,支持单兵APP访问 -# ---- 内存数据存储 ---- -soldiers = {} # { soldier_id: { id, name, lat, lng, updated_at } } -danger_zones = [] # [ { id, lat, lng, radius, description, created_at } ] -_demands = [] # [ { id, soldier_id, soldier_name, type, quantity, unit, urgency, status, created_at } ] -_demand_id_counter = 0 -_drop_points = [ # 推荐投放点(演示数据) +# ---- SQLite 数据库 ---- +def get_db(): + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + +def init_db(): + conn = get_db() + conn.executescript(''' + CREATE TABLE IF NOT EXISTS accounts ( + soldier_id TEXT PRIMARY KEY, + password_hash TEXT NOT NULL, + name TEXT NOT NULL, + unit TEXT, + role TEXT, + created_at TEXT + ); + CREATE TABLE IF NOT EXISTS demands ( + id TEXT PRIMARY KEY, + soldier_id TEXT, + soldier_name TEXT, + type TEXT, + items TEXT, + quantity INTEGER, + unit TEXT, + urgency TEXT, + status TEXT DEFAULT '待处理', + lat REAL, + lng REAL, + created_at TEXT + ); + CREATE TABLE IF NOT EXISTS soldiers ( + soldier_id TEXT PRIMARY KEY, + name TEXT, + lat REAL, + lng REAL, + updated_at TEXT + ); + CREATE TABLE IF NOT EXISTS sos_alerts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + soldier_id TEXT, + soldier_name TEXT, + lat REAL, + lng REAL, + alert_time TEXT, + handled INTEGER DEFAULT 0 + ); + CREATE TABLE IF NOT EXISTS danger_zones ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + lat REAL, + lng REAL, + radius REAL, + description TEXT, + created_at TEXT + ); + CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + soldier_id TEXT, + soldier_name TEXT, + type TEXT, + status TEXT, + progress INTEGER DEFAULT 0, + eta TEXT, + remain_time TEXT, + start_name TEXT, + target_name TEXT, + start_lat REAL, + start_lng REAL, + end_lat REAL, + end_lng REAL, + safety_score INTEGER, + created_at TEXT + ); + CREATE TABLE IF NOT EXISTS tokens ( + token TEXT PRIMARY KEY, + soldier_id TEXT, + name TEXT, + unit TEXT, + role TEXT, + created_at TEXT + ); + ''') + conn.commit() + conn.close() + +def init_demo_accounts(): + """初始化演示账号到数据库,确保前后端账号统一""" + conn = get_db() + demo_accounts = [ + ("soldier_001", "123456", "张三", "第3步兵师/1连", "狙击手"), + ("soldier_002", "123456", "李四", "第3步兵师/2连", "机枪手"), + ("soldier_003", "123456", "王五", "第3步兵师/3连", "通讯员"), + ] + for sid, pwd, name, unit, role in demo_accounts: + existing = conn.execute('SELECT 1 FROM accounts WHERE soldier_id=?', (sid,)).fetchone() + if not existing: + conn.execute(''' + INSERT INTO accounts (soldier_id, password_hash, name, unit, role, created_at) + VALUES (?, ?, ?, ?, ?, ?) + ''', (sid, generate_password_hash(pwd), name, unit, role, + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) + conn.commit() + conn.close() + +init_db() +init_demo_accounts() + +# ---- 内存数据(演示数据:投放点、无人机状态重启后恢复默认值)---- +_drop_points = [ {"id": 1, "name": "安全点 #1", "safety_score": 92, "distance": 5, "reason": "深掩体保护,视角盲区", "lat": 30.0050, "lng": 120.0030}, {"id": 2, "name": "安全点 #2", "safety_score": 85, "distance": 12, "reason": "钢筋混凝土建筑内部", "lat": 30.0055, "lng": 120.0035}, {"id": 3, "name": "陷阱点 #3", "safety_score": 35, "distance": 20, "reason": "孤立断墙,成瞄准点", "lat": 30.0060, "lng": 120.0040} ] -_tasks = {} # { soldier_id: task } -_drone_status = { # 无人机实时状态(演示数据) +_drone_status = { "drone_id": "无人机-01", "task_id": "#001", "status": "飞行中", @@ -40,20 +147,26 @@ _drone_status = { # 无人机实时状态(演示数据) "lat": 30.0040, "lng": 120.0025 } -_drone_logs = [ # 无人机动态日志 +_drone_logs = [ {"time": "12:25:30", "message": "到达投放点"}, {"time": "12:20:45", "message": "接收任务指令"}, {"time": "12:10:00", "message": "任务分配"} ] -_sos_alerts = [] # 求救记录 -_accounts = { # 士兵账号 { soldier_id: { password, name, unit, role } } - "soldier_001": { "password": "123456", "name": "张三", "unit": "第3步兵师/1连", "role": "狙击手" }, - "soldier_002": { "password": "123456", "name": "李四", "unit": "第3步兵师/2连", "role": "机枪手" }, - "soldier_003": { "password": "123456", "name": "王五", "unit": "第3步兵师/3连", "role": "通讯员" } -} -_tasks = {} # { soldier_id: { id, status, progress, eta, ... } } -_task_id_counter = 0 -_danger_id_counter = 0 +# 计数器已从内存变量改为数据库查询,避免服务重启后主键冲突 + +def require_auth(f): + """接口认证装饰器:验证 X-Auth-Token(从数据库查询,支持多 worker 共享)""" + @wraps(f) + def decorated(*args, **kwargs): + token = request.headers.get('X-Auth-Token') or request.args.get('token') + conn = get_db() + row = conn.execute('SELECT soldier_id, name, unit, role FROM tokens WHERE token=?', (token,)).fetchone() + conn.close() + if not row: + return jsonify({"ok": False, "error": "未登录或token无效"}), 401 + request.current_user = dict(row) + return f(*args, **kwargs) + return decorated # ===== 静态文件路由 ===== @@ -76,78 +189,108 @@ def js_files(filename): # ===== REST API: 士兵位置 ===== @app.route("/api/soldier/location", methods=["POST"]) +@require_auth def update_soldier_location(): data = request.get_json(force=True) sid = data.get("id") if not sid: return jsonify({"ok": False, "error": "missing id"}), 400 - soldiers[sid] = { - "id": sid, - "name": data.get("name", sid), - "lat": float(data.get("lat", 0)), - "lng": float(data.get("lng", 0)), - "updated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") - } + conn = get_db() + conn.execute(''' + INSERT INTO soldiers (soldier_id, name, lat, lng, updated_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(soldier_id) DO UPDATE SET + name=excluded.name, lat=excluded.lat, lng=excluded.lng, updated_at=excluded.updated_at + ''', (sid, data.get("name", sid), float(data.get("lat", 0)), float(data.get("lng", 0)), + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) + conn.commit() + conn.close() return jsonify({"ok": True}) @app.route("/api/soldiers", methods=["GET"]) def get_soldiers(): - return jsonify({"soldiers": list(soldiers.values())}) + conn = get_db() + rows = conn.execute('SELECT * FROM soldiers').fetchall() + conn.close() + return jsonify({"soldiers": [dict(r) for r in rows]}) # ===== REST API: 危险区域 ===== @app.route("/api/danger-zones", methods=["GET"]) def get_danger_zones(): - return jsonify({"danger_zones": danger_zones}) + conn = get_db() + rows = conn.execute('SELECT * FROM danger_zones').fetchall() + conn.close() + return jsonify({"danger_zones": [dict(r) for r in rows]}) @app.route("/api/danger-zones", methods=["POST"]) +@require_auth def add_danger_zone(): - global _danger_id_counter data = request.get_json(force=True) - _danger_id_counter += 1 zone = { - "id": _danger_id_counter, "lat": float(data.get("lat", 0)), "lng": float(data.get("lng", 0)), "radius": float(data.get("radius", 50)), "description": data.get("description", "危险区域"), "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") } - danger_zones.append(zone) - return jsonify({"ok": True, "id": zone["id"]}) + conn = get_db() + cur = conn.execute(''' + INSERT INTO danger_zones (lat, lng, radius, description, created_at) + VALUES (?, ?, ?, ?, ?) + ''', (zone["lat"], zone["lng"], zone["radius"], zone["description"], zone["created_at"])) + zone_id = cur.lastrowid + conn.commit() + conn.close() + return jsonify({"ok": True, "id": zone_id}) # ===== REST API: 物资需求(单兵APP) ===== +def _next_demand_id(): + """从数据库查询当前最大需求ID,生成下一个(避免内存计数器重启归零导致主键冲突)""" + conn = get_db() + row = conn.execute("SELECT id FROM demands WHERE id LIKE 'REQ-%' ORDER BY id DESC LIMIT 1").fetchone() + conn.close() + if row: + try: + num = int(row["id"].replace("REQ-", "")) + except ValueError: + num = 0 + else: + num = 0 + return "REQ-" + str(num + 1).zfill(3) + + @app.route("/api/demand", methods=["POST"]) +@require_auth def post_demand(): - global _demand_id_counter data = request.get_json(force=True) - _demand_id_counter += 1 - # 获取投放点坐标(如果有) drop_point = data.get("drop_point") or {} lat = float(drop_point.get("lat", 0)) if isinstance(drop_point, dict) else 0 lng = float(drop_point.get("lng", 0)) if isinstance(drop_point, dict) else 0 - # 如果APP没传坐标,尝试从士兵位置获取 soldier_id = data.get("soldier_id", "unknown") - if lat == 0 and lng == 0 and soldier_id in soldiers: - lat = soldiers[soldier_id].get("lat", 0) - lng = soldiers[soldier_id].get("lng", 0) + if lat == 0 and lng == 0: + conn = get_db() + row = conn.execute('SELECT lat, lng FROM soldiers WHERE soldier_id=?', (soldier_id,)).fetchone() + conn.close() + if row: + lat, lng = row["lat"], row["lng"] - # 构建物资清单描述 qty = int(data.get("quantity", 0)) unit = data.get("unit", "个") item_type = data.get("type", "未知") items = f"{item_type} × {qty}{unit}" + demand_id = _next_demand_id() demand = { - "id": "REQ-" + str(_demand_id_counter).zfill(3), + "id": demand_id, "soldier_id": soldier_id, "soldier_name": data.get("soldier_name", "未知"), "type": item_type + "补给", @@ -160,25 +303,38 @@ def post_demand(): "lng": lng, "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") } - _demands.append(demand) + conn = get_db() + conn.execute(''' + INSERT INTO demands (id, soldier_id, soldier_name, type, items, quantity, unit, urgency, status, lat, lng, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', (demand["id"], demand["soldier_id"], demand["soldier_name"], demand["type"], + demand["items"], demand["quantity"], demand["unit"], demand["urgency"], + demand["status"], demand["lat"], demand["lng"], demand["created_at"])) + conn.commit() + conn.close() return jsonify({"ok": True, "id": demand["id"]}) @app.route("/api/demands", methods=["GET"]) def get_demands(): soldier_id = request.args.get("soldier_id") + conn = get_db() if soldier_id: - return jsonify({"demands": [d for d in _demands if d["soldier_id"] == soldier_id]}) - # 返回所有待处理需求(供电脑端使用) - pending = [d for d in _demands if d["status"] == "待处理"] - return jsonify({"demands": pending}) + rows = conn.execute('SELECT * FROM demands WHERE soldier_id=?', (soldier_id,)).fetchall() + else: + rows = conn.execute("SELECT * FROM demands WHERE status='待处理'").fetchall() + conn.close() + return jsonify({"demands": [dict(r) for r in rows]}) -@app.route("/api/demands/", methods=["GET"]) +@app.route("/api/demands/", methods=["GET"]) +@require_auth def get_demand(demand_id): - for d in _demands: - if d["id"] == demand_id: - return jsonify({"demand": d}) + conn = get_db() + row = conn.execute('SELECT * FROM demands WHERE id=?', (demand_id,)).fetchone() + conn.close() + if row: + return jsonify({"demand": dict(row)}) return jsonify({"ok": False, "error": "not found"}), 404 @@ -190,24 +346,37 @@ def get_drop_points(): @app.route("/api/drop-point", methods=["POST"]) +@require_auth def post_drop_point(): data = request.get_json(force=True) - # 记录士兵选择的投放点 return jsonify({"ok": True}) # ===== REST API: 任务状态 ===== +def _next_task_id(): + """从数据库查询当前最大任务ID,生成下一个""" + conn = get_db() + row = conn.execute("SELECT id FROM tasks WHERE id LIKE '#%' ORDER BY id DESC LIMIT 1").fetchone() + conn.close() + if row: + try: + num = int(row["id"].replace("#", "")) + except ValueError: + num = 0 + else: + num = 0 + return "#" + str(num + 1).zfill(3) + + @app.route("/api/task/dispatch", methods=["POST"]) +@require_auth def dispatch_task(): - """电脑端调度任务时调用""" - global _task_id_counter data = request.get_json(force=True) soldier_id = data.get("soldier_id", "unknown") - _task_id_counter += 1 task = { - "id": "#" + str(_task_id_counter).zfill(3), + "id": _next_task_id(), "soldier_id": soldier_id, "soldier_name": data.get("soldier_name", "未知"), "type": data.get("type", "未知"), @@ -224,57 +393,76 @@ def dispatch_task(): "safety_score": data.get("safety_score", 90), "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") } - _tasks[soldier_id] = task + conn = get_db() + conn.execute(''' + INSERT INTO tasks (id, soldier_id, soldier_name, type, status, progress, eta, remain_time, + start_name, target_name, start_lat, start_lng, end_lat, end_lng, safety_score, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', (task["id"], task["soldier_id"], task["soldier_name"], task["type"], task["status"], + task["progress"], task["eta"], task["remain_time"], task["start_name"], task["target_name"], + task["start_lat"], task["start_lng"], task["end_lat"], task["end_lng"], + task["safety_score"], task["created_at"])) + conn.commit() + conn.close() # 更新需求状态为"已调度" demand_id = data.get("demand_id") if demand_id: - for d in _demands: - if d["id"] == demand_id: - d["status"] = "已调度" - break + conn = get_db() + conn.execute("UPDATE demands SET status='已调度' WHERE id=?", (demand_id,)) + conn.commit() + conn.close() return jsonify({"ok": True, "task": task}) @app.route("/api/task/current", methods=["GET"]) +@require_auth def get_current_task(): soldier_id = request.args.get("soldier_id", "soldier_001") - task = _tasks.get(soldier_id) - if not task: - # 没有任务时返回模拟数据 - task = { - "id": "#--", - "status": "无任务", - "progress": 0, - "eta": "--", - "remain_time": "--", - "start_name": "--", - "target_name": "--", - "safety_score": 0 - } - return jsonify({"task": task}) + conn = get_db() + row = conn.execute('SELECT * FROM tasks WHERE soldier_id=?', (soldier_id,)).fetchone() + conn.close() + if row: + return jsonify({"task": dict(row)}) + return jsonify({"task": { + "id": "#--", "status": "无任务", "progress": 0, "eta": "--", + "remain_time": "--", "start_name": "--", "target_name": "--", "safety_score": 0 + }}) @app.route("/api/task/update", methods=["POST"]) +@require_auth def update_task(): - """更新任务状态(进度、ETA等)""" data = request.get_json(force=True) soldier_id = data.get("soldier_id", "unknown") - task = _tasks.get(soldier_id) - if not task: + conn = get_db() + row = conn.execute('SELECT * FROM tasks WHERE soldier_id=?', (soldier_id,)).fetchone() + if not row: + conn.close() return jsonify({"ok": False, "error": "无进行中的任务"}), 404 + updates = [] + params = [] if "progress" in data: - task["progress"] = int(data["progress"]) + updates.append("progress=?") + params.append(int(data["progress"])) if "status" in data: - task["status"] = data["status"] + updates.append("status=?") + params.append(data["status"]) if "eta" in data: - task["eta"] = data["eta"] + updates.append("eta=?") + params.append(data["eta"]) if "remain_time" in data: - task["remain_time"] = data["remain_time"] - - return jsonify({"ok": True, "task": task}) + updates.append("remain_time=?") + params.append(data["remain_time"]) + + if updates: + params.append(soldier_id) + conn.execute(f"UPDATE tasks SET {','.join(updates)} WHERE soldier_id=?", params) + conn.commit() + conn.close() + return jsonify({"ok": True}) # ===== REST API: 无人机状态 ===== @@ -292,32 +480,28 @@ def get_drone_logs(): # ===== REST API: SOS求救 ===== @app.route("/api/sos", methods=["POST"]) +@require_auth def post_sos(): data = request.get_json(force=True) - alert = { - "soldier_id": data.get("soldier_id", "unknown"), - "soldier_name": data.get("soldier_name", "未知"), - "lat": float(data.get("lat", 0)), - "lng": float(data.get("lng", 0)), - "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "handled": False - } - _sos_alerts.append(alert) + alert_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conn = get_db() + conn.execute(''' + INSERT INTO sos_alerts (soldier_id, soldier_name, lat, lng, alert_time, handled) + VALUES (?, ?, ?, ?, ?, ?) + ''', (data.get("soldier_id", "unknown"), data.get("soldier_name", "未知"), + float(data.get("lat", 0)), float(data.get("lng", 0)), alert_time, 0)) # 同时标记为危险区域 - global _danger_id_counter - _danger_id_counter += 1 - danger_zones.append({ - "id": _danger_id_counter, - "lat": alert["lat"], - "lng": alert["lng"], - "radius": 100, - "description": f"士兵求救: {alert['soldier_name']}", - "created_at": alert["time"] - }) - return jsonify({"ok": True, "alert_id": len(_sos_alerts)}) + conn.execute(''' + INSERT INTO danger_zones (lat, lng, radius, description, created_at) + VALUES (?, ?, ?, ?, ?) + ''', (float(data.get("lat", 0)), float(data.get("lng", 0)), 100, + f"士兵求救: {data.get('soldier_name', '未知')}", alert_time)) + conn.commit() + conn.close() + return jsonify({"ok": True}) -# ===== REST API: 账号系统 ===== +# ===== REST API: 账号系统(核心改造:持久化 + Token + 密码加密) ===== @app.route("/api/auth/register", methods=["POST"]) def auth_register(): @@ -330,15 +514,20 @@ def auth_register(): if not sid or not pwd or not name: return jsonify({"ok": False, "error": "士兵编号、密码、姓名不能为空"}), 400 - if sid in _accounts: + + conn = get_db() + existing = conn.execute('SELECT 1 FROM accounts WHERE soldier_id=?', (sid,)).fetchone() + if existing: + conn.close() return jsonify({"ok": False, "error": "该士兵编号已注册"}), 409 - _accounts[sid] = { - "password": pwd, - "name": name, - "unit": unit, - "role": role - } + password_hash = generate_password_hash(pwd) + conn.execute(''' + INSERT INTO accounts (soldier_id, password_hash, name, unit, role, created_at) + VALUES (?, ?, ?, ?, ?, ?) + ''', (sid, password_hash, name, unit, role, datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) + conn.commit() + conn.close() return jsonify({"ok": True, "message": "注册成功"}) @@ -348,31 +537,61 @@ def auth_login(): sid = data.get("soldier_id", "").strip() pwd = data.get("password", "").strip() - account = _accounts.get(sid) - if not account: + conn = get_db() + row = conn.execute('SELECT * FROM accounts WHERE soldier_id=?', (sid,)).fetchone() + + if not row: + conn.close() return jsonify({"ok": False, "error": "士兵编号不存在"}), 404 - if account["password"] != pwd: + + if not check_password_hash(row["password_hash"], pwd): + conn.close() return jsonify({"ok": False, "error": "密码错误"}), 401 + token = secrets.token_hex(16) + conn.execute(''' + INSERT INTO tokens (token, soldier_id, name, unit, role, created_at) + VALUES (?, ?, ?, ?, ?, ?) + ''', (token, sid, row["name"], row["unit"], row["role"], datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) + conn.commit() + conn.close() + return jsonify({ "ok": True, + "token": token, "soldier_id": sid, - "name": account["name"], - "unit": account["unit"], - "role": account["role"] + "name": row["name"], + "unit": row["unit"], + "role": row["role"] }) @app.route("/api/auth/accounts", methods=["GET"]) def get_accounts(): + conn = get_db() + rows = conn.execute('SELECT soldier_id, name, unit, role FROM accounts').fetchall() + conn.close() return jsonify({ - "accounts": [ - {"soldier_id": k, "name": v["name"], "unit": v["unit"], "role": v["role"]} - for k, v in _accounts.items() - ] + "accounts": [dict(r) for r in rows] }) +@app.route("/api/auth/me", methods=["GET"]) +def auth_me(): + token = request.headers.get('X-Auth-Token') or request.args.get('token') + conn = get_db() + row = conn.execute('SELECT soldier_id, name, unit, role FROM tokens WHERE token=?', (token,)).fetchone() + conn.close() + if not row: + return jsonify({"ok": False, "error": "未登录"}), 401 + return jsonify({"ok": True, "user": dict(row)}) + + +@app.route("/api/ping", methods=["GET"]) +def ping(): + return jsonify({"ok": True, "message": "智途投送后端运行正常"}) + + if __name__ == "__main__": import sys import io diff --git a/src/软件电脑端/server/deploy.sh b/src/软件电脑端/server/deploy.sh new file mode 100644 index 00000000..0738da3d --- /dev/null +++ b/src/软件电脑端/server/deploy.sh @@ -0,0 +1,192 @@ +#!/bin/bash +# ============================================================================= +# 智途投送 - 云端一键部署脚本 +# 适用系统:Ubuntu 20.04 / 22.04 LTS +# 运行方式:root 用户执行 bash deploy.sh +# ============================================================================= + +set -e # 遇到错误立即退出 + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +PROJECT_DIR="/opt/zhitu" +SERVICE_NAME="zhitu" +NGINX_CONF="/etc/nginx/sites-available/zhitu" + +echo -e "${GREEN}=========================================${NC}" +echo -e "${GREEN} 智途投送 - 云端一键部署脚本${NC}" +echo -e "${GREEN}=========================================${NC}" +echo "" + +# 检查是否 root +if [ "$EUID" -ne 0 ]; then + echo -e "${RED}错误:请使用 root 用户运行此脚本${NC}" + echo "提示:先执行 sudo -i 切换到 root,再运行 bash deploy.sh" + exit 1 +fi + +# 获取脚本所在目录(即后端代码目录) +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +echo -e "${YELLOW}[INFO]${NC} 后端代码目录: $SCRIPT_DIR" + +# ============================================================================= +# 步骤1: 更新系统并安装依赖 +# ============================================================================= +echo "" +echo -e "${GREEN}[1/8] 更新系统软件源...${NC}" +apt update -qq + +echo -e "${GREEN}[2/8] 安装 Python3、Nginx、防火墙...${NC}" +apt install -y -qq python3 python3-pip python3-venv nginx curl ufw + +# ============================================================================= +# 步骤2: 创建项目目录并复制代码 +# ============================================================================= +echo -e "${GREEN}[3/8] 创建项目目录并复制代码...${NC}" +mkdir -p "$PROJECT_DIR" + +# 复制所有文件(排除可能的旧数据库,避免覆盖) +rsync -av --exclude='zhitu.db' "$SCRIPT_DIR/" "$PROJECT_DIR/" 2>/dev/null || cp -r "$SCRIPT_DIR"/* "$PROJECT_DIR/" + +# 如果本地有开发数据库,可以选择性复制(提示用户) +if [ -f "$SCRIPT_DIR/zhitu.db" ]; then + echo -e "${YELLOW}[提示]${NC} 检测到本地存在 zhitu.db 数据库文件" + read -p "是否将本地数据库复制到服务器?(y/n): " copy_db + if [ "$copy_db" = "y" ] || [ "$copy_db" = "Y" ]; then + cp "$SCRIPT_DIR/zhitu.db" "$PROJECT_DIR/zhitu.db" + echo -e "${GREEN}已复制数据库${NC}" + fi +fi + +# ============================================================================= +# 步骤3: 安装 Python 依赖 +# ============================================================================= +echo -e "${GREEN}[4/8] 安装 Python 依赖...${NC}" +cd "$PROJECT_DIR" +pip3 install -q -r requirements.txt + +# ============================================================================= +# 步骤4: 创建 Systemd 服务(比 Supervisor 更现代、更稳定) +# ============================================================================= +echo -e "${GREEN}[5/8] 创建系统服务...${NC}" + +cat > /etc/systemd/system/${SERVICE_NAME}.service << EOF +[Unit] +Description=智途投送后端服务 (ZhiTu Flask Backend) +After=network.target + +[Service] +Type=simple +User=root +WorkingDirectory=${PROJECT_DIR} +ExecStart=/usr/local/bin/gunicorn -w 2 -b 127.0.0.1:5000 app:app +Restart=always +RestartSec=5 +StandardOutput=append:/var/log/zhitu.log +StandardError=append:/var/log/zhitu.log + +[Install] +WantedBy=multi-user.target +EOF + +systemctl daemon-reload +systemctl enable ${SERVICE_NAME} + +# ============================================================================= +# 步骤5: 配置 Nginx 反向代理 +# ============================================================================= +echo -e "${GREEN}[6/8] 配置 Nginx...${NC}" + +cat > "$NGINX_CONF" << 'EOF' +server { + listen 80 default_server; + listen [::]:80 default_server; + server_name _; + + client_max_body_size 20M; + + location / { + proxy_pass http://127.0.0.1:5000; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_connect_timeout 60s; + proxy_send_timeout 60s; + proxy_read_timeout 60s; + } + + # 静态文件直接由 Nginx 处理(提高效率) + location /static { + alias /opt/zhitu/static; + expires 7d; + } +} +EOF + +# 启用配置,禁用默认站点 +rm -f /etc/nginx/sites-enabled/default +ln -sf "$NGINX_CONF" /etc/nginx/sites-enabled/zhitu + +# 测试 Nginx 配置 +nginx -t + +# ============================================================================= +# 步骤6: 配置防火墙 +# ============================================================================= +echo -e "${GREEN}[7/8] 配置防火墙...${NC}" +ufw default deny incoming +ufw default allow outgoing +ufw allow 80/tcp comment 'HTTP' +ufw allow 443/tcp comment 'HTTPS' +ufw allow 22/tcp comment 'SSH' +ufw --force enable + +echo -e "${YELLOW}[提示]${NC} 防火墙已启用,仅开放 22(SSH)、80(HTTP)、443(HTTPS)" + +# ============================================================================= +# 步骤7: 启动服务 +# ============================================================================= +echo -e "${GREEN}[8/8] 启动服务...${NC}" +systemctl restart ${SERVICE_NAME} +systemctl restart nginx +sleep 2 + +# ============================================================================= +# 部署完成,输出状态 +# ============================================================================= +echo "" +echo -e "${GREEN}=========================================${NC}" +echo -e "${GREEN} 部署完成!${NC}" +echo -e "${GREEN}=========================================${NC}" +echo "" + +# 获取公网 IP +PUBLIC_IP=$(curl -s ifconfig.me 2>/dev/null || curl -s icanhazip.com 2>/dev/null || echo "无法获取") + +SERVICE_STATUS=$(systemctl is-active ${SERVICE_NAME}) +NGINX_STATUS=$(systemctl is-active nginx) + +echo -e "后端服务状态 : ${SERVICE_STATUS}" +echo -e "Nginx 状态 : ${NGINX_STATUS}" +echo -e "公网 IP : ${PUBLIC_IP}" +echo "" +echo -e "访问地址:" +echo -e " HTTP : ${YELLOW}http://${PUBLIC_IP}${NC}" +echo -e " 或 : ${YELLOW}http://你的域名${NC} (如果已配置域名)" +echo "" +echo -e "常用命令:" +echo -e " 查看后端日志 : ${YELLOW}journalctl -u zhitu -f${NC}" +echo -e " 重启后端 : ${YELLOW}systemctl restart zhitu${NC}" +echo -e " 查看服务状态 : ${YELLOW}systemctl status zhitu${NC}" +echo -e " 数据库位置 : ${YELLOW}${PROJECT_DIR}/zhitu.db${NC}" +echo "" +echo -e "${YELLOW}下一步:${NC}" +echo -e " 1. 打开单兵终端 APP,将服务器地址配置为: ${GREEN}http://${PUBLIC_IP}${NC}" +echo -e " 2. 如需 HTTPS,请先购买域名,然后运行: ${GREEN}certbot --nginx${NC}" +echo "" diff --git a/src/软件电脑端/server/deploy_safe.sh b/src/软件电脑端/server/deploy_safe.sh new file mode 100644 index 00000000..13979e68 --- /dev/null +++ b/src/软件电脑端/server/deploy_safe.sh @@ -0,0 +1,212 @@ +#!/bin/bash +# ============================================================================= +# 智途投送 - 安全部署脚本(自动检测冲突、自适应端口) +# 适用:已部署其他项目的共享服务器 +# ============================================================================= + +set -e + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +PROJECT_DIR="/opt/zhitu" +SERVICE_NAME="zhitu" + +echo -e "${GREEN}=========================================${NC}" +echo -e "${GREEN} 智途投送 - 安全部署脚本${NC}" +echo -e "${GREEN} (自动检测冲突 + 自适应端口)${NC}" +echo -e "${GREEN}=========================================${NC}" +echo "" + +if [ "$EUID" -ne 0 ]; then + echo -e "${RED}错误:请使用 root 用户运行${NC}" + exit 1 +fi + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" + +# ============================================================================= +# 自动检测端口占用,选择可用端口 +# ============================================================================= +echo -e "${BLUE}[检测]${NC} 检查端口占用情况..." + +BASE_PORT=5000 +PORT=$BASE_PORT +while ss -tlnp 2>/dev/null | grep -q ":$PORT\b" || netstat -tlnp 2>/dev/null | grep -q ":$PORT\b"; do + PORT=$((PORT + 1)) + if [ $PORT -gt 5100 ]; then + echo -e "${RED}错误:5000-5100 端口全部被占用,请手动清理${NC}" + exit 1 + fi +done + +echo -e "${GREEN} ✓ 后端将使用端口: $PORT${NC}" + +# ============================================================================= +# 检测 Nginx 状态 +# ============================================================================= +NGINX_ACTIVE=false +if systemctl is-active --quiet nginx 2>/dev/null || service nginx status 2>/dev/null | grep -q running; then + NGINX_ACTIVE=true + echo -e "${YELLOW} ! 检测到 Nginx 已在运行(其他项目正在使用 80 端口)${NC}" + echo -e "${YELLOW} 智途投送将直接暴露在高端口 $PORT,不占用 80 端口${NC}" +else + echo -e "${GREEN} ✓ 80 端口空闲,将配置 Nginx 反向代理${NC}" +fi + +# ============================================================================= +# 检测 /opt/zhitu 是否已存在(防止覆盖其他项目) +# ============================================================================= +if [ -d "$PROJECT_DIR" ]; then + echo -e "${YELLOW}[警告]${NC} 检测到目录 $PROJECT_DIR 已存在" + BACKUP_DIR="${PROJECT_DIR}.backup.$(date +%Y%m%d%H%M%S)" + echo -e "${YELLOW} 自动备份到: $BACKUP_DIR${NC}" + mv "$PROJECT_DIR" "$BACKUP_DIR" +fi + +# ============================================================================= +# 安装依赖 +# ============================================================================= +echo "" +echo -e "${GREEN}[1/5] 安装环境依赖...${NC}" +apt update -qq +apt install -y -qq python3 python3-pip python3-venv curl + +# ============================================================================= +# 复制代码 +# ============================================================================= +echo -e "${GREEN}[2/5] 复制后端代码...${NC}" +mkdir -p "$PROJECT_DIR" +rsync -av --exclude='zhitu.db' "$SCRIPT_DIR/" "$PROJECT_DIR/" 2>/dev/null || cp -r "$SCRIPT_DIR"/* "$PROJECT_DIR/" + +# 如果本地有数据库,询问是否复制 +if [ -f "$SCRIPT_DIR/zhitu.db" ]; then + echo -e "${YELLOW}[提示]${NC} 检测到本地数据库文件" + # 非交互式环境直接复制(课程服务器通常自动化部署) + cp "$SCRIPT_DIR/zhitu.db" "$PROJECT_DIR/zhitu.db" 2>/dev/null && echo -e "${GREEN} ✓ 已复制数据库${NC}" || true +fi + +# ============================================================================= +# 安装 Python 依赖 +# ============================================================================= +echo -e "${GREEN}[3/5] 创建虚拟环境并安装依赖...${NC}" +cd "$PROJECT_DIR" +# Ubuntu 24.04 禁止直接 pip 安装到系统,使用虚拟环境 +if [ ! -d "$PROJECT_DIR/venv" ]; then + python3 -m venv venv +fi +source "$PROJECT_DIR/venv/bin/activate" +pip install -q -r requirements.txt + +# ============================================================================= +# 创建 Systemd 服务 +# ============================================================================= +echo -e "${GREEN}[4/5] 创建后端服务...${NC}" + +# 如果存在旧服务,先停止 +systemctl stop ${SERVICE_NAME} 2>/dev/null || true +systemctl disable ${SERVICE_NAME} 2>/dev/null || true +rm -f /etc/systemd/system/${SERVICE_NAME}.service + +cat > /etc/systemd/system/${SERVICE_NAME}.service << EOF +[Unit] +Description=智途投送后端服务 (Port $PORT) +After=network.target + +[Service] +Type=simple +User=root +WorkingDirectory=${PROJECT_DIR} +ExecStart=${PROJECT_DIR}/venv/bin/gunicorn -w 2 -b 0.0.0.0:${PORT} app:app +Restart=always +RestartSec=5 +StandardOutput=append:/var/log/zhitu.log +StandardError=append:/var/log/zhitu.log + +[Install] +WantedBy=multi-user.target +EOF + +systemctl daemon-reload +systemctl enable ${SERVICE_NAME} +systemctl restart ${SERVICE_NAME} + +sleep 2 + +# ============================================================================= +# 配置 Nginx(仅在 Nginx 未运行时) +# ============================================================================= +if [ "$NGINX_ACTIVE" = false ]; then + echo -e "${GREEN}[5/5] 配置 Nginx...${NC}" + apt install -y -qq nginx + + cat > /etc/nginx/sites-available/zhitu << EOF +server { + listen 80 default_server; + listen [::]:80 default_server; + server_name _; + + client_max_body_size 20M; + + location / { + proxy_pass http://127.0.0.1:${PORT}; + proxy_http_version 1.1; + proxy_set_header Host \$host; + proxy_set_header X-Real-IP \$remote_addr; + proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for; + } +} +EOF + rm -f /etc/nginx/sites-enabled/default + ln -sf /etc/nginx/sites-available/zhitu /etc/nginx/sites-enabled/zhitu + nginx -t && systemctl restart nginx + + # 防火墙 + ufw allow 80/tcp 2>/dev/null || true + ufw allow 443/tcp 2>/dev/null || true + ufw allow 22/tcp 2>/dev/null || true +else + echo -e "${GREEN}[5/5] 跳过 Nginx 配置(已被其他项目占用)${NC}" + # 只开放智途使用的端口 + ufw allow ${PORT}/tcp 2>/dev/null || true +fi + +# ============================================================================= +# 输出结果 +# ============================================================================= +echo "" +echo -e "${GREEN}=========================================${NC}" +echo -e "${GREEN} 部署完成!${NC}" +echo -e "${GREEN}=========================================${NC}" +echo "" + +PUBLIC_IP=$(curl -s ifconfig.me 2>/dev/null || curl -s icanhazip.com 2>/dev/null || echo "你的服务器IP") +SERVICE_STATUS=$(systemctl is-active ${SERVICE_NAME} 2>/dev/null || echo "unknown") + +echo -e "后端服务状态 : ${SERVICE_STATUS}" +echo -e "使用端口 : ${YELLOW}${PORT}${NC}" +echo -e "公网 IP : ${PUBLIC_IP}" +echo "" + +if [ "$NGINX_ACTIVE" = true ]; then + echo -e "访问地址(直接端口):" + echo -e " ${YELLOW}http://${PUBLIC_IP}:${PORT}${NC}" + echo "" + echo -e "${YELLOW}提示:${NC} 由于 80 端口被其他项目占用," + echo -e " 单兵终端请配置为: ${GREEN}http://${PUBLIC_IP}:${PORT}${NC}" +else + echo -e "访问地址:" + echo -e " HTTP : ${YELLOW}http://${PUBLIC_IP}${NC}" + echo -e " 或 : ${YELLOW}http://${PUBLIC_IP}:${PORT}${NC}" +fi + +echo "" +echo -e "常用命令:" +echo -e " 查看日志 : ${YELLOW}journalctl -u zhitu -f${NC}" +echo -e " 重启服务 : ${YELLOW}systemctl restart zhitu${NC}" +echo -e " 查看状态 : ${YELLOW}systemctl status zhitu${NC}" +echo -e " 数据库位置 : ${YELLOW}${PROJECT_DIR}/zhitu.db${NC}" +echo "" diff --git a/src/软件电脑端/server/requirements.txt b/src/软件电脑端/server/requirements.txt index 26cfc63e..2f866706 100644 --- a/src/软件电脑端/server/requirements.txt +++ b/src/软件电脑端/server/requirements.txt @@ -1,2 +1,3 @@ flask==3.0.0 flask-cors==5.0.0 +gunicorn==23.0.0 diff --git a/后端操作手册.md b/后端操作手册.md new file mode 100644 index 00000000..1d933568 --- /dev/null +++ b/后端操作手册.md @@ -0,0 +1,146 @@ +# 智途投送后端操作手册 + +> 服务器:阿里云 ECS `121.41.216.243`(Ubuntu 24.04) +> 部署路径:`/opt/zhitu/` +> 服务名:`zhitu.service` +> 数据库:SQLite `/opt/zhitu/zhitu.db` + +--- + +## 1. 连接服务器 + +```bash +ssh root@121.41.216.243 +# 输入密码 +``` + +--- + +## 2. 服务启停(最常用) + +| 操作 | 命令 | +|------|------| +| 查看状态 | `systemctl status zhitu` | +| 启动服务 | `systemctl start zhitu` | +| 停止服务 | `systemctl stop zhitu` | +| 重启服务 | `systemctl restart zhitu` | +| 重载配置 | `systemctl daemon-reload` | +| 开机自启 | `systemctl enable zhitu`(已配置)| + +--- + +## 3. 查看日志(排错必备) + +```bash +# 查看最新 50 行日志 +journalctl -u zhitu -n 50 + +# 实时跟踪日志(调试时用,按 Ctrl+C 退出) +journalctl -u zhitu -f +``` + +--- + +## 4. 改 Worker 数量为 1(避免 SQLite 并发锁) + +当前配置为 2 个 worker,SQLite 并发写入时可能锁竞争导致请求卡住,建议改为 1 个: + +```bash +sed -i 's/-w 2/-w 1/' /etc/systemd/system/zhitu.service +systemctl daemon-reload +systemctl restart zhitu +systemctl status zhitu +``` + +--- + +## 5. 备份数据库 + +数据库文件在 `/opt/zhitu/zhitu.db`,课程结束前务必备份: + +```bash +# 方式一:在服务器上备份到 root 目录 +cp /opt/zhitu/zhitu.db ~/zhitu_backup_$(date +%Y%m%d_%H%M%S).db + +# 方式二:下载到本地电脑(在本地终端执行) +scp root@121.41.216.243:/opt/zhitu/zhitu.db D:\zhitu_backup.db +``` + +--- + +## 6. 修改后端代码后热更新 + +改完 `app.py` 后需要重启生效: + +```bash +# 先检查语法 +python3 -m py_compile /opt/zhitu/app.py + +# 语法通过后重启 +systemctl restart zhitu +``` + +--- + +## 7. 常用验证命令 + +```bash +# 测试后端是否存活 +curl http://121.41.216.243/api/ping + +# 查看所有需求(无需登录) +curl http://121.41.216.243/api/demands + +# 查看账号列表(无需登录) +curl http://121.41.216.243/api/auth/accounts + +# 直接查看数据库中的 token(调试用) +sqlite3 /opt/zhitu/zhitu.db "SELECT * FROM tokens;" + +# 查看所有需求记录 +sqlite3 /opt/zhitu/zhitu.db "SELECT * FROM demands;" +``` + +--- + +## 8. 网络/端口检查 + +```bash +# 查看 5000 端口是否在监听 +ss -tlnp | grep 5000 + +# 查看 Nginx 状态 +systemctl status nginx + +# 查看防火墙状态 +ufw status +``` + +--- + +## 9. 常见问题速查 + +| 现象 | 可能原因 | 解决 | +|------|----------|------| +| APP 上报 500 | `_demand_id_counter` 内存计数器归零导致主键冲突 | 已修复:改为数据库查询生成 ID | +| APP 上报超时/卡住 | Gunicorn 多 worker + SQLite 并发锁竞争 | 改 `-w 1` | +| APP 登录 401 | Token 跨 worker 不共享 / localStorage 缓存旧地址 | 已修复:Token 存数据库;检查 APP 服务器地址设置 | +| 电脑端首页 404 | 前端静态文件未上传到 `/opt/zhitu/` | 确保 `index.html`、`css/`、`js/` 都在部署目录 | +| 服务启动失败 | 语法错误 / 端口被占 | `python3 -m py_compile /opt/zhitu/app.py` 检查语法 | + +--- + +## 10. 关键文件路径 + +| 文件/目录 | 路径 | +|-----------|------| +| 后端主程序 | `/opt/zhitu/app.py` | +| 数据库 | `/opt/zhitu/zhitu.db` | +| 服务配置 | `/etc/systemd/system/zhitu.service` | +| Nginx 配置 | `/etc/nginx/sites-enabled/zhitu` | +| Python 虚拟环境 | `/opt/zhitu/venv/` | +| 日志输出 | `journalctl -u zhitu` | + +--- + +*手册生成时间:2026-05-23*