fix: 后端主键冲突修复 + Token持久化 + 部署脚本 + 操作手册

- 修复需求上报/任务分配 500 错误:内存计数器改为数据库查询
- 修复危险区域 ID 错乱:改用 lastrowid 获取自增 ID
- 后端:SQLite 持久化 + Werkzeug 密码哈希 + Token 数据库存储
- 前端:API 统一封装 + Token 自动携带 + Mock 兜底
- 新增 deploy_safe.sh 自适应部署脚本
- 新增后端操作手册
zhaochang_branch
赵昌 2 weeks ago
parent ed014e08a0
commit 0d2dba3910

21
.gitignore vendored

@ -0,0 +1,21 @@
# Python
__pycache__/
*.pyc
*.pyo
*.egg-info/
venv/
# Database
zhitu.db
*.db
# Audio data (large files)
*.wav
# Generated results
results/
# IDE
.idea/
.vscode/
*.iml

@ -0,0 +1,47 @@
@echo off
chcp 65001 >nul
setlocal EnableDelayedExpansion
echo ==========================================
echo Acoustic Offline Multichannel Demo Build
echo ==========================================
set SRC_ROOT=%~dp0
set EIGEN=%SRC_ROOT%third_party\eigen-3.4.0
set ONNX_INC=%SRC_ROOT%third_party\onnxruntime\include
set ONNX_LIB=%SRC_ROOT%third_party\onnxruntime\lib\libonnxruntime.a
set INCLUDES=-I%SRC_ROOT%include -I%EIGEN% -I%ONNX_INC%
set FLAGS=-std=c++17 -O2 -D_USE_MATH_DEFINES -Wa,-mbig-obj
set LIBS=%ONNX_LIB%
if not exist build mkdir build
echo [1/1] Building demo_offline_multichannel.exe ...
g++ %FLAGS% %INCLUDES% ^
tests\demo_offline_multichannel.cpp ^
src\core\pipeline.cpp ^
src\core\audio_buffer.cpp ^
src\core\fft_utils.cpp ^
src\core\feature_extractor.cpp ^
src\core\gunshot_classifier.cpp ^
src\core\gcc_phat_localizer.cpp ^
src\core\distance_estimator.cpp ^
src\core\threat_tracker.cpp ^
src\io\wav_file_source.cpp ^
-o build\demo_offline_multichannel.exe ^
%LIBS% -D_stdcall=
if %ERRORLEVEL% NEQ 0 (
echo [FAIL] demo_offline_multichannel build failed.
exit /b 1
)
echo.
echo [OK] demo_offline_multichannel.exe built successfully in build\^
echo.
echo Usage: build\demo_offline_multichannel.exe dataset\multichannel_test.wav --num_mics 4 --layout cross
echo build\demo_offline_multichannel.exe dataset\real\threat\ --threshold 0.6 --num_mics 4
echo.
echo With ground-truth for error analysis:
echo build\demo_offline_multichannel.exe synth_90deg_100m.wav --ground_azimuth 90 --ground_distance 100
endlocal

@ -86,6 +86,17 @@ struct PipelineConfig {
float confidence_threshold = 0.7f; ///< 输出置信度阈值
float min_event_interval = 0.3f; ///< 同一威胁最小上报间隔 (秒)
// 特征提取参数(原 hard-coded现可配置
uint32_t n_fft = 2048; ///< FFT 长度
uint32_t hop_length = 512; ///< 帧移样本数
float f_min = 0.0f; ///< 最低频率 (Hz)
float f_max = 8000.0f; ///< 最高频率 (Hz)
float preemphasis = 0.97f; ///< 预加重系数
// 定位参数(原 hard-coded现可配置
float max_tdoa = 0.00044f; ///< 最大允许时延 (秒)
int interpolation_factor = 4; ///< GCC-PHAT 插值倍数
MicArrayConfig mic_array;
ClassifierConfig classifier;
DistanceConfig distance;

@ -4,8 +4,12 @@
namespace acoustic {
class ThreatPublisher;
class AcousticNode;
std::unique_ptr<AcousticNode> create_acoustic_node(ros::NodeHandle& nh, ros::NodeHandle& pnh);
std::unique_ptr<AcousticNode> create_acoustic_node(
ros::NodeHandle& nh,
ros::NodeHandle& pnh,
ThreatPublisher* publisher = nullptr);
} // namespace acoustic

@ -5,8 +5,7 @@
"""
import argparse
import numpy as np
import torch
import torchaudio
import soundfile as sf
def generate_gunshot_signal(sr=16000, duration=2.0, freq=1000, decay=0.05):
@ -64,8 +63,8 @@ def main():
if max_val > 0:
multich = multich / max_val * 0.9
tensor = torch.from_numpy(multich)
torchaudio.save(args.output, tensor, args.sr)
# soundfile expects [samples, channels] for multi-channel
sf.write(args.output, multich.T, args.sr, subtype="PCM_16")
print(f"Generated {args.output}: azimuth={args.azimuth}°, distance={args.distance}m, shape={multich.shape}")

@ -0,0 +1,243 @@
#!/usr/bin/env python3
"""
噪声注入批量测试脚本
功能
1. 对干净音频按指定 SNR 注入噪声生成测试样本
2. 调用 C++ 离线推理程序批量检测
3. 输出 CSV 统计结果便于绘制 Pd-SNR / 定位误差-SNR 曲线
用法示例
# 单通道分类测试
python inject_noise_eval.py \\
--clean_dir dataset/real/threat \\
--noise noise_samples/drone_hover.wav \\
--snrs 20 10 5 0 -5 \\
--exe build/demo_offline.exe \\
--output results/threat_vs_drone_noise.csv
# 多通道定位测试
python inject_noise_eval.py \\
--clean_dir dataset/multichannel \\
--noise noise_samples/white.wav \\
--snrs 20 10 5 0 \\
--exe build/demo_offline_multichannel.exe \\
--num_mics 4 --layout cross \\
--output results/loc_snr_sweep.csv
"""
import argparse
import csv
import os
import subprocess
import sys
import tempfile
from pathlib import Path
import numpy as np
import soundfile as sf
def list_wav_files(directory: str):
"""递归列出目录下所有 wav 文件。"""
return sorted([str(p) for p in Path(directory).rglob("*.wav")])
def compute_rms(signal: np.ndarray):
"""计算信号 RMS。"""
return np.sqrt(np.mean(signal ** 2))
def mix_with_noise(clean: np.ndarray, noise: np.ndarray, target_snr_db: float):
"""
将噪声以指定 SNR 混入干净信号
如果噪声比干净信号短会循环拼接
支持多通道 cleannoise 会先混为单通道再扩展到匹配通道数或直接用同通道数
"""
# Ensure 2D: [channels, samples]
if clean.ndim == 1:
clean = clean.reshape(1, -1)
if noise.ndim == 1:
noise = noise.reshape(1, -1)
n_ch, n_samples = clean.shape
# Tile or truncate noise to match length
noise_tiled = np.tile(noise, (1, int(np.ceil(n_samples / noise.shape[1])) + 1))
noise_tiled = noise_tiled[:, :n_samples]
# If noise has fewer channels, replicate first channel
if noise_tiled.shape[0] < n_ch:
noise_tiled = np.repeat(noise_tiled[:1, :], n_ch, axis=0)
elif noise_tiled.shape[0] > n_ch:
noise_tiled = noise_tiled[:n_ch, :]
# Compute per-channel RMS and average for a global scaling factor
clean_rms = compute_rms(clean)
if clean_rms < 1e-10:
return clean # silent signal, return as-is
noise_rms = compute_rms(noise_tiled)
if noise_rms < 1e-10:
return clean
target_noise_rms = clean_rms / (10 ** (target_snr_db / 20.0))
scale = target_noise_rms / noise_rms
mixed = clean + noise_tiled * scale
# Normalize to prevent clipping
max_val = np.max(np.abs(mixed))
if max_val > 1.0:
mixed = mixed / max_val * 0.95
return mixed
def run_classifier(exe_path: str, wav_path: str, extra_args: list) -> dict:
"""
调用 C++ 离线推理程序解析其标准输出
返回字典包含 pred_label, confidence
"""
import re
cmd = [exe_path, wav_path] + extra_args
try:
result = subprocess.run(cmd, capture_output=True, timeout=30)
# Decode with GBK first (Windows console default), fallback to UTF-8
try:
output = result.stdout.decode("gbk", errors="replace")
except Exception:
output = result.stdout.decode("utf-8", errors="replace")
except Exception as e:
return {"error": str(e)}
parsed = {}
# Robust regex parsing for output like:
# File: xxx.wav | True: threat | Pred: threat | Conf: 0.9234 | Az: 90.00° | El: 0.00° | Dist: 95.20m
for line in output.splitlines():
if line.startswith("File:") and "Pred:" in line:
m = re.search(r"Pred:\s*(\S+)", line)
if m:
parsed["pred"] = m.group(1)
m = re.search(r"Conf:\s*([0-9.]+)", line)
if m:
parsed["conf"] = float(m.group(1))
m = re.search(r"Az:\s*([0-9.+-]+)", line)
if m:
parsed["azimuth"] = float(m.group(1))
m = re.search(r"El:\s*([0-9.+-]+)", line)
if m:
parsed["elevation"] = float(m.group(1))
m = re.search(r"Dist:\s*([0-9.+-]+)", line)
if m:
parsed["distance"] = float(m.group(1))
break
return parsed
def main():
parser = argparse.ArgumentParser(description="Noise injection evaluation for acoustic analyzer")
parser.add_argument("--clean_dir", required=True, help="Directory containing clean WAV files")
parser.add_argument("--noise", required=True, help="Noise WAV file path")
parser.add_argument("--snrs", nargs="+", type=float, required=True,
help="List of SNR values in dB, e.g. 20 10 5 0 -5")
parser.add_argument("--exe", required=True, help="Path to demo_offline or demo_offline_multichannel exe")
parser.add_argument("--output", required=True, help="Output CSV path")
parser.add_argument("--num_mics", type=int, default=1, help="Number of channels (for multichannel exe)")
parser.add_argument("--layout", default="cross", help="Array layout (for multichannel exe)")
parser.add_argument("--threshold", type=float, default=0.5, help="Detection threshold")
parser.add_argument("--ground_azimuth", type=float, default=None, help="Ground-truth azimuth for error calc")
parser.add_argument("--ground_distance", type=float, default=None, help="Ground-truth distance for error calc")
args = parser.parse_args()
clean_files = list_wav_files(args.clean_dir)
if not clean_files:
print(f"[ERROR] No WAV files found in {args.clean_dir}")
sys.exit(1)
noise_data, noise_sr = sf.read(args.noise, dtype="float32")
if noise_data.ndim == 1:
noise_data = noise_data.reshape(1, -1)
else:
noise_data = noise_data.T # [channels, samples]
# Prepare extra args for C++ exe
extra_args = [f"--threshold", str(args.threshold)]
if "multichannel" in os.path.basename(args.exe).lower():
extra_args += ["--num_mics", str(args.num_mics), "--layout", args.layout]
if args.ground_azimuth is not None:
extra_args += ["--ground_azimuth", str(args.ground_azimuth)]
if args.ground_distance is not None:
extra_args += ["--ground_distance", str(args.ground_distance)]
# CSV header
fieldnames = ["file", "true_label", "snr_db", "pred_label", "confidence",
"azimuth", "elevation", "distance"]
if args.ground_azimuth is not None:
fieldnames.append("azimuth_error")
if args.ground_distance is not None:
fieldnames.append("distance_error")
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
with open(args.output, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
total_runs = len(clean_files) * len(args.snrs)
run_idx = 0
with tempfile.TemporaryDirectory() as tmpdir:
for clean_path in clean_files:
clean_data, clean_sr = sf.read(clean_path, dtype="float32")
true_label = Path(clean_path).parent.name
# Ensure 2D
if clean_data.ndim == 1:
clean_data = clean_data.reshape(1, -1)
else:
clean_data = clean_data.T
for snr in args.snrs:
run_idx += 1
mixed = mix_with_noise(clean_data, noise_data, snr)
# Save temp WAV (interleaved)
if mixed.shape[0] == 1:
mixed_interleaved = mixed[0]
else:
mixed_interleaved = mixed.T # [samples, channels]
tmp_wav = os.path.join(tmpdir, f"mix_{run_idx:04d}.wav")
sf.write(tmp_wav, mixed_interleaved, clean_sr, subtype="PCM_16")
# Run inference
parsed = run_classifier(args.exe, tmp_wav, extra_args)
row = {
"file": Path(clean_path).name,
"true_label": true_label,
"snr_db": snr,
"pred_label": parsed.get("pred", "error"),
"confidence": parsed.get("conf", 0.0),
"azimuth": parsed.get("azimuth", ""),
"elevation": parsed.get("elevation", ""),
"distance": parsed.get("distance", ""),
}
if args.ground_azimuth is not None and "azimuth" in parsed:
err = abs(parsed["azimuth"] - args.ground_azimuth)
if err > 180:
err = 360 - err
row["azimuth_error"] = round(err, 2)
if args.ground_distance is not None and "distance" in parsed:
row["distance_error"] = round(abs(parsed["distance"] - args.ground_distance), 2)
writer.writerow(row)
print(f"[{run_idx}/{total_runs}] {Path(clean_path).name} SNR={snr}dB -> {row['pred_label']} ({row['confidence']:.3f})")
print(f"\n[OK] Results saved to {args.output}")
if __name__ == "__main__":
main()

@ -127,7 +127,9 @@ struct GccPhatLocalizer::Impl {
}
float delay_samples = max_idx + p;
float delay_sec = delay_samples / static_cast<float>(sample_rate);
return delay_sec;
// Note: GCC-PHAT IFFT peak at tau corresponds to ch2[n] = ch1[n - tau],
// so tau is actually the negative of the physical delay of ch2 relative to ch1.
return -delay_sec;
}
bool SolveDirection(const std::vector<float>& delays, float& azimuth, float& elevation) {

@ -6,7 +6,17 @@
#include "acoustic_analyzer/core/distance_estimator.h"
#include "acoustic_analyzer/core/threat_tracker.h"
// yaml-cpp is optional: if absent, Pipeline still works but FromYaml returns defaults
#if defined(__has_include)
# if __has_include(<yaml-cpp/yaml.h>)
# define ACOUSTIC_HAS_YAML_CPP 1
# endif
#endif
#ifdef ACOUSTIC_HAS_YAML_CPP
#include <yaml-cpp/yaml.h>
#endif
#include <cmath>
#include <algorithm>
#include <numeric>
@ -40,8 +50,13 @@ struct Pipeline::Impl {
static_cast<int>(config.mic_array.num_mics));
feature_extractor = std::make_unique<FeatureExtractor>(
static_cast<int>(config.sample_rate), 2048, 512,
static_cast<int>(config.n_mels));
static_cast<int>(config.sample_rate),
static_cast<int>(config.n_fft),
static_cast<int>(config.hop_length),
static_cast<int>(config.n_mels),
config.f_min,
config.f_max,
config.preemphasis);
classifier = std::make_unique<GunshotClassifier>(config.classifier);
@ -49,7 +64,8 @@ struct Pipeline::Impl {
localizer = std::make_unique<GccPhatLocalizer>(
config.mic_array,
static_cast<int>(config.sample_rate),
0.00044f, 4);
config.max_tdoa,
config.interpolation_factor);
}
distance_estimator = std::make_unique<DistanceEstimator>(config.distance);
@ -239,6 +255,7 @@ AcousticFrame Pipeline::Process(const std::vector<float>& audio_samples) {
PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) {
PipelineConfig config;
#ifdef ACOUSTIC_HAS_YAML_CPP
try {
YAML::Node root = YAML::LoadFile(yaml_path);
@ -283,19 +300,19 @@ PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) {
auto feat = root["features"];
if (feat["n_mels"]) config.n_mels = feat["n_mels"].as<uint32_t>();
if (feat["n_fft"]) {
// Not stored in PipelineConfig directly; FeatureExtractor uses hard-coded 2048
config.n_fft = feat["n_fft"].as<uint32_t>();
}
if (feat["hop_length"]) {
// Not stored in PipelineConfig directly
config.hop_length = feat["hop_length"].as<uint32_t>();
}
if (feat["f_min"]) {
// Not stored directly
config.f_min = feat["f_min"].as<float>();
}
if (feat["f_max"]) {
// Not stored directly
config.f_max = feat["f_max"].as<float>();
}
if (feat["preemphasis"]) {
// Not stored directly
config.preemphasis = feat["preemphasis"].as<float>();
}
}
@ -314,10 +331,10 @@ PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) {
if (root["localization"]) {
auto loc = root["localization"];
if (loc["max_tdoa"]) {
// Not stored directly in PipelineConfig
config.max_tdoa = loc["max_tdoa"].as<float>();
}
if (loc["interpolation_factor"]) {
// Not stored directly
config.interpolation_factor = loc["interpolation_factor"].as<int>();
}
}
@ -356,6 +373,9 @@ PipelineConfig Pipeline::FromYaml(const std::string& yaml_path) {
(void)e;
// On parse error, return defaults
}
#else
(void)yaml_path;
#endif
return config;
}

@ -7,8 +7,8 @@ int main(int argc, char** argv) {
ros::NodeHandle nh;
ros::NodeHandle pnh("~");
auto node = acoustic::create_acoustic_node(nh, pnh);
acoustic::ThreatPublisher publisher(nh);
auto node = acoustic::create_acoustic_node(nh, pnh, &publisher);
node->run();
return 0;

@ -15,7 +15,8 @@ namespace acoustic {
class AcousticNode {
public:
AcousticNode(ros::NodeHandle& nh, ros::NodeHandle& pnh) : nh_(nh), pnh_(pnh) {
AcousticNode(ros::NodeHandle& nh, ros::NodeHandle& pnh, ThreatPublisher* publisher)
: nh_(nh), pnh_(pnh), threat_publisher_(publisher) {
load_params();
init_pipeline();
init_source();
@ -36,6 +37,7 @@ private:
ros::NodeHandle nh_;
ros::NodeHandle pnh_;
ros::Subscriber audio_sub_;
ThreatPublisher* threat_publisher_ = nullptr;
std::string source_type_;
std::unique_ptr<Pipeline> pipeline_;
@ -129,6 +131,12 @@ private:
return flat;
}
void publish_if_available(const AcousticFrame& frame) {
if (threat_publisher_ && !frame.is_clear) {
threat_publisher_->Publish(frame);
}
}
void on_mic_array_audio(const std_msgs::Float32MultiArray::ConstPtr& msg) {
if (msg->layout.dim.size() < 2) return;
int channels = static_cast<int>(msg->layout.dim[0].size);
@ -138,14 +146,14 @@ private:
// Assuming data is interleaved or [channels x samples] row-major
std::vector<float> flat(msg->data.begin(), msg->data.end());
auto frame = pipeline_->Process(flat);
(void)frame; // Would be published by threat_publisher in main loop
publish_if_available(frame);
}
void on_mobile_phone_audio(const std_msgs::Float32MultiArray::ConstPtr& msg) {
if (msg->data.empty()) return;
std::vector<float> flat(msg->data.begin(), msg->data.end());
auto frame = pipeline_->Process(flat);
(void)frame;
publish_if_available(frame);
}
void process_wav_source() {
@ -160,12 +168,15 @@ private:
}
auto flat = flatten_audio(audio, static_cast<int>(wav_source_->num_channels()));
auto frame = pipeline_->Process(flat);
(void)frame;
publish_if_available(frame);
}
};
std::unique_ptr<AcousticNode> create_acoustic_node(ros::NodeHandle& nh, ros::NodeHandle& pnh) {
return std::make_unique<AcousticNode>(nh, pnh);
std::unique_ptr<AcousticNode> create_acoustic_node(
ros::NodeHandle& nh,
ros::NodeHandle& pnh,
ThreatPublisher* publisher) {
return std::make_unique<AcousticNode>(nh, pnh, publisher);
}
} // namespace acoustic

@ -0,0 +1,331 @@
#include <iostream>
#include <iomanip>
#include <vector>
#include <string>
#include <filesystem>
#include <algorithm>
#include <cmath>
#include <numeric>
#include <map>
#include <chrono>
#include <cstring>
#include "acoustic_analyzer/core/pipeline.h"
#include "acoustic_analyzer/io/wav_file_source.h"
namespace fs = std::filesystem;
using namespace acoustic;
struct Prediction {
std::string file_path;
std::string true_label;
std::string pred_label;
float confidence = 0.0f;
float azimuth = 0.0f;
float elevation = 0.0f;
float distance = -1.0f;
bool detected = false;
};
void print_usage(const char* prog) {
std::cerr << "Usage: " << prog << " <file_or_dir> [options]" << std::endl;
std::cerr << "Options:" << std::endl;
std::cerr << " --model <path> ONNX model path (default: models/gunshot_classifier.onnx)" << std::endl;
std::cerr << " --label_map <path> Label map file (default: models/label_map.json)" << std::endl;
std::cerr << " --threshold <float> Detection threshold (default: 0.5)" << std::endl;
std::cerr << " --num_mics <int> Number of channels in WAV (default: 4)" << std::endl;
std::cerr << " --spacing <float> Mic spacing in meters (default: 0.15)" << std::endl;
std::cerr << " --layout <str> Array layout: cross/linear/circular (default: cross)" << std::endl;
std::cerr << " --ref_spl <float> Reference SPL for distance estimation (default: 150)" << std::endl;
std::cerr << " --ground_azimuth <float> Ground-truth azimuth for error calc (optional)" << std::endl;
std::cerr << " --ground_distance <float> Ground-truth distance for error calc (optional)" << std::endl;
}
bool ends_with(const std::string& s, const std::string& suffix) {
if (s.size() < suffix.size()) return false;
return s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}
std::string get_parent_folder_name(const std::string& path) {
fs::path p(path);
if (p.has_parent_path()) {
return p.parent_path().filename().string();
}
return "";
}
// Convert vector-of-vectors [channels][samples] to flat interleaved [ch0_s0, ch1_s0, ...]
std::vector<float> flatten_audio(const std::vector<std::vector<float>>& audio, int channels) {
if (audio.empty() || channels == 0) return {};
size_t samples = audio[0].size();
std::vector<float> flat(samples * channels);
for (size_t s = 0; s < samples; ++s) {
for (int ch = 0; ch < channels; ++ch) {
flat[s * channels + ch] = (ch < static_cast<int>(audio.size()) && s < audio[ch].size())
? audio[ch][s] : 0.0f;
}
}
return flat;
}
Prediction process_file(const std::string& path,
Pipeline& pipeline,
int num_mics,
float ground_azimuth,
float ground_distance) {
Prediction result;
result.file_path = path;
result.true_label = get_parent_folder_name(path);
WavFileSource wav(path);
if (!wav.open()) {
std::cerr << "[SKIP] Cannot open: " << path << std::endl;
result.pred_label = "error";
return result;
}
int sr = wav.sample_rate();
int file_ch = wav.num_channels();
if (file_ch != num_mics) {
std::cerr << "[WARN] " << path << " has " << file_ch
<< " channels, expected " << num_mics << std::endl;
}
// Use actual file channels if mismatch, but pipeline config should match
int effective_channels = (file_ch < num_mics) ? file_ch : num_mics;
size_t chunk_samples = static_cast<size_t>(sr * pipeline.Config().chunk_duration);
std::vector<std::vector<float>> audio;
size_t got = wav.read(audio, chunk_samples);
if (got == 0 || audio.empty()) {
result.pred_label = "empty";
return result;
}
// Process only the first chunk (like demo_offline)
// For sliding-window analysis, call Process() on hop-sized chunks
auto flat = flatten_audio(audio, effective_channels);
auto frame = pipeline.Process(flat);
if (!frame.is_clear && !frame.threats.empty()) {
const auto& t = frame.threats[0];
result.detected = true;
result.pred_label = t.sound_type;
result.confidence = t.confidence;
result.azimuth = t.azimuth;
result.elevation = t.elevation;
result.distance = t.distance;
} else {
result.pred_label = "ambient";
result.confidence = 0.0f;
}
std::cout << "File: " << fs::path(path).filename().string()
<< " | True: " << result.true_label
<< " | Pred: " << result.pred_label
<< " | Conf: " << std::fixed << std::setprecision(4) << result.confidence;
if (result.detected) {
std::cout << " | Az: " << std::setprecision(2) << result.azimuth << "°"
<< " | El: " << std::setprecision(2) << result.elevation << "°"
<< " | Dist: " << std::setprecision(2) << result.distance << "m";
if (ground_azimuth >= 0.0f) {
float az_err = std::fabs(result.azimuth - ground_azimuth);
if (az_err > 180.0f) az_err = 360.0f - az_err;
std::cout << " | AzErr: " << az_err << "°";
}
if (ground_distance >= 0.0f) {
std::cout << " | DistErr: " << std::fabs(result.distance - ground_distance) << "m";
}
}
std::cout << std::endl;
return result;
}
void collect_wav_files(const std::string& target, std::vector<std::string>& out) {
if (fs::is_regular_file(target) && ends_with(target, ".wav")) {
out.push_back(target);
return;
}
if (!fs::is_directory(target)) return;
for (const auto& entry : fs::recursive_directory_iterator(target)) {
if (entry.is_regular_file() && ends_with(entry.path().string(), ".wav")) {
out.push_back(entry.path().string());
}
}
std::sort(out.begin(), out.end());
}
void print_report(const std::vector<Prediction>& results,
float ground_azimuth,
float ground_distance) {
std::map<std::string, int> total_by_true;
std::map<std::string, int> correct_by_true;
std::map<std::string, float> conf_sum_by_true;
std::map<std::string, std::map<std::string, int>> confusion;
int total = 0, correct = 0;
int detected_count = 0;
float az_err_sum = 0.0f, dist_err_sum = 0.0f;
int az_err_count = 0, dist_err_count = 0;
for (const auto& r : results) {
if (r.pred_label == "error" || r.pred_label == "empty") continue;
total++;
total_by_true[r.true_label]++;
conf_sum_by_true[r.true_label] += r.confidence;
confusion[r.true_label][r.pred_label]++;
if (r.true_label == r.pred_label) {
correct++;
correct_by_true[r.true_label]++;
}
if (r.detected) {
detected_count++;
if (ground_azimuth >= 0.0f) {
float az_err = std::fabs(r.azimuth - ground_azimuth);
if (az_err > 180.0f) az_err = 360.0f - az_err;
az_err_sum += az_err;
az_err_count++;
}
if (ground_distance >= 0.0f) {
dist_err_sum += std::fabs(r.distance - ground_distance);
dist_err_count++;
}
}
}
std::cout << "\n==========================================" << std::endl;
std::cout << " MULTICHANNEL VALIDATION REPORT" << std::endl;
std::cout << "==========================================" << std::endl;
std::cout << "Total samples: " << total << std::endl;
std::cout << "Correct: " << correct << std::endl;
std::cout << "Accuracy: " << std::fixed << std::setprecision(2)
<< (total > 0 ? 100.0f * correct / total : 0.0f) << "%" << std::endl;
std::cout << "Detected frames: " << detected_count << std::endl;
if (az_err_count > 0) {
std::cout << "Azimuth RMSE: " << std::setprecision(2)
<< std::sqrt(az_err_sum / az_err_count) << "°" << std::endl;
}
if (dist_err_count > 0) {
std::cout << "Distance MAE: " << std::setprecision(2)
<< (dist_err_sum / dist_err_count) << "m" << std::endl;
}
std::cout << "\nPer-class breakdown:" << std::endl;
for (const auto& kv : total_by_true) {
const std::string& cls = kv.first;
int cls_total = kv.second;
int cls_correct = correct_by_true[cls];
float avg_conf = conf_sum_by_true[cls] / cls_total;
std::cout << " " << std::setw(10) << std::left << cls
<< " Count: " << std::setw(3) << cls_total
<< " Correct: " << std::setw(3) << cls_correct
<< " Acc: " << std::setw(6) << std::fixed << std::setprecision(2)
<< (100.0f * cls_correct / cls_total) << "%"
<< " AvgConf: " << std::setprecision(4) << avg_conf << std::endl;
}
std::cout << "\nConfusion matrix (rows=true, cols=pred):" << std::endl;
std::vector<std::string> labels;
for (const auto& row : confusion) labels.push_back(row.first);
for (const auto& row : confusion) {
for (const auto& col : row.second) {
if (std::find(labels.begin(), labels.end(), col.first) == labels.end()) {
labels.push_back(col.first);
}
}
}
std::sort(labels.begin(), labels.end());
std::cout << std::setw(12) << " ";
for (const auto& l : labels) std::cout << std::setw(10) << l;
std::cout << std::endl;
for (const auto& true_l : labels) {
std::cout << std::setw(10) << std::left << true_l << " ";
for (const auto& pred_l : labels) {
int count = confusion.count(true_l) ? confusion[true_l].count(pred_l) ? confusion[true_l].at(pred_l) : 0 : 0;
std::cout << std::setw(10) << count;
}
std::cout << std::endl;
}
std::cout << "==========================================" << std::endl;
}
int main(int argc, char** argv) {
if (argc < 2 || std::strcmp(argv[1], "--help") == 0 || std::strcmp(argv[1], "-h") == 0) {
print_usage(argv[0]);
return argc < 2 ? 1 : 0;
}
std::string target = argv[1];
std::string model_path = "models/gunshot_classifier.onnx";
std::string label_map_path = "models/label_map.json";
float threshold = 0.5f;
int num_mics = 4;
float spacing = 0.15f;
std::string layout = "cross";
float ref_spl = 150.0f;
float ground_azimuth = -1.0f;
float ground_distance = -1.0f;
for (int i = 2; i < argc; ++i) {
if (std::strcmp(argv[i], "--model") == 0 && i + 1 < argc) model_path = argv[++i];
else if (std::strcmp(argv[i], "--label_map") == 0 && i + 1 < argc) label_map_path = argv[++i];
else if (std::strcmp(argv[i], "--threshold") == 0 && i + 1 < argc) threshold = std::stof(argv[++i]);
else if (std::strcmp(argv[i], "--num_mics") == 0 && i + 1 < argc) num_mics = std::stoi(argv[++i]);
else if (std::strcmp(argv[i], "--spacing") == 0 && i + 1 < argc) spacing = std::stof(argv[++i]);
else if (std::strcmp(argv[i], "--layout") == 0 && i + 1 < argc) layout = argv[++i];
else if (std::strcmp(argv[i], "--ref_spl") == 0 && i + 1 < argc) ref_spl = std::stof(argv[++i]);
else if (std::strcmp(argv[i], "--ground_azimuth") == 0 && i + 1 < argc) ground_azimuth = std::stof(argv[++i]);
else if (std::strcmp(argv[i], "--ground_distance") == 0 && i + 1 < argc) ground_distance = std::stof(argv[++i]);
}
// Build PipelineConfig directly (no yaml-cpp needed)
PipelineConfig config;
config.sample_rate = 16000;
config.chunk_duration = 2.0f;
config.hop_duration = 0.5f;
config.n_mels = 64;
config.confidence_threshold = threshold;
config.classifier.model_path = model_path;
config.classifier.label_map_path = label_map_path;
config.classifier.threshold = threshold;
config.classifier.smoothing_window = 1; // offline: no temporal smoothing
config.mic_array.num_mics = static_cast<uint32_t>(num_mics);
config.mic_array.layout = layout;
config.mic_array.spacing = spacing;
config.distance.ref_spl_gunshot = ref_spl;
config.distance.ref_spl_artillery = ref_spl;
config.distance.ref_spl_explosion = ref_spl;
Pipeline pipeline(config);
std::vector<std::string> files;
collect_wav_files(target, files);
if (files.empty()) {
std::cerr << "No .wav files found in: " << target << std::endl;
return 1;
}
std::cout << "Found " << files.size() << " WAV file(s)." << std::endl;
std::cout << "Channels: " << num_mics << ", Layout: " << layout
<< ", Spacing: " << spacing << "m" << std::endl;
std::cout << std::endl;
std::vector<Prediction> results;
results.reserve(files.size());
double total_ms = 0.0;
for (const auto& f : files) {
auto t0 = std::chrono::steady_clock::now();
results.push_back(process_file(f, pipeline, num_mics, ground_azimuth, ground_distance));
auto t1 = std::chrono::steady_clock::now();
total_ms += std::chrono::duration<double, std::milli>(t1 - t0).count();
}
std::cout << "\nTotal inference time: " << std::fixed << std::setprecision(2)
<< total_ms << " ms"
<< " | Avg per file: " << (files.empty() ? 0.0 : total_ms / files.size())
<< " ms" << std::endl;
print_report(results, ground_azimuth, ground_distance);
return 0;
}

@ -4,7 +4,8 @@
*/
const API = (() => {
let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://192.168.1.14:5000';
let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://121.41.216.243';
let TOKEN = localStorage.getItem('auth_token') || '';
async function request(url, options = {}) {
// 每次请求时重新读取BASE支持动态切换
@ -15,14 +16,33 @@ const API = (() => {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 5000);
// 自动带上认证Token
const headers = {
'Content-Type': 'application/json',
...(options.headers || {})
};
if (TOKEN) {
headers['X-Auth-Token'] = TOKEN;
}
try {
const resp = await fetch(fullUrl, {
headers: { 'Content-Type': 'application/json' },
headers,
signal: controller.signal,
...options
});
clearTimeout(timeoutId);
if (!resp.ok) {
if (resp.status === 401) {
// Token失效清除登录状态并跳转登录页
TOKEN = '';
localStorage.removeItem('auth_token');
localStorage.removeItem('soldier_session');
if (typeof App !== 'undefined' && App.router) {
App.router('login');
App.showToast('登录已过期,请重新登录');
}
}
const text = await resp.text();
throw new Error(`HTTP ${resp.status}: ${text}`);
}
@ -187,6 +207,14 @@ const API = (() => {
];
}
function setBase(url) { BASE = url; }
function setToken(token) {
TOKEN = token;
if (token) localStorage.setItem('auth_token', token);
else localStorage.removeItem('auth_token');
}
function getToken() { return TOKEN; }
return {
updateLocation,
getSoldiers,
@ -201,6 +229,9 @@ const API = (() => {
sendSOS,
login,
register,
getAccounts
getAccounts,
setBase,
setToken,
getToken
};
})();

@ -6,10 +6,10 @@
const App = (() => {
// ===== 配置 =====
const CONFIG = {
apiBase: localStorage.getItem('api_base') || 'http://192.168.1.14:5000',
soldierId: localStorage.getItem('soldier_id') || 'soldier_001',
soldierName: localStorage.getItem('soldier_name') || '张三',
soldierUnit: localStorage.getItem('soldier_unit') || '第3步兵师/1连',
apiBase: localStorage.getItem('api_base') || 'http://121.41.216.243',
soldierId: '',
soldierName: '',
soldierUnit: '',
pollInterval: 5000
};
@ -171,7 +171,7 @@ const App = (() => {
if (!url.startsWith('http')) url = 'http://' + url;
CONFIG.apiBase = url;
localStorage.setItem('api_base', url);
API.BASE = url;
API.setBase(url);
updateSoldierInfo();
showToast('服务器地址已更新');
}
@ -186,31 +186,10 @@ const App = (() => {
return;
}
// 演示账号:不连后端也能直接登录
const demoAccounts = {
'soldier_001': { name: '张三', unit: '第3步兵师/1连', role: '狙击手' },
'soldier_002': { name: '李四', unit: '第3步兵师/2连', role: '机枪手' },
'soldier_003': { name: '王五', unit: '第3步兵师/3连', role: '通讯员' }
};
if (demoAccounts[id] && pwd === '123456') {
const acc = demoAccounts[id];
CONFIG.soldierId = id;
CONFIG.soldierName = acc.name;
CONFIG.soldierUnit = acc.unit;
localStorage.setItem('soldier_session', JSON.stringify({
soldier_id: id, name: acc.name, unit: acc.unit, role: acc.role
}));
updateSoldierInfo();
router('home');
startPolling();
startLocationReporting();
showToast('登录成功,欢迎 ' + acc.name);
return;
}
try {
const result = await API.login(id, pwd);
if (result.ok) {
API.setToken(result.token); // 保存认证令牌到前后端统一中心
CONFIG.soldierId = result.soldier_id;
CONFIG.soldierName = result.name;
CONFIG.soldierUnit = result.unit;
@ -270,6 +249,7 @@ const App = (() => {
function logout() {
if (confirm('确定要退出登录吗?')) {
localStorage.removeItem('soldier_session');
API.setToken(''); // 清除认证令牌
CONFIG.soldierId = '';
CONFIG.soldierName = '';
stopPolling();

@ -4,7 +4,8 @@
*/
const API = (() => {
let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://192.168.1.14:5000';
let BASE = (typeof App !== 'undefined' && App.CONFIG) ? App.CONFIG.apiBase : 'http://121.41.216.243';
let TOKEN = localStorage.getItem('auth_token') || '';
async function request(url, options = {}) {
// 每次请求时重新读取BASE支持动态切换
@ -15,14 +16,33 @@ const API = (() => {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 5000);
// 自动带上认证Token
const headers = {
'Content-Type': 'application/json',
...(options.headers || {})
};
if (TOKEN) {
headers['X-Auth-Token'] = TOKEN;
}
try {
const resp = await fetch(fullUrl, {
headers: { 'Content-Type': 'application/json' },
headers,
signal: controller.signal,
...options
});
clearTimeout(timeoutId);
if (!resp.ok) {
if (resp.status === 401) {
// Token失效清除登录状态并跳转登录页
TOKEN = '';
localStorage.removeItem('auth_token');
localStorage.removeItem('soldier_session');
if (typeof App !== 'undefined' && App.router) {
App.router('login');
App.showToast('登录已过期,请重新登录');
}
}
const text = await resp.text();
throw new Error(`HTTP ${resp.status}: ${text}`);
}
@ -187,6 +207,14 @@ const API = (() => {
];
}
function setBase(url) { BASE = url; }
function setToken(token) {
TOKEN = token;
if (token) localStorage.setItem('auth_token', token);
else localStorage.removeItem('auth_token');
}
function getToken() { return TOKEN; }
return {
updateLocation,
getSoldiers,
@ -201,6 +229,9 @@ const API = (() => {
sendSOS,
login,
register,
getAccounts
getAccounts,
setBase,
setToken,
getToken
};
})();

@ -6,10 +6,10 @@
const App = (() => {
// ===== 配置 =====
const CONFIG = {
apiBase: localStorage.getItem('api_base') || 'http://192.168.1.14:5000',
soldierId: localStorage.getItem('soldier_id') || 'soldier_001',
soldierName: localStorage.getItem('soldier_name') || '张三',
soldierUnit: localStorage.getItem('soldier_unit') || '第3步兵师/1连',
apiBase: localStorage.getItem('api_base') || 'http://121.41.216.243',
soldierId: '',
soldierName: '',
soldierUnit: '',
pollInterval: 5000
};
@ -171,7 +171,7 @@ const App = (() => {
if (!url.startsWith('http')) url = 'http://' + url;
CONFIG.apiBase = url;
localStorage.setItem('api_base', url);
API.BASE = url;
API.setBase(url);
updateSoldierInfo();
showToast('服务器地址已更新');
}
@ -186,31 +186,10 @@ const App = (() => {
return;
}
// 演示账号:不连后端也能直接登录
const demoAccounts = {
'soldier_001': { name: '张三', unit: '第3步兵师/1连', role: '狙击手' },
'soldier_002': { name: '李四', unit: '第3步兵师/2连', role: '机枪手' },
'soldier_003': { name: '王五', unit: '第3步兵师/3连', role: '通讯员' }
};
if (demoAccounts[id] && pwd === '123456') {
const acc = demoAccounts[id];
CONFIG.soldierId = id;
CONFIG.soldierName = acc.name;
CONFIG.soldierUnit = acc.unit;
localStorage.setItem('soldier_session', JSON.stringify({
soldier_id: id, name: acc.name, unit: acc.unit, role: acc.role
}));
updateSoldierInfo();
router('home');
startPolling();
startLocationReporting();
showToast('登录成功,欢迎 ' + acc.name);
return;
}
try {
const result = await API.login(id, pwd);
if (result.ok) {
API.setToken(result.token); // 保存认证令牌到前后端统一中心
CONFIG.soldierId = result.soldier_id;
CONFIG.soldierName = result.name;
CONFIG.soldierUnit = result.unit;
@ -270,6 +249,7 @@ const App = (() => {
function logout() {
if (confirm('确定要退出登录吗?')) {
localStorage.removeItem('soldier_session');
API.setToken(''); // 清除认证令牌
CONFIG.soldierId = '';
CONFIG.soldierName = '';
stopPolling();

@ -5,28 +5,135 @@ Flask 服务器 - 智途投送电脑端
"""
import os
import sqlite3
import secrets
from functools import wraps
from datetime import datetime
from flask import Flask, send_from_directory, request, jsonify
from flask_cors import CORS
from werkzeug.security import generate_password_hash, check_password_hash
# 项目根目录server 的上级目录)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 项目根目录app.py 所在目录,部署时前端静态文件应放在同级目录)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(os.path.dirname(__file__), 'zhitu.db')
app = Flask(__name__, static_folder=BASE_DIR)
CORS(app) # 允许跨域请求支持单兵APP访问
# ---- 内存数据存储 ----
soldiers = {} # { soldier_id: { id, name, lat, lng, updated_at } }
danger_zones = [] # [ { id, lat, lng, radius, description, created_at } ]
_demands = [] # [ { id, soldier_id, soldier_name, type, quantity, unit, urgency, status, created_at } ]
_demand_id_counter = 0
_drop_points = [ # 推荐投放点(演示数据)
# ---- SQLite 数据库 ----
def get_db():
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn
def init_db():
conn = get_db()
conn.executescript('''
CREATE TABLE IF NOT EXISTS accounts (
soldier_id TEXT PRIMARY KEY,
password_hash TEXT NOT NULL,
name TEXT NOT NULL,
unit TEXT,
role TEXT,
created_at TEXT
);
CREATE TABLE IF NOT EXISTS demands (
id TEXT PRIMARY KEY,
soldier_id TEXT,
soldier_name TEXT,
type TEXT,
items TEXT,
quantity INTEGER,
unit TEXT,
urgency TEXT,
status TEXT DEFAULT '待处理',
lat REAL,
lng REAL,
created_at TEXT
);
CREATE TABLE IF NOT EXISTS soldiers (
soldier_id TEXT PRIMARY KEY,
name TEXT,
lat REAL,
lng REAL,
updated_at TEXT
);
CREATE TABLE IF NOT EXISTS sos_alerts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
soldier_id TEXT,
soldier_name TEXT,
lat REAL,
lng REAL,
alert_time TEXT,
handled INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS danger_zones (
id INTEGER PRIMARY KEY AUTOINCREMENT,
lat REAL,
lng REAL,
radius REAL,
description TEXT,
created_at TEXT
);
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
soldier_id TEXT,
soldier_name TEXT,
type TEXT,
status TEXT,
progress INTEGER DEFAULT 0,
eta TEXT,
remain_time TEXT,
start_name TEXT,
target_name TEXT,
start_lat REAL,
start_lng REAL,
end_lat REAL,
end_lng REAL,
safety_score INTEGER,
created_at TEXT
);
CREATE TABLE IF NOT EXISTS tokens (
token TEXT PRIMARY KEY,
soldier_id TEXT,
name TEXT,
unit TEXT,
role TEXT,
created_at TEXT
);
''')
conn.commit()
conn.close()
def init_demo_accounts():
"""初始化演示账号到数据库,确保前后端账号统一"""
conn = get_db()
demo_accounts = [
("soldier_001", "123456", "张三", "第3步兵师/1连", "狙击手"),
("soldier_002", "123456", "李四", "第3步兵师/2连", "机枪手"),
("soldier_003", "123456", "王五", "第3步兵师/3连", "通讯员"),
]
for sid, pwd, name, unit, role in demo_accounts:
existing = conn.execute('SELECT 1 FROM accounts WHERE soldier_id=?', (sid,)).fetchone()
if not existing:
conn.execute('''
INSERT INTO accounts (soldier_id, password_hash, name, unit, role, created_at)
VALUES (?, ?, ?, ?, ?, ?)
''', (sid, generate_password_hash(pwd), name, unit, role,
datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
conn.commit()
conn.close()
init_db()
init_demo_accounts()
# ---- 内存数据(演示数据:投放点、无人机状态重启后恢复默认值)----
_drop_points = [
{"id": 1, "name": "安全点 #1", "safety_score": 92, "distance": 5, "reason": "深掩体保护,视角盲区", "lat": 30.0050, "lng": 120.0030},
{"id": 2, "name": "安全点 #2", "safety_score": 85, "distance": 12, "reason": "钢筋混凝土建筑内部", "lat": 30.0055, "lng": 120.0035},
{"id": 3, "name": "陷阱点 #3", "safety_score": 35, "distance": 20, "reason": "孤立断墙,成瞄准点", "lat": 30.0060, "lng": 120.0040}
]
_tasks = {} # { soldier_id: task }
_drone_status = { # 无人机实时状态(演示数据)
_drone_status = {
"drone_id": "无人机-01",
"task_id": "#001",
"status": "飞行中",
@ -40,20 +147,26 @@ _drone_status = { # 无人机实时状态(演示数据)
"lat": 30.0040,
"lng": 120.0025
}
_drone_logs = [ # 无人机动态日志
_drone_logs = [
{"time": "12:25:30", "message": "到达投放点"},
{"time": "12:20:45", "message": "接收任务指令"},
{"time": "12:10:00", "message": "任务分配"}
]
_sos_alerts = [] # 求救记录
_accounts = { # 士兵账号 { soldier_id: { password, name, unit, role } }
"soldier_001": { "password": "123456", "name": "张三", "unit": "第3步兵师/1连", "role": "狙击手" },
"soldier_002": { "password": "123456", "name": "李四", "unit": "第3步兵师/2连", "role": "机枪手" },
"soldier_003": { "password": "123456", "name": "王五", "unit": "第3步兵师/3连", "role": "通讯员" }
}
_tasks = {} # { soldier_id: { id, status, progress, eta, ... } }
_task_id_counter = 0
_danger_id_counter = 0
# 计数器已从内存变量改为数据库查询,避免服务重启后主键冲突
def require_auth(f):
"""接口认证装饰器:验证 X-Auth-Token从数据库查询支持多 worker 共享)"""
@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get('X-Auth-Token') or request.args.get('token')
conn = get_db()
row = conn.execute('SELECT soldier_id, name, unit, role FROM tokens WHERE token=?', (token,)).fetchone()
conn.close()
if not row:
return jsonify({"ok": False, "error": "未登录或token无效"}), 401
request.current_user = dict(row)
return f(*args, **kwargs)
return decorated
# ===== 静态文件路由 =====
@ -76,78 +189,108 @@ def js_files(filename):
# ===== REST API: 士兵位置 =====
@app.route("/api/soldier/location", methods=["POST"])
@require_auth
def update_soldier_location():
data = request.get_json(force=True)
sid = data.get("id")
if not sid:
return jsonify({"ok": False, "error": "missing id"}), 400
soldiers[sid] = {
"id": sid,
"name": data.get("name", sid),
"lat": float(data.get("lat", 0)),
"lng": float(data.get("lng", 0)),
"updated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
conn = get_db()
conn.execute('''
INSERT INTO soldiers (soldier_id, name, lat, lng, updated_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(soldier_id) DO UPDATE SET
name=excluded.name, lat=excluded.lat, lng=excluded.lng, updated_at=excluded.updated_at
''', (sid, data.get("name", sid), float(data.get("lat", 0)), float(data.get("lng", 0)),
datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
conn.commit()
conn.close()
return jsonify({"ok": True})
@app.route("/api/soldiers", methods=["GET"])
def get_soldiers():
return jsonify({"soldiers": list(soldiers.values())})
conn = get_db()
rows = conn.execute('SELECT * FROM soldiers').fetchall()
conn.close()
return jsonify({"soldiers": [dict(r) for r in rows]})
# ===== REST API: 危险区域 =====
@app.route("/api/danger-zones", methods=["GET"])
def get_danger_zones():
return jsonify({"danger_zones": danger_zones})
conn = get_db()
rows = conn.execute('SELECT * FROM danger_zones').fetchall()
conn.close()
return jsonify({"danger_zones": [dict(r) for r in rows]})
@app.route("/api/danger-zones", methods=["POST"])
@require_auth
def add_danger_zone():
global _danger_id_counter
data = request.get_json(force=True)
_danger_id_counter += 1
zone = {
"id": _danger_id_counter,
"lat": float(data.get("lat", 0)),
"lng": float(data.get("lng", 0)),
"radius": float(data.get("radius", 50)),
"description": data.get("description", "危险区域"),
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
danger_zones.append(zone)
return jsonify({"ok": True, "id": zone["id"]})
conn = get_db()
cur = conn.execute('''
INSERT INTO danger_zones (lat, lng, radius, description, created_at)
VALUES (?, ?, ?, ?, ?)
''', (zone["lat"], zone["lng"], zone["radius"], zone["description"], zone["created_at"]))
zone_id = cur.lastrowid
conn.commit()
conn.close()
return jsonify({"ok": True, "id": zone_id})
# ===== REST API: 物资需求单兵APP =====
def _next_demand_id():
"""从数据库查询当前最大需求ID生成下一个避免内存计数器重启归零导致主键冲突"""
conn = get_db()
row = conn.execute("SELECT id FROM demands WHERE id LIKE 'REQ-%' ORDER BY id DESC LIMIT 1").fetchone()
conn.close()
if row:
try:
num = int(row["id"].replace("REQ-", ""))
except ValueError:
num = 0
else:
num = 0
return "REQ-" + str(num + 1).zfill(3)
@app.route("/api/demand", methods=["POST"])
@require_auth
def post_demand():
global _demand_id_counter
data = request.get_json(force=True)
_demand_id_counter += 1
# 获取投放点坐标(如果有)
drop_point = data.get("drop_point") or {}
lat = float(drop_point.get("lat", 0)) if isinstance(drop_point, dict) else 0
lng = float(drop_point.get("lng", 0)) if isinstance(drop_point, dict) else 0
# 如果APP没传坐标尝试从士兵位置获取
soldier_id = data.get("soldier_id", "unknown")
if lat == 0 and lng == 0 and soldier_id in soldiers:
lat = soldiers[soldier_id].get("lat", 0)
lng = soldiers[soldier_id].get("lng", 0)
if lat == 0 and lng == 0:
conn = get_db()
row = conn.execute('SELECT lat, lng FROM soldiers WHERE soldier_id=?', (soldier_id,)).fetchone()
conn.close()
if row:
lat, lng = row["lat"], row["lng"]
# 构建物资清单描述
qty = int(data.get("quantity", 0))
unit = data.get("unit", "")
item_type = data.get("type", "未知")
items = f"{item_type} × {qty}{unit}"
demand_id = _next_demand_id()
demand = {
"id": "REQ-" + str(_demand_id_counter).zfill(3),
"id": demand_id,
"soldier_id": soldier_id,
"soldier_name": data.get("soldier_name", "未知"),
"type": item_type + "补给",
@ -160,25 +303,38 @@ def post_demand():
"lng": lng,
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
_demands.append(demand)
conn = get_db()
conn.execute('''
INSERT INTO demands (id, soldier_id, soldier_name, type, items, quantity, unit, urgency, status, lat, lng, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (demand["id"], demand["soldier_id"], demand["soldier_name"], demand["type"],
demand["items"], demand["quantity"], demand["unit"], demand["urgency"],
demand["status"], demand["lat"], demand["lng"], demand["created_at"]))
conn.commit()
conn.close()
return jsonify({"ok": True, "id": demand["id"]})
@app.route("/api/demands", methods=["GET"])
def get_demands():
soldier_id = request.args.get("soldier_id")
conn = get_db()
if soldier_id:
return jsonify({"demands": [d for d in _demands if d["soldier_id"] == soldier_id]})
# 返回所有待处理需求(供电脑端使用)
pending = [d for d in _demands if d["status"] == "待处理"]
return jsonify({"demands": pending})
rows = conn.execute('SELECT * FROM demands WHERE soldier_id=?', (soldier_id,)).fetchall()
else:
rows = conn.execute("SELECT * FROM demands WHERE status='待处理'").fetchall()
conn.close()
return jsonify({"demands": [dict(r) for r in rows]})
@app.route("/api/demands/<int:demand_id>", methods=["GET"])
@app.route("/api/demands/<demand_id>", methods=["GET"])
@require_auth
def get_demand(demand_id):
for d in _demands:
if d["id"] == demand_id:
return jsonify({"demand": d})
conn = get_db()
row = conn.execute('SELECT * FROM demands WHERE id=?', (demand_id,)).fetchone()
conn.close()
if row:
return jsonify({"demand": dict(row)})
return jsonify({"ok": False, "error": "not found"}), 404
@ -190,24 +346,37 @@ def get_drop_points():
@app.route("/api/drop-point", methods=["POST"])
@require_auth
def post_drop_point():
data = request.get_json(force=True)
# 记录士兵选择的投放点
return jsonify({"ok": True})
# ===== REST API: 任务状态 =====
def _next_task_id():
"""从数据库查询当前最大任务ID生成下一个"""
conn = get_db()
row = conn.execute("SELECT id FROM tasks WHERE id LIKE '#%' ORDER BY id DESC LIMIT 1").fetchone()
conn.close()
if row:
try:
num = int(row["id"].replace("#", ""))
except ValueError:
num = 0
else:
num = 0
return "#" + str(num + 1).zfill(3)
@app.route("/api/task/dispatch", methods=["POST"])
@require_auth
def dispatch_task():
"""电脑端调度任务时调用"""
global _task_id_counter
data = request.get_json(force=True)
soldier_id = data.get("soldier_id", "unknown")
_task_id_counter += 1
task = {
"id": "#" + str(_task_id_counter).zfill(3),
"id": _next_task_id(),
"soldier_id": soldier_id,
"soldier_name": data.get("soldier_name", "未知"),
"type": data.get("type", "未知"),
@ -224,57 +393,76 @@ def dispatch_task():
"safety_score": data.get("safety_score", 90),
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
_tasks[soldier_id] = task
conn = get_db()
conn.execute('''
INSERT INTO tasks (id, soldier_id, soldier_name, type, status, progress, eta, remain_time,
start_name, target_name, start_lat, start_lng, end_lat, end_lng, safety_score, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (task["id"], task["soldier_id"], task["soldier_name"], task["type"], task["status"],
task["progress"], task["eta"], task["remain_time"], task["start_name"], task["target_name"],
task["start_lat"], task["start_lng"], task["end_lat"], task["end_lng"],
task["safety_score"], task["created_at"]))
conn.commit()
conn.close()
# 更新需求状态为"已调度"
demand_id = data.get("demand_id")
if demand_id:
for d in _demands:
if d["id"] == demand_id:
d["status"] = "已调度"
break
conn = get_db()
conn.execute("UPDATE demands SET status='已调度' WHERE id=?", (demand_id,))
conn.commit()
conn.close()
return jsonify({"ok": True, "task": task})
@app.route("/api/task/current", methods=["GET"])
@require_auth
def get_current_task():
soldier_id = request.args.get("soldier_id", "soldier_001")
task = _tasks.get(soldier_id)
if not task:
# 没有任务时返回模拟数据
task = {
"id": "#--",
"status": "无任务",
"progress": 0,
"eta": "--",
"remain_time": "--",
"start_name": "--",
"target_name": "--",
"safety_score": 0
}
return jsonify({"task": task})
conn = get_db()
row = conn.execute('SELECT * FROM tasks WHERE soldier_id=?', (soldier_id,)).fetchone()
conn.close()
if row:
return jsonify({"task": dict(row)})
return jsonify({"task": {
"id": "#--", "status": "无任务", "progress": 0, "eta": "--",
"remain_time": "--", "start_name": "--", "target_name": "--", "safety_score": 0
}})
@app.route("/api/task/update", methods=["POST"])
@require_auth
def update_task():
"""更新任务状态进度、ETA等"""
data = request.get_json(force=True)
soldier_id = data.get("soldier_id", "unknown")
task = _tasks.get(soldier_id)
if not task:
conn = get_db()
row = conn.execute('SELECT * FROM tasks WHERE soldier_id=?', (soldier_id,)).fetchone()
if not row:
conn.close()
return jsonify({"ok": False, "error": "无进行中的任务"}), 404
updates = []
params = []
if "progress" in data:
task["progress"] = int(data["progress"])
updates.append("progress=?")
params.append(int(data["progress"]))
if "status" in data:
task["status"] = data["status"]
updates.append("status=?")
params.append(data["status"])
if "eta" in data:
task["eta"] = data["eta"]
updates.append("eta=?")
params.append(data["eta"])
if "remain_time" in data:
task["remain_time"] = data["remain_time"]
return jsonify({"ok": True, "task": task})
updates.append("remain_time=?")
params.append(data["remain_time"])
if updates:
params.append(soldier_id)
conn.execute(f"UPDATE tasks SET {','.join(updates)} WHERE soldier_id=?", params)
conn.commit()
conn.close()
return jsonify({"ok": True})
# ===== REST API: 无人机状态 =====
@ -292,32 +480,28 @@ def get_drone_logs():
# ===== REST API: SOS求救 =====
@app.route("/api/sos", methods=["POST"])
@require_auth
def post_sos():
data = request.get_json(force=True)
alert = {
"soldier_id": data.get("soldier_id", "unknown"),
"soldier_name": data.get("soldier_name", "未知"),
"lat": float(data.get("lat", 0)),
"lng": float(data.get("lng", 0)),
"time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"handled": False
}
_sos_alerts.append(alert)
alert_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conn = get_db()
conn.execute('''
INSERT INTO sos_alerts (soldier_id, soldier_name, lat, lng, alert_time, handled)
VALUES (?, ?, ?, ?, ?, ?)
''', (data.get("soldier_id", "unknown"), data.get("soldier_name", "未知"),
float(data.get("lat", 0)), float(data.get("lng", 0)), alert_time, 0))
# 同时标记为危险区域
global _danger_id_counter
_danger_id_counter += 1
danger_zones.append({
"id": _danger_id_counter,
"lat": alert["lat"],
"lng": alert["lng"],
"radius": 100,
"description": f"士兵求救: {alert['soldier_name']}",
"created_at": alert["time"]
})
return jsonify({"ok": True, "alert_id": len(_sos_alerts)})
conn.execute('''
INSERT INTO danger_zones (lat, lng, radius, description, created_at)
VALUES (?, ?, ?, ?, ?)
''', (float(data.get("lat", 0)), float(data.get("lng", 0)), 100,
f"士兵求救: {data.get('soldier_name', '未知')}", alert_time))
conn.commit()
conn.close()
return jsonify({"ok": True})
# ===== REST API: 账号系统 =====
# ===== REST API: 账号系统(核心改造:持久化 + Token + 密码加密) =====
@app.route("/api/auth/register", methods=["POST"])
def auth_register():
@ -330,15 +514,20 @@ def auth_register():
if not sid or not pwd or not name:
return jsonify({"ok": False, "error": "士兵编号、密码、姓名不能为空"}), 400
if sid in _accounts:
conn = get_db()
existing = conn.execute('SELECT 1 FROM accounts WHERE soldier_id=?', (sid,)).fetchone()
if existing:
conn.close()
return jsonify({"ok": False, "error": "该士兵编号已注册"}), 409
_accounts[sid] = {
"password": pwd,
"name": name,
"unit": unit,
"role": role
}
password_hash = generate_password_hash(pwd)
conn.execute('''
INSERT INTO accounts (soldier_id, password_hash, name, unit, role, created_at)
VALUES (?, ?, ?, ?, ?, ?)
''', (sid, password_hash, name, unit, role, datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
conn.commit()
conn.close()
return jsonify({"ok": True, "message": "注册成功"})
@ -348,31 +537,61 @@ def auth_login():
sid = data.get("soldier_id", "").strip()
pwd = data.get("password", "").strip()
account = _accounts.get(sid)
if not account:
conn = get_db()
row = conn.execute('SELECT * FROM accounts WHERE soldier_id=?', (sid,)).fetchone()
if not row:
conn.close()
return jsonify({"ok": False, "error": "士兵编号不存在"}), 404
if account["password"] != pwd:
if not check_password_hash(row["password_hash"], pwd):
conn.close()
return jsonify({"ok": False, "error": "密码错误"}), 401
token = secrets.token_hex(16)
conn.execute('''
INSERT INTO tokens (token, soldier_id, name, unit, role, created_at)
VALUES (?, ?, ?, ?, ?, ?)
''', (token, sid, row["name"], row["unit"], row["role"], datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
conn.commit()
conn.close()
return jsonify({
"ok": True,
"token": token,
"soldier_id": sid,
"name": account["name"],
"unit": account["unit"],
"role": account["role"]
"name": row["name"],
"unit": row["unit"],
"role": row["role"]
})
@app.route("/api/auth/accounts", methods=["GET"])
def get_accounts():
conn = get_db()
rows = conn.execute('SELECT soldier_id, name, unit, role FROM accounts').fetchall()
conn.close()
return jsonify({
"accounts": [
{"soldier_id": k, "name": v["name"], "unit": v["unit"], "role": v["role"]}
for k, v in _accounts.items()
]
"accounts": [dict(r) for r in rows]
})
@app.route("/api/auth/me", methods=["GET"])
def auth_me():
token = request.headers.get('X-Auth-Token') or request.args.get('token')
conn = get_db()
row = conn.execute('SELECT soldier_id, name, unit, role FROM tokens WHERE token=?', (token,)).fetchone()
conn.close()
if not row:
return jsonify({"ok": False, "error": "未登录"}), 401
return jsonify({"ok": True, "user": dict(row)})
@app.route("/api/ping", methods=["GET"])
def ping():
return jsonify({"ok": True, "message": "智途投送后端运行正常"})
if __name__ == "__main__":
import sys
import io

@ -0,0 +1,192 @@
#!/bin/bash
# =============================================================================
# 智途投送 - 云端一键部署脚本
# 适用系统Ubuntu 20.04 / 22.04 LTS
# 运行方式root 用户执行 bash deploy.sh
# =============================================================================
set -e # 遇到错误立即退出
# 颜色输出
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
PROJECT_DIR="/opt/zhitu"
SERVICE_NAME="zhitu"
NGINX_CONF="/etc/nginx/sites-available/zhitu"
echo -e "${GREEN}=========================================${NC}"
echo -e "${GREEN} 智途投送 - 云端一键部署脚本${NC}"
echo -e "${GREEN}=========================================${NC}"
echo ""
# 检查是否 root
if [ "$EUID" -ne 0 ]; then
echo -e "${RED}错误:请使用 root 用户运行此脚本${NC}"
echo "提示:先执行 sudo -i 切换到 root再运行 bash deploy.sh"
exit 1
fi
# 获取脚本所在目录(即后端代码目录)
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
echo -e "${YELLOW}[INFO]${NC} 后端代码目录: $SCRIPT_DIR"
# =============================================================================
# 步骤1: 更新系统并安装依赖
# =============================================================================
echo ""
echo -e "${GREEN}[1/8] 更新系统软件源...${NC}"
apt update -qq
echo -e "${GREEN}[2/8] 安装 Python3、Nginx、防火墙...${NC}"
apt install -y -qq python3 python3-pip python3-venv nginx curl ufw
# =============================================================================
# 步骤2: 创建项目目录并复制代码
# =============================================================================
echo -e "${GREEN}[3/8] 创建项目目录并复制代码...${NC}"
mkdir -p "$PROJECT_DIR"
# 复制所有文件(排除可能的旧数据库,避免覆盖)
rsync -av --exclude='zhitu.db' "$SCRIPT_DIR/" "$PROJECT_DIR/" 2>/dev/null || cp -r "$SCRIPT_DIR"/* "$PROJECT_DIR/"
# 如果本地有开发数据库,可以选择性复制(提示用户)
if [ -f "$SCRIPT_DIR/zhitu.db" ]; then
echo -e "${YELLOW}[提示]${NC} 检测到本地存在 zhitu.db 数据库文件"
read -p "是否将本地数据库复制到服务器?(y/n): " copy_db
if [ "$copy_db" = "y" ] || [ "$copy_db" = "Y" ]; then
cp "$SCRIPT_DIR/zhitu.db" "$PROJECT_DIR/zhitu.db"
echo -e "${GREEN}已复制数据库${NC}"
fi
fi
# =============================================================================
# 步骤3: 安装 Python 依赖
# =============================================================================
echo -e "${GREEN}[4/8] 安装 Python 依赖...${NC}"
cd "$PROJECT_DIR"
pip3 install -q -r requirements.txt
# =============================================================================
# 步骤4: 创建 Systemd 服务(比 Supervisor 更现代、更稳定)
# =============================================================================
echo -e "${GREEN}[5/8] 创建系统服务...${NC}"
cat > /etc/systemd/system/${SERVICE_NAME}.service << EOF
[Unit]
Description=智途投送后端服务 (ZhiTu Flask Backend)
After=network.target
[Service]
Type=simple
User=root
WorkingDirectory=${PROJECT_DIR}
ExecStart=/usr/local/bin/gunicorn -w 2 -b 127.0.0.1:5000 app:app
Restart=always
RestartSec=5
StandardOutput=append:/var/log/zhitu.log
StandardError=append:/var/log/zhitu.log
[Install]
WantedBy=multi-user.target
EOF
systemctl daemon-reload
systemctl enable ${SERVICE_NAME}
# =============================================================================
# 步骤5: 配置 Nginx 反向代理
# =============================================================================
echo -e "${GREEN}[6/8] 配置 Nginx...${NC}"
cat > "$NGINX_CONF" << 'EOF'
server {
listen 80 default_server;
listen [::]:80 default_server;
server_name _;
client_max_body_size 20M;
location / {
proxy_pass http://127.0.0.1:5000;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
}
# 静态文件直接由 Nginx 处理(提高效率)
location /static {
alias /opt/zhitu/static;
expires 7d;
}
}
EOF
# 启用配置,禁用默认站点
rm -f /etc/nginx/sites-enabled/default
ln -sf "$NGINX_CONF" /etc/nginx/sites-enabled/zhitu
# 测试 Nginx 配置
nginx -t
# =============================================================================
# 步骤6: 配置防火墙
# =============================================================================
echo -e "${GREEN}[7/8] 配置防火墙...${NC}"
ufw default deny incoming
ufw default allow outgoing
ufw allow 80/tcp comment 'HTTP'
ufw allow 443/tcp comment 'HTTPS'
ufw allow 22/tcp comment 'SSH'
ufw --force enable
echo -e "${YELLOW}[提示]${NC} 防火墙已启用,仅开放 22(SSH)、80(HTTP)、443(HTTPS)"
# =============================================================================
# 步骤7: 启动服务
# =============================================================================
echo -e "${GREEN}[8/8] 启动服务...${NC}"
systemctl restart ${SERVICE_NAME}
systemctl restart nginx
sleep 2
# =============================================================================
# 部署完成,输出状态
# =============================================================================
echo ""
echo -e "${GREEN}=========================================${NC}"
echo -e "${GREEN} 部署完成!${NC}"
echo -e "${GREEN}=========================================${NC}"
echo ""
# 获取公网 IP
PUBLIC_IP=$(curl -s ifconfig.me 2>/dev/null || curl -s icanhazip.com 2>/dev/null || echo "无法获取")
SERVICE_STATUS=$(systemctl is-active ${SERVICE_NAME})
NGINX_STATUS=$(systemctl is-active nginx)
echo -e "后端服务状态 : ${SERVICE_STATUS}"
echo -e "Nginx 状态 : ${NGINX_STATUS}"
echo -e "公网 IP : ${PUBLIC_IP}"
echo ""
echo -e "访问地址:"
echo -e " HTTP : ${YELLOW}http://${PUBLIC_IP}${NC}"
echo -e " 或 : ${YELLOW}http://你的域名${NC} (如果已配置域名)"
echo ""
echo -e "常用命令:"
echo -e " 查看后端日志 : ${YELLOW}journalctl -u zhitu -f${NC}"
echo -e " 重启后端 : ${YELLOW}systemctl restart zhitu${NC}"
echo -e " 查看服务状态 : ${YELLOW}systemctl status zhitu${NC}"
echo -e " 数据库位置 : ${YELLOW}${PROJECT_DIR}/zhitu.db${NC}"
echo ""
echo -e "${YELLOW}下一步:${NC}"
echo -e " 1. 打开单兵终端 APP将服务器地址配置为: ${GREEN}http://${PUBLIC_IP}${NC}"
echo -e " 2. 如需 HTTPS请先购买域名然后运行: ${GREEN}certbot --nginx${NC}"
echo ""

@ -0,0 +1,212 @@
#!/bin/bash
# =============================================================================
# 智途投送 - 安全部署脚本(自动检测冲突、自适应端口)
# 适用:已部署其他项目的共享服务器
# =============================================================================
set -e
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'
PROJECT_DIR="/opt/zhitu"
SERVICE_NAME="zhitu"
echo -e "${GREEN}=========================================${NC}"
echo -e "${GREEN} 智途投送 - 安全部署脚本${NC}"
echo -e "${GREEN} (自动检测冲突 + 自适应端口)${NC}"
echo -e "${GREEN}=========================================${NC}"
echo ""
if [ "$EUID" -ne 0 ]; then
echo -e "${RED}错误:请使用 root 用户运行${NC}"
exit 1
fi
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
# =============================================================================
# 自动检测端口占用,选择可用端口
# =============================================================================
echo -e "${BLUE}[检测]${NC} 检查端口占用情况..."
BASE_PORT=5000
PORT=$BASE_PORT
while ss -tlnp 2>/dev/null | grep -q ":$PORT\b" || netstat -tlnp 2>/dev/null | grep -q ":$PORT\b"; do
PORT=$((PORT + 1))
if [ $PORT -gt 5100 ]; then
echo -e "${RED}错误5000-5100 端口全部被占用,请手动清理${NC}"
exit 1
fi
done
echo -e "${GREEN} ✓ 后端将使用端口: $PORT${NC}"
# =============================================================================
# 检测 Nginx 状态
# =============================================================================
NGINX_ACTIVE=false
if systemctl is-active --quiet nginx 2>/dev/null || service nginx status 2>/dev/null | grep -q running; then
NGINX_ACTIVE=true
echo -e "${YELLOW} ! 检测到 Nginx 已在运行(其他项目正在使用 80 端口)${NC}"
echo -e "${YELLOW} 智途投送将直接暴露在高端口 $PORT,不占用 80 端口${NC}"
else
echo -e "${GREEN} ✓ 80 端口空闲,将配置 Nginx 反向代理${NC}"
fi
# =============================================================================
# 检测 /opt/zhitu 是否已存在(防止覆盖其他项目)
# =============================================================================
if [ -d "$PROJECT_DIR" ]; then
echo -e "${YELLOW}[警告]${NC} 检测到目录 $PROJECT_DIR 已存在"
BACKUP_DIR="${PROJECT_DIR}.backup.$(date +%Y%m%d%H%M%S)"
echo -e "${YELLOW} 自动备份到: $BACKUP_DIR${NC}"
mv "$PROJECT_DIR" "$BACKUP_DIR"
fi
# =============================================================================
# 安装依赖
# =============================================================================
echo ""
echo -e "${GREEN}[1/5] 安装环境依赖...${NC}"
apt update -qq
apt install -y -qq python3 python3-pip python3-venv curl
# =============================================================================
# 复制代码
# =============================================================================
echo -e "${GREEN}[2/5] 复制后端代码...${NC}"
mkdir -p "$PROJECT_DIR"
rsync -av --exclude='zhitu.db' "$SCRIPT_DIR/" "$PROJECT_DIR/" 2>/dev/null || cp -r "$SCRIPT_DIR"/* "$PROJECT_DIR/"
# 如果本地有数据库,询问是否复制
if [ -f "$SCRIPT_DIR/zhitu.db" ]; then
echo -e "${YELLOW}[提示]${NC} 检测到本地数据库文件"
# 非交互式环境直接复制(课程服务器通常自动化部署)
cp "$SCRIPT_DIR/zhitu.db" "$PROJECT_DIR/zhitu.db" 2>/dev/null && echo -e "${GREEN} ✓ 已复制数据库${NC}" || true
fi
# =============================================================================
# 安装 Python 依赖
# =============================================================================
echo -e "${GREEN}[3/5] 创建虚拟环境并安装依赖...${NC}"
cd "$PROJECT_DIR"
# Ubuntu 24.04 禁止直接 pip 安装到系统,使用虚拟环境
if [ ! -d "$PROJECT_DIR/venv" ]; then
python3 -m venv venv
fi
source "$PROJECT_DIR/venv/bin/activate"
pip install -q -r requirements.txt
# =============================================================================
# 创建 Systemd 服务
# =============================================================================
echo -e "${GREEN}[4/5] 创建后端服务...${NC}"
# 如果存在旧服务,先停止
systemctl stop ${SERVICE_NAME} 2>/dev/null || true
systemctl disable ${SERVICE_NAME} 2>/dev/null || true
rm -f /etc/systemd/system/${SERVICE_NAME}.service
cat > /etc/systemd/system/${SERVICE_NAME}.service << EOF
[Unit]
Description=智途投送后端服务 (Port $PORT)
After=network.target
[Service]
Type=simple
User=root
WorkingDirectory=${PROJECT_DIR}
ExecStart=${PROJECT_DIR}/venv/bin/gunicorn -w 2 -b 0.0.0.0:${PORT} app:app
Restart=always
RestartSec=5
StandardOutput=append:/var/log/zhitu.log
StandardError=append:/var/log/zhitu.log
[Install]
WantedBy=multi-user.target
EOF
systemctl daemon-reload
systemctl enable ${SERVICE_NAME}
systemctl restart ${SERVICE_NAME}
sleep 2
# =============================================================================
# 配置 Nginx仅在 Nginx 未运行时)
# =============================================================================
if [ "$NGINX_ACTIVE" = false ]; then
echo -e "${GREEN}[5/5] 配置 Nginx...${NC}"
apt install -y -qq nginx
cat > /etc/nginx/sites-available/zhitu << EOF
server {
listen 80 default_server;
listen [::]:80 default_server;
server_name _;
client_max_body_size 20M;
location / {
proxy_pass http://127.0.0.1:${PORT};
proxy_http_version 1.1;
proxy_set_header Host \$host;
proxy_set_header X-Real-IP \$remote_addr;
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
}
}
EOF
rm -f /etc/nginx/sites-enabled/default
ln -sf /etc/nginx/sites-available/zhitu /etc/nginx/sites-enabled/zhitu
nginx -t && systemctl restart nginx
# 防火墙
ufw allow 80/tcp 2>/dev/null || true
ufw allow 443/tcp 2>/dev/null || true
ufw allow 22/tcp 2>/dev/null || true
else
echo -e "${GREEN}[5/5] 跳过 Nginx 配置(已被其他项目占用)${NC}"
# 只开放智途使用的端口
ufw allow ${PORT}/tcp 2>/dev/null || true
fi
# =============================================================================
# 输出结果
# =============================================================================
echo ""
echo -e "${GREEN}=========================================${NC}"
echo -e "${GREEN} 部署完成!${NC}"
echo -e "${GREEN}=========================================${NC}"
echo ""
PUBLIC_IP=$(curl -s ifconfig.me 2>/dev/null || curl -s icanhazip.com 2>/dev/null || echo "你的服务器IP")
SERVICE_STATUS=$(systemctl is-active ${SERVICE_NAME} 2>/dev/null || echo "unknown")
echo -e "后端服务状态 : ${SERVICE_STATUS}"
echo -e "使用端口 : ${YELLOW}${PORT}${NC}"
echo -e "公网 IP : ${PUBLIC_IP}"
echo ""
if [ "$NGINX_ACTIVE" = true ]; then
echo -e "访问地址(直接端口):"
echo -e " ${YELLOW}http://${PUBLIC_IP}:${PORT}${NC}"
echo ""
echo -e "${YELLOW}提示:${NC} 由于 80 端口被其他项目占用,"
echo -e " 单兵终端请配置为: ${GREEN}http://${PUBLIC_IP}:${PORT}${NC}"
else
echo -e "访问地址:"
echo -e " HTTP : ${YELLOW}http://${PUBLIC_IP}${NC}"
echo -e " 或 : ${YELLOW}http://${PUBLIC_IP}:${PORT}${NC}"
fi
echo ""
echo -e "常用命令:"
echo -e " 查看日志 : ${YELLOW}journalctl -u zhitu -f${NC}"
echo -e " 重启服务 : ${YELLOW}systemctl restart zhitu${NC}"
echo -e " 查看状态 : ${YELLOW}systemctl status zhitu${NC}"
echo -e " 数据库位置 : ${YELLOW}${PROJECT_DIR}/zhitu.db${NC}"
echo ""

@ -1,2 +1,3 @@
flask==3.0.0
flask-cors==5.0.0
gunicorn==23.0.0

@ -0,0 +1,146 @@
# 智途投送后端操作手册
> 服务器:阿里云 ECS `121.41.216.243`Ubuntu 24.04
> 部署路径:`/opt/zhitu/`
> 服务名:`zhitu.service`
> 数据库SQLite `/opt/zhitu/zhitu.db`
---
## 1. 连接服务器
```bash
ssh root@121.41.216.243
# 输入密码
```
---
## 2. 服务启停(最常用)
| 操作 | 命令 |
|------|------|
| 查看状态 | `systemctl status zhitu` |
| 启动服务 | `systemctl start zhitu` |
| 停止服务 | `systemctl stop zhitu` |
| 重启服务 | `systemctl restart zhitu` |
| 重载配置 | `systemctl daemon-reload` |
| 开机自启 | `systemctl enable zhitu`(已配置)|
---
## 3. 查看日志(排错必备)
```bash
# 查看最新 50 行日志
journalctl -u zhitu -n 50
# 实时跟踪日志(调试时用,按 Ctrl+C 退出)
journalctl -u zhitu -f
```
---
## 4. 改 Worker 数量为 1避免 SQLite 并发锁)
当前配置为 2 个 workerSQLite 并发写入时可能锁竞争导致请求卡住,建议改为 1 个:
```bash
sed -i 's/-w 2/-w 1/' /etc/systemd/system/zhitu.service
systemctl daemon-reload
systemctl restart zhitu
systemctl status zhitu
```
---
## 5. 备份数据库
数据库文件在 `/opt/zhitu/zhitu.db`,课程结束前务必备份:
```bash
# 方式一:在服务器上备份到 root 目录
cp /opt/zhitu/zhitu.db ~/zhitu_backup_$(date +%Y%m%d_%H%M%S).db
# 方式二:下载到本地电脑(在本地终端执行)
scp root@121.41.216.243:/opt/zhitu/zhitu.db D:\zhitu_backup.db
```
---
## 6. 修改后端代码后热更新
改完 `app.py` 后需要重启生效:
```bash
# 先检查语法
python3 -m py_compile /opt/zhitu/app.py
# 语法通过后重启
systemctl restart zhitu
```
---
## 7. 常用验证命令
```bash
# 测试后端是否存活
curl http://121.41.216.243/api/ping
# 查看所有需求(无需登录)
curl http://121.41.216.243/api/demands
# 查看账号列表(无需登录)
curl http://121.41.216.243/api/auth/accounts
# 直接查看数据库中的 token调试用
sqlite3 /opt/zhitu/zhitu.db "SELECT * FROM tokens;"
# 查看所有需求记录
sqlite3 /opt/zhitu/zhitu.db "SELECT * FROM demands;"
```
---
## 8. 网络/端口检查
```bash
# 查看 5000 端口是否在监听
ss -tlnp | grep 5000
# 查看 Nginx 状态
systemctl status nginx
# 查看防火墙状态
ufw status
```
---
## 9. 常见问题速查
| 现象 | 可能原因 | 解决 |
|------|----------|------|
| APP 上报 500 | `_demand_id_counter` 内存计数器归零导致主键冲突 | 已修复:改为数据库查询生成 ID |
| APP 上报超时/卡住 | Gunicorn 多 worker + SQLite 并发锁竞争 | 改 `-w 1` |
| APP 登录 401 | Token 跨 worker 不共享 / localStorage 缓存旧地址 | 已修复Token 存数据库;检查 APP 服务器地址设置 |
| 电脑端首页 404 | 前端静态文件未上传到 `/opt/zhitu/` | 确保 `index.html`、`css/`、`js/` 都在部署目录 |
| 服务启动失败 | 语法错误 / 端口被占 | `python3 -m py_compile /opt/zhitu/app.py` 检查语法 |
---
## 10. 关键文件路径
| 文件/目录 | 路径 |
|-----------|------|
| 后端主程序 | `/opt/zhitu/app.py` |
| 数据库 | `/opt/zhitu/zhitu.db` |
| 服务配置 | `/etc/systemd/system/zhitu.service` |
| Nginx 配置 | `/etc/nginx/sites-enabled/zhitu` |
| Python 虚拟环境 | `/opt/zhitu/venv/` |
| 日志输出 | `journalctl -u zhitu` |
---
*手册生成时间2026-05-23*
Loading…
Cancel
Save