Compare commits
4 Commits
caoweiqion
...
main
| Author | SHA1 | Date |
|---|---|---|
|
|
4013ca06fc | 1 week ago |
|
|
1cb26cb6eb | 1 week ago |
|
|
7e932c82cc | 1 month ago |
|
|
2f6e88a6a4 | 1 month ago |
@ -0,0 +1,38 @@
|
|||||||
|
# C4ISR Server Environment Variables
# WARNING: every *_PASSWORD and *SECRET_KEY value below is a development
# default — replace them all before any non-local deployment.
|
||||||
|
|
||||||
|
# MySQL
|
||||||
|
MYSQL_ROOT_PASSWORD=go2_patrol_2026
|
||||||
|
MYSQL_DATABASE=go2_patrol
|
||||||
|
MYSQL_USER=patrol_admin
|
||||||
|
MYSQL_PASSWORD=patrol_pass_2026
|
||||||
|
|
||||||
|
# MongoDB
|
||||||
|
MONGO_URI=mongodb://mongodb:27017
|
||||||
|
MONGO_DATABASE=go2_patrol
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
REDIS_URL=redis://:redis_pass_2026@redis:6379/0
|
||||||
|
|
||||||
|
# FastAPI App Server
|
||||||
|
APP_HOST=0.0.0.0
|
||||||
|
APP_PORT=8000
|
||||||
|
APP_DEBUG=true
|
||||||
|
APP_SECRET_KEY=change-me-in-production-32chars
|
||||||
|
APP_ACCESS_TOKEN_EXPIRE_MINUTES=120
|
||||||
|
|
||||||
|
# Comm Server
|
||||||
|
COMM_HOST=0.0.0.0
|
||||||
|
COMM_PORT=8001
|
||||||
|
AI_SERVER_URL=localhost:50051
|
||||||
|
|
||||||
|
# AI Server (gRPC)
|
||||||
|
AI_HOST=0.0.0.0
|
||||||
|
AI_PORT=50051
|
||||||
|
AI_DB_URL=mysql+pymysql://patrol_admin:patrol_pass_2026@localhost:3306/go2_patrol
|
||||||
|
|
||||||
|
# JWT
|
||||||
|
JWT_SECRET_KEY=change-me-jwt-secret-32chars
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
|
|
||||||
|
# MapLibre
|
||||||
|
MAPTILER_KEY=your-maptiler-key-here
|
||||||
@ -0,0 +1,62 @@
|
|||||||
|
syntax = "proto3";

package threat;

// Threat assessment service.
service ThreatAssessment {
  // Threat classification: map a detection label to a threat type and risk score.
  rpc ClassifyThreat (ThreatRequest) returns (ThreatClassification);
  // Threat level assessment: grade a detection into a severity tier.
  rpc AssessThreatLevel (ThreatRequest) returns (ThreatLevelAssessment);
  // Anomalous-behaviour detection over a telemetry window.
  rpc DetectAnomaly (AnomalyRequest) returns (AnomalyResponse);
}

// A single detection event to be classified or assessed.
message ThreatRequest {
  string threat_id = 1;
  string class_label = 2;       // raw detector label (matched case-insensitively server-side)
  float confidence = 3;         // detector confidence score
  double latitude = 4;
  double longitude = 5;
  string detected_by_dog = 6;
  int64 detected_at_unix = 7;   // detection time, unix seconds
  string image_url = 8;
}

message ThreatClassification {
  string threat_id = 1;
  string threat_type = 2;
  float risk_score = 3;
  string description = 4;
}

message ThreatLevelAssessment {
  string threat_id = 1;
  string threat_level = 2; // LOW, MEDIUM, HIGH, CRITICAL
  float confidence = 3;
  string reasoning = 4;
  string recommended_action = 5;
}

// A time window of telemetry samples for one patrol unit.
message AnomalyRequest {
  string dog_id = 1;
  repeated TelemetrySample samples = 2;
  int64 window_start_unix = 3;
  int64 window_end_unix = 4;
}

message TelemetrySample {
  int64 timestamp_unix = 1;
  double latitude = 2;
  double longitude = 3;
  float speed_mps = 4;
  float heading_deg = 5;
}

message AnomalyResponse {
  string dog_id = 1;
  bool is_anomalous = 2;
  string anomaly_type = 3; // loitering, intrusion, gathering
  float anomaly_score = 4;
  string description = 5;
}
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
[project]
|
||||||
|
name = "go2-ai-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "GO2 Patrol C4ISR AI Inference Server (gRPC)"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
    "grpcio>=1.62",
    "grpcio-tools>=1.62",
    "pydantic>=2.5",
    "pydantic-settings>=2.2",
    "structlog>=24.1",
    # server.py's RuleEngine imports these at runtime for DB-backed rules
    # (it degrades to built-in defaults without them, but they belong here):
    "sqlalchemy>=2.0",
    "pymysql>=1.1",
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
@ -0,0 +1,275 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
# Ensure ai_server/ is on sys.path so `from src import ...` works
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from src import threat_assessment_pb2 as pb2
|
||||||
|
from src import threat_assessment_pb2_grpc as pb2_grpc
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("ai_server")
|
||||||
|
|
||||||
|
# ── Database-backed rule engine ──
|
||||||
|
|
||||||
|
class RuleEngine:
    """Thread-safe store of threat-assessment rules.

    Rules are loaded from the ``threat_rules`` MySQL table when reachable and
    merged over built-in defaults; if the DB is unavailable the defaults are
    used as-is. A daemon thread re-loads the rules every REFRESH_INTERVAL
    seconds. Readers must hold ``self._lock`` while accessing the rule tables.
    """

    REFRESH_INTERVAL = 60  # seconds between background reloads

    def __init__(self):
        # Current rule tables; swapped atomically under self._lock.
        self.label_to_type: dict[str, str] = {}
        self.label_risk: dict[str, float] = {}
        self.type_descriptions: dict[str, str] = {}
        # (min_combined_score, level, reasoning, recommended_action), sorted
        # by min_combined_score descending so the first match wins.
        self.thresholds: list[tuple[float, str, str, str]] = []
        self._lock = threading.Lock()
        self._engine = None
        self._db_ok = False
        self._init_db()
        self.load_rules()
        # Background refresh thread; daemon so it never blocks shutdown.
        t = threading.Thread(target=self._refresh_loop, daemon=True)
        t.start()

    def _init_db(self):
        """Create the SQLAlchemy engine; mark the DB unusable on any failure."""
        try:
            from sqlalchemy import create_engine
            # NOTE(review): hard-coded credential fallback mirrors .env; prefer
            # requiring AI_DB_URL to be set in non-dev environments.
            url = os.environ.get(
                "AI_DB_URL",
                "mysql+pymysql://patrol_admin:patrol_pass_2026@localhost:3306/go2_patrol",
            )
            self._engine = create_engine(url, pool_pre_ping=True, pool_recycle=1800)
            self._db_ok = True
            logger.info("Rule engine DB connection ready")
        except Exception as e:
            logger.warning(f"Rule engine DB unavailable, using defaults: {e}")
            self._db_ok = False

    def _publish(self, label_to_type, label_risk, type_descriptions, thresholds):
        """Atomically swap in a complete new rule set."""
        with self._lock:
            self.label_to_type = label_to_type
            self.label_risk = label_risk
            self.type_descriptions = type_descriptions
            self.thresholds = thresholds

    def load_rules(self):
        """Reload rule tables from the DB, falling back to defaults on failure."""
        defaults = self._defaults()
        if not self._db_ok:
            self._publish(defaults["label_to_type"], defaults["label_risk"],
                          defaults["type_descriptions"], defaults["thresholds"])
            return

        try:
            from sqlalchemy import text
            with self._engine.connect() as conn:
                rows = conn.execute(
                    text("SELECT rule_type, rule_condition, rule_action, is_active "
                         "FROM threat_rules WHERE is_active = 1 ORDER BY priority DESC")
                ).fetchall()
        except Exception as e:
            logger.warning(f"Failed to load rules from DB: {e}")
            self._publish(defaults["label_to_type"], defaults["label_risk"],
                          defaults["type_descriptions"], defaults["thresholds"])
            return

        # Start from the defaults and let DB rows override / extend them.
        label_to_type = dict(defaults["label_to_type"])
        label_risk = dict(defaults["label_risk"])
        type_descriptions = dict(defaults["type_descriptions"])
        thresholds = list(defaults["thresholds"])

        for rule_type, condition_str, action_str, _ in rows:
            try:
                condition = json.loads(condition_str) if isinstance(condition_str, str) else condition_str
                action = json.loads(action_str) if isinstance(action_str, str) else action_str
            except (json.JSONDecodeError, TypeError):
                # Skip malformed rows rather than failing the whole reload.
                continue

            if rule_type == "classification":
                label = condition.get("class_label", "").lower().strip()
                if label and "threat_type" in action:
                    label_to_type[label] = action["threat_type"]
                if label and "description" in action:
                    type_descriptions[action.get("threat_type", "UNKNOWN")] = action["description"]

            elif rule_type == "risk":
                label = condition.get("class_label", "").lower().strip()
                if label and "risk_score" in action:
                    label_risk[label] = float(action["risk_score"])

            elif rule_type == "threshold":
                if "levels" in action:
                    # A threshold rule replaces the whole ladder, not single rungs.
                    thresholds = [
                        (float(lvl.get("min_combined", 0)),
                         lvl.get("level", "LOW"),
                         lvl.get("reasoning", ""),
                         lvl.get("action", ""))
                        for lvl in action["levels"]
                    ]
                    thresholds.sort(key=lambda x: -x[0])

        self._publish(label_to_type, label_risk, type_descriptions, thresholds)

        logger.info(f"Rules loaded: {len(label_to_type)} classifications, "
                    f"{len(label_risk)} risk scores, {len(thresholds)} thresholds")

    def _refresh_loop(self):
        """Periodically reload rules; load_rules handles its own failures."""
        while True:
            time.sleep(self.REFRESH_INTERVAL)
            self.load_rules()

    @staticmethod
    def _defaults():
        """Built-in rule set used when the DB is absent or a reload fails."""
        return {
            "label_to_type": {
                "person": "INTRUSION", "human": "INTRUSION",
                "vehicle": "SUSPICIOUS_VEHICLE", "car": "SUSPICIOUS_VEHICLE",
                "fire": "FIRE", "smoke": "FIRE", "flame": "FIRE",
                "animal": "WILDLIFE", "dog": "WILDLIFE", "cat": "WILDLIFE",
                "obstacle": "OBSTACLE", "debris": "OBSTACLE",
            },
            "label_risk": {
                "person": 0.7, "human": 0.7,
                "fire": 0.95, "smoke": 0.85, "flame": 0.95,
                "vehicle": 0.5, "car": 0.5,
                "animal": 0.2, "dog": 0.15, "cat": 0.1,
                "obstacle": 0.3, "debris": 0.25,
            },
            "type_descriptions": {
                "INTRUSION": "检测到人员入侵",
                "SUSPICIOUS_VEHICLE": "检测到可疑车辆",
                "FIRE": "检测到火灾/烟雾",
                "WILDLIFE": "检测到动物活动",
                "OBSTACLE": "检测到路径障碍物",
                "UNKNOWN": "未知威胁类型",
            },
            "thresholds": [
                (0.7, "CRITICAL", "高置信度+高风险类别", "立即派遣拦截"),
                (0.5, "HIGH", "中高置信度检测", "通知指挥员评估"),
                (0.3, "MEDIUM", "中等置信度", "持续监控并记录"),
                (0.0, "LOW", "低置信度检测", "记录归档"),
            ],
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton. NOTE: constructing it at import time attempts a DB
# connection (best-effort) and starts the background refresh thread.
rules = RuleEngine()
|
||||||
|
|
||||||
|
|
||||||
|
def _assess_level(confidence: float, base_risk: float) -> tuple[str, str, str]:
    """Map confidence x base_risk onto the first matching threshold tier.

    Returns (level, reasoning, recommended_action); falls back to the LOW
    tier if no threshold matches.
    """
    combined = confidence * base_risk
    with rules._lock:
        matched = next(
            (
                (level, reasoning, action)
                for floor, level, reasoning, action in rules.thresholds
                if combined >= floor
            ),
            None,
        )
    if matched is not None:
        return matched
    return "LOW", "低置信度检测", "记录归档"
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatAssessmentServicer(pb2_grpc.ThreatAssessmentServicer):
    """gRPC servicer implementing the three ThreatAssessment RPCs.

    All rule lookups go through the shared RuleEngine singleton; its lock is
    held only for the dictionary reads, never across the response build.
    """

    def ClassifyThreat(self, request, context):
        """Map a detector class label to a threat type, risk score and description."""
        label = (request.class_label or "").lower().strip()
        with rules._lock:
            threat_type = rules.label_to_type.get(label, "UNKNOWN")
            base_risk = rules.label_risk.get(label, 0.3)
            description = rules.type_descriptions.get(threat_type, "未知威胁")
        # Final risk is detector confidence scaled by the label's base risk.
        risk_score = round(request.confidence * base_risk, 3)

        logger.info(f"ClassifyThreat: {label} → {threat_type} (risk={risk_score})")

        return pb2.ThreatClassification(
            threat_id=request.threat_id,
            threat_type=threat_type,
            risk_score=risk_score,
            description=description,
        )

    def AssessThreatLevel(self, request, context):
        """Grade a detection into a severity tier via the threshold ladder."""
        label = (request.class_label or "").lower().strip()
        with rules._lock:
            base_risk = rules.label_risk.get(label, 0.3)
        # _assess_level takes the rules lock itself, so it must be called
        # outside the `with` above (threading.Lock is not reentrant).
        level, reasoning, action = _assess_level(request.confidence, base_risk)

        logger.info(f"AssessThreatLevel: {label} → {level} (conf={request.confidence:.2f})")

        return pb2.ThreatLevelAssessment(
            threat_id=request.threat_id,
            threat_level=level,
            confidence=round(request.confidence, 3),
            reasoning=reasoning,
            recommended_action=action,
        )

    def DetectAnomaly(self, request, context):
        """Rule-based anomaly screening over a telemetry window."""
        samples = list(request.samples)

        # Need at least two samples before any pattern can be inferred.
        if len(samples) < 2:
            return pb2.AnomalyResponse(
                dog_id=request.dog_id,
                is_anomalous=False,
                anomaly_type="",
                anomaly_score=0.0,
                description="样本不足,无法分析",
            )

        flagged = False
        kind = ""
        score = 0.0
        description = "行为正常"

        # Loitering: the whole window fits in a ~10 m box (0.0001 deg).
        lat_span = max(s.latitude for s in samples) - min(s.latitude for s in samples)
        lon_span = max(s.longitude for s in samples) - min(s.longitude for s in samples)
        if lat_span < 0.0001 and lon_span < 0.0001 and len(samples) >= 5:
            flagged, kind, score = True, "loitering", 0.6
            description = f"设备在区域内长时间停留 ({len(samples)}个采样点)"
        else:
            # Otherwise look at the speed profile.
            speeds = [s.speed_mps for s in samples]
            avg_speed = sum(speeds) / len(speeds) if speeds else 0
            max_speed = max(speeds) if speeds else 0
            if max_speed > 2.0 and avg_speed < 0.3:
                flagged, kind, score = True, "erratic_movement", 0.5
                description = f"异常运动模式: 突发高速 ({max_speed:.1f}m/s) 后静止"
            elif avg_speed > 1.5:
                flagged, kind, score = True, "high_speed", 0.4
                description = f"持续高速移动: 平均 {avg_speed:.1f}m/s"

        logger.info(f"DetectAnomaly: dog={request.dog_id} anomalous={flagged} type={kind}")

        return pb2.AnomalyResponse(
            dog_id=request.dog_id,
            is_anomalous=flagged,
            anomaly_type=kind,
            anomaly_score=score,
            description=description,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def serve():
    """Start the ThreatAssessment gRPC server and block until termination."""
    # Honour AI_PORT from the environment (the project's .env sets it);
    # previously the port was hard-coded and the env var was ignored.
    port = int(os.environ.get("AI_PORT", "50051"))
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    pb2_grpc.add_ThreatAssessmentServicer_to_server(ThreatAssessmentServicer(), server)
    # Insecure transport: intended for deployment behind the compose network.
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    logger.info(f"AI server started on port {port}")
    server.wait_for_termination()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the gRPC server in the foreground.
if __name__ == "__main__":
    serve()
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
|
# source: threat_assessment.proto
|
||||||
|
# Protobuf Python Version: 6.31.1
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import runtime_version as _runtime_version
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
|
_runtime_version.Domain.PUBLIC,
|
||||||
|
6,
|
||||||
|
31,
|
||||||
|
1,
|
||||||
|
'',
|
||||||
|
'threat_assessment.proto'
|
||||||
|
)
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17threat_assessment.proto\x12\x06threat\"\xb6\x01\n\rThreatRequest\x12\x11\n\tthreat_id\x18\x01 \x01(\t\x12\x13\n\x0b\x63lass_label\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x10\n\x08latitude\x18\x04 \x01(\x01\x12\x11\n\tlongitude\x18\x05 \x01(\x01\x12\x17\n\x0f\x64\x65tected_by_dog\x18\x06 \x01(\t\x12\x18\n\x10\x64\x65tected_at_unix\x18\x07 \x01(\x03\x12\x11\n\timage_url\x18\x08 \x01(\t\"g\n\x14ThreatClassification\x12\x11\n\tthreat_id\x18\x01 \x01(\t\x12\x13\n\x0bthreat_type\x18\x02 \x01(\t\x12\x12\n\nrisk_score\x18\x03 \x01(\x02\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\x83\x01\n\x15ThreatLevelAssessment\x12\x11\n\tthreat_id\x18\x01 \x01(\t\x12\x14\n\x0cthreat_level\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x11\n\treasoning\x18\x04 \x01(\t\x12\x1a\n\x12recommended_action\x18\x05 \x01(\t\"~\n\x0e\x41nomalyRequest\x12\x0e\n\x06\x64og_id\x18\x01 \x01(\t\x12(\n\x07samples\x18\x02 \x03(\x0b\x32\x17.threat.TelemetrySample\x12\x19\n\x11window_start_unix\x18\x03 \x01(\x03\x12\x17\n\x0fwindow_end_unix\x18\x04 \x01(\x03\"v\n\x0fTelemetrySample\x12\x16\n\x0etimestamp_unix\x18\x01 \x01(\x03\x12\x10\n\x08latitude\x18\x02 \x01(\x01\x12\x11\n\tlongitude\x18\x03 \x01(\x01\x12\x11\n\tspeed_mps\x18\x04 \x01(\x02\x12\x13\n\x0bheading_deg\x18\x05 \x01(\x02\"y\n\x0f\x41nomalyResponse\x12\x0e\n\x06\x64og_id\x18\x01 \x01(\t\x12\x14\n\x0cis_anomalous\x18\x02 \x01(\x08\x12\x14\n\x0c\x61nomaly_type\x18\x03 \x01(\t\x12\x15\n\ranomaly_score\x18\x04 \x01(\x02\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t2\xe6\x01\n\x10ThreatAssessment\x12\x45\n\x0e\x43lassifyThreat\x12\x15.threat.ThreatRequest\x1a\x1c.threat.ThreatClassification\x12I\n\x11\x41ssessThreatLevel\x12\x15.threat.ThreatRequest\x1a\x1d.threat.ThreatLevelAssessment\x12@\n\rDetectAnomaly\x12\x16.threat.AnomalyRequest\x1a\x17.threat.AnomalyResponseb\x06proto3')
|
||||||
|
|
||||||
|
_globals = globals()
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'threat_assessment_pb2', _globals)
|
||||||
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals['_THREATREQUEST']._serialized_start=36
|
||||||
|
_globals['_THREATREQUEST']._serialized_end=218
|
||||||
|
_globals['_THREATCLASSIFICATION']._serialized_start=220
|
||||||
|
_globals['_THREATCLASSIFICATION']._serialized_end=323
|
||||||
|
_globals['_THREATLEVELASSESSMENT']._serialized_start=326
|
||||||
|
_globals['_THREATLEVELASSESSMENT']._serialized_end=457
|
||||||
|
_globals['_ANOMALYREQUEST']._serialized_start=459
|
||||||
|
_globals['_ANOMALYREQUEST']._serialized_end=585
|
||||||
|
_globals['_TELEMETRYSAMPLE']._serialized_start=587
|
||||||
|
_globals['_TELEMETRYSAMPLE']._serialized_end=705
|
||||||
|
_globals['_ANOMALYRESPONSE']._serialized_start=707
|
||||||
|
_globals['_ANOMALYRESPONSE']._serialized_end=828
|
||||||
|
_globals['_THREATASSESSMENT']._serialized_start=831
|
||||||
|
_globals['_THREATASSESSMENT']._serialized_end=1061
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
@ -0,0 +1,189 @@
|
|||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from src import threat_assessment_pb2 as threat__assessment__pb2
|
||||||
|
|
||||||
|
GRPC_GENERATED_VERSION = '1.80.0'
|
||||||
|
GRPC_VERSION = grpc.__version__
|
||||||
|
_version_not_supported = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from grpc._utilities import first_version_is_lower
|
||||||
|
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||||
|
except ImportError:
|
||||||
|
_version_not_supported = True
|
||||||
|
|
||||||
|
if _version_not_supported:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||||
|
+ ' but the generated code in threat_assessment_pb2_grpc.py depends on'
|
||||||
|
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||||
|
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||||
|
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatAssessmentStub(object):
|
||||||
|
"""威胁评估服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.ClassifyThreat = channel.unary_unary(
|
||||||
|
'/threat.ThreatAssessment/ClassifyThreat',
|
||||||
|
request_serializer=threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
response_deserializer=threat__assessment__pb2.ThreatClassification.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.AssessThreatLevel = channel.unary_unary(
|
||||||
|
'/threat.ThreatAssessment/AssessThreatLevel',
|
||||||
|
request_serializer=threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
response_deserializer=threat__assessment__pb2.ThreatLevelAssessment.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.DetectAnomaly = channel.unary_unary(
|
||||||
|
'/threat.ThreatAssessment/DetectAnomaly',
|
||||||
|
request_serializer=threat__assessment__pb2.AnomalyRequest.SerializeToString,
|
||||||
|
response_deserializer=threat__assessment__pb2.AnomalyResponse.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatAssessmentServicer(object):
|
||||||
|
"""威胁评估服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
def ClassifyThreat(self, request, context):
|
||||||
|
"""威胁分类
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def AssessThreatLevel(self, request, context):
|
||||||
|
"""威胁等级评估
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def DetectAnomaly(self, request, context):
|
||||||
|
"""异常行为检测
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_ThreatAssessmentServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'ClassifyThreat': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.ClassifyThreat,
|
||||||
|
request_deserializer=threat__assessment__pb2.ThreatRequest.FromString,
|
||||||
|
response_serializer=threat__assessment__pb2.ThreatClassification.SerializeToString,
|
||||||
|
),
|
||||||
|
'AssessThreatLevel': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.AssessThreatLevel,
|
||||||
|
request_deserializer=threat__assessment__pb2.ThreatRequest.FromString,
|
||||||
|
response_serializer=threat__assessment__pb2.ThreatLevelAssessment.SerializeToString,
|
||||||
|
),
|
||||||
|
'DetectAnomaly': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.DetectAnomaly,
|
||||||
|
request_deserializer=threat__assessment__pb2.AnomalyRequest.FromString,
|
||||||
|
response_serializer=threat__assessment__pb2.AnomalyResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'threat.ThreatAssessment', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
server.add_registered_method_handlers('threat.ThreatAssessment', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class ThreatAssessment(object):
|
||||||
|
"""威胁评估服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ClassifyThreat(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/threat.ThreatAssessment/ClassifyThreat',
|
||||||
|
threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
threat__assessment__pb2.ThreatClassification.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def AssessThreatLevel(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/threat.ThreatAssessment/AssessThreatLevel',
|
||||||
|
threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
threat__assessment__pb2.ThreatLevelAssessment.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def DetectAnomaly(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/threat.ThreatAssessment/DetectAnomaly',
|
||||||
|
threat__assessment__pb2.AnomalyRequest.SerializeToString,
|
||||||
|
threat__assessment__pb2.AnomalyResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
@ -0,0 +1,41 @@
|
|||||||
|
[project]
|
||||||
|
name = "go2-app-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "GO2 Patrol C4ISR Application Server"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.110",
|
||||||
|
"uvicorn[standard]>=0.29",
|
||||||
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
|
"asyncmy>=0.2",
|
||||||
|
"motor>=3.3",
|
||||||
|
"redis[hiredis]>=5.0",
|
||||||
|
"python-jose[cryptography]>=3.3",
|
||||||
|
"bcrypt>=4.1",
|
||||||
|
"pydantic>=2.5",
|
||||||
|
"pydantic-settings>=2.2",
|
||||||
|
"alembic>=1.13",
|
||||||
|
"structlog>=24.1",
|
||||||
|
"python-multipart>=0.0.9",
|
||||||
|
"httpx>=0.27",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.1",
|
||||||
|
"pytest-asyncio>=0.23",
|
||||||
|
"httpx>=0.27",
|
||||||
|
"ruff>=0.3",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
@ -0,0 +1,119 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.core.security import (
|
||||||
|
create_access_token, create_refresh_token, decode_refresh_token,
|
||||||
|
hash_password, verify_password,
|
||||||
|
)
|
||||||
|
from src.dependencies import get_current_user, get_db
|
||||||
|
from src.models import User
|
||||||
|
from src.schemas.auth import (
|
||||||
|
LoginRequest,
|
||||||
|
RegisterRequest,
|
||||||
|
TokenResponse,
|
||||||
|
UserResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All endpoints below are mounted under the /auth prefix.
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate by username/password and issue an access+refresh token pair."""
    found = await db.execute(select(User).where(User.username == body.username))
    user = found.scalar_one_or_none()
    # Same 401 for unknown user and wrong password, so usernames can't be probed.
    if user is None or not verify_password(body.password, user.password_hash):
        raise HTTPException(status_code=401, detail="用户名或密码错误")
    if not user.is_active:
        raise HTTPException(status_code=403, detail="账号已禁用")

    # NOTE(review): naive local timestamp — confirm whether the column is
    # meant to store local time or UTC.
    user.last_login_at = datetime.now()
    await db.commit()

    return TokenResponse(
        access_token=create_access_token(subject=user.id, role=user.role),
        refresh_token=create_refresh_token(subject=user.id, role=user.role),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=TokenResponse)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Create a new user account and log it in immediately."""
    existing = await db.execute(select(User).where(User.username == body.username))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status_code=400, detail="用户名已存在")

    # NOTE(review): check-then-insert races under concurrent registration; a
    # unique index on username plus IntegrityError handling would close the gap.
    new_user = User(
        username=body.username,
        password_hash=hash_password(body.password),
        display_name=body.display_name,
        role=body.role,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)

    return TokenResponse(
        access_token=create_access_token(subject=new_user.id, role=new_user.role),
        refresh_token=create_refresh_token(subject=new_user.id, role=new_user.role),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshRequest(BaseModel):
    """Body for POST /auth/refresh."""
    # Refresh JWT previously issued by /auth/login or /auth/register.
    refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a valid refresh token for a fresh access+refresh pair.

    The role claim embedded in the old token is deliberately ignored (it was
    previously read into an unused local): the user's current DB role is
    authoritative, so role changes take effect on the next refresh.
    """
    payload = decode_refresh_token(body.refresh_token)
    if not payload:
        raise HTTPException(status_code=401, detail="无效的刷新令牌")

    user_id = payload.get("sub")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user or not user.is_active:
        raise HTTPException(status_code=401, detail="用户不存在或已禁用")

    return TokenResponse(
        access_token=create_access_token(subject=user.id, role=user.role),
        refresh_token=create_refresh_token(subject=user.id, role=user.role),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the profile of the authenticated user."""
    row = await db.execute(select(User).where(User.id == current_user["user_id"]))
    user = row.scalar_one_or_none()
    if user is None:
        # Token was valid but the account no longer exists.
        raise HTTPException(status_code=404, detail="用户不存在")
    return user
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordBody(BaseModel):
    """Body for POST /auth/change-password."""
    # Current password, verified before the change is applied.
    old_password: str
    # Replacement password, stored hashed.
    new_password: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/change-password")
async def change_password(
    body: ChangePasswordBody,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Verify the caller's old password, then store a hash of the new one."""
    row = await db.execute(select(User).where(User.id == current_user["user_id"]))
    user = row.scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Require proof of the current password before accepting a new one.
    if not verify_password(body.old_password, user.password_hash):
        raise HTTPException(status_code=400, detail="旧密码错误")
    user.password_hash = hash_password(body.new_password)
    await db.commit()
    return {"message": "密码修改成功"}
|
||||||
@ -0,0 +1,77 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.dependencies import get_current_user, get_db
|
||||||
|
from src.models import Dog
|
||||||
|
from src.schemas.dog import DogCreate, DogResponse, DogUpdate
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/dogs", tags=["dogs"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[DogResponse])
async def list_dogs(
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List every registered robot dog, oldest first."""
    rows = await db.execute(select(Dog).order_by(Dog.created_at))
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=DogResponse, status_code=201)
async def create_dog(
    body: DogCreate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Register a new robot dog from the request payload."""
    record = Dog(**body.model_dump())
    db.add(record)
    await db.commit()
    # Refresh to pick up DB-generated fields for the response.
    await db.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{dog_id}", response_model=DogResponse)
async def get_dog(
    dog_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single robot dog by id; 404 when unknown."""
    res = await db.execute(select(Dog).where(Dog.id == dog_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="机器狗不存在")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{dog_id}", response_model=DogResponse)
async def update_dog(
    dog_id: str,
    body: DogUpdate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply only the fields present in the payload to an existing dog."""
    res = await db.execute(select(Dog).where(Dog.id == dog_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="机器狗不存在")

    changes = body.model_dump(exclude_unset=True)
    for field, value in changes.items():
        setattr(record, field, value)
    await db.commit()
    await db.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{dog_id}", status_code=204)
async def delete_dog(
    dog_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove a robot dog; 404 when it does not exist."""
    res = await db.execute(select(Dog).where(Dog.id == dog_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="机器狗不存在")
    await db.delete(record)
    await db.commit()
|
||||||
@ -0,0 +1,200 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from src.core.redis import get_redis
|
||||||
|
from src.dependencies import get_current_user, get_db
|
||||||
|
from src.models import PatrolTask, Waypoint
|
||||||
|
from src.schemas.patrol import (
|
||||||
|
PatrolCreate,
|
||||||
|
PatrolResponse,
|
||||||
|
PatrolUpdate,
|
||||||
|
WaypointResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(prefix="/patrols", tags=["patrols"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[PatrolResponse])
async def list_patrols(
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all patrol tasks, newest first."""
    query = select(PatrolTask).order_by(PatrolTask.created_at.desc())
    rows = await db.execute(query)
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_patrol(task_id: str, db: AsyncSession) -> PatrolTask:
    """Fetch one patrol task with its waypoints eagerly loaded, or raise 404."""
    result = await db.execute(
        select(PatrolTask)
        # Eager-load waypoints: async sessions cannot lazy-load on attribute access.
        .options(selectinload(PatrolTask.waypoints))
        .where(PatrolTask.task_id == task_id)
    )
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="巡逻任务不存在")
    return task
|
||||||
|
|
||||||
|
|
||||||
|
def _waypoints_to_response(task: PatrolTask) -> list[WaypointResponse]:
    """Serialize the task's waypoints in ascending sequence order."""
    ordered = sorted(task.waypoints, key=lambda wp: wp.sequence_order)
    return [WaypointResponse.model_validate(wp) for wp in ordered]
|
||||||
|
|
||||||
|
|
||||||
|
def _patrol_to_response(task: PatrolTask) -> PatrolResponse:
    """Serialize a patrol task, attaching its waypoints in sequence order."""
    resp = PatrolResponse.model_validate(task)
    # Overwrite with explicitly ordered waypoints rather than ORM order.
    resp.waypoints = _waypoints_to_response(task)
    return resp
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=PatrolResponse, status_code=201)
async def create_patrol(
    body: PatrolCreate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a patrol task plus its waypoints in a single transaction."""
    # Waypoints are created separately below, so exclude them here.
    data = body.model_dump(exclude={"waypoints"})

    task = PatrolTask(**data, created_by=current_user["user_id"])
    db.add(task)
    # Flush (not commit) so task.task_id is available for the waypoint rows.
    await db.flush()

    for wp in (body.waypoints or []):
        waypoint = Waypoint(task_id=task.task_id, **wp.model_dump())
        db.add(waypoint)

    await db.commit()
    # Re-read with waypoints eagerly loaded for serialization.
    task = await _load_patrol(task.task_id, db)
    return _patrol_to_response(task)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{task_id}", response_model=PatrolResponse)
async def get_patrol(
    task_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch one patrol task with its waypoints; 404 when unknown."""
    return _patrol_to_response(await _load_patrol(task_id, db))
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{task_id}", response_model=PatrolResponse)
async def update_patrol(
    task_id: str,
    body: PatrolUpdate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update task fields; when waypoints are supplied, replace them wholesale.

    waypoints semantics: None → leave unchanged; [] → delete all; a list →
    full replacement of the existing set.
    """
    task = await _load_patrol(task_id, db)
    update_data = body.model_dump(exclude_unset=True, exclude={"waypoints"})

    for key, value in update_data.items():
        setattr(task, key, value)

    if body.waypoints is not None:
        for wp in task.waypoints:
            await db.delete(wp)
        # Flush deletes before inserting replacements — presumably to avoid
        # key collisions between old and new waypoint rows; confirm schema.
        await db.flush()

        for wp in body.waypoints:
            waypoint = Waypoint(task_id=task_id, **wp.model_dump())
            db.add(waypoint)

    await db.commit()
    # Re-load so the response reflects the committed waypoint set.
    task = await _load_patrol(task_id, db)
    return _patrol_to_response(task)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{task_id}/deploy", response_model=PatrolResponse)
async def deploy_patrol(
    task_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Mark a DRAFT/SAVED task as DEPLOYED and push it to the assigned dog.

    Raises 400 when no dog is assigned or the task is not in a deployable
    state. Redis publish failures are logged but do not fail the request.
    """
    task = await _load_patrol(task_id, db)

    if not task.assigned_dog_id:
        raise HTTPException(status_code=400, detail="任务未分配机器狗")
    if task.status not in ("DRAFT", "SAVED"):
        raise HTTPException(status_code=400, detail=f"任务状态 {task.status} 不可部署")

    # Update status
    task.status = "DEPLOYED"
    # NOTE(review): naive local timestamp; the token code in security.py uses
    # timezone-aware UTC — confirm which convention the DB expects.
    task.started_at = datetime.now()
    await db.commit()
    # Re-load so waypoints are attached to a live session after commit.
    task = await _load_patrol(task_id, db)

    # Send task to comm_server via Redis pub/sub → dog
    deploy_msg = {
        "type": "task_deploy",
        "task_id": task.task_id,
        "dog_id": task.assigned_dog_id,
        "waypoints": [
            {
                "index": wp.sequence_order,
                # Cast to float: lat/lon may be Decimal, which json.dumps rejects.
                "latitude": float(wp.latitude),
                "longitude": float(wp.longitude),
                "action_type": wp.action_type,
                "dwell_time_sec": wp.dwell_time_sec,
            }
            for wp in sorted(task.waypoints, key=lambda x: x.sequence_order)
        ],
    }

    try:
        redis = get_redis()
        await redis.publish("comm:command", json.dumps(deploy_msg))
        logger.info(f"Deployed task {task_id} to dog {task.assigned_dog_id}")
    except Exception as e:
        # Best-effort delivery: the DB already shows DEPLOYED even if the
        # publish fails; operators must re-deploy manually in that case.
        logger.warning(f"Failed to publish deploy command via Redis: {e}")

    return _patrol_to_response(task)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{task_id}/cancel", response_model=PatrolResponse)
async def cancel_patrol(
    task_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Cancel a DEPLOYED/EXECUTING task and notify the dog via Redis."""
    task = await _load_patrol(task_id, db)
    if task.status not in ("DEPLOYED", "EXECUTING"):
        raise HTTPException(status_code=400, detail="只能取消已部署或执行中的任务")

    task.status = "CANCELLED"
    await db.commit()
    # Re-load so waypoints are available for the response serialization.
    task = await _load_patrol(task_id, db)

    cancel_msg = {
        "type": "task_cancel",
        "task_id": task.task_id,
        "dog_id": task.assigned_dog_id,
    }
    try:
        redis = get_redis()
        await redis.publish("comm:command", json.dumps(cancel_msg))
    except Exception as e:
        # Best-effort: DB already shows CANCELLED even if the notification fails.
        logger.warning(f"Failed to publish cancel command: {e}")

    return _patrol_to_response(task)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{task_id}", status_code=204)
async def delete_patrol(
    task_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a patrol task; 404 when it does not exist."""
    doomed = await _load_patrol(task_id, db)
    await db.delete(doomed)
    await db.commit()
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from src.api.v1.auth import router as auth_router
|
||||||
|
from src.api.v1.dogs import router as dogs_router
|
||||||
|
from src.api.v1.patrols import router as patrols_router
|
||||||
|
from src.api.v1.threats import router as threats_router
|
||||||
|
from src.api.v1.settings import router as settings_router
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
router.include_router(auth_router)
|
||||||
|
router.include_router(dogs_router)
|
||||||
|
router.include_router(patrols_router)
|
||||||
|
router.include_router(threats_router)
|
||||||
|
router.include_router(settings_router)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
async def health_check():
    """Liveness probe for the app server."""
    payload = {"status": "ok", "service": "go2-app-server"}
    return payload
|
||||||
@ -0,0 +1,81 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.dependencies import get_current_user, get_db, require_role
|
||||||
|
from src.models import SystemParameter, User
|
||||||
|
from src.schemas.auth import UserResponse
|
||||||
|
from src.schemas.settings import SystemParamResponse, SystemParamUpdate
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/settings", tags=["settings"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/params", response_model=list[SystemParamResponse])
async def list_params(
    section: str | None = None,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List system parameters, optionally restricted to one section."""
    query = select(SystemParameter).order_by(
        SystemParameter.param_section, SystemParameter.param_key
    )
    if section:
        query = query.where(SystemParameter.param_section == section)
    rows = await db.execute(query)
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/params/{param_key}", response_model=SystemParamResponse)
async def update_param(
    param_key: str,
    body: SystemParamUpdate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a system parameter's value (and optionally its description).

    NOTE(review): requires only authentication, not the ADMIN role used by
    the user-management endpoints in this router — confirm that is intended.
    """
    result = await db.execute(select(SystemParameter).where(SystemParameter.param_key == param_key))
    param = result.scalar_one_or_none()
    if not param:
        raise HTTPException(status_code=404, detail="参数不存在")

    param.param_value = body.param_value
    # Record who changed the parameter for auditability.
    param.updated_by = current_user["user_id"]
    if body.description is not None:
        param.description = body.description
    await db.commit()
    await db.refresh(param)
    return param
|
||||||
|
|
||||||
|
|
||||||
|
# ── User management (admin only) ──
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users", response_model=list[UserResponse])
async def list_users(
    current_user: dict = Depends(require_role("ADMIN")),
    db: AsyncSession = Depends(get_db),
):
    """List all user accounts (ADMIN only), oldest first."""
    rows = await db.execute(select(User).order_by(User.created_at))
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/users/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: str,
    display_name: str | None = None,
    role: str | None = None,
    is_active: bool | None = None,
    current_user: dict = Depends(require_role("ADMIN")),
    db: AsyncSession = Depends(get_db),
):
    """Partially update a user account (ADMIN only).

    NOTE(review): display_name/role/is_active arrive as query parameters
    rather than a JSON body — confirm clients expect that for a PUT.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    # Only overwrite the fields that were actually supplied.
    if display_name is not None:
        user.display_name = display_name
    if role is not None:
        user.role = role
    if is_active is not None:
        user.is_active = is_active
    await db.commit()
    await db.refresh(user)
    return user
|
||||||
@ -0,0 +1,320 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.dependencies import get_current_user, get_db
|
||||||
|
from src.models import AlertEvent, ThreatAlert, ThreatArchive, ThreatMark, ThreatRule
|
||||||
|
from src.schemas.threat import (
|
||||||
|
AlertAckRequest, AlertResponse, ThreatMarkCreate, ThreatMarkResponse,
|
||||||
|
ThreatMarkUpdate, ThreatResponse, ThreatRuleCreate, ThreatRuleResponse,
|
||||||
|
ThreatRuleUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("threats_api")
|
||||||
|
router = APIRouter(prefix="/threats", tags=["threats"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _broadcast_alert(alert: ThreatAlert):
    """Publish new alert to comm_server for browser broadcast.

    Best-effort: any failure (Redis down, serialization issue) is logged
    and swallowed so it never fails the API request that created the alert.
    """
    try:
        # Local import — presumably to avoid a circular import; confirm.
        from src.core.redis import get_redis
        redis = get_redis()
        msg = {
            "type": "alert",
            "data": {
                "alert_id": alert.alert_id,
                "threat_id": alert.threat_id,
                "dog_id": alert.dog_id,
                "alert_level": alert.alert_level,
                "status": alert.status,
                "message": alert.message,
                # created_at may still be unset on an unrefreshed row.
                "created_at": alert.created_at.isoformat() if alert.created_at else None,
            },
        }
        await redis.publish("comm:broadcast", json.dumps(msg))
    except Exception as e:
        logger.warning(f"Failed to broadcast alert: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Alerts ──
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/alerts", response_model=list[AlertResponse])
async def list_alerts(
    status_filter: str | None = Query(None, alias="status"),
    level: str | None = Query(None, alias="level"),
    limit: int = Query(50, le=200),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the newest alerts, optionally filtered by status and level."""
    query = select(ThreatAlert).order_by(ThreatAlert.created_at.desc()).limit(limit)
    if status_filter:
        query = query.where(ThreatAlert.status == status_filter)
    if level:
        query = query.where(ThreatAlert.alert_level == level)
    rows = await db.execute(query)
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/alerts/stats")
async def alert_stats(
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Aggregate alert counts: total, unacknowledged, and critical."""
    # select() objects are immutable, so reusing the base count is safe.
    count_q = select(func.count(ThreatAlert.alert_id))
    total = await db.scalar(count_q)
    unack = await db.scalar(
        count_q.where(ThreatAlert.status.in_(["INITIAL", "PENDING"]))
    )
    critical = await db.scalar(
        count_q.where(ThreatAlert.alert_level == "CRITICAL")
    )
    return {"total": total or 0, "unacknowledged": unack or 0, "critical": critical or 0}
|
||||||
|
|
||||||
|
|
||||||
|
class AlertCreate(BaseModel):
    """Payload for manually creating a threat alert."""

    # Optional links to the originating dog and archived threat record.
    dog_id: str | None = None
    threat_id: str | None = None
    alert_level: str  # LOW, MEDIUM, HIGH, CRITICAL
    message: str  # Human-readable alert text shown to operators.
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/alerts", response_model=AlertResponse, status_code=201)
async def create_alert(
    body: AlertCreate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Persist a new alert, then broadcast it to connected browsers."""
    record = ThreatAlert(
        dog_id=body.dog_id,
        threat_id=body.threat_id,
        alert_level=body.alert_level,
        message=body.message,
    )
    db.add(record)
    await db.commit()
    # Refresh first so the broadcast carries DB-generated id/timestamp.
    await db.refresh(record)
    await _broadcast_alert(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/alerts/{alert_id}/action", response_model=AlertResponse)
async def alert_action(
    alert_id: str,
    body: AlertAckRequest,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply an operator action (acknowledge / resolve / dismiss) to an alert.

    Each action updates the alert's status and appends an AlertEvent row as
    an audit-trail entry recording who performed it.
    """
    result = await db.execute(select(ThreatAlert).where(ThreatAlert.alert_id == alert_id))
    alert = result.scalar_one_or_none()
    if not alert:
        raise HTTPException(status_code=404, detail="警报不存在")

    user_id = current_user["user_id"]
    # NOTE(review): naive local time, while the token code uses timezone-aware
    # UTC — confirm which convention the DB columns expect.
    now = datetime.now()

    if body.action == "acknowledge":
        alert.status = "CONFIRMED"
        alert.acknowledged_at = now
        alert.acknowledged_by = user_id
        event_type = "ACKNOWLEDGED"
    elif body.action == "resolve":
        alert.status = "RESOLVED"
        alert.resolved_at = now
        # NOTE(review): resolver identity is captured only in the AlertEvent
        # below, not on the alert row itself — confirm that is sufficient.
        event_type = "RESOLVED"
    elif body.action == "dismiss":
        alert.status = "DISMISSED"
        event_type = "DISMISSED"
    else:
        raise HTTPException(status_code=400, detail=f"未知操作: {body.action}")

    # Audit-trail record of who performed the action.
    event = AlertEvent(
        alert_id=alert_id,
        event_type=event_type,
        performed_by=user_id,
    )
    db.add(event)
    await db.commit()
    await db.refresh(alert)
    return alert
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threat Archive ──
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/archive", response_model=list[ThreatResponse])
async def list_threats(
    status_filter: str | None = Query(None, alias="status"),
    level: str | None = Query(None, alias="level"),
    limit: int = Query(50, le=200),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the most recently detected archived threats, optionally filtered."""
    query = select(ThreatArchive).order_by(ThreatArchive.detected_at.desc()).limit(limit)
    if status_filter:
        query = query.where(ThreatArchive.status == status_filter)
    if level:
        query = query.where(ThreatArchive.threat_level == level)
    rows = await db.execute(query)
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/archive/{threat_id}", response_model=ThreatResponse)
async def get_threat(
    threat_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch one archived threat record; 404 when unknown."""
    res = await db.execute(select(ThreatArchive).where(ThreatArchive.threat_id == threat_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="威胁记录不存在")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threat Marks (map markers) ──
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/marks", response_model=list[ThreatMarkResponse])
async def list_marks(
    mark_type: str | None = Query(None),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List map markers, newest first, optionally filtered by type."""
    query = select(ThreatMark).order_by(ThreatMark.created_at.desc())
    if mark_type:
        query = query.where(ThreatMark.mark_type == mark_type)
    rows = await db.execute(query)
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/marks", response_model=ThreatMarkResponse, status_code=201)
async def create_mark(
    body: ThreatMarkCreate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a map marker owned by the calling user."""
    attrs = {
        "mark_type": body.mark_type,
        "latitude": body.latitude,
        "longitude": body.longitude,
        "radius_meters": body.radius_meters,
        "priority": body.priority,
        "threat_id": body.threat_id,
        "expires_at": body.expires_at,
        "notes": body.notes,
    }
    record = ThreatMark(created_by=current_user["user_id"], **attrs)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/marks/{mark_id}", response_model=ThreatMarkResponse)
async def update_mark(
    mark_id: str,
    body: ThreatMarkUpdate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply only the supplied fields to an existing map marker."""
    res = await db.execute(select(ThreatMark).where(ThreatMark.mark_id == mark_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="标记不存在")
    changes = body.model_dump(exclude_unset=True)
    for name, value in changes.items():
        setattr(record, name, value)
    await db.commit()
    await db.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/marks/{mark_id}", status_code=204)
async def delete_mark(
    mark_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a map marker; 404 when it does not exist."""
    res = await db.execute(select(ThreatMark).where(ThreatMark.mark_id == mark_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="标记不存在")
    await db.delete(record)
    await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threat Rules ──
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/rules", response_model=list[ThreatRuleResponse])
async def list_rules(
    rule_type: str | None = Query(None),
    is_active: bool | None = Query(None),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List threat rules, highest priority first, with optional filters."""
    query = select(ThreatRule).order_by(ThreatRule.priority.desc(), ThreatRule.created_at.desc())
    if rule_type:
        query = query.where(ThreatRule.rule_type == rule_type)
    # is_active is a tri-state filter: None means "no filter", False is valid.
    if is_active is not None:
        query = query.where(ThreatRule.is_active == is_active)
    rows = await db.execute(query)
    return rows.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/rules", response_model=ThreatRuleResponse, status_code=201)
async def create_rule(
    body: ThreatRuleCreate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a threat-evaluation rule."""
    attrs = {
        "rule_name": body.rule_name,
        "rule_type": body.rule_type,
        "rule_condition": body.rule_condition,
        "rule_action": body.rule_action,
        "priority": body.priority,
        "is_active": body.is_active,
    }
    record = ThreatRule(**attrs)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/rules/{rule_id}", response_model=ThreatRuleResponse)
async def update_rule(
    rule_id: str,
    body: ThreatRuleUpdate,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply only the supplied fields to an existing rule."""
    res = await db.execute(select(ThreatRule).where(ThreatRule.rule_id == rule_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="规则不存在")
    changes = body.model_dump(exclude_unset=True)
    for name, value in changes.items():
        setattr(record, name, value)
    await db.commit()
    await db.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/rules/{rule_id}", status_code=204)
async def delete_rule(
    rule_id: str,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a threat rule; 404 when it does not exist."""
    res = await db.execute(select(ThreatRule).where(ThreatRule.rule_id == rule_id))
    record = res.scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="规则不存在")
    await db.delete(record)
    await db.commit()
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Application configuration loaded from environment variables / .env.

    Field names map to env vars (pydantic-settings); the committed defaults
    are development placeholders only.
    """

    # App
    app_host: str = "0.0.0.0"
    app_port: int = 8000
    app_debug: bool = False
    # SECURITY: override both secret defaults in production deployments.
    app_secret_key: str = "change-me-in-production-32chars"
    app_access_token_expire_minutes: int = 120

    # MySQL
    database_url: str = "mysql+asyncmy://patrol_admin:patrol_pass_2026@localhost:3306/go2_patrol"

    # MongoDB
    mongo_uri: str = "mongodb://localhost:27017"
    mongo_database: str = "go2_patrol"

    # Redis
    redis_url: str = "redis://:redis_pass_2026@localhost:6379/0"

    # JWT
    jwt_secret_key: str = "change-me-jwt-secret-32chars"
    jwt_algorithm: str = "HS256"

    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}


# Process-wide settings instance imported throughout the app.
settings = Settings()
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
class AppError(Exception):
    """Base class for application errors that carry an HTTP status code."""

    def __init__(self, message: str, status_code: int = 400):
        # Pass the message to Exception so str(e) and tracebacks show it;
        # the original stored it only on self.message, leaving str(e) empty.
        super().__init__(message)
        self.message = message
        self.status_code = status_code


class NotFoundError(AppError):
    """Requested resource does not exist (HTTP 404)."""

    def __init__(self, resource: str, resource_id: str):
        super().__init__(f"{resource} '{resource_id}' not found", status_code=404)


class ForbiddenError(AppError):
    """Authenticated but not permitted to perform the action (HTTP 403)."""

    def __init__(self, message: str = "Permission denied"):
        super().__init__(message, status_code=403)


class UnauthorizedError(AppError):
    """Missing or invalid credentials (HTTP 401)."""

    def __init__(self, message: str = "Invalid credentials"):
        super().__init__(message, status_code=401)
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
# Async engine for the application database; echoes SQL when debug is on.
engine = create_async_engine(
    settings.database_url,
    echo=settings.app_debug,
    pool_size=10,
    max_overflow=20,
    # Recycle pooled connections hourly (guards against server-side idle timeouts).
    pool_recycle=3600,
)

# Session factory; expire_on_commit=False keeps ORM objects readable after commit.
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||||||
@ -0,0 +1,23 @@
|
|||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
_client: AsyncIOMotorClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_mongo_client() -> AsyncIOMotorClient:
    """Return the process-wide Motor client, creating it on first use."""
    global _client
    if _client is None:
        _client = AsyncIOMotorClient(settings.mongo_uri)
    return _client
|
||||||
|
|
||||||
|
|
||||||
|
def get_mongo_db():
    """Return a handle to the configured application MongoDB database."""
    return get_mongo_client()[settings.mongo_database]
|
||||||
|
|
||||||
|
|
||||||
|
async def close_mongo():
    """Close the shared Motor client and reset the singleton for reuse."""
    global _client
    if _client:
        _client.close()
        _client = None
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
_redis: aioredis.Redis | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis() -> aioredis.Redis:
    """Return the process-wide async Redis client, creating it lazily.

    decode_responses=True makes command results come back as str, not bytes.
    """
    global _redis
    if _redis is None:
        _redis = aioredis.from_url(settings.redis_url, decode_responses=True)
    return _redis
|
||||||
|
|
||||||
|
|
||||||
|
async def close_redis():
    """Close the shared Redis client and reset the singleton for reuse."""
    global _redis
    if _redis:
        # NOTE(review): newer redis-py deprecates close() in favor of
        # aclose() — confirm the pinned library version.
        await _redis.close()
        _redis = None
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
    """Return a salted bcrypt hash of *password* for storage."""
    return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a previously stored bcrypt hash."""
    return pwd_context.verify(plain, hashed)
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(subject: str, role: str, expires_delta: timedelta | None = None) -> str:
    """Issue a signed JWT access token for *subject* carrying *role*.

    Expiry defaults to settings.app_access_token_expire_minutes.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.app_access_token_expire_minutes)
    claims = {
        "sub": subject,
        "role": role,
        "exp": datetime.now(timezone.utc) + expires_delta,
        "type": "access",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(subject: str, role: str) -> str:
    """Issue a signed JWT refresh token for *subject*, valid for 7 days."""
    claims = {
        "sub": subject,
        "role": role,
        "exp": datetime.now(timezone.utc) + timedelta(days=7),
        "type": "refresh",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_access_token(token: str) -> dict | None:
    """Decode *token*; return its claims if it is a valid access token.

    Returns None for expired/invalid signatures and for tokens whose
    "type" claim is not "access" (e.g. a refresh token).
    """
    try:
        claims = jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except JWTError:
        return None
    return claims if claims.get("type") == "access" else None
|
||||||
|
|
||||||
|
|
||||||
|
def decode_refresh_token(token: str) -> dict | None:
    """Decode *token*; return its claims if it is a valid refresh token.

    Returns None for expired/invalid signatures and for tokens whose
    "type" claim is not "refresh" (e.g. an access token).
    """
    try:
        claims = jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except JWTError:
        return None
    return claims if claims.get("type") == "refresh" else None
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
from src.core.database import async_session_factory
|
||||||
|
from src.core.mongodb import get_mongo_db
|
||||||
|
from src.core.redis import get_redis
|
||||||
|
from src.core.security import decode_access_token
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncSession:
    """FastAPI dependency yielding a request-scoped AsyncSession."""
    async with async_session_factory() as db:
        yield db
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> dict:
    """Resolve the Bearer token into ``{"user_id", "role"}``.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge,
        per RFC 6750) when the token is invalid, expired, or lacks the
        required "sub"/"role" claims.
    """
    payload = decode_access_token(token)
    # Also reject structurally incomplete payloads explicitly; previously a
    # token without "sub"/"role" raised KeyError and surfaced as a 500.
    if payload is None or "sub" not in payload or "role" not in payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"user_id": payload["sub"], "role": payload["role"]}
|
||||||
|
|
||||||
|
|
||||||
|
def require_role(*roles: str):
    """Build a dependency that admits only users whose role is in *roles*.

    Usage: ``Depends(require_role("ADMIN", "COMMANDER"))``.
    """
    allowed = set(roles)

    def checker(current_user: Annotated[dict, Depends(get_current_user)]):
        if current_user["role"] in allowed:
            return current_user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions")

    return checker
|
||||||
@ -0,0 +1,53 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from src.api.v1.router import router as v1_router
|
||||||
|
from src.config import settings
|
||||||
|
from src.core.mongodb import close_mongo
|
||||||
|
from src.core.redis import close_redis
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: no startup work; on shutdown close Mongo and Redis."""
    yield
    await close_mongo()
    await close_redis()
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI application instance for the command-and-control server.
app = FastAPI(
    title="GO2 Patrol C4ISR",
    version="1.0.0",
    description="无人狗自主巡逻系统 — 指挥控制应用服务器",
    lifespan=lifespan,
    docs_url="/docs",
    openapi_url="/openapi.json",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is not
# honored by browsers for credentialed requests — confirm whether production
# needs an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount all versioned API routes under /api/v1.
app.include_router(v1_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler mapping any unhandled exception to a 500 JSON body.

    NOTE(review): str(exc) is returned verbatim to the client, which can leak
    internal details — consider a generic message outside debug mode.
    """
    return JSONResponse(status_code=500, content={"detail": str(exc)})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Dev entry point: serve via uvicorn on the settings-driven host/port;
    # auto-reload follows the APP_DEBUG flag.
    import uvicorn

    uvicorn.run(
        "src.main:app",
        host=settings.app_host,
        port=settings.app_port,
        reload=settings.app_debug,
    )
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
# Re-export every ORM model (and the declarative Base) so callers can write
# `from src.models import X` and so all tables register with Base.metadata.
from src.models.user import Base, User
from src.models.dog import Dog
from src.models.patrol_task import PatrolTask
from src.models.waypoint import Waypoint
from src.models.threat import ThreatArchive, ThreatAlert, ThreatMark, ThreatRule, AlertEvent
from src.models.system_parameter import SystemParameter

__all__ = [
    "Base",
    "User",
    "Dog",
    "PatrolTask",
    "Waypoint",
    "ThreatArchive",
    "ThreatAlert",
    "ThreatMark",
    "ThreatRule",
    "AlertEvent",
    "SystemParameter",
]
|
||||||
@ -0,0 +1,23 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, String
|
||||||
|
from sqlalchemy.dialects.mysql import ENUM
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.models.user import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Dog(Base):
    """Robot-dog asset registered with the system."""

    __tablename__ = "dog"

    # UUID4 primary key stored as a 36-char string.
    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
    serial_number: Mapped[str | None] = mapped_column(String(100), unique=True)
    model: Mapped[str] = mapped_column(String(50), default="GO2")
    # String(45) fits the textual form of an IPv6 address.
    ip_address: Mapped[str | None] = mapped_column(String(45))
    # Operational state; rows start OFFLINE until a heartbeat arrives.
    status: Mapped[str] = mapped_column(
        ENUM("ONLINE", "OFFLINE", "MAINTENANCE", "EMERGENCY"), default="OFFLINE"
    )
    last_heartbeat: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    # NOTE(review): datetime.now is naive local time — confirm UTC policy.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Float, ForeignKey, String, Text
|
||||||
|
from sqlalchemy.dialects.mysql import ENUM, JSON
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from src.models.user import Base
|
||||||
|
|
||||||
|
|
||||||
|
class PatrolTask(Base):
    """A patrol mission: lifecycle status, schedule, and ordered waypoints."""

    __tablename__ = "patrol_task"

    # UUID4 primary key stored as a 36-char string.
    task_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    task_name: Mapped[str] = mapped_column(String(200), nullable=False)
    description: Mapped[str | None] = mapped_column(Text)
    # Mission lifecycle state.
    status: Mapped[str] = mapped_column(
        ENUM("DRAFT", "SAVED", "DEPLOYED", "EXECUTING", "COMPLETED", "CANCELLED", "ABORTED"),
        default="DRAFT",
    )
    assigned_dog_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("dog.id"))
    created_by: Mapped[str] = mapped_column(String(36), ForeignKey("user.id"), nullable=False)
    # NOTE(review): datetime.now is naive local time — confirm UTC policy.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now)
    started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    schedule_type: Mapped[str] = mapped_column(
        ENUM("ONCE", "PERIODIC", "CRON"), default="ONCE"
    )
    # Free-form schedule parameters; shape presumably depends on
    # schedule_type — confirm against the scheduler code.
    schedule_config: Mapped[dict | None] = mapped_column(JSON)
    progress: Mapped[float] = mapped_column(Float, default=0.0)

    # Ordered waypoints; removed with the task (delete-orphan cascade) and
    # eagerly loaded via SELECT IN.
    waypoints: Mapped[list["Waypoint"]] = relationship(  # noqa: F821
        "Waypoint",
        backref="patrol_task",
        cascade="all, delete-orphan",
        lazy="selectin",
    )
|
||||||
@ -0,0 +1,18 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, String, Text
|
||||||
|
from sqlalchemy.dialects.mysql import JSON
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.models.user import Base
|
||||||
|
|
||||||
|
|
||||||
|
class SystemParameter(Base):
    """Key/value configuration entry, grouped into sections."""

    __tablename__ = "system_parameter"

    param_key: Mapped[str] = mapped_column(String(200), primary_key=True)
    # JSON column: any JSON-serializable scalar or container.
    param_value: Mapped[str | int | float | dict | list] = mapped_column(JSON, nullable=False)
    param_section: Mapped[str] = mapped_column(String(100), nullable=False)
    description: Mapped[str | None] = mapped_column(Text)
    # User id of the last editor (not a DB-level foreign key).
    updated_by: Mapped[str | None] = mapped_column(String(36))
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||||
@ -0,0 +1,111 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Float, ForeignKey, Integer, Numeric, String, Text
|
||||||
|
from sqlalchemy.dialects.mysql import ENUM, JSON
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.models.user import Base
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatArchive(Base):
    """Archived threat detection with location, confidence, and resolution."""

    __tablename__ = "threat_archive"

    # UUID4 primary key stored as a 36-char string.
    threat_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    threat_type: Mapped[str] = mapped_column(String(100), nullable=False)
    threat_level: Mapped[str] = mapped_column(
        ENUM("LOW", "MEDIUM", "HIGH", "CRITICAL"), nullable=False
    )
    status: Mapped[str] = mapped_column(
        ENUM("PENDING", "CONFIRMED", "DISMISSED", "TRACKING", "ESCALATED", "RESOLVED"),
        default="PENDING",
    )
    detected_by_dog: Mapped[str] = mapped_column(String(36), ForeignKey("dog.id"), nullable=False)
    detected_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    # WGS-84 coordinates at 7 decimal places (~cm precision) — presumably;
    # confirm the datum against the mapping layer.
    latitude: Mapped[Decimal | None] = mapped_column(Numeric(10, 7))
    longitude: Mapped[Decimal | None] = mapped_column(Numeric(10, 7))
    altitude: Mapped[Decimal] = mapped_column(Numeric(8, 2), default=Decimal("0"))
    # Detector confidence score.
    confidence: Mapped[float] = mapped_column(Float, nullable=False)
    class_label: Mapped[str | None] = mapped_column(String(100))
    description: Mapped[str | None] = mapped_column(Text)
    image_url: Mapped[str | None] = mapped_column(String(500))
    resolved_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    resolved_by: Mapped[str | None] = mapped_column(String(36), ForeignKey("user.id"), nullable=True)
    notes: Mapped[str | None] = mapped_column(Text)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatAlert(Base):
    """Operator-facing alert, optionally linked to an archived threat."""

    __tablename__ = "threat_alert"

    # UUID4 primary key stored as a 36-char string.
    alert_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    threat_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("threat_archive.threat_id"))
    # Originating dog id; not a DB-level foreign key.
    dog_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
    alert_level: Mapped[str] = mapped_column(
        ENUM("LOW", "MEDIUM", "HIGH", "CRITICAL"), nullable=False
    )
    status: Mapped[str] = mapped_column(
        ENUM("INITIAL", "PENDING", "CONFIRMED", "DISMISSED", "TRACKING", "ESCALATED", "RESOLVED"),
        default="INITIAL",
    )
    message: Mapped[str] = mapped_column(Text, nullable=False)
    # NOTE(review): datetime.now is naive local time — confirm UTC policy.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    acknowledged_by: Mapped[str | None] = mapped_column(String(36), ForeignKey("user.id"), nullable=True)
    resolved_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatMark(Base):
    """Circular map annotation (keep-out / caution / POI / incident)."""

    __tablename__ = "threat_mark"

    # UUID4 primary key stored as a 36-char string.
    mark_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    threat_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("threat_archive.threat_id"))
    mark_type: Mapped[str] = mapped_column(
        ENUM("KEEP_OUT", "CAUTION", "POINT_OF_INTEREST", "INCIDENT"), nullable=False
    )
    priority: Mapped[int] = mapped_column(Integer, default=0)
    latitude: Mapped[Decimal] = mapped_column(Numeric(10, 7), nullable=False)
    longitude: Mapped[Decimal] = mapped_column(Numeric(10, 7), nullable=False)
    # Circle radius around (latitude, longitude).
    radius_meters: Mapped[Decimal] = mapped_column(Numeric(8, 2), default=Decimal("10.0"))
    created_by: Mapped[str] = mapped_column(String(36), ForeignKey("user.id"), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    # NULL means the mark never expires.
    expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    notes: Mapped[str | None] = mapped_column(Text)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatRule(Base):
    """Configurable rule: a JSON condition mapped to a JSON action."""

    __tablename__ = "threat_rules"

    # UUID4 primary key stored as a 36-char string.
    rule_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    rule_name: Mapped[str] = mapped_column(String(200), nullable=False)
    rule_type: Mapped[str] = mapped_column(String(50), nullable=False)
    # Free-form JSON; schema presumably depends on rule_type — confirm
    # against the rule-evaluation code.
    rule_condition: Mapped[dict] = mapped_column(JSON, nullable=False)
    rule_action: Mapped[dict] = mapped_column(JSON, nullable=False)
    priority: Mapped[int] = mapped_column(Integer, default=0)
    is_active: Mapped[bool] = mapped_column(default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||||
|
|
||||||
|
|
||||||
|
class AlertEvent(Base):
    """Audit-trail entry recording a state transition on a threat alert."""

    __tablename__ = "alert_event"

    # UUID4 primary key stored as a 36-char string.
    event_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    alert_id: Mapped[str] = mapped_column(String(36), ForeignKey("threat_alert.alert_id"), nullable=False)
    event_type: Mapped[str] = mapped_column(
        ENUM("CREATED", "ACKNOWLEDGED", "DISMISSED", "CONFIRMED", "TRACKING", "ESCALATED", "RESOLVED"),
        nullable=False,
    )
    # Acting user id; NULL presumably means a system-generated event.
    performed_by: Mapped[str | None] = mapped_column(String(36))
    performed_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    # Optional free-form event payload.
    details: Mapped[dict | None] = mapped_column(JSON)
|
||||||
@ -0,0 +1,26 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, DateTime, String, Text
|
||||||
|
from sqlalchemy.dialects.mysql import ENUM
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.core.database import engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in the project."""
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
    """System operator account with a role-based permission level."""

    __tablename__ = "user"

    # UUID4 primary key stored as a 36-char string.
    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    username: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
    # bcrypt hash; never store the plaintext password.
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    display_name: Mapped[str | None] = mapped_column(String(200))
    # Uppercase role values — callers comparing roles must match case.
    role: Mapped[str] = mapped_column(ENUM("ADMIN", "COMMANDER", "OPERATOR"), default="OPERATOR")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    # NOTE(review): datetime.now is naive local time — confirm UTC policy.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
import uuid
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey, Integer, Numeric, String
|
||||||
|
from sqlalchemy.dialects.mysql import ENUM
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.models.user import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Waypoint(Base):
    """One ordered stop on a patrol route, deleted with its parent task."""

    __tablename__ = "waypoint"

    # UUID4 primary key stored as a 36-char string.
    waypoint_id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    # DB-level cascade ensures waypoints vanish when the task row is deleted.
    task_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("patrol_task.task_id", ondelete="CASCADE"), nullable=False
    )
    # Position of this waypoint within the route.
    sequence_order: Mapped[int] = mapped_column(Integer, nullable=False)
    latitude: Mapped[Decimal] = mapped_column(Numeric(10, 7), nullable=False)
    longitude: Mapped[Decimal] = mapped_column(Numeric(10, 7), nullable=False)
    altitude: Mapped[Decimal] = mapped_column(Numeric(8, 2), default=Decimal("0"))
    # What the dog does on arrival.
    action_type: Mapped[str] = mapped_column(
        ENUM("PASS", "SCAN", "HOVER", "OBSERVE"), default="PASS"
    )
    # Seconds to remain at the waypoint before moving on.
    dwell_time_sec: Mapped[int] = mapped_column(Integer, default=0)
    # Optional facing direction in degrees.
    heading_deg: Mapped[Decimal | None] = mapped_column(Numeric(5, 2), nullable=True)
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
    """Credentials payload for POST /auth/login."""
    username: str
    password: str
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
    """Payload for POST /auth/register.

    NOTE(review): password min_length=1 imposes no real strength policy —
    confirm whether stricter validation is intended.
    """
    username: str = Field(min_length=3, max_length=100)
    password: str = Field(min_length=1, max_length=200)
    display_name: str | None = None
    role: str = "OPERATOR"
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
    """JWT pair returned by the auth endpoints."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
    """Public view of a User ORM row (password hash excluded)."""
    id: str
    username: str
    display_name: str | None
    role: str
    is_active: bool
    last_login_at: datetime | None

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class DogCreate(BaseModel):
    """Payload for registering a new dog."""
    name: str
    serial_number: str | None = None
    model: str = "GO2"
    ip_address: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DogUpdate(BaseModel):
    """Partial-update payload for a dog; omitted fields stay unchanged."""
    name: str | None = None
    serial_number: str | None = None
    model: str | None = None
    ip_address: str | None = None
    status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DogResponse(BaseModel):
    """API view of a Dog ORM row."""
    id: str
    name: str
    serial_number: str | None
    model: str
    ip_address: str | None
    status: str
    last_heartbeat: datetime | None
    created_at: datetime

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
@ -0,0 +1,67 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class WaypointInput(BaseModel):
    """One waypoint supplied by the client when creating/updating a patrol."""
    sequence_order: int
    latitude: Decimal = Field(decimal_places=7)
    longitude: Decimal = Field(decimal_places=7)
    altitude: Decimal = Field(default=Decimal("0"), decimal_places=2)
    action_type: str = "PASS"
    dwell_time_sec: int = 0
    heading_deg: Decimal | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WaypointResponse(BaseModel):
    """API view of a Waypoint ORM row."""
    waypoint_id: str
    task_id: str
    sequence_order: int
    latitude: Decimal
    longitude: Decimal
    altitude: Decimal
    action_type: str
    dwell_time_sec: int
    heading_deg: Decimal | None

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class PatrolCreate(BaseModel):
    """Payload for creating a patrol task with its waypoints."""
    task_name: str = Field(min_length=1, max_length=200)
    description: str | None = None
    assigned_dog_id: str | None = None
    schedule_type: str = "ONCE"
    schedule_config: dict | None = None
    waypoints: list[WaypointInput] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PatrolUpdate(BaseModel):
    """Partial-update payload for a patrol; omitted fields stay unchanged."""
    task_name: str | None = None
    description: str | None = None
    assigned_dog_id: str | None = None
    status: str | None = None
    schedule_type: str | None = None
    schedule_config: dict | None = None
    progress: float | None = None
    # When provided, presumably replaces the full waypoint list — confirm
    # against the route handler.
    waypoints: list[WaypointInput] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PatrolResponse(BaseModel):
    """API view of a PatrolTask ORM row, including its waypoints."""
    task_id: str
    task_name: str
    description: str | None
    status: str
    assigned_dog_id: str | None
    created_by: str
    created_at: datetime
    updated_at: datetime
    started_at: datetime | None
    completed_at: datetime | None
    schedule_type: str
    schedule_config: dict | None
    progress: float
    waypoints: list[WaypointResponse] = []

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class SystemParamResponse(BaseModel):
    """API view of a SystemParameter ORM row."""
    param_key: str
    param_value: Any
    param_section: str
    description: str | None
    updated_by: str | None
    updated_at: datetime

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class SystemParamUpdate(BaseModel):
    """Payload for updating a single system parameter's value."""
    param_value: Any
    description: str | None = None
|
||||||
@ -0,0 +1,111 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AlertResponse(BaseModel):
    """API view of a ThreatAlert ORM row."""
    alert_id: str
    threat_id: str | None
    dog_id: str | None
    alert_level: str
    status: str
    message: str
    created_at: datetime
    acknowledged_at: datetime | None
    acknowledged_by: str | None
    resolved_at: datetime | None

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class AlertAckRequest(BaseModel):
    """Operator action on an alert."""
    action: str  # "acknowledge" or "resolve" or "dismiss"
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatResponse(BaseModel):
    """API view of a ThreatArchive ORM row."""
    threat_id: str
    threat_type: str
    threat_level: str
    status: str
    detected_by_dog: str
    detected_at: datetime
    latitude: Decimal | None
    longitude: Decimal | None
    confidence: float
    class_label: str | None
    description: str | None
    image_url: str | None
    resolved_at: datetime | None
    notes: str | None

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatMarkCreate(BaseModel):
    """Payload for placing a circular map mark."""
    mark_type: str  # KEEP_OUT, CAUTION, POINT_OF_INTEREST, INCIDENT
    latitude: Decimal
    longitude: Decimal
    radius_meters: Decimal = Decimal("10.0")
    priority: int = 0
    threat_id: str | None = None
    expires_at: datetime | None = None
    notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatMarkUpdate(BaseModel):
    """Partial-update payload for a map mark; omitted fields stay unchanged."""
    mark_type: str | None = None
    latitude: Decimal | None = None
    longitude: Decimal | None = None
    radius_meters: Decimal | None = None
    priority: int | None = None
    expires_at: datetime | None = None
    notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatMarkResponse(BaseModel):
    """API view of a ThreatMark ORM row."""
    mark_id: str
    threat_id: str | None
    mark_type: str
    priority: int
    latitude: Decimal
    longitude: Decimal
    radius_meters: Decimal
    created_by: str
    created_at: datetime
    expires_at: datetime | None
    notes: str | None

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatRuleCreate(BaseModel):
    """Payload for creating a threat-handling rule."""
    rule_name: str
    rule_type: str  # e.g. "classification", "threshold", "anomaly"
    rule_condition: dict
    rule_action: dict
    priority: int = 0
    is_active: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatRuleUpdate(BaseModel):
    """Partial-update payload for a rule; omitted fields stay unchanged."""
    rule_name: str | None = None
    rule_type: str | None = None
    rule_condition: dict | None = None
    rule_action: dict | None = None
    priority: int | None = None
    is_active: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatRuleResponse(BaseModel):
    """API view of a ThreatRule ORM row."""
    rule_id: str
    rule_name: str
    rule_type: str
    rule_condition: dict
    rule_action: dict
    priority: int
    is_active: bool
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
import sys
import os

import pytest

# Make the project package importable when pytest runs from the tests dir.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from src.core.security import create_access_token

# Fixed 36-char user id embedded in the auth fixtures' tokens.
_TEST_USER_ID = "sys-admin-00000000-0000-0000-0000-00"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def admin_token():
    """JWT access token for the synthetic admin user.

    NOTE(review): role is lowercase "admin" while the User model enum uses
    uppercase values — confirm role checks are case-consistent.
    """
    return create_access_token(subject=_TEST_USER_ID, role="admin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def auth_headers(admin_token):
    """Authorization header dict carrying the admin bearer token."""
    return {"Authorization": f"Bearer {admin_token}"}
|
||||||
@ -0,0 +1,211 @@
|
|||||||
|
"""Integration tests — each test uses its own event loop and DB engine."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
|
from src.main import app
|
||||||
|
from src.core.security import create_access_token
|
||||||
|
from src.dependencies import get_db
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
TEST_USER = "sys-admin-00000000-0000-0000-0000-00"
|
||||||
|
|
||||||
|
|
||||||
|
def _headers():
    """Authorization header with a freshly minted admin bearer token."""
    # NOTE(review): role "admin" is lowercase; the User enum uses uppercase.
    return {"Authorization": f"Bearer {create_access_token(TEST_USER, 'admin')}"}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_test(coro):
    """Run an async test with a fresh DB engine bound to its own loop.

    A new engine + session factory per test avoids sharing asyncio-bound DB
    connections across event loops. get_db is overridden on the app for the
    duration; the override is cleared and the engine disposed even if the
    test body raises.
    """
    engine = create_async_engine(settings.database_url, pool_pre_ping=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    async def _get_test_db():
        # Request-scoped session backed by this test's private engine.
        async with session_factory() as s:
            yield s

    app.dependency_overrides[get_db] = _get_test_db
    try:
        # In-process ASGI transport: no real network socket is opened.
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as c:
            await coro(c)
    finally:
        app.dependency_overrides.clear()
        await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run(coro):
    """Drive an async test function to completion on a fresh event loop."""
    return asyncio.run(_run_test(coro))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth ──
|
||||||
|
|
||||||
|
class TestAuth:
    """Smoke tests for the auth endpoints (register / login / me)."""

    def test_register(self):
        # 400 is tolerated: the user may already exist from a previous run.
        async def t(c):
            r = await c.post("/api/v1/auth/register", json={
                "username": "pytest_user", "password": "Test1234!",
                "display_name": "Pytest", "role": "operator",
            })
            assert r.status_code in (200, 400)
        run(t)

    def test_login_bad_password(self):
        # Wrong credentials must yield 401, never a 5xx.
        async def t(c):
            r = await c.post("/api/v1/auth/login", json={"username": "admin", "password": "wrong"})
            assert r.status_code == 401
        run(t)

    def test_me_unauthorized(self):
        # Missing bearer token is rejected.
        async def t(c):
            r = await c.get("/api/v1/auth/me")
            assert r.status_code == 401
        run(t)

    def test_me_with_token(self):
        # 404 tolerated: the token's synthetic user id may not exist in the DB.
        async def t(c):
            r = await c.get("/api/v1/auth/me", headers=_headers())
            assert r.status_code in (200, 404)
        run(t)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dogs ──
|
||||||
|
|
||||||
|
class TestDogs:
    """CRUD and auth smoke tests for the /dogs endpoints."""

    def test_list_dogs(self):
        async def t(c):
            r = await c.get("/api/v1/dogs", headers=_headers())
            assert r.status_code == 200
            assert isinstance(r.json(), list)
        run(t)

    def test_create_and_delete(self):
        # Create a dog, then delete it so the test leaves no residue.
        async def t(c):
            r = await c.post("/api/v1/dogs", headers=_headers(), json={
                "name": "pytest-dog", "serial_number": "PYTEST", "ip_address": "10.0.0.99",
            })
            assert r.status_code == 201
            did = r.json()["id"]
            r = await c.delete(f"/api/v1/dogs/{did}", headers=_headers())
            assert r.status_code == 204
        run(t)

    def test_unauthorized(self):
        # Listing without a token must be rejected.
        async def t(c):
            r = await c.get("/api/v1/dogs")
            assert r.status_code == 401
        run(t)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Patrols ──
|
||||||
|
|
||||||
|
class TestPatrols:
    """CRUD and deploy smoke tests for the /patrols endpoints."""

    def test_list(self):
        async def t(c):
            r = await c.get("/api/v1/patrols", headers=_headers())
            assert r.status_code == 200
        run(t)

    def test_create_and_delete(self):
        # Creating a patrol with two waypoints should echo both back.
        async def t(c):
            r = await c.post("/api/v1/patrols", headers=_headers(), json={
                "task_name": "pytest-patrol",
                "waypoints": [
                    {"sequence_order": 1, "latitude": 28.25, "longitude": 112.97, "action_type": "PASS"},
                    {"sequence_order": 2, "latitude": 28.251, "longitude": 112.971, "action_type": "SCAN"},
                ],
            })
            assert r.status_code == 201
            assert len(r.json().get("waypoints", [])) == 2
            # Cleanup so the test leaves no residue.
            await c.delete(f"/api/v1/patrols/{r.json()['task_id']}", headers=_headers())
        run(t)

    def test_update(self):
        # Create, rename via PUT, verify, then clean up.
        async def t(c):
            r = await c.post("/api/v1/patrols", headers=_headers(), json={
                "task_name": "pytest-upd", "waypoints": [
                    {"sequence_order": 1, "latitude": 28.25, "longitude": 112.97, "action_type": "PASS"},
                ],
            })
            tid = r.json()["task_id"]
            r = await c.put(f"/api/v1/patrols/{tid}", headers=_headers(), json={"task_name": "updated"})
            assert r.status_code == 200
            assert r.json()["task_name"] == "updated"
            await c.delete(f"/api/v1/patrols/{tid}", headers=_headers())
        run(t)

    def test_deploy_not_found(self):
        # Deploying a nonexistent task id must 404.
        async def t(c):
            r = await c.post("/api/v1/patrols/no-such-id/deploy", headers=_headers())
            assert r.status_code == 404
        run(t)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threats ──


class TestThreats:
    """API tests for /api/v1/threats: alerts, archive, marks, and rules."""

    def test_list_alerts(self):
        """Listing alerts returns 200."""
        async def t(c):
            r = await c.get("/api/v1/threats/alerts", headers=_headers())
            assert r.status_code == 200
        run(t)

    def test_alert_stats(self):
        """Alert statistics endpoint returns 200 with a 'total' field."""
        async def t(c):
            r = await c.get("/api/v1/threats/alerts/stats", headers=_headers())
            assert r.status_code == 200
            assert "total" in r.json()
        run(t)

    def test_create_alert(self):
        """Creating an alert returns 201 and echoes the alert level."""
        async def t(c):
            r = await c.post("/api/v1/threats/alerts", headers=_headers(), json={
                "alert_level": "HIGH", "message": "pytest alert",
            })
            assert r.status_code == 201
            assert r.json()["alert_level"] == "HIGH"
        run(t)

    def test_alert_action_not_found(self):
        """Acting on a nonexistent alert id returns 404."""
        async def t(c):
            r = await c.post("/api/v1/threats/alerts/nonexistent/action",
                             headers=_headers(), json={"action": "acknowledge"})
            assert r.status_code == 404
        run(t)

    def test_archive(self):
        """Listing the threat archive returns 200."""
        async def t(c):
            r = await c.get("/api/v1/threats/archive", headers=_headers())
            assert r.status_code == 200
        run(t)

    def test_marks_crud(self):
        """A threat mark can be created (201), updated (200), and deleted (204)."""
        async def t(c):
            r = await c.post("/api/v1/threats/marks", headers=_headers(), json={
                "mark_type": "CAUTION", "latitude": 28.251, "longitude": 112.971,
                "radius_meters": 15, "notes": "pytest",
            })
            assert r.status_code == 201
            mid = r.json()["mark_id"]
            r = await c.put(f"/api/v1/threats/marks/{mid}", headers=_headers(), json={"notes": "upd"})
            assert r.status_code == 200
            r = await c.delete(f"/api/v1/threats/marks/{mid}", headers=_headers())
            assert r.status_code == 204
        run(t)

    def test_rules_crud(self):
        """A threat rule can be created (201), deactivated (200), and deleted (204)."""
        async def t(c):
            r = await c.post("/api/v1/threats/rules", headers=_headers(), json={
                "rule_name": "pytest rule", "rule_type": "classification",
                "rule_condition": {"class_label": "pytest"},
                "rule_action": {"threat_type": "INTRUSION"},
            })
            assert r.status_code == 201
            rid = r.json()["rule_id"]
            r = await c.put(f"/api/v1/threats/rules/{rid}", headers=_headers(), json={"is_active": False})
            assert r.status_code == 200
            r = await c.delete(f"/api/v1/threats/rules/{rid}", headers=_headers())
            assert r.status_code == 204
        run(t)
|
||||||
@ -0,0 +1,24 @@
|
|||||||
|
[project]
|
||||||
|
name = "go2-comm-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "GO2 Patrol C4ISR Communication Server (WebSocket Hub)"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"websockets>=12.0",
|
||||||
|
"redis[hiredis]>=5.0",
|
||||||
|
"pydantic>=2.5",
|
||||||
|
"pydantic-settings>=2.2",
|
||||||
|
"structlog>=24.1",
|
||||||
|
"grpcio>=1.60.0",
|
||||||
|
"grpcio-tools>=1.60.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.1",
|
||||||
|
"pytest-asyncio>=0.23",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
|
# source: threat_assessment.proto
|
||||||
|
# Protobuf Python Version: 6.31.1
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import runtime_version as _runtime_version
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
|
_runtime_version.Domain.PUBLIC,
|
||||||
|
6,
|
||||||
|
31,
|
||||||
|
1,
|
||||||
|
'',
|
||||||
|
'threat_assessment.proto'
|
||||||
|
)
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17threat_assessment.proto\x12\x06threat\"\xb6\x01\n\rThreatRequest\x12\x11\n\tthreat_id\x18\x01 \x01(\t\x12\x13\n\x0b\x63lass_label\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x10\n\x08latitude\x18\x04 \x01(\x01\x12\x11\n\tlongitude\x18\x05 \x01(\x01\x12\x17\n\x0f\x64\x65tected_by_dog\x18\x06 \x01(\t\x12\x18\n\x10\x64\x65tected_at_unix\x18\x07 \x01(\x03\x12\x11\n\timage_url\x18\x08 \x01(\t\"g\n\x14ThreatClassification\x12\x11\n\tthreat_id\x18\x01 \x01(\t\x12\x13\n\x0bthreat_type\x18\x02 \x01(\t\x12\x12\n\nrisk_score\x18\x03 \x01(\x02\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\x83\x01\n\x15ThreatLevelAssessment\x12\x11\n\tthreat_id\x18\x01 \x01(\t\x12\x14\n\x0cthreat_level\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x11\n\treasoning\x18\x04 \x01(\t\x12\x1a\n\x12recommended_action\x18\x05 \x01(\t\"~\n\x0e\x41nomalyRequest\x12\x0e\n\x06\x64og_id\x18\x01 \x01(\t\x12(\n\x07samples\x18\x02 \x03(\x0b\x32\x17.threat.TelemetrySample\x12\x19\n\x11window_start_unix\x18\x03 \x01(\x03\x12\x17\n\x0fwindow_end_unix\x18\x04 \x01(\x03\"v\n\x0fTelemetrySample\x12\x16\n\x0etimestamp_unix\x18\x01 \x01(\x03\x12\x10\n\x08latitude\x18\x02 \x01(\x01\x12\x11\n\tlongitude\x18\x03 \x01(\x01\x12\x11\n\tspeed_mps\x18\x04 \x01(\x02\x12\x13\n\x0bheading_deg\x18\x05 \x01(\x02\"y\n\x0f\x41nomalyResponse\x12\x0e\n\x06\x64og_id\x18\x01 \x01(\t\x12\x14\n\x0cis_anomalous\x18\x02 \x01(\x08\x12\x14\n\x0c\x61nomaly_type\x18\x03 \x01(\t\x12\x15\n\ranomaly_score\x18\x04 \x01(\x02\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t2\xe6\x01\n\x10ThreatAssessment\x12\x45\n\x0e\x43lassifyThreat\x12\x15.threat.ThreatRequest\x1a\x1c.threat.ThreatClassification\x12I\n\x11\x41ssessThreatLevel\x12\x15.threat.ThreatRequest\x1a\x1d.threat.ThreatLevelAssessment\x12@\n\rDetectAnomaly\x12\x16.threat.AnomalyRequest\x1a\x17.threat.AnomalyResponseb\x06proto3')
|
||||||
|
|
||||||
|
_globals = globals()
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'threat_assessment_pb2', _globals)
|
||||||
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals['_THREATREQUEST']._serialized_start=36
|
||||||
|
_globals['_THREATREQUEST']._serialized_end=218
|
||||||
|
_globals['_THREATCLASSIFICATION']._serialized_start=220
|
||||||
|
_globals['_THREATCLASSIFICATION']._serialized_end=323
|
||||||
|
_globals['_THREATLEVELASSESSMENT']._serialized_start=326
|
||||||
|
_globals['_THREATLEVELASSESSMENT']._serialized_end=457
|
||||||
|
_globals['_ANOMALYREQUEST']._serialized_start=459
|
||||||
|
_globals['_ANOMALYREQUEST']._serialized_end=585
|
||||||
|
_globals['_TELEMETRYSAMPLE']._serialized_start=587
|
||||||
|
_globals['_TELEMETRYSAMPLE']._serialized_end=705
|
||||||
|
_globals['_ANOMALYRESPONSE']._serialized_start=707
|
||||||
|
_globals['_ANOMALYRESPONSE']._serialized_end=828
|
||||||
|
_globals['_THREATASSESSMENT']._serialized_start=831
|
||||||
|
_globals['_THREATASSESSMENT']._serialized_end=1061
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
@ -0,0 +1,189 @@
|
|||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from . import threat_assessment_pb2 as threat__assessment__pb2
|
||||||
|
|
||||||
|
GRPC_GENERATED_VERSION = '1.80.0'
|
||||||
|
GRPC_VERSION = grpc.__version__
|
||||||
|
_version_not_supported = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from grpc._utilities import first_version_is_lower
|
||||||
|
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||||
|
except ImportError:
|
||||||
|
_version_not_supported = True
|
||||||
|
|
||||||
|
if _version_not_supported:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||||
|
+ ' but the generated code in threat_assessment_pb2_grpc.py depends on'
|
||||||
|
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||||
|
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||||
|
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatAssessmentStub(object):
|
||||||
|
"""威胁评估服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.ClassifyThreat = channel.unary_unary(
|
||||||
|
'/threat.ThreatAssessment/ClassifyThreat',
|
||||||
|
request_serializer=threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
response_deserializer=threat__assessment__pb2.ThreatClassification.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.AssessThreatLevel = channel.unary_unary(
|
||||||
|
'/threat.ThreatAssessment/AssessThreatLevel',
|
||||||
|
request_serializer=threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
response_deserializer=threat__assessment__pb2.ThreatLevelAssessment.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.DetectAnomaly = channel.unary_unary(
|
||||||
|
'/threat.ThreatAssessment/DetectAnomaly',
|
||||||
|
request_serializer=threat__assessment__pb2.AnomalyRequest.SerializeToString,
|
||||||
|
response_deserializer=threat__assessment__pb2.AnomalyResponse.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatAssessmentServicer(object):
|
||||||
|
"""威胁评估服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
def ClassifyThreat(self, request, context):
|
||||||
|
"""威胁分类
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def AssessThreatLevel(self, request, context):
|
||||||
|
"""威胁等级评估
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def DetectAnomaly(self, request, context):
|
||||||
|
"""异常行为检测
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_ThreatAssessmentServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'ClassifyThreat': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.ClassifyThreat,
|
||||||
|
request_deserializer=threat__assessment__pb2.ThreatRequest.FromString,
|
||||||
|
response_serializer=threat__assessment__pb2.ThreatClassification.SerializeToString,
|
||||||
|
),
|
||||||
|
'AssessThreatLevel': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.AssessThreatLevel,
|
||||||
|
request_deserializer=threat__assessment__pb2.ThreatRequest.FromString,
|
||||||
|
response_serializer=threat__assessment__pb2.ThreatLevelAssessment.SerializeToString,
|
||||||
|
),
|
||||||
|
'DetectAnomaly': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.DetectAnomaly,
|
||||||
|
request_deserializer=threat__assessment__pb2.AnomalyRequest.FromString,
|
||||||
|
response_serializer=threat__assessment__pb2.AnomalyResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'threat.ThreatAssessment', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
server.add_registered_method_handlers('threat.ThreatAssessment', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class ThreatAssessment(object):
|
||||||
|
"""威胁评估服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ClassifyThreat(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/threat.ThreatAssessment/ClassifyThreat',
|
||||||
|
threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
threat__assessment__pb2.ThreatClassification.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def AssessThreatLevel(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/threat.ThreatAssessment/AssessThreatLevel',
|
||||||
|
threat__assessment__pb2.ThreatRequest.SerializeToString,
|
||||||
|
threat__assessment__pb2.ThreatLevelAssessment.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def DetectAnomaly(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/threat.ThreatAssessment/DetectAnomaly',
|
||||||
|
threat__assessment__pb2.AnomalyRequest.SerializeToString,
|
||||||
|
threat__assessment__pb2.AnomalyResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
@ -0,0 +1,393 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time as _time
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("comm_server")
|
||||||
|
|
||||||
|
# ── Threat classification: gRPC (primary) + inline fallback ──
|
||||||
|
|
||||||
|
_ai_channel = None
|
||||||
|
_ai_stub = None
|
||||||
|
_ai_pb2 = None
|
||||||
|
_ai_available = False
|
||||||
|
|
||||||
|
# Inline fallback rules (used when ai_server is unreachable)
|
||||||
|
_LABEL_TO_TYPE = {
|
||||||
|
"person": "INTRUSION", "human": "INTRUSION",
|
||||||
|
"vehicle": "SUSPICIOUS_VEHICLE", "car": "SUSPICIOUS_VEHICLE",
|
||||||
|
"truck": "SUSPICIOUS_VEHICLE", "bus": "SUSPICIOUS_VEHICLE",
|
||||||
|
"motorcycle": "SUSPICIOUS_VEHICLE", "bicycle": "SUSPICIOUS_VEHICLE",
|
||||||
|
"fire": "FIRE", "smoke": "FIRE", "flame": "FIRE",
|
||||||
|
"animal": "WILDLIFE", "dog": "WILDLIFE", "cat": "WILDLIFE",
|
||||||
|
"obstacle": "OBSTACLE", "debris": "OBSTACLE",
|
||||||
|
}
|
||||||
|
_LABEL_RISK = {
|
||||||
|
"person": 0.7, "human": 0.7,
|
||||||
|
"fire": 0.95, "smoke": 0.85, "flame": 0.95,
|
||||||
|
"vehicle": 0.5, "car": 0.5, "truck": 0.5, "bus": 0.5,
|
||||||
|
"motorcycle": 0.4, "bicycle": 0.3,
|
||||||
|
"animal": 0.2, "dog": 0.15, "cat": 0.1,
|
||||||
|
"obstacle": 0.3, "debris": 0.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _init_ai_grpc():
|
||||||
|
global _ai_channel, _ai_stub, _ai_pb2, _ai_available
|
||||||
|
try:
|
||||||
|
import grpc
|
||||||
|
from grpc_stubs import threat_assessment_pb2 as pb2
|
||||||
|
from grpc_stubs import threat_assessment_pb2_grpc as pb2_grpc
|
||||||
|
|
||||||
|
_ai_pb2 = pb2
|
||||||
|
ai_url = os.environ.get("AI_SERVER_URL", "localhost:50051")
|
||||||
|
_ai_channel = grpc.aio.insecure_channel(ai_url)
|
||||||
|
_ai_stub = pb2_grpc.ThreatAssessmentStub(_ai_channel)
|
||||||
|
await _ai_stub.ClassifyThreat(pb2.ThreatRequest(
|
||||||
|
threat_id="health-check", class_label="test", confidence=0.0
|
||||||
|
), timeout=3)
|
||||||
|
_ai_available = True
|
||||||
|
logger.info(f"AI server gRPC connected: {ai_url}")
|
||||||
|
except Exception as e:
|
||||||
|
_ai_available = False
|
||||||
|
logger.warning(f"AI server unavailable, using inline classification: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_inline(det: dict) -> dict:
|
||||||
|
label = det.get("class_label", "").lower().strip()
|
||||||
|
confidence = det.get("confidence", 0.0)
|
||||||
|
base_risk = _LABEL_RISK.get(label, 0.3)
|
||||||
|
risk_score = round(confidence * base_risk, 3)
|
||||||
|
threat_type = _LABEL_TO_TYPE.get(label, "UNKNOWN")
|
||||||
|
combined = confidence * base_risk
|
||||||
|
if combined >= 0.7:
|
||||||
|
level = "CRITICAL"
|
||||||
|
elif combined >= 0.5:
|
||||||
|
level = "HIGH"
|
||||||
|
elif combined >= 0.3:
|
||||||
|
level = "MEDIUM"
|
||||||
|
else:
|
||||||
|
level = "LOW"
|
||||||
|
det["threat_type"] = threat_type
|
||||||
|
det["risk_score"] = risk_score
|
||||||
|
det["threat_level"] = level
|
||||||
|
return det
|
||||||
|
|
||||||
|
|
||||||
|
async def _classify_detection(det: dict) -> dict:
    """Enrich a detection dict with ``threat_type``/``risk_score``/``threat_level``.

    Uses the AI server over gRPC when it was reachable at startup; any gRPC
    failure (or an unavailable server) falls back to the inline rule-based
    classifier. Mutates *det* in place and returns it.
    """
    grpc_ready = _ai_available and _ai_stub is not None and _ai_pb2 is not None
    if not grpc_ready:
        return _classify_inline(det)
    try:
        label = det.get("class_label", "").lower().strip()
        confidence = det.get("confidence", 0.0)
        det_id = str(det.get("class_id", 0))
        request = _ai_pb2.ThreatRequest(
            threat_id=det_id, class_label=label, confidence=confidence
        )
        classification = await _ai_stub.ClassifyThreat(request, timeout=2)
        assessment = await _ai_stub.AssessThreatLevel(request, timeout=2)
        det["threat_type"] = classification.threat_type
        det["risk_score"] = round(classification.risk_score, 3)
        det["threat_level"] = assessment.threat_level
        return det
    except Exception as e:
        logger.debug(f"gRPC classification failed, falling back to inline: {e}")
        return _classify_inline(det)
|
||||||
|
|
||||||
|
|
||||||
|
# Optional Redis caching
redis_client = None
try:
    import redis.asyncio as aioredis

    async def init_redis():
        """Connect the module-level Redis client; disable caching on failure.

        Leaves ``redis_client`` as ``None`` when the server is unreachable so
        the rest of the server degrades gracefully instead of crashing.
        """
        global redis_client
        url = os.environ.get("REDIS_URL", "redis://:redis_pass_2026@localhost:6379/0")
        try:
            redis_client = await aioredis.from_url(url, decode_responses=True)
            await redis_client.ping()
            logger.info("Redis connected")
        except Exception as e:
            logger.warning(f"Redis unavailable: {e}")
            redis_client = None
except ImportError:
    logger.warning("redis package not installed, caching disabled")

    async def init_redis():
        """No-op fallback so main() can still ``await init_redis()``.

        Without this, a missing redis package would leave the name undefined
        and startup would crash with NameError.
        """
        return None
|
||||||
|
|
||||||
|
# Optional MongoDB for telemetry persistence
mongo_db = None
try:
    from motor.motor_asyncio import AsyncIOMotorClient

    async def init_mongo():
        """Connect the module-level Mongo database handle and ensure indexes.

        Leaves ``mongo_db`` as ``None`` on any connection failure so MongoDB
        persistence is simply disabled instead of crashing startup.
        """
        global mongo_db
        uri = os.environ.get("MONGO_URI", "mongodb://localhost:27017")
        db_name = os.environ.get("MONGO_DATABASE", "go2_patrol")
        try:
            client = AsyncIOMotorClient(uri)
            mongo_db = client[db_name]
            await mongo_db.command("ping")
            # Ensure indexes
            await mongo_db.telemetry_records.create_index([("dog_id", 1), ("timestamp", -1)])
            await mongo_db.detection_records.create_index([("dog_id", 1), ("timestamp", -1)])
            logger.info("MongoDB connected")
        except Exception as e:
            logger.warning(f"MongoDB unavailable: {e}")
            mongo_db = None
except ImportError:
    logger.warning("motor package not installed, MongoDB logging disabled")

    async def init_mongo():
        """No-op fallback so main() can still ``await init_mongo()``.

        Without this, a missing motor package would leave the name undefined
        and startup would crash with NameError.
        """
        return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Client:
    """One connected WebSocket peer (a patrol dog or a browser UI)."""

    ws: websockets.WebSocketServerProtocol  # underlying connection
    client_type: str  # "dog" or "browser" (anything else is treated as browser)
    client_id: str  # unique id; for dogs this doubles as the dog_id
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketHub:
    """Central message router between patrol-dog and browser connections.

    Keeps the latest telemetry per dog plus a bounded alert history so a
    newly connected browser can be synchronised immediately, and mirrors
    traffic into Redis / MongoDB when those backends are available.
    """

    MAX_HISTORY = 200  # bound for the in-memory (and Redis) alert history

    def __init__(self):
        self.dog_connections: dict[str, Client] = {}      # dog_id -> Client
        self.browser_connections: dict[str, Client] = {}  # client_id -> Client
        self.recent_alerts: deque = deque(maxlen=self.MAX_HISTORY)
        self.latest_telemetry: dict[str, dict] = {}  # dog_id -> last telemetry
        # Strong references to fire-and-forget persistence tasks:
        # asyncio.create_task only keeps a weak reference, so an otherwise
        # unreferenced task can be garbage-collected before it finishes.
        self._bg_tasks: set = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a background task, retaining it until done."""
        task = asyncio.create_task(coro)
        self._bg_tasks.add(task)
        task.add_done_callback(self._bg_tasks.discard)

    async def register(self, client: Client):
        """Track a new connection; new browsers also receive the cached state."""
        if client.client_type == "dog":
            self.dog_connections[client.client_id] = client
            logger.info(f"Dog connected: {client.client_id}")
        else:
            self.browser_connections[client.client_id] = client
            logger.info(f"Browser connected: {client.client_id}")
            # Send cached state to new browser
            await self._send_cached_state(client)

    async def unregister(self, client: Client):
        """Forget a connection; unknown ids are ignored."""
        if client.client_type == "dog":
            self.dog_connections.pop(client.client_id, None)
            logger.info(f"Dog disconnected: {client.client_id}")
        else:
            self.browser_connections.pop(client.client_id, None)
            logger.info(f"Browser disconnected: {client.client_id}")

    async def broadcast_to_browsers(self, message: dict):
        """Cache *message*, then fan it out to every connected browser."""
        self._cache_message(message)
        raw = json.dumps(message)
        stale = []
        for cid, client in self.browser_connections.items():
            try:
                await client.ws.send(raw)
            except websockets.ConnectionClosed:
                # Collect and drop after the loop; never mutate mid-iteration.
                stale.append(cid)
        for cid in stale:
            self.browser_connections.pop(cid, None)

    async def send_to_dog(self, dog_id: str, message: dict):
        """Send *message* to one dog; silently drop if it is not connected."""
        client = self.dog_connections.get(dog_id)
        if client:
            try:
                await client.ws.send(json.dumps(message))
            except websockets.ConnectionClosed:
                self.dog_connections.pop(dog_id, None)

    def _cache_message(self, message: dict):
        """Cache *message* in memory and schedule Redis/Mongo persistence."""
        msg_type = message.get("type")
        if msg_type == "telemetry":
            dog_id = message.get("dog_id", "")
            self.latest_telemetry[dog_id] = message
        elif msg_type == "alert":
            self.recent_alerts.append(message)
        if redis_client is not None:
            self._spawn(self._cache_to_redis(message))
        if mongo_db is not None and msg_type in ("telemetry", "detection", "task_progress"):
            self._spawn(self._write_to_mongo(message))

    async def _cache_to_redis(self, message: dict):
        """Best-effort Redis mirror: latest telemetry per dog, capped alert list."""
        try:
            msg_type = message.get("type")
            raw = json.dumps(message)
            if msg_type == "telemetry":
                dog_id = message.get("dog_id", "unknown")
                # 60s TTL so stale telemetry expires on its own.
                await redis_client.setex(f"telemetry:{dog_id}", 60, raw)
            elif msg_type == "alert":
                await redis_client.lpush("alerts:recent", raw)
                await redis_client.ltrim("alerts:recent", 0, self.MAX_HISTORY - 1)
        except Exception as e:
            logger.warning(f"Redis cache error: {e}")

    async def _write_to_mongo(self, message: dict):
        """Best-effort MongoDB persistence for telemetry/detections/progress."""
        try:
            msg_type = message.get("type")
            dog_id = message.get("dog_id", "unknown")
            data = message.get("data", {})
            doc = {
                "dog_id": dog_id,
                "timestamp": _time.time(),  # module-level `time as _time`
                "data": data,
            }
            if msg_type == "telemetry":
                doc["type"] = "telemetry"
                await mongo_db.telemetry_records.insert_one(doc)
            elif msg_type == "detection":
                doc["type"] = "detection"
                await mongo_db.detection_records.insert_one(doc)
            elif msg_type == "task_progress":
                doc["type"] = "task_progress"
                doc["task_id"] = message.get("task_id")
                await mongo_db.patrol_execution_logs.insert_one(doc)
        except Exception as e:
            logger.warning(f"MongoDB write error: {e}")

    async def _send_cached_state(self, client: Client):
        """Send cached telemetry + recent alerts to newly connected browser."""
        # Latest telemetry for each dog
        for telemetry in self.latest_telemetry.values():
            try:
                await client.ws.send(json.dumps(telemetry))
            except websockets.ConnectionClosed:
                return
        # Recent alerts
        for alert in self.recent_alerts:
            try:
                await client.ws.send(json.dumps(alert))
            except websockets.ConnectionClosed:
                return
        # Send dog status summary
        if self.dog_connections:
            status_msg = {
                "type": "dog_status",
                "data": [
                    {
                        "id": dog_id,
                        "status": "online",
                        "lastSeen": telemetry.get("data", {}).get("lastSeen", ""),
                    }
                    for dog_id, telemetry in self.latest_telemetry.items()
                ],
            }
            try:
                await client.ws.send(json.dumps(status_msg))
            except websockets.ConnectionClosed:
                return
|
||||||
|
|
||||||
|
|
||||||
|
hub = WebSocketHub()


async def handler(websocket: websockets.WebSocketServerProtocol):
    """Per-connection WebSocket entry point.

    Protocol (all messages are JSON objects with a "type" field):
      * {"type": "register", "client_type": ..., "client_id": ...}
        identifies the peer as a dog or browser and is acked.
      * "ping" is answered with "pong".
      * Dog messages (telemetry/detection/task_progress/alert/lidar_scan)
        are tagged with the dog id, detections are enriched with threat
        classification, and everything is broadcast to browsers.
      * Browser commands (task_deploy/task_cancel/command) are forwarded
        to the dog named in "dog_id".
    Unrecognised messages are ignored.
    """
    # Defaults until the peer sends a register message.
    client_type = "browser"
    client_id = "unknown"

    try:
        async for raw in websocket:
            msg = json.loads(raw)

            if msg.get("type") == "register":
                client_type = msg.get("client_type", "browser")
                client_id = msg.get("client_id", "unknown")
                client = Client(ws=websocket, client_type=client_type, client_id=client_id)
                await hub.register(client)
                await websocket.send(json.dumps({"type": "register_ack", "status": "ok"}))
                continue

            if msg.get("type") == "ping":
                await websocket.send(json.dumps({"type": "pong"}))
                continue

            # Dog telemetry -> broadcast to browsers
            if client_type == "dog" and msg.get("type") in ("telemetry", "detection", "task_progress", "alert", "lidar_scan"):
                # Stamp the message with the sender's id before fan-out.
                msg["dog_id"] = client_id

                # Enrich detections with threat classification
                if msg.get("type") == "detection":
                    detections = msg.get("data", {}).get("detections", [])
                    significant = False
                    for det in detections:
                        # Mutates each detection in place (type/score/level).
                        await _classify_detection(det)
                        if det.get("threat_level") in ("CRITICAL", "HIGH"):
                            significant = True

                    # Auto-create alert for significant detections
                    # NOTE(review): gated on redis_client being up — presumably
                    # to avoid un-persisted alerts; confirm that is intended.
                    if significant and redis_client is not None:
                        # Alert on the single highest-risk detection.
                        top = max(detections, key=lambda d: d.get("risk_score", 0))
                        alert_msg = {
                            "type": "alert",
                            "dog_id": client_id,
                            "data": {
                                "alert_id": str(uuid.uuid4()),
                                "alert_level": top.get("threat_level", "MEDIUM"),
                                "threat_type": top.get("threat_type", "UNKNOWN"),
                                "message": f"检测到{top.get('class_label', '未知目标')}, 置信度 {top.get('confidence', 0):.0%}",
                                "dog_id": client_id,
                                "created_at": _time.strftime("%Y-%m-%dT%H:%M:%S"),
                            },
                        }
                        await hub.broadcast_to_browsers(alert_msg)

                await hub.broadcast_to_browsers(msg)
                continue

            # Browser command -> forward to dog
            if client_type == "browser" and msg.get("type") in ("task_deploy", "task_cancel", "command"):
                target_dog = msg.get("dog_id")
                if target_dog:
                    await hub.send_to_dog(target_dog, msg)
                continue

    except websockets.ConnectionClosed:
        pass
    finally:
        # unregister() keys off (client_type, client_id), so a fresh Client
        # wrapper suffices here; if the peer never registered, the pop is a
        # harmless no-op.
        client = Client(ws=websocket, client_type=client_type, client_id=client_id)
        await hub.unregister(client)
|
||||||
|
|
||||||
|
|
||||||
|
async def redis_subscriber():
    """Listen for commands from app_server via Redis pub/sub.

    Channels:
      * "comm:broadcast" -> fan the payload out to all connected browsers.
      * "comm:command"   -> forward the payload to the dog named in its
        "dog_id" field.
    No-op when Redis is unavailable; exits with a warning on any error.
    """
    if not redis_client:
        return
    try:
        pubsub = redis_client.pubsub()
        await pubsub.subscribe("comm:command", "comm:broadcast")
        async for message in pubsub.listen():
            # Skip subscribe confirmations and other non-payload events.
            if message["type"] != "message":
                continue
            data = json.loads(message["data"])
            channel = message.get("channel", "")
            if isinstance(channel, bytes):
                # Defensive: channel names arrive as bytes unless the client
                # was created with decode_responses=True.
                channel = channel.decode()
            if channel == "comm:broadcast":
                await hub.broadcast_to_browsers(data)
                logger.info(f"Broadcast {data.get('type')} to browsers")
            else:
                dog_id = data.get("dog_id")
                if dog_id:
                    await hub.send_to_dog(dog_id, data)
                    logger.info(f"Forwarded {data.get('type')} to dog {dog_id}")
    except Exception as e:
        logger.warning(f"Redis subscriber error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Boot the comm server: init backends, then serve WebSocket clients forever."""
    await init_redis()
    await init_mongo()
    await _init_ai_grpc()
    port = int(os.environ.get("COMM_PORT", "8001"))
    logger.info(f"Comm server starting on ws://0.0.0.0:{port}")
    async with websockets.serve(handler, "0.0.0.0", port):
        subscriber_task = None
        if redis_client is not None:
            # Keep a strong reference: asyncio.create_task only holds a weak
            # reference, so an unreferenced task may be garbage-collected
            # before (or while) it runs.
            subscriber_task = asyncio.create_task(redis_subscriber())
        try:
            await asyncio.Future()  # run until cancelled
        finally:
            if subscriber_task is not None:
                subscriber_task.cancel()


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
# C4ISR Server Environment Variables
|
||||||
|
|
||||||
|
# MySQL
|
||||||
|
MYSQL_ROOT_PASSWORD=go2_patrol_2026
|
||||||
|
MYSQL_DATABASE=go2_patrol
|
||||||
|
MYSQL_USER=patrol_admin
|
||||||
|
MYSQL_PASSWORD=patrol_pass_2026
|
||||||
|
|
||||||
|
# MongoDB
|
||||||
|
MONGO_URI=mongodb://mongodb:27017
|
||||||
|
MONGO_DATABASE=go2_patrol
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
REDIS_URL=redis://:redis_pass_2026@redis:6379/0
|
||||||
|
|
||||||
|
# FastAPI App Server
|
||||||
|
APP_HOST=0.0.0.0
|
||||||
|
APP_PORT=8000
|
||||||
|
APP_DEBUG=true
|
||||||
|
APP_SECRET_KEY=change-me-in-production-32chars
|
||||||
|
APP_ACCESS_TOKEN_EXPIRE_MINUTES=120
|
||||||
|
|
||||||
|
# Comm Server
|
||||||
|
COMM_HOST=0.0.0.0
|
||||||
|
COMM_PORT=8001
|
||||||
|
AI_SERVER_URL=localhost:50051
|
||||||
|
|
||||||
|
# AI Server (gRPC)
|
||||||
|
AI_HOST=0.0.0.0
|
||||||
|
AI_PORT=50051
|
||||||
|
AI_DB_URL=mysql+pymysql://patrol_admin:patrol_pass_2026@localhost:3306/go2_patrol
|
||||||
|
|
||||||
|
# JWT
|
||||||
|
JWT_SECRET_KEY=change-me-jwt-secret-32chars
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
|
|
||||||
|
# MapLibre
|
||||||
|
MAPTILER_KEY=your-maptiler-key-here
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
CMD ["python", "-m", "src.main"]
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
CMD ["python", "-m", "src.main"]
|
||||||
@ -0,0 +1,139 @@
|
|||||||
|
version: "3.9"
|
||||||
|
|
||||||
|
services:
|
||||||
|
# ── MySQL 8.0 ──────────────────────────────────────
|
||||||
|
mysql:
|
||||||
|
image: mysql:8.0
|
||||||
|
container_name: go2_mysql
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD}
|
||||||
|
MYSQL_DATABASE: ${MYSQL_DATABASE}
|
||||||
|
MYSQL_USER: ${MYSQL_USER}
|
||||||
|
MYSQL_PASSWORD: ${MYSQL_PASSWORD}
|
||||||
|
ports:
|
||||||
|
- "3306:3306"
|
||||||
|
volumes:
|
||||||
|
- mysql_data:/var/lib/mysql
|
||||||
|
- ./mysql/init.sql:/docker-entrypoint-initdb.d/01_schema.sql
|
||||||
|
command: >
|
||||||
|
--character-set-server=utf8mb4
|
||||||
|
--collation-server=utf8mb4_unicode_ci
|
||||||
|
--default-authentication-plugin=mysql_native_password
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
# ── MongoDB 6 ──────────────────────────────────────
|
||||||
|
mongodb:
|
||||||
|
image: mongo:6
|
||||||
|
container_name: go2_mongodb
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "27017:27017"
|
||||||
|
volumes:
|
||||||
|
- mongo_data:/data/db
|
||||||
|
- ./mongodb/init.js:/docker-entrypoint-initdb.d/init.js
|
||||||
|
environment:
|
||||||
|
MONGO_INITDB_DATABASE: ${MONGO_DATABASE}
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
# ── Redis 7 ────────────────────────────────────────
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
container_name: go2_redis
|
||||||
|
restart: unless-stopped
|
||||||
|
command: redis-server --requirepass ${REDIS_PASSWORD:-redis_pass_2026} --appendonly yes
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD:-redis_pass_2026}", "ping"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
# ── FastAPI App Server ─────────────────────────────
|
||||||
|
app_server:
|
||||||
|
build:
|
||||||
|
context: ../backend/app_server
|
||||||
|
dockerfile: ../../docker/app_server/Dockerfile
|
||||||
|
container_name: go2_app_server
|
||||||
|
restart: unless-stopped
|
||||||
|
env_file: ../.env
|
||||||
|
environment:
|
||||||
|
- DATABASE_URL=mysql+asyncmy://${MYSQL_USER}:${MYSQL_PASSWORD}@mysql:3306/${MYSQL_DATABASE}
|
||||||
|
- MONGO_URI=mongodb://mongodb:27017
|
||||||
|
- MONGO_DATABASE=${MONGO_DATABASE}
|
||||||
|
- REDIS_URL=redis://:${REDIS_PASSWORD:-redis_pass_2026}@redis:6379/0
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
depends_on:
|
||||||
|
mysql:
|
||||||
|
condition: service_healthy
|
||||||
|
mongodb:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ── Comm Server (WebSocket Hub) ────────────────────
|
||||||
|
comm_server:
|
||||||
|
build:
|
||||||
|
context: ../backend/comm_server
|
||||||
|
dockerfile: ../../docker/comm_server/Dockerfile
|
||||||
|
container_name: go2_comm_server
|
||||||
|
restart: unless-stopped
|
||||||
|
env_file: ../.env
|
||||||
|
environment:
|
||||||
|
- REDIS_URL=redis://:${REDIS_PASSWORD:-redis_pass_2026}@redis:6379/0
|
||||||
|
- AI_SERVER_URL=ai_server:50051
|
||||||
|
- MONGO_URI=mongodb://mongodb:27017
|
||||||
|
ports:
|
||||||
|
- "8001:8001"
|
||||||
|
depends_on:
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
ai_server:
|
||||||
|
condition: service_started
|
||||||
|
|
||||||
|
# ── AI Server (gRPC) ───────────────────────────────
|
||||||
|
ai_server:
|
||||||
|
build:
|
||||||
|
context: ../backend/ai_server
|
||||||
|
dockerfile: ../../docker/ai_server/Dockerfile
|
||||||
|
container_name: go2_ai_server
|
||||||
|
restart: unless-stopped
|
||||||
|
env_file: ../.env
|
||||||
|
environment:
|
||||||
|
- AI_DB_URL=mysql+pymysql://${MYSQL_USER}:${MYSQL_PASSWORD}@mysql:3306/${MYSQL_DATABASE}
|
||||||
|
ports:
|
||||||
|
- "50051:50051"
|
||||||
|
depends_on:
|
||||||
|
mysql:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ── Nginx Reverse Proxy ────────────────────────────
|
||||||
|
nginx:
|
||||||
|
image: nginx:alpine
|
||||||
|
container_name: go2_nginx
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
volumes:
|
||||||
|
- ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro
|
||||||
|
- ../frontend/dist:/usr/share/nginx/html:ro
|
||||||
|
depends_on:
|
||||||
|
- app_server
|
||||||
|
- comm_server
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
mysql_data:
|
||||||
|
mongo_data:
|
||||||
|
redis_data:
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
// GO2 自主巡逻系统 MongoDB 初始化脚本
|
||||||
|
// 版本: V1.0
|
||||||
|
|
||||||
|
db = db.getSiblingDB('go2_patrol');
|
||||||
|
|
||||||
|
// ── 遥测记录 ──────────────────────────────────────
|
||||||
|
db.createCollection('telemetry_records');
|
||||||
|
|
||||||
|
db.telemetry_records.createIndex({ "dog_id": 1, "timestamp": -1 });
|
||||||
|
db.telemetry_records.createIndex({ "timestamp": 1 }, { expireAfterSeconds: 7776000 }); // TTL 90 天
|
||||||
|
db.telemetry_records.createIndex({ "task_id": 1 });
|
||||||
|
|
||||||
|
// ── 检测记录 ──────────────────────────────────────
|
||||||
|
db.createCollection('detection_records');
|
||||||
|
|
||||||
|
db.detection_records.createIndex({ "dog_id": 1, "timestamp": -1 });
|
||||||
|
db.detection_records.createIndex({ "timestamp": 1 }, { expireAfterSeconds: 7776000 }); // TTL 90 天
|
||||||
|
db.detection_records.createIndex({ "detections.class_label": 1 });
|
||||||
|
|
||||||
|
// ── 巡逻执行日志 ──────────────────────────────────
|
||||||
|
db.createCollection('patrol_execution_logs');
|
||||||
|
|
||||||
|
db.patrol_execution_logs.createIndex({ "task_id": 1, "timestamp": 1 });
|
||||||
|
db.patrol_execution_logs.createIndex({ "dog_id": 1, "timestamp": -1 });
|
||||||
|
|
||||||
|
// ── 视频片段 (GridFS 元数据) ─────────────────────
|
||||||
|
db.createCollection('video_clips');
|
||||||
|
|
||||||
|
db.video_clips.createIndex({ "dog_id": 1, "start_time": -1 });
|
||||||
|
db.video_clips.createIndex({ "related_alert_id": 1 });
|
||||||
|
db.video_clips.createIndex({ "start_time": 1 }, { expireAfterSeconds: 2592000 }); // TTL 30 天
|
||||||
@ -0,0 +1,196 @@
|
|||||||
|
-- GO2 自主巡逻系统 MySQL Schema
|
||||||
|
-- 版本: V1.0
|
||||||
|
-- 日期: 2026-04-22
|
||||||
|
|
||||||
|
SET NAMES utf8mb4;
|
||||||
|
SET CHARACTER SET utf8mb4;
|
||||||
|
|
||||||
|
-- ── 用户表 ──────────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `user` (
|
||||||
|
`id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`username` VARCHAR(100) NOT NULL UNIQUE,
|
||||||
|
`password_hash` VARCHAR(255) NOT NULL,
|
||||||
|
`display_name` VARCHAR(200),
|
||||||
|
`role` ENUM('ADMIN','COMMANDER','OPERATOR') NOT NULL DEFAULT 'OPERATOR',
|
||||||
|
`is_active` BOOLEAN DEFAULT TRUE,
|
||||||
|
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`last_login_at` DATETIME NULL,
|
||||||
|
INDEX `idx_role` (`role`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 机器狗表 ────────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `dog` (
|
||||||
|
`id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`name` VARCHAR(100) NOT NULL UNIQUE,
|
||||||
|
`serial_number` VARCHAR(100) UNIQUE,
|
||||||
|
`model` VARCHAR(50) DEFAULT 'GO2',
|
||||||
|
`ip_address` VARCHAR(45),
|
||||||
|
`status` ENUM('ONLINE','OFFLINE','MAINTENANCE','EMERGENCY') DEFAULT 'OFFLINE',
|
||||||
|
`last_heartbeat` DATETIME NULL,
|
||||||
|
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
INDEX `idx_status` (`status`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 巡逻任务表 ──────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `patrol_task` (
|
||||||
|
`task_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`task_name` VARCHAR(200) NOT NULL,
|
||||||
|
`description` TEXT,
|
||||||
|
`status` ENUM('DRAFT','SAVED','DEPLOYED','EXECUTING','COMPLETED','CANCELLED','ABORTED')
|
||||||
|
NOT NULL DEFAULT 'DRAFT',
|
||||||
|
`assigned_dog_id` VARCHAR(36),
|
||||||
|
`created_by` VARCHAR(36) NOT NULL,
|
||||||
|
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`updated_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
`started_at` DATETIME NULL,
|
||||||
|
`completed_at` DATETIME NULL,
|
||||||
|
`schedule_type` ENUM('ONCE','PERIODIC','CRON') NOT NULL DEFAULT 'ONCE',
|
||||||
|
`schedule_config` JSON,
|
||||||
|
`progress` FLOAT DEFAULT 0.0,
|
||||||
|
FOREIGN KEY (`assigned_dog_id`) REFERENCES `dog`(`id`),
|
||||||
|
FOREIGN KEY (`created_by`) REFERENCES `user`(`id`),
|
||||||
|
INDEX `idx_status` (`status`),
|
||||||
|
INDEX `idx_dog` (`assigned_dog_id`),
|
||||||
|
INDEX `idx_created` (`created_at`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 航点表 ──────────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `waypoint` (
|
||||||
|
`waypoint_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`task_id` VARCHAR(36) NOT NULL,
|
||||||
|
`sequence_order` INT NOT NULL,
|
||||||
|
`latitude` DECIMAL(10,7) NOT NULL,
|
||||||
|
`longitude` DECIMAL(10,7) NOT NULL,
|
||||||
|
`altitude` DECIMAL(8,2) DEFAULT 0,
|
||||||
|
`action_type` ENUM('PASS','SCAN','HOVER','OBSERVE') NOT NULL DEFAULT 'PASS',
|
||||||
|
`dwell_time_sec` INT DEFAULT 0,
|
||||||
|
`heading_deg` DECIMAL(5,2) NULL,
|
||||||
|
FOREIGN KEY (`task_id`) REFERENCES `patrol_task`(`task_id`) ON DELETE CASCADE,
|
||||||
|
INDEX `idx_task_order` (`task_id`, `sequence_order`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 威胁存档表 ──────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `threat_archive` (
|
||||||
|
`threat_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`threat_type` VARCHAR(100) NOT NULL,
|
||||||
|
`threat_level` ENUM('LOW','MEDIUM','HIGH','CRITICAL') NOT NULL,
|
||||||
|
`status` ENUM('PENDING','CONFIRMED','DISMISSED','TRACKING','ESCALATED','RESOLVED')
|
||||||
|
NOT NULL DEFAULT 'PENDING',
|
||||||
|
`detected_by_dog` VARCHAR(36) NOT NULL,
|
||||||
|
`detected_at` DATETIME NOT NULL,
|
||||||
|
`latitude` DECIMAL(10,7),
|
||||||
|
`longitude` DECIMAL(10,7),
|
||||||
|
`altitude` DECIMAL(8,2) DEFAULT 0,
|
||||||
|
`confidence` FLOAT NOT NULL,
|
||||||
|
`class_label` VARCHAR(100),
|
||||||
|
`description` TEXT,
|
||||||
|
`image_url` VARCHAR(500),
|
||||||
|
`resolved_at` DATETIME NULL,
|
||||||
|
`resolved_by` VARCHAR(36) NULL,
|
||||||
|
`notes` TEXT,
|
||||||
|
FOREIGN KEY (`detected_by_dog`) REFERENCES `dog`(`id`),
|
||||||
|
FOREIGN KEY (`resolved_by`) REFERENCES `user`(`id`),
|
||||||
|
INDEX `idx_level` (`threat_level`),
|
||||||
|
INDEX `idx_status` (`status`),
|
||||||
|
INDEX `idx_detected_at` (`detected_at`),
|
||||||
|
INDEX `idx_dog` (`detected_by_dog`),
|
||||||
|
INDEX `idx_type_level` (`threat_type`, `threat_level`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 威胁警报表 ──────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `threat_alert` (
|
||||||
|
`alert_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`threat_id` VARCHAR(36),
|
||||||
|
`dog_id` VARCHAR(36) NULL,
|
||||||
|
`alert_level` ENUM('LOW','MEDIUM','HIGH','CRITICAL') NOT NULL,
|
||||||
|
`status` ENUM('INITIAL','PENDING','CONFIRMED','DISMISSED','TRACKING','ESCALATED','RESOLVED')
|
||||||
|
NOT NULL DEFAULT 'INITIAL',
|
||||||
|
`message` TEXT NOT NULL,
|
||||||
|
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`acknowledged_at` DATETIME NULL,
|
||||||
|
`acknowledged_by` VARCHAR(36) NULL,
|
||||||
|
`resolved_at` DATETIME NULL,
|
||||||
|
FOREIGN KEY (`threat_id`) REFERENCES `threat_archive`(`threat_id`),
|
||||||
|
FOREIGN KEY (`acknowledged_by`) REFERENCES `user`(`id`),
|
||||||
|
INDEX `idx_status` (`status`),
|
||||||
|
INDEX `idx_level` (`alert_level`),
|
||||||
|
INDEX `idx_created` (`created_at`),
|
||||||
|
INDEX `idx_dog` (`dog_id`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 威胁标记表 ──────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `threat_mark` (
|
||||||
|
`mark_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`threat_id` VARCHAR(36),
|
||||||
|
`mark_type` ENUM('KEEP_OUT','CAUTION','POINT_OF_INTEREST','INCIDENT') NOT NULL,
|
||||||
|
`priority` INT DEFAULT 0,
|
||||||
|
`latitude` DECIMAL(10,7) NOT NULL,
|
||||||
|
`longitude` DECIMAL(10,7) NOT NULL,
|
||||||
|
`radius_meters` DECIMAL(8,2) DEFAULT 10.0,
|
||||||
|
`created_by` VARCHAR(36) NOT NULL,
|
||||||
|
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`expires_at` DATETIME NULL,
|
||||||
|
`notes` TEXT,
|
||||||
|
FOREIGN KEY (`threat_id`) REFERENCES `threat_archive`(`threat_id`),
|
||||||
|
FOREIGN KEY (`created_by`) REFERENCES `user`(`id`),
|
||||||
|
INDEX `idx_type` (`mark_type`),
|
||||||
|
INDEX `idx_location` (`latitude`, `longitude`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 威胁规则表 ──────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `threat_rules` (
|
||||||
|
`rule_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`rule_name` VARCHAR(200) NOT NULL,
|
||||||
|
`rule_type` VARCHAR(50) NOT NULL,
|
||||||
|
`rule_condition` JSON NOT NULL,
|
||||||
|
`rule_action` JSON NOT NULL,
|
||||||
|
`priority` INT DEFAULT 0,
|
||||||
|
`is_active` BOOLEAN DEFAULT TRUE,
|
||||||
|
`created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`updated_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
INDEX `idx_type_active` (`rule_type`, `is_active`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 系统参数表 ──────────────────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `system_parameter` (
|
||||||
|
`param_key` VARCHAR(200) PRIMARY KEY,
|
||||||
|
`param_value` JSON NOT NULL,
|
||||||
|
`param_section` VARCHAR(100) NOT NULL,
|
||||||
|
`description` TEXT,
|
||||||
|
`updated_by` VARCHAR(36),
|
||||||
|
`updated_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
INDEX `idx_section` (`param_section`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 警报事件表(事件溯源) ──────────────────────────
|
||||||
|
CREATE TABLE IF NOT EXISTS `alert_event` (
|
||||||
|
`event_id` VARCHAR(36) PRIMARY KEY,
|
||||||
|
`alert_id` VARCHAR(36) NOT NULL,
|
||||||
|
`event_type` ENUM('CREATED','ACKNOWLEDGED','DISMISSED','CONFIRMED','TRACKING','ESCALATED','RESOLVED')
|
||||||
|
NOT NULL,
|
||||||
|
`performed_by` VARCHAR(36),
|
||||||
|
`performed_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
`details` JSON,
|
||||||
|
FOREIGN KEY (`alert_id`) REFERENCES `threat_alert`(`alert_id`),
|
||||||
|
INDEX `idx_alert` (`alert_id`),
|
||||||
|
INDEX `idx_time` (`performed_at`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||||
|
|
||||||
|
-- ── 初始数据 ────────────────────────────────────────
|
||||||
|
INSERT IGNORE INTO `user` (`id`, `username`, `password_hash`, `display_name`, `role`) VALUES
|
||||||
|
('sys-admin-00000000-0000-0000-0000-000000000001', 'admin', '$2b$12$lJhNHUHOYTh1yN0xkvGXkO36yojupBc5H.uwB28hC8K0J2uCVwd0a', '系统管理员', 'ADMIN');
|
||||||
|
|
||||||
|
INSERT IGNORE INTO `dog` (`id`, `name`, `serial_number`, `model`, `ip_address`, `status`) VALUES
|
||||||
|
('dog-00000000-0000-0000-0000-000000000001', 'GO2-Alpha', 'GO2-EDU-001', 'GO2-EDU', '192.168.123.161', 'OFFLINE');
|
||||||
|
|
||||||
|
INSERT IGNORE INTO `system_parameter` (`param_key`, `param_value`, `param_section`, `description`) VALUES
|
||||||
|
('yolo.confidence_threshold', '0.5', 'perception', 'YOLO 检测置信度阈值'),
|
||||||
|
('yolo.model_name', '"yolov8n"', 'perception', 'YOLO 模型文件名'),
|
||||||
|
('patrol.default_speed_level', '3', 'patrol', '默认速度等级 (1-5)'),
|
||||||
|
('patrol.default_gait', '"trot"', 'patrol', '默认步态模式'),
|
||||||
|
('patrol.position_tolerance_m', '0.5', 'patrol', '航点到达判定距离 (米)'),
|
||||||
|
('patrol.heading_tolerance_deg', '15.0', 'patrol', '航点到达判定角度 (度)'),
|
||||||
|
('alert.auto_acknowledge_timeout_sec', '300', 'alert', '警报自动确认超时 (秒)'),
|
||||||
|
('comm.heartbeat_interval_sec', '5', 'communication', '心跳间隔 (秒)'),
|
||||||
|
('comm.reconnect_max_retries', '10', 'communication', '最大重连次数'),
|
||||||
|
('comm.telemetry_rate_hz', '1.0', 'communication', '遥测上报频率 (Hz)');
|
||||||
@ -0,0 +1,61 @@
|
|||||||
|
events {
|
||||||
|
worker_connections 1024;
|
||||||
|
}
|
||||||
|
|
||||||
|
http {
|
||||||
|
upstream app_server {
|
||||||
|
server app_server:8000;
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream comm_server {
|
||||||
|
server comm_server:8001;
|
||||||
|
}
|
||||||
|
|
||||||
|
# WebSocket upgrade support
|
||||||
|
map $http_upgrade $connection_upgrade {
|
||||||
|
default upgrade;
|
||||||
|
'' close;
|
||||||
|
}
|
||||||
|
|
||||||
|
server {
|
||||||
|
listen 80;
|
||||||
|
server_name _;
|
||||||
|
|
||||||
|
# Frontend static files
|
||||||
|
location / {
|
||||||
|
root /usr/share/nginx/html;
|
||||||
|
index index.html;
|
||||||
|
try_files $uri $uri/ /index.html;
|
||||||
|
}
|
||||||
|
|
||||||
|
# REST API proxy
|
||||||
|
location /api/ {
|
||||||
|
proxy_pass http://app_server;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
}
|
||||||
|
|
||||||
|
# WebSocket proxy (dog-side connections + browser events)
|
||||||
|
location /ws {
|
||||||
|
proxy_pass http://comm_server;
|
||||||
|
proxy_http_version 1.1;
|
||||||
|
proxy_set_header Upgrade $http_upgrade;
|
||||||
|
proxy_set_header Connection $connection_upgrade;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_read_timeout 3600s;
|
||||||
|
proxy_send_timeout 3600s;
|
||||||
|
}
|
||||||
|
|
||||||
|
# API docs
|
||||||
|
location /docs {
|
||||||
|
proxy_pass http://app_server;
|
||||||
|
}
|
||||||
|
|
||||||
|
location /openapi.json {
|
||||||
|
proxy_pass http://app_server;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,83 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Mock detection simulator — sends synthetic detection data to comm_server.
|
||||||
|
Useful for testing the detection pipeline without Jetson hardware.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 mock_detector.py --comm-url ws://127.0.0.1:8001 --dog-id go2-001
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
except ImportError:
|
||||||
|
print("pip3 install websockets")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
COCO_CLASSES = {0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"}
|
||||||
|
|
||||||
|
|
||||||
|
def random_detection():
|
||||||
|
cid = random.choice(list(COCO_CLASSES.keys()))
|
||||||
|
return {
|
||||||
|
"class_id": cid,
|
||||||
|
"class_label": COCO_CLASSES[cid],
|
||||||
|
"confidence": round(random.uniform(0.4, 0.95), 3),
|
||||||
|
"x_center": round(random.uniform(0.1, 0.9), 4),
|
||||||
|
"y_center": round(random.uniform(0.1, 0.9), 4),
|
||||||
|
"width": round(random.uniform(0.05, 0.25), 4),
|
||||||
|
"height": round(random.uniform(0.08, 0.35), 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Stream synthetic detection messages to the comm server forever.

    Parses CLI options, registers as a dog-side client, then sends a
    random batch of detections every ``1/fps`` seconds.  On any
    connection failure it waits 3 seconds and reconnects.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--comm-url", default="ws://127.0.0.1:8001")
    parser.add_argument("--dog-id", default="go2-001")
    parser.add_argument("--fps", type=float, default=2)
    parser.add_argument("--max-objects", type=int, default=3)
    args = parser.parse_args()

    interval = 1.0 / args.fps
    count = 0  # total messages sent, preserved across reconnects

    while True:
        try:
            print(f"[MOCK] Connecting to {args.comm_url}...")
            async with websockets.connect(args.comm_url) as ws:
                # Identify ourselves to the hub as a dog-side client.
                register_payload = {
                    "type": "register",
                    "client_type": "dog",
                    "client_id": args.dog_id + "-detector",
                }
                await ws.send(json.dumps(register_payload))
                resp = await ws.recv()
                print(f"[MOCK] Registered: {resp}")

                # Detection stream: runs until the connection drops.
                while True:
                    batch_size = random.randint(0, args.max_objects)
                    detections = [random_detection() for _ in range(batch_size)]
                    await ws.send(json.dumps({
                        "type": "detection",
                        "dog_id": args.dog_id,
                        "data": {
                            "inference_ms": round(random.uniform(5, 35), 1),
                            "frame_timestamp": time.time(),
                            "detections": detections,
                        },
                    }))
                    count += 1
                    # Progress heartbeat every 20 messages.
                    if count % 20 == 0:
                        print(f"[MOCK] #{count}: {len(detections)} objects")
                    await asyncio.sleep(interval)
        except (ConnectionRefusedError, websockets.ConnectionClosed, OSError) as e:
            print(f"[MOCK] Error: {e}, retry 3s...")
            await asyncio.sleep(3)
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entrypoint: run the mock detector until interrupted.
    asyncio.run(main())
|
||||||
@ -0,0 +1,18 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(go2_nav_bringup)
|
||||||
|
|
||||||
|
find_package(ament_cmake REQUIRED)
|
||||||
|
|
||||||
|
install(DIRECTORY
|
||||||
|
launch
|
||||||
|
config
|
||||||
|
DESTINATION share/${PROJECT_NAME}/
|
||||||
|
)
|
||||||
|
|
||||||
|
install(PROGRAMS
|
||||||
|
scripts/odom_to_tf.py
|
||||||
|
scripts/cmd_vel_bridge.py
|
||||||
|
DESTINATION lib/${PROJECT_NAME}/
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_package()
|
||||||
@ -0,0 +1,113 @@
|
|||||||
|
Panels:
|
||||||
|
- Class: rviz_common/Displays
|
||||||
|
Help Height: 0
|
||||||
|
Name: Displays
|
||||||
|
Property Tree Widget:
|
||||||
|
Expanded: ~
|
||||||
|
Splitter Ratio: 0.5
|
||||||
|
Tree Height: 400
|
||||||
|
- Class: rviz_common/Views
|
||||||
|
Expanded:
|
||||||
|
- /Current View1
|
||||||
|
Name: Views
|
||||||
|
Splitter Ratio: 0.5
|
||||||
|
Visualization Manager:
|
||||||
|
Class: ""
|
||||||
|
Displays:
|
||||||
|
- Alpha: 0.5
|
||||||
|
Cell Count: 10
|
||||||
|
Cell Size: 1
|
||||||
|
Class: rviz_default_plugins/Grid
|
||||||
|
Color: 128; 128; 128
|
||||||
|
Enabled: true
|
||||||
|
Name: Grid
|
||||||
|
- Alpha: 1
|
||||||
|
Autocompute Intensity Bounds: true
|
||||||
|
Class: rviz_default_plugins/LaserScan
|
||||||
|
Decay Time: 0
|
||||||
|
Enabled: true
|
||||||
|
Max Intensity: 4096
|
||||||
|
Min Intensity: 0
|
||||||
|
Name: LaserScan
|
||||||
|
Position Transformer: ""
|
||||||
|
Selectable: true
|
||||||
|
Size (Pixels): 3
|
||||||
|
Size (m): 0.05
|
||||||
|
Style: Flat Squares
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Best Effort
|
||||||
|
Value: /scan
|
||||||
|
Use Fixed Frame: true
|
||||||
|
Use rainbow: true
|
||||||
|
- Alpha: 0.7
|
||||||
|
Class: rviz_default_plugins/Map
|
||||||
|
Color Scheme: map
|
||||||
|
Draw Behind: true
|
||||||
|
Enabled: true
|
||||||
|
Name: SLAM Map
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /map
|
||||||
|
- Alpha: 1
|
||||||
|
Autocompute Intensity Bounds: true
|
||||||
|
Class: rviz_default_plugins/PointCloud2
|
||||||
|
Decay Time: 0
|
||||||
|
Enabled: true
|
||||||
|
Max Intensity: 4096
|
||||||
|
Min Intensity: 0
|
||||||
|
Name: PointCloud
|
||||||
|
Position Transformer: ""
|
||||||
|
Selectable: true
|
||||||
|
Size (Pixels): 2
|
||||||
|
Size (m): 0.01
|
||||||
|
Style: Flat Squares
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Best Effort
|
||||||
|
Value: /utlidar/cloud
|
||||||
|
Use Fixed Frame: true
|
||||||
|
Use rainbow: true
|
||||||
|
- Class: rviz_default_plugins/TF
|
||||||
|
Enabled: true
|
||||||
|
Name: TF
|
||||||
|
Marker Scale: 1.0
|
||||||
|
Show Arrows: true
|
||||||
|
Show Axes: true
|
||||||
|
Show Names: true
|
||||||
|
Enabled: true
|
||||||
|
Global Options:
|
||||||
|
Background Color: 48; 48; 48
|
||||||
|
Fixed Frame: map
|
||||||
|
Frame Rate: 30
|
||||||
|
Name: root
|
||||||
|
Tools:
|
||||||
|
- Class: rviz_default_plugins/MoveCamera
|
||||||
|
- Class: rviz_default_plugins/SetInitialPose
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /initialpose
|
||||||
|
Value: true
|
||||||
|
Views:
|
||||||
|
Current:
|
||||||
|
Class: rviz_default_plugins/TopDownOrtho
|
||||||
|
Enabled: true
|
||||||
|
Name: Current View
|
||||||
|
Near Clip Distance: 0.01
|
||||||
|
Scale: 50
|
||||||
|
Target Frame: <Fixed Frame>
|
||||||
|
X: 0
|
||||||
|
Y: 0
|
||||||
|
Window Geometry:
|
||||||
|
Displays:
|
||||||
|
collapsed: false
|
||||||
|
Height: 800
|
||||||
|
Width: 1200
|
||||||
|
X: 100
|
||||||
|
Y: 100
|
||||||
@ -0,0 +1,236 @@
|
|||||||
|
amcl:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
alpha1: 0.2
|
||||||
|
alpha2: 0.2
|
||||||
|
alpha3: 0.2
|
||||||
|
alpha4: 0.2
|
||||||
|
alpha5: 0.2
|
||||||
|
base_frame_id: "base_link"
|
||||||
|
beam_skip_distance: 0.5
|
||||||
|
beam_skip_error_threshold: 0.9
|
||||||
|
beam_skip_threshold: 0.3
|
||||||
|
do_beams_skip: false
|
||||||
|
global_frame_id: "map"
|
||||||
|
lambda_short: 0.1
|
||||||
|
laser_likelihood_max_dist: 2.0
|
||||||
|
laser_max_range: 10.0
|
||||||
|
laser_min_range: 0.3
|
||||||
|
laser_model_type: "likelihood_field"
|
||||||
|
max_beams: 60
|
||||||
|
max_particles: 2000
|
||||||
|
min_particles: 500
|
||||||
|
odom_frame_id: "odom"
|
||||||
|
pf_err: 0.05
|
||||||
|
pf_z: 0.99
|
||||||
|
recovery_alpha_fast: 0.0
|
||||||
|
recovery_alpha_slow: 0.0
|
||||||
|
resample_interval: 1
|
||||||
|
robot_model_type: "nav2_amcl::DifferentialMotionModel"
|
||||||
|
save_pose_rate: 0.5
|
||||||
|
sigma_hit: 0.2
|
||||||
|
tf_broadcast: true
|
||||||
|
transform_tolerance: 1.0
|
||||||
|
update_min_a: 0.2
|
||||||
|
update_min_d: 0.25
|
||||||
|
z_hit: 0.5
|
||||||
|
z_max: 0.05
|
||||||
|
z_rand: 0.5
|
||||||
|
z_short: 0.05
|
||||||
|
scan_topic: scan
|
||||||
|
|
||||||
|
controller_server:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
controller_frequency: 10.0
|
||||||
|
min_x_velocity_threshold: 0.001
|
||||||
|
min_y_velocity_threshold: 0.5
|
||||||
|
min_theta_velocity_threshold: 0.001
|
||||||
|
progress_checker_plugins: ["progress_checker"]
|
||||||
|
goal_checker_plugins: ["general_goal_checker"]
|
||||||
|
controller_plugins: ["FollowPath"]
|
||||||
|
|
||||||
|
progress_checker:
|
||||||
|
plugin: "nav2_controller::SimpleProgressChecker"
|
||||||
|
required_movement_radius: 0.5
|
||||||
|
movement_time_allowance: 10.0
|
||||||
|
|
||||||
|
general_goal_checker:
|
||||||
|
plugin: "nav2_controller::SimpleGoalChecker"
|
||||||
|
stateful: true
|
||||||
|
xy_goal_tolerance: 0.25
|
||||||
|
yaw_goal_tolerance: 0.25
|
||||||
|
|
||||||
|
FollowPath:
|
||||||
|
plugin: "dwb_core::DWBLocalPlanner"
|
||||||
|
debug_trajectory_details: true
|
||||||
|
min_vel_x: 0.0
|
||||||
|
min_vel_y: 0.0
|
||||||
|
max_vel_x: 0.3
|
||||||
|
max_vel_y: 0.0
|
||||||
|
max_vel_theta: 1.0
|
||||||
|
min_speed_xy: 0.0
|
||||||
|
max_speed_xy: 0.3
|
||||||
|
min_speed_theta: 0.0
|
||||||
|
acc_lim_x: 2.5
|
||||||
|
acc_lim_y: 0.0
|
||||||
|
acc_lim_theta: 3.2
|
||||||
|
decel_lim_x: 2.5
|
||||||
|
decel_lim_y: 0.0
|
||||||
|
decel_lim_theta: 3.2
|
||||||
|
vx_samples: 20
|
||||||
|
vy_samples: 0
|
||||||
|
vtheta_samples: 40
|
||||||
|
sim_time: 1.5
|
||||||
|
linear_granularity: 0.05
|
||||||
|
angular_granularity: 0.025
|
||||||
|
transform_tolerance: 0.2
|
||||||
|
xy_goal_tolerance: 0.25
|
||||||
|
trans_stopped_velocity: 0.25
|
||||||
|
short_circuit_trajectory_evaluation: true
|
||||||
|
limit_vel_cmd_in_traj: false
|
||||||
|
critics: ["RotateToGoal", "Oscillation", "BaseObstacle", "GoalAlign", "PathAlign", "PathDist", "GoalDist"]
|
||||||
|
BaseObstacle.scale: 0.02
|
||||||
|
PathAlign.scale: 32.0
|
||||||
|
PathAlign.forward_point_distance: 0.1
|
||||||
|
GoalAlign.scale: 24.0
|
||||||
|
GoalAlign.forward_point_distance: 0.1
|
||||||
|
PathDist.scale: 32.0
|
||||||
|
GoalDist.scale: 24.0
|
||||||
|
RotateToGoal.scale: 32.0
|
||||||
|
RotateToGoal.slowing_factor: 5.0
|
||||||
|
RotateToGoal.lookahead_time: -1.0
|
||||||
|
|
||||||
|
local_costmap:
|
||||||
|
local_costmap:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
update_frequency: 5.0
|
||||||
|
publish_frequency: 2.0
|
||||||
|
global_frame: odom
|
||||||
|
robot_base_frame: base_link
|
||||||
|
rolling_window: true
|
||||||
|
width: 3
|
||||||
|
height: 3
|
||||||
|
resolution: 0.05
|
||||||
|
robot_radius: 0.35
|
||||||
|
plugins: ["voxel_layer", "inflation_layer"]
|
||||||
|
voxel_layer:
|
||||||
|
plugin: "nav2_costmap_2d::VoxelLayer"
|
||||||
|
enabled: true
|
||||||
|
publish_voxel_map: true
|
||||||
|
origin_z: 0.0
|
||||||
|
z_resolution: 0.05
|
||||||
|
z_voxels: 16
|
||||||
|
max_obstacle_height: 2.0
|
||||||
|
mark_threshold: 0
|
||||||
|
observation_sources: "pointcloud"
|
||||||
|
pointcloud:
|
||||||
|
topic: /utlidar/cloud
|
||||||
|
max_obstacle_height: 2.0
|
||||||
|
clearing: true
|
||||||
|
marking: true
|
||||||
|
data_type: "PointCloud2"
|
||||||
|
observation_persistence: 0.0
|
||||||
|
inflation_layer:
|
||||||
|
plugin: "nav2_costmap_2d::InflationLayer"
|
||||||
|
cost_scaling_factor: 3.0
|
||||||
|
inflation_radius: 0.8
|
||||||
|
always_send_full_costmap: true
|
||||||
|
|
||||||
|
global_costmap:
|
||||||
|
global_costmap:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
update_frequency: 1.0
|
||||||
|
publish_frequency: 1.0
|
||||||
|
global_frame: map
|
||||||
|
robot_base_frame: base_link
|
||||||
|
robot_radius: 0.35
|
||||||
|
resolution: 0.05
|
||||||
|
track_unknown_space: true
|
||||||
|
plugins: ["static_layer", "obstacle_layer", "inflation_layer"]
|
||||||
|
static_layer:
|
||||||
|
plugin: "nav2_costmap_2d::StaticLayer"
|
||||||
|
map_subscribe_transient_local: true
|
||||||
|
obstacle_layer:
|
||||||
|
plugin: "nav2_costmap_2d::ObstacleLayer"
|
||||||
|
enabled: true
|
||||||
|
observation_sources: "pointcloud"
|
||||||
|
pointcloud:
|
||||||
|
topic: /utlidar/cloud
|
||||||
|
max_obstacle_height: 2.0
|
||||||
|
clearing: true
|
||||||
|
marking: true
|
||||||
|
data_type: "PointCloud2"
|
||||||
|
inflation_layer:
|
||||||
|
plugin: "nav2_costmap_2d::InflationLayer"
|
||||||
|
cost_scaling_factor: 3.0
|
||||||
|
inflation_radius: 0.8
|
||||||
|
always_send_full_costmap: true
|
||||||
|
|
||||||
|
planner_server:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
expected_planner_frequency: 5.0
|
||||||
|
planner_plugins: ["GridBased"]
|
||||||
|
GridBased:
|
||||||
|
plugin: "nav2_navfn_planner::NavfnPlanner"
|
||||||
|
tolerance: 0.5
|
||||||
|
use_astar: true
|
||||||
|
allow_unknown: true
|
||||||
|
|
||||||
|
smoother_server:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
smoother_plugins: ["simple_smoother"]
|
||||||
|
simple_smoother:
|
||||||
|
plugin: "nav2_smoother::SimpleSmoother"
|
||||||
|
tolerance: 1.0e-10
|
||||||
|
max_its: 1000
|
||||||
|
do_refinement: true
|
||||||
|
|
||||||
|
behavior_server:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
costmap_topic: local_costmap/costmap_raw
|
||||||
|
footprint_topic: local_costmap/published_footprint
|
||||||
|
cycle_frequency: 10.0
|
||||||
|
behavior_plugins: ["spin", "backup", "drive_on_heading", "wait"]
|
||||||
|
spin:
|
||||||
|
plugin: "nav2_behaviors::Spin"
|
||||||
|
backup:
|
||||||
|
plugin: "nav2_behaviors::BackUp"
|
||||||
|
drive_on_heading:
|
||||||
|
plugin: "nav2_behaviors::DriveOnHeading"
|
||||||
|
wait:
|
||||||
|
plugin: "nav2_behaviors::Wait"
|
||||||
|
global_frame: odom
|
||||||
|
robot_base_frame: base_link
|
||||||
|
transform_tolerance: 0.1
|
||||||
|
|
||||||
|
waypoint_follower:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
loop_rate: 20
|
||||||
|
stop_on_failure: false
|
||||||
|
waypoint_task_executor_plugin: "wait_at_waypoint"
|
||||||
|
wait_at_waypoint:
|
||||||
|
plugin: "nav2_waypoint_follower::WaitAtWaypoint"
|
||||||
|
enabled: true
|
||||||
|
waypoint_pause_duration: 500
|
||||||
|
|
||||||
|
velocity_smoother:
|
||||||
|
ros__parameters:
|
||||||
|
use_sim_time: false
|
||||||
|
smoothing_frequency: 20.0
|
||||||
|
scale_velocities: false
|
||||||
|
feedback: "OPEN_LOOP"
|
||||||
|
max_velocity: [0.3, 0.0, 1.0]
|
||||||
|
min_velocity: [-0.3, 0.0, -1.0]
|
||||||
|
max_accel: [2.5, 0.0, 3.2]
|
||||||
|
max_decel: [-2.5, 0.0, -3.2]
|
||||||
|
odom_topic: "odom"
|
||||||
|
odom_duration: 0.1
|
||||||
|
deadband_velocity: [0.0, 0.0, 0.0]
|
||||||
|
velocity_timeout: 1.0
|
||||||
@ -0,0 +1,127 @@
|
|||||||
|
Panels:
|
||||||
|
- Class: rviz_common/Displays
|
||||||
|
Help Height: 0
|
||||||
|
Name: Displays
|
||||||
|
Property Tree Widget:
|
||||||
|
Expanded: ~
|
||||||
|
Splitter Ratio: 0.5
|
||||||
|
Tree Height: 400
|
||||||
|
- Class: rviz_common/Views
|
||||||
|
Expanded:
|
||||||
|
- /Current View1
|
||||||
|
Name: Views
|
||||||
|
Splitter Ratio: 0.5
|
||||||
|
Visualization Manager:
|
||||||
|
Class: ""
|
||||||
|
Displays:
|
||||||
|
- Alpha: 0.5
|
||||||
|
Cell Count: 10
|
||||||
|
Cell Size: 1
|
||||||
|
Class: rviz_default_plugins/Grid
|
||||||
|
Color: 128; 128; 128
|
||||||
|
Enabled: true
|
||||||
|
Name: Grid
|
||||||
|
- Alpha: 0.7
|
||||||
|
Class: rviz_default_plugins/Map
|
||||||
|
Color Scheme: map
|
||||||
|
Draw Behind: true
|
||||||
|
Enabled: true
|
||||||
|
Name: Map
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /map
|
||||||
|
- Alpha: 1
|
||||||
|
Class: rviz_default_plugins/Map
|
||||||
|
Color Scheme: costmap
|
||||||
|
Draw Behind: false
|
||||||
|
Enabled: true
|
||||||
|
Name: Local Costmap
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /local_costmap/costmap
|
||||||
|
- Alpha: 1
|
||||||
|
Autocompute Intensity Bounds: true
|
||||||
|
Class: rviz_default_plugins/LaserScan
|
||||||
|
Decay Time: 0
|
||||||
|
Enabled: true
|
||||||
|
Name: LaserScan
|
||||||
|
Position Transformer: ""
|
||||||
|
Selectable: true
|
||||||
|
Size (Pixels): 3
|
||||||
|
Size (m): 0.05
|
||||||
|
Style: Flat Squares
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Best Effort
|
||||||
|
Value: /scan
|
||||||
|
- Class: rviz_default_plugins/TF
|
||||||
|
Enabled: true
|
||||||
|
Name: TF
|
||||||
|
Marker Scale: 1.0
|
||||||
|
Show Arrows: false
|
||||||
|
Show Axes: true
|
||||||
|
Show Names: true
|
||||||
|
- Alpha: 1
|
||||||
|
Buffer Length: 100
|
||||||
|
Class: rviz_default_plugins/Path
|
||||||
|
Enabled: true
|
||||||
|
Name: Global Path
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /plan
|
||||||
|
- Alpha: 1
|
||||||
|
Buffer Length: 100
|
||||||
|
Class: rviz_default_plugins/Path
|
||||||
|
Color: 0; 255; 0
|
||||||
|
Enabled: true
|
||||||
|
Name: Local Path
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /local_plan
|
||||||
|
Enabled: true
|
||||||
|
Global Options:
|
||||||
|
Background Color: 48; 48; 48
|
||||||
|
Fixed Frame: map
|
||||||
|
Frame Rate: 30
|
||||||
|
Name: root
|
||||||
|
Tools:
|
||||||
|
- Class: rviz_default_plugins/MoveCamera
|
||||||
|
- Class: rviz_default_plugins/SetInitialPose
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /initialpose
|
||||||
|
- Class: rviz_default_plugins/SetGoal
|
||||||
|
Topic:
|
||||||
|
Depth: 5
|
||||||
|
Durability Policy: Volatile
|
||||||
|
Reliability Policy: Reliable
|
||||||
|
Value: /goal_pose
|
||||||
|
Value: true
|
||||||
|
Views:
|
||||||
|
Current:
|
||||||
|
Class: rviz_default_plugins/TopDownOrtho
|
||||||
|
Enabled: true
|
||||||
|
Name: Current View
|
||||||
|
Near Clip Distance: 0.01
|
||||||
|
Scale: 50
|
||||||
|
Target Frame: <Fixed Frame>
|
||||||
|
X: 0
|
||||||
|
Y: 0
|
||||||
|
Window Geometry:
|
||||||
|
Displays:
|
||||||
|
collapsed: false
|
||||||
|
Height: 800
|
||||||
|
Width: 1200
|
||||||
|
X: 100
|
||||||
|
Y: 100
|
||||||
@ -0,0 +1,14 @@
|
|||||||
|
pointcloud_to_laserscan:
|
||||||
|
ros__parameters:
|
||||||
|
target_frame: "base_link"
|
||||||
|
transform_tolerance: 0.01
|
||||||
|
min_height: -0.1
|
||||||
|
max_height: 1.0
|
||||||
|
angle_min: -3.14159 # -π
|
||||||
|
angle_max: 3.14159 # +π
|
||||||
|
angle_increment: 0.0087 # ~0.5°
|
||||||
|
scan_time: 0.0667 # 15Hz
|
||||||
|
range_min: 0.3
|
||||||
|
range_max: 10.0
|
||||||
|
use_inf: true
|
||||||
|
inf_epsilon: 1.0
|
||||||
@ -0,0 +1,66 @@
|
|||||||
|
slam_toolbox:
|
||||||
|
ros__parameters:
|
||||||
|
solver_plugin: solver_plugins::CeresSolver
|
||||||
|
ceres_linear_solver: SPARSE_NORMAL_CHOLESKY
|
||||||
|
ceres_preconditioner: SCHUR_JACOBI
|
||||||
|
ceres_trust_strategy: LM
|
||||||
|
ceres_dogleg_type: TRADITIONAL_DOGLEG
|
||||||
|
ceres_loss_function: None
|
||||||
|
|
||||||
|
# ROS Parameters
|
||||||
|
odom_frame: odom
|
||||||
|
map_frame: map
|
||||||
|
base_frame: base_link
|
||||||
|
scan_topic: /scan
|
||||||
|
mode: mapping
|
||||||
|
use_sim_time: false
|
||||||
|
|
||||||
|
# Map parameters
|
||||||
|
resolution: 0.05
|
||||||
|
max_laser_range: 10.0
|
||||||
|
minimum_time_interval: 0.5
|
||||||
|
transform_timeout: 0.2
|
||||||
|
tf_buffer_duration: 30.0
|
||||||
|
stack_size_to_use: 40000000
|
||||||
|
|
||||||
|
# Thresholds
|
||||||
|
minimum_travel_distance: 0.3
|
||||||
|
minimum_travel_heading: 0.3
|
||||||
|
|
||||||
|
# Scan matching
|
||||||
|
use_scan_matching: true
|
||||||
|
use_scan_barycenter: true
|
||||||
|
scan_buffer_size: 10
|
||||||
|
scan_buffer_maximum_scan_distance: 10.0
|
||||||
|
link_match_minimum_response_fine: 0.1
|
||||||
|
link_scan_maximum_distance: 1.5
|
||||||
|
loop_search_maximum_distance: 3.0
|
||||||
|
do_loop_closing: true
|
||||||
|
loop_match_minimum_chain_size: 10
|
||||||
|
loop_match_maximum_variance_coarse: 3.0
|
||||||
|
loop_match_minimum_response_coarse: 0.35
|
||||||
|
loop_match_minimum_response_fine: 0.45
|
||||||
|
|
||||||
|
# Correlation Parameters
|
||||||
|
correlation_search_space_dimension: 0.5
|
||||||
|
correlation_search_space_resolution: 0.01
|
||||||
|
correlation_search_space_smear_deviation: 0.1
|
||||||
|
|
||||||
|
# Loop Closure Parameters
|
||||||
|
loop_search_space_dimension: 8.0
|
||||||
|
loop_search_space_resolution: 0.05
|
||||||
|
loop_search_space_smear_deviation: 0.03
|
||||||
|
|
||||||
|
# Update and publish
|
||||||
|
distance_variance_penalty: 0.5
|
||||||
|
angle_variance_penalty: 1.0
|
||||||
|
fine_search_angle_offset: 0.00349
|
||||||
|
coarse_search_angle_offset: 0.349
|
||||||
|
coarse_angle_resolution: 0.0349
|
||||||
|
minimum_angle_penalty: 0.9
|
||||||
|
minimum_distance_penalty: 0.5
|
||||||
|
use_response_expansion: true
|
||||||
|
|
||||||
|
# Map update
|
||||||
|
map_update_interval: 2.0
|
||||||
|
enable_interactive_mode: false
|
||||||
@ -0,0 +1,65 @@
|
|||||||
|
"""
|
||||||
|
SLAM mapping launch: pointcloud_to_laserscan + slam_toolbox
|
||||||
|
Usage: ros2 launch go2_nav_bringup mapping.launch.py
|
||||||
|
"""
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Assemble the SLAM mapping pipeline.

    Launches: static LiDAR TF, the odom->TF bridge, pointcloud-to-laserscan
    conversion, slam_toolbox async mapping, and RViz with the mapping layout.
    """
    share_dir = get_package_share_directory('go2_nav_bringup')
    config_dir = os.path.join(share_dir, 'config')

    # Static TF: utlidar_lidar -> base_link (LiDAR mounted on the Go2 head).
    # Positional args follow the tf2_ros convention: x y z yaw pitch roll
    # parent child.
    static_lidar_tf = Node(
        package='tf2_ros',
        executable='static_transform_publisher',
        name='lidar_to_base_link',
        arguments=['0.28', '0', '0.08', '0', '0', '0',
                   'base_link', 'utlidar_lidar'],
    )

    # Bridges /utlidar/robot_odom into the odom -> base_link TF for SLAM.
    odom_bridge = Node(
        package='go2_nav_bringup',
        executable='odom_to_tf.py',
        name='odom_tf_bridge',
        output='screen',
    )

    # Flatten the 3D point cloud into a 2D /scan for slam_toolbox.
    cloud_to_scan = Node(
        package='pointcloud_to_laserscan',
        executable='pointcloud_to_laserscan_node',
        name='pointcloud_to_laserscan',
        parameters=[os.path.join(config_dir, 'pointcloud_to_laserscan.yaml')],
        remappings=[('cloud_in', '/utlidar/cloud')],
        output='screen',
    )

    # SLAM Toolbox in asynchronous (online) mapping mode.
    slam = Node(
        package='slam_toolbox',
        executable='async_slam_toolbox_node',
        name='slam_toolbox',
        parameters=[
            os.path.join(config_dir, 'slam_toolbox_params.yaml'),
            {'use_sim_time': False},
        ],
        output='screen',
    )

    # Visualization preconfigured for the mapping session.
    rviz = Node(
        package='rviz2',
        executable='rviz2',
        name='rviz2',
        arguments=['-d', os.path.join(config_dir, 'mapping.rviz')],
        output='screen',
    )

    return LaunchDescription([
        static_lidar_tf,
        odom_bridge,
        cloud_to_scan,
        slam,
        rviz,
    ])
|
||||||
@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
Nav2 navigation launch: uses saved map + Nav2 stack
|
||||||
|
Usage: ros2 launch go2_nav_bringup navigation.launch.py map:=$HOME/map.yaml
|
||||||
|
"""
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch.actions import DeclareLaunchArgument, IncludeLaunchDescription
|
||||||
|
from launch.launch_description_sources import PythonLaunchDescriptionSource
|
||||||
|
from launch.conditions import IfCondition
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Bring up Nav2 navigation on a saved map.

    Launches: static LiDAR TF, pointcloud-to-laserscan conversion, the core
    Nav2 stack, a map server plus its lifecycle manager, AMCL localization,
    and (optionally) RViz.
    """
    share_dir = get_package_share_directory('go2_nav_bringup')
    config_dir = os.path.join(share_dir, 'config')
    params_file = os.path.join(config_dir, 'nav2_params.yaml')
    nav2_launch = os.path.join(
        get_package_share_directory('nav2_bringup'),
        'launch', 'navigation_launch.py')

    map_file = LaunchConfiguration('map')
    use_rviz = LaunchConfiguration('use_rviz')

    return LaunchDescription([
        DeclareLaunchArgument(
            'map',
            default_value=os.path.join(share_dir, 'maps', 'map.yaml'),
            description='Full path to map yaml file'),
        DeclareLaunchArgument('use_rviz', default_value='true'),

        # Static TF: utlidar_lidar -> base_link (LiDAR mounted on Go2 head)
        Node(
            package='tf2_ros',
            executable='static_transform_publisher',
            name='lidar_to_base_link',
            arguments=['0.28', '0', '0.08', '0', '0', '0',
                       'base_link', 'utlidar_lidar'],
        ),

        # Flatten the point cloud into /scan for AMCL and the costmaps
        Node(
            package='pointcloud_to_laserscan',
            executable='pointcloud_to_laserscan_node',
            name='pointcloud_to_laserscan',
            parameters=[os.path.join(config_dir, 'pointcloud_to_laserscan.yaml')],
            remappings=[('cloud_in', '/utlidar/cloud')],
            output='screen',
        ),

        # Core Nav2 stack (planner, controller, behaviors, bt_navigator)
        IncludeLaunchDescription(
            PythonLaunchDescriptionSource(nav2_launch),
            launch_arguments={
                'use_sim_time': 'false',
                'params_file': params_file,
            }.items(),
        ),

        # Serve the saved occupancy map
        Node(
            package='nav2_map_server',
            executable='map_server',
            name='map_server',
            parameters=[
                {'yaml_filename': map_file, 'use_sim_time': False},
            ],
            output='screen',
        ),

        # Drive the map server through its lifecycle transitions
        Node(
            package='nav2_lifecycle_manager',
            executable='lifecycle_manager',
            name='lifecycle_manager_map',
            parameters=[
                {'use_sim_time': False},
                {'autostart': True},
                {'node_names': ['map_server']},
            ],
            output='screen',
        ),

        # AMCL localization against the served map
        Node(
            package='nav2_amcl',
            executable='amcl',
            name='amcl',
            parameters=[
                params_file,
                {'use_sim_time': False},
            ],
            output='screen',
        ),

        # RViz, gated on the use_rviz launch argument
        Node(
            package='rviz2',
            executable='rviz2',
            name='rviz2',
            arguments=['-d', os.path.join(config_dir, 'navigation.rviz')],
            condition=IfCondition(use_rviz),
            output='screen',
        ),
    ])
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>go2_nav_bringup</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Go2 SLAM mapping and Nav2 navigation bringup</description>
|
||||||
|
<maintainer email="dev@example.com">GO2 Patrol Team</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
|
||||||
|
<exec_depend>nav2_bringup</exec_depend>
|
||||||
|
<exec_depend>nav2_common</exec_depend>
|
||||||
|
<exec_depend>navigation2</exec_depend>
|
||||||
|
<exec_depend>slam_toolbox</exec_depend>
|
||||||
|
<exec_depend>pointcloud_to_laserscan</exec_depend>
|
||||||
|
<exec_depend>rviz2</exec_depend>
|
||||||
|
<exec_depend>rclpy</exec_depend>
|
||||||
|
<exec_depend>sensor_msgs</exec_depend>
|
||||||
|
<exec_depend>geometry_msgs</exec_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_cmake</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,139 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Bridge Nav2 /cmd_vel to Unitree Go2 SportClient.Move()
|
||||||
|
|
||||||
|
Subscribes to /cmd_vel (Twist) and sends velocity commands to Go2.
|
||||||
|
Also publishes /odom from /sportmodestate for Nav2 odometry.
|
||||||
|
"""
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from geometry_msgs.msg import Twist, Vector3
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
from sensor_msgs.msg import JointState
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
import math
|
||||||
|
|
||||||
|
try:
|
||||||
|
from unitree_sdk2py.go2.sport.sport_client import SportClient
|
||||||
|
HAS_SDK = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_SDK = False
|
||||||
|
print("[WARN] unitree_sdk2py not available, running in simulation mode")
|
||||||
|
|
||||||
|
|
||||||
|
class CmdVelBridge(Node):
    """Bridge Nav2 velocity commands to the Unitree Go2 sport client.

    Forwards /cmd_vel Twists (clamped to configured limits) to
    SportClient.Move() and dead-reckons an /odom message from the
    /sportmodestate velocity fields. A watchdog timer commands zero
    velocity when /cmd_vel goes stale.
    """

    def __init__(self):
        super().__init__('cmd_vel_bridge')

        self.declare_parameter('dog_ip', '192.168.123.161')
        self.declare_parameter('publish_rate', 20.0)
        self.declare_parameter('max_linear_x', 0.3)
        self.declare_parameter('max_linear_y', 0.0)
        self.declare_parameter('max_angular_z', 1.0)

        self.sport_client = None
        if HAS_SDK:
            self.sport_client = SportClient()
            self.sport_client.SetTimeout(10.0)
            self.sport_client.Init()
            self.get_logger().info('SportClient initialized')

        self.cmd_vel_sub = self.create_subscription(
            Twist, '/cmd_vel', self.cmd_vel_callback, 10)

        self.odom_pub = self.create_publisher(Odometry, '/odom', 10)

        # NOTE(review): /sportmodestate is subscribed as sensor_msgs/JointState
        # here — confirm the publisher actually emits that message type.
        self.sport_state_sub = self.create_subscription(
            JointState, '/sportmodestate', self.sport_state_callback, 10)

        self.latest_twist = Twist()
        self.cmd_timeout = 0.5  # seconds without /cmd_vel before stopping
        self.last_cmd_time = self.get_clock().now()

        # Dead-reckoned pose state (integrated from reported velocities)
        self.x = 0.0
        self.y = 0.0
        self.theta = 0.0
        self.prev_time = self.get_clock().now()

        # Watchdog timer, fired at the configured publish rate
        period = 1.0 / self.get_parameter('publish_rate').value
        self.timer = self.create_timer(period, self.timer_callback)

        self.get_logger().info('cmd_vel_bridge started')

    def _clamped(self, value, limit_param):
        """Clamp *value* into [-limit, +limit] where limit is a parameter."""
        limit = self.get_parameter(limit_param).value
        return max(-limit, min(limit, value))

    def cmd_vel_callback(self, msg: Twist):
        """Record the latest command and forward it (clamped) to the dog."""
        self.latest_twist = msg
        self.last_cmd_time = self.get_clock().now()

        if not self.sport_client:
            return
        try:
            self.sport_client.Move(
                self._clamped(msg.linear.x, 'max_linear_x'),
                self._clamped(msg.linear.y, 'max_linear_y'),
                self._clamped(msg.angular.z, 'max_angular_z'),
            )
        except Exception as e:
            self.get_logger().error(f'Move failed: {e}')

    def sport_state_callback(self, msg: JointState):
        """Convert sportmodestate velocities to a dead-reckoned /odom.

        Integrates (vx, vy, vyaw) from msg.velocity[0:3] in the body frame,
        rotated into the odom frame. Exact field mapping depends on
        Unitree's convention — assumed here, not verified.
        """
        now = self.get_clock().now()
        dt = (now - self.prev_time).nanoseconds / 1e9
        self.prev_time = now

        have_vel = len(msg.velocity) >= 3
        if have_vel:
            vx, vy, vyaw = msg.velocity[0], msg.velocity[1], msg.velocity[2]

            # Update heading first, then rotate body-frame velocity into odom
            self.theta += vyaw * dt
            cos_t = math.cos(self.theta)
            sin_t = math.sin(self.theta)
            self.x += (vx * cos_t - vy * sin_t) * dt
            self.y += (vx * sin_t + vy * cos_t) * dt

        odom = Odometry()
        odom.header = Header()
        odom.header.stamp = now.to_msg()
        odom.header.frame_id = 'odom'
        # NOTE(review): child frame is 'body' while the static TF uses
        # 'base_link' — confirm downstream consumers expect this.
        odom.child_frame_id = 'body'
        odom.pose.pose.position.x = self.x
        odom.pose.pose.position.y = self.y
        # Planar yaw encoded as a quaternion about +z
        odom.pose.pose.orientation.z = math.sin(self.theta / 2)
        odom.pose.pose.orientation.w = math.cos(self.theta / 2)
        if have_vel:
            odom.twist.twist.linear.x = msg.velocity[0]
            odom.twist.twist.linear.y = msg.velocity[1]
            odom.twist.twist.angular.z = msg.velocity[2]

        self.odom_pub.publish(odom)

    def timer_callback(self):
        """Watchdog: command zero velocity once /cmd_vel goes stale."""
        idle = (self.get_clock().now() - self.last_cmd_time).nanoseconds / 1e9
        if idle > self.cmd_timeout and self.sport_client:
            try:
                self.sport_client.Move(0.0, 0.0, 0.0)
            except Exception:
                # Best-effort stop; an SDK hiccup must not kill the watchdog
                pass
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point: init rclpy, spin the bridge, and always clean up.

    Cleanup runs in ``finally`` so the node and the rclpy context are
    released even when spin exits via Ctrl-C (KeyboardInterrupt) or an
    exception — the original skipped cleanup on any non-normal exit.
    """
    rclpy.init(args=args)
    node = CmdVelBridge()
    try:
        rclpy.spin(node)
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Subscribe to /utlidar/robot_odom and broadcast odom → base_link TF.
|
||||||
|
Also republishes as /odom for Nav2 compatibility.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from geometry_msgs.msg import TransformStamped
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
from tf2_ros import TransformBroadcaster
|
||||||
|
|
||||||
|
|
||||||
|
class OdomTFBridge(Node):
    """Rebroadcast /utlidar/robot_odom as TF and republish it on /odom.

    Each incoming Odometry is mirrored twice: once as a TransformStamped
    through the TF broadcaster (odom -> base_link) and once as a fresh
    Odometry message on /odom for Nav2.
    """

    def __init__(self):
        super().__init__('odom_tf_bridge')
        self.tf_broadcaster = TransformBroadcaster(self)
        self.odom_pub = self.create_publisher(Odometry, '/odom', 10)

        self.sub = self.create_subscription(
            Odometry, '/utlidar/robot_odom', self.odom_callback, 10)
        self.get_logger().info('odom_tf_bridge started (odom -> base_link)')

    def odom_callback(self, msg: Odometry):
        """Mirror one odometry sample into TF and onto /odom."""
        pose = msg.pose.pose

        # TF: frame ids are taken verbatim from the incoming message
        tf_msg = TransformStamped()
        tf_msg.header = msg.header
        tf_msg.child_frame_id = msg.child_frame_id
        tf_msg.transform.translation.x = pose.position.x
        tf_msg.transform.translation.y = pose.position.y
        tf_msg.transform.translation.z = pose.position.z
        tf_msg.transform.rotation = pose.orientation
        self.tf_broadcaster.sendTransform(tf_msg)

        # Republish an identical Odometry for Nav2 consumers
        mirrored = Odometry()
        mirrored.header = msg.header
        mirrored.child_frame_id = msg.child_frame_id
        mirrored.pose = msg.pose
        mirrored.twist = msg.twist
        self.odom_pub.publish(mirrored)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point: init rclpy, spin the bridge, and always clean up.

    Cleanup runs in ``finally`` so the node and the rclpy context are
    released even when spin exits via Ctrl-C (KeyboardInterrupt) or an
    exception — the original skipped cleanup on any non-normal exit.
    """
    rclpy.init(args=args)
    node = OdomTFBridge()
    try:
        rclpy.spin(node)
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
gait_controller:
|
||||||
|
ros__parameters:
|
||||||
|
dog_ip: "192.168.123.161"
|
||||||
|
publish_rate_hz: 1.0
|
||||||
|
default_speed_level: 3
|
||||||
|
default_gait: "TrotRun"
|
||||||
|
obstacle_avoidance: true
|
||||||
@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
GO2 Gait Controller - ROS 2 Lifecycle Node
|
||||||
|
|
||||||
|
Bridges /go2/cmd_vel (Twist) to Unitree SportClient.Move(vx, vy, vyaw).
|
||||||
|
Publishes /go2/robot_state (DogHealth) at 1Hz.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.lifecycle import LifecycleNode, TransitionCallbackReturn
|
||||||
|
from geometry_msgs.msg import Twist
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
|
||||||
|
from go2_patrol_msgs.msg import DogHealth, GaitCommand
|
||||||
|
|
||||||
|
# Unitree SDK2 Python imports — available after installing unitree_sdk2py
|
||||||
|
try:
|
||||||
|
from unitree_sdk2py.go2.sport.sport_client import SportClient
|
||||||
|
from unitree_sdk2py.go2.sport.sport_client import ObstaclesAvoidClient
|
||||||
|
HAS_SDK = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_SDK = False
|
||||||
|
|
||||||
|
|
||||||
|
class GaitController(LifecycleNode):
    """Lifecycle node bridging /go2/cmd_vel to Unitree SportClient.Move().

    Lifecycle contract (as implemented below):
      configure  -> create SDK clients and the /go2/robot_state publisher
      activate   -> StandUp(), enable obstacle avoidance, wire subscriptions
                    and the periodic health timer
      deactivate -> destroy subscriptions/timer, Damp() (sit down)
      cleanup    -> drop SDK client / publisher references
      shutdown   -> best-effort Damp()
    """

    def __init__(self):
        super().__init__("gait_controller")

        # Defaults mirror config/gait_controller.yaml
        self.declare_parameter("dog_ip", "192.168.123.161")
        self.declare_parameter("publish_rate_hz", 1.0)
        self.declare_parameter("default_speed_level", 3)
        self.declare_parameter("default_gait", "TrotRun")
        self.declare_parameter("obstacle_avoidance", True)

        # SDK clients and ROS entities are created in on_configure/on_activate,
        # per the lifecycle pattern — nothing touches hardware in __init__.
        self._sport_client = None
        self._avoid_client = None
        self._cmd_vel_sub = None
        self._gait_cmd_sub = None
        self._health_pub = None
        self._health_timer = None
        # Cached state echoed into DogHealth messages
        self._speed_level = self.get_parameter("default_speed_level").value
        self._gait_mode = self.get_parameter("default_gait").value
        self._obstacle_avoidance = self.get_parameter("obstacle_avoidance").value

    def on_configure(self, state):
        """Create the Unitree SDK clients and the lifecycle health publisher."""
        self.get_logger().info("Configuring gait controller...")

        if HAS_SDK:
            # NOTE(review): dog_ip is only used for logging here — the SDK
            # connection presumably comes from its own channel setup; confirm.
            dog_ip = self.get_parameter("dog_ip").value
            self._sport_client = SportClient()
            self._sport_client.SetTimeout(10.0)
            self._sport_client.Init()
            self.get_logger().info(f"SportClient initialized (target: {dog_ip})")

            self._avoid_client = ObstaclesAvoidClient()
            self._avoid_client.SetTimeout(10.0)
            self._avoid_client.Init()
        else:
            self.get_logger().warn("unitree_sdk2py not available — running in simulation mode")

        self._health_pub = self.create_lifecycle_publisher(DogHealth, "/go2/robot_state", 10)
        return TransitionCallbackReturn.SUCCESS

    def on_activate(self, state):
        """Stand the dog up, enable avoidance, start subscriptions + timer."""
        self.get_logger().info("Activating gait controller...")

        # Stand up the dog
        if self._sport_client:
            self._sport_client.StandUp()

        # Enable obstacle avoidance
        if self._avoid_client and self._obstacle_avoidance:
            self._avoid_client.Switch(True)

        # Subscribe to velocity commands
        self._cmd_vel_sub = self.create_subscription(
            Twist, "/go2/cmd_vel", self._cmd_vel_callback, 10
        )

        # Subscribe to gait commands
        self._gait_cmd_sub = self.create_subscription(
            GaitCommand, "/go2/gait_command", self._gait_command_callback, 10
        )

        # Health status publisher at 1Hz
        rate = self.get_parameter("publish_rate_hz").value
        self._health_timer = self.create_timer(1.0 / rate, self._publish_health)

        return super().on_activate(state)

    def on_deactivate(self, state):
        """Tear down subscriptions/timer and sit the dog down."""
        self.get_logger().info("Deactivating gait controller...")

        if self._cmd_vel_sub:
            self.destroy_subscription(self._cmd_vel_sub)
        if self._gait_cmd_sub:
            self.destroy_subscription(self._gait_cmd_sub)
        if self._health_timer:
            self.destroy_timer(self._health_timer)

        # Damp (sit down)
        if self._sport_client:
            self._sport_client.Damp()

        return super().on_deactivate(state)

    def on_cleanup(self, state):
        """Release SDK client and publisher references."""
        self._sport_client = None
        self._avoid_client = None
        self._health_pub = None
        return TransitionCallbackReturn.SUCCESS

    def on_shutdown(self, state):
        """Best-effort Damp() so the dog sits before the process exits."""
        self.get_logger().info("Shutting down gait controller...")
        if self._sport_client:
            self._sport_client.Damp()
        return TransitionCallbackReturn.SUCCESS

    def _cmd_vel_callback(self, msg: Twist):
        """Forward a Twist directly to SportClient.Move().

        Note: no velocity clamping is applied here, unlike cmd_vel_bridge.
        """
        if not self._sport_client:
            return

        vx = msg.linear.x
        vy = msg.linear.y
        vyaw = msg.angular.z

        try:
            self._sport_client.Move(vx, vy, vyaw)
        except Exception as e:
            self.get_logger().error(f"Move command failed: {e}")

    def _gait_command_callback(self, msg: GaitCommand):
        """Apply gait mode, speed level, and obstacle-avoidance settings."""
        if not self._sport_client:
            return

        try:
            if msg.gait_mode:
                # _gait_mode is cached even for unrecognized names; only the
                # three mapped gaits trigger an SDK call.
                self._gait_mode = msg.gait_mode
                gait_map = {
                    "StaticWalk": self._sport_client.StaticWalk,
                    "TrotRun": self._sport_client.TrotRun,
                    "EconomicGait": self._sport_client.EconomicGait,
                }
                gait_fn = gait_map.get(msg.gait_mode)
                if gait_fn:
                    gait_fn()
                    self.get_logger().info(f"Switched gait to {msg.gait_mode}")

            if msg.speed_level > 0:
                self._speed_level = msg.speed_level
                self._sport_client.SpeedLevel(msg.speed_level)
                self.get_logger().info(f"Set speed level to {msg.speed_level}")

            # NOTE(review): avoidance is toggled on EVERY gait command using
            # whatever msg.obstacle_avoidance holds — confirm callers always
            # set that field intentionally.
            if self._avoid_client:
                self._obstacle_avoidance = msg.obstacle_avoidance
                self._avoid_client.Switch(msg.obstacle_avoidance)

        except Exception as e:
            self.get_logger().error(f"Gait command failed: {e}")

    def _publish_health(self):
        """Publish a DogHealth snapshot; no-op while the publisher is inactive."""
        if not self._health_pub or not self._health_pub.is_activated:
            return

        msg = DogHealth()
        msg.header = Header()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = "base_link"
        msg.dog_id = "dog-00000000-0000-0000-0000-000000000001"  # TODO: from config
        msg.gait_mode = self._gait_mode
        msg.speed_level = self._speed_level
        msg.obstacle_avoidance_enabled = self._obstacle_avoidance
        msg.emergency_stop = False

        # TODO: read actual battery/IMU from /lowstate topic
        msg.battery_percent = 100.0
        msg.battery_voltage = 28.5
        msg.cpu_temp_c = 45.0
        msg.gpu_temp_c = 42.0

        self._health_pub.publish(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point: init rclpy, spin the controller, and always clean up.

    Cleanup runs in ``finally`` so the node and the rclpy context are
    released even when spin exits via Ctrl-C (KeyboardInterrupt) or an
    exception — the original skipped cleanup on any non-normal exit.
    """
    rclpy.init(args=args)
    node = GaitController()
    try:
        rclpy.spin(node)
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import LifecycleNode
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Launch the gait_controller lifecycle node with its default parameters."""
    controller = LifecycleNode(
        package="go2_locomotion",
        executable="gait_controller",
        name="gait_controller",
        output="screen",
        parameters=[{
            "dog_ip": "192.168.123.161",
            "publish_rate_hz": 1.0,
            "default_speed_level": 3,
            "default_gait": "TrotRun",
            "obstacle_avoidance": True,
        }],
    )
    return LaunchDescription([controller])
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>go2_locomotion</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>GO2 gait controller - bridges cmd_vel to Unitree SportClient</description>
|
||||||
|
<maintainer email="dev@example.com">GO2 Patrol Team</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>go2_patrol_msgs</depend>
|
||||||
|
<depend>unitree_go</depend>
|
||||||
|
<depend>unitree_api</depend>
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>python3-pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,30 @@
|
|||||||
|
import os
from glob import glob

from setuptools import find_packages, setup

package_name = "go2_locomotion"

setup(
    name=package_name,
    version="0.1.0",
    packages=find_packages(),
    data_files=[
        # ament resource-index marker so ROS 2 can discover the package
        ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
        ("share/" + package_name, ["package.xml"]),
        # install launch files and YAML configs into the package share dir
        (os.path.join("share", package_name, "launch"), glob("launch/*.py")),
        (os.path.join("share", package_name, "config"), glob("config/*.yaml")),
    ],
    install_requires=["setuptools"],
    zip_safe=True,
    maintainer="GO2 Patrol Team",
    maintainer_email="dev@example.com",
    description="GO2 gait controller - bridges cmd_vel to Unitree SportClient",
    license="MIT",
    tests_require=["pytest"],
    entry_points={
        "console_scripts": [
            # NOTE(review): module path "go2_gait_controller" differs from the
            # package name "go2_locomotion" — confirm the source directory is
            # actually go2_gait_controller/, otherwise this console script
            # fails at import time.
            "gait_controller = go2_gait_controller.gait_controller:main",
        ],
    },
)
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
project(go2_patrol_msgs)

# Enable stricter diagnostics on GCC/Clang toolchains.
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
  add_compile_options(-Wall -Wextra -Wpedantic)
endif()

# Build-time dependencies: ament plus the rosidl code generators and the
# message packages whose types our interfaces embed.
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(sensor_msgs REQUIRED)

# All interface definitions provided by this package, gathered in one list
# so adding a new .msg/.srv/.action only touches one place.
set(interface_files
  "msg/Detection2D.msg"
  "msg/DetectionArray.msg"
  "msg/PatrolTaskStatus.msg"
  "msg/TelemetryReport.msg"
  "msg/Waypoint.msg"
  "msg/DogHealth.msg"
  "msg/GaitCommand.msg"
  "action/FollowPatrolWaypoints.action"
  "srv/GaitControl.srv"
)

rosidl_generate_interfaces(${PROJECT_NAME}
  ${interface_files}
  DEPENDENCIES std_msgs geometry_msgs sensor_msgs
)

ament_package()
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
# FollowPatrolWaypoints.action
# Action interface for driving the robot through an ordered patrol route.

# Goal: navigate through a sequence of waypoints
go2_patrol_msgs/Waypoint[] waypoints
# Acceptance radius around each waypoint, in meters.
float32 position_tolerance_m
# Allowed heading error when arriving at a waypoint, in degrees.
float32 heading_tolerance_deg
---
# Result: final status after all waypoints processed
bool success
# Human-readable outcome description (e.g. failure reason).
string message
# How many waypoints were actually reached before finishing/aborting.
int32 waypoints_completed
# Total distance traveled over the whole route, in meters.
float32 total_distance_m
# Wall-clock duration of the route execution, in seconds.
float32 elapsed_time_sec
---
# Feedback: current progress
# Index into the goal's `waypoints` array of the waypoint being pursued.
int32 current_waypoint_index
# Remaining distance to that waypoint, in meters.
float32 distance_to_current_m
# NOTE(review): presumably overall fraction complete in [0.0, 1.0] —
# confirm against the action server that publishes this feedback.
float32 progress
# Free-form status text from the navigation layer.
string current_status
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
# Detection2D.msg
# Single object detection result
# Numeric class index assigned by the detector.
int32 class_id
# Human-readable class name corresponding to class_id.
string class_label
# Detector confidence score (typically 0.0 - 1.0).
float32 confidence
# Bounding box center and size.  Units (pixels vs. normalized image
# coordinates) are not fixed by this file — NOTE(review): confirm against
# the node that publishes these detections.
float32 x_center
float32 y_center
float32 width
float32 height
# Identity assigned by the tracker to follow this object across frames;
# semantics (e.g. value when untracked) are set by the tracker node.
int32 track_id
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
# DetectionArray.msg
# Array of detections from a single frame
# Standard header: timestamp of the source frame and its coordinate frame id.
std_msgs/Header header
# Model inference time for this frame, in milliseconds.
float32 inference_ms
# All detections found in the frame (may be empty).
Detection2D[] detections
|
||||||
@ -0,0 +1,11 @@
|
|||||||
|
# DogHealth.msg
# Dog hardware health status
std_msgs/Header header
# Identifier of the reporting robot.
string dog_id
# Battery state of charge, percent (0-100).
float32 battery_percent
# Battery pack voltage, volts.
float32 battery_voltage
string gait_mode  # StaticWalk, TrotRun, EconomicGait
int32 speed_level  # 1-5
# Onboard compute temperatures, degrees Celsius.
float32 cpu_temp_c
float32 gpu_temp_c
# True when the onboard obstacle-avoidance feature is active.
bool obstacle_avoidance_enabled
# True when the robot is in an emergency-stop state.
bool emergency_stop
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
# GaitCommand.msg
# Gait control command — requests a locomotion mode change on the robot.
string gait_mode  # StaticWalk, TrotRun, EconomicGait
int32 speed_level  # 1-5
# Enable/disable the robot's built-in obstacle avoidance.
bool obstacle_avoidance
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
# PatrolTaskStatus.msg
# Patrol task execution status
# Identifier of the patrol task this status refers to.
string task_id
string status  # DRAFT, SAVED, DEPLOYED, EXECUTING, COMPLETED, CANCELLED, ABORTED
# Index of the waypoint currently being pursued.
int32 current_waypoint_index
# Total number of waypoints in the task.
int32 total_waypoints
float32 progress  # 0.0 - 1.0
string current_action  # PASS, SCAN, HOVER, OBSERVE
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
# TelemetryReport.msg
# Telemetry data reported to C4ISR
std_msgs/Header header
# Identifier of the reporting robot.
string dog_id
# Geodetic position.  NOTE(review): presumably WGS84 degrees / meters —
# confirm against the GPS driver that fills these fields.
float64 latitude
float64 longitude
float64 altitude
# Heading in degrees.
float32 heading_deg
# Ground speed, meters per second.
float32 speed_mps
# Battery state of charge, percent (0-100).
float32 battery_percent
# Compute temperatures (Celsius) and utilization percentages.
float32 cpu_temp_c
float32 gpu_temp_c
float32 cpu_usage_pct
float32 gpu_usage_pct
float32 memory_usage_pct
# Patrol task currently being executed (empty semantics set by publisher).
string task_id
string nav_status  # IDLE, NAVIGATING, RECOVERING
# Radio link strength, dBm (negative values; closer to 0 is stronger).
float32 signal_strength_dbm
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue