You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

310 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import time
import uuid
from collections import Counter, deque
from typing import Dict, List, Optional

from config_manager import ConfigManager
from models import (AggregatedStatus, SafetyCommand, SafetyLevel, DegradationMode,
                    DecisionStrategy, ModeTransition)
# Module-level logger used by SafetyDecider below. Note that it is registered
# under the fixed name "Decider" rather than __name__.
logger = logging.getLogger("Decider")
class ScoreBasedStrategy:
    """Score-based decision strategy (the default strategy).

    Every module whose aggregated level is not OK contributes a fixed score
    for its level. The per-cycle total is averaged over a sliding window of
    the most recent ``hysteresis_window`` cycles, and that smoothed score is
    compared against the ERROR/FATAL thresholds. Any single FATAL or ERROR
    module also escalates directly to the corresponding mode.
    """

    def __init__(self, config: ConfigManager):
        """Read thresholds and the smoothing window from the ``guardian`` section.

        Args:
            config: Configuration source. If ``config.get("guardian")`` is
                falsy, all defaults are used.
        """
        guardian_cfg = config.get("guardian") or {}
        self.error_threshold = guardian_cfg.get("error_threshold", 25)
        self.fatal_threshold = guardian_cfg.get("fatal_threshold", 50)
        # A window of 0 (or less) would drop the current sample immediately and
        # make the average below divide by zero, so keep at least one sample.
        self.hysteresis = max(1, int(guardian_cfg.get("hysteresis_window", 3)))
        self.level_scores = {
            SafetyLevel.WARN: 5,
            SafetyLevel.ERROR: 15,
            SafetyLevel.FATAL: 30,
        }
        # Bounded window: appending beyond maxlen discards the oldest score.
        self._score_history = deque(maxlen=self.hysteresis)

    def decide(self, aggregated: List[AggregatedStatus]) -> SafetyCommand:
        """Turn the current aggregated statuses into a safety command.

        Args:
            aggregated: Aggregated status of each detection module.

        Returns:
            A SCORE_BASED ``SafetyCommand``. FATAL/ERROR are chosen by the
            smoothed score or by any module at that level. WARN is chosen when
            the current cycle's raw score is positive. Otherwise the result
            is OK / L0_NORMAL.
        """
        total_score = 0
        reasons = []
        for status in aggregated:
            if status.overall_level == SafetyLevel.OK:
                continue
            score = self.level_scores.get(status.overall_level, 0)
            total_score += score
            reasons.append(f"{status.module}:{status.overall_level.value}(+{score})")

        self._score_history.append(total_score)
        # Averaging over recent cycles means a high score keeps the mode
        # escalated for a few cycles after the fault clears (hysteresis).
        smoothed_score = sum(self._score_history) / len(self._score_history)

        if smoothed_score >= self.fatal_threshold or any(
            s.overall_level == SafetyLevel.FATAL for s in aggregated
        ):
            target_mode = DegradationMode.L3_EMERGENCY_STOP
            level = SafetyLevel.FATAL
        elif smoothed_score >= self.error_threshold or any(
            s.overall_level == SafetyLevel.ERROR for s in aggregated
        ):
            target_mode = DegradationMode.L2_SOFT_STOP
            level = SafetyLevel.ERROR
        elif total_score > 0:
            # WARN is judged on this cycle's raw score, not the smoothed one.
            target_mode = DegradationMode.L1_LIMITED
            level = SafetyLevel.WARN
        else:
            target_mode = DegradationMode.L0_NORMAL
            level = SafetyLevel.OK

        return SafetyCommand(
            command_id=str(uuid.uuid4())[:8],
            level=level,
            target_mode=target_mode,
            reasons=reasons if reasons else ["All systems nominal"],
            strategy=DecisionStrategy.SCORE_BASED,
        )
class VotingStrategy:
    """Voting-based decision strategy.

    Each detection module casts one vote: its aggregated level. The share of
    ERROR/FATAL votes decides whether to escalate to a soft stop or an
    emergency stop. When ``fatal_veto_enabled`` is set, a single FATAL vote
    vetoes everything else and forces an emergency stop.
    """

    def __init__(self, config: ConfigManager):
        """Read vote ratios and the veto flag from the ``guardian`` section.

        Args:
            config: Configuration source. If ``config.get("guardian")`` is
                falsy, all defaults are used.
        """
        guardian_cfg = config.get("guardian") or {}
        # Minimum fraction of ERROR+FATAL votes for L2_SOFT_STOP / L3_EMERGENCY_STOP.
        self.error_vote_ratio = guardian_cfg.get("error_vote_ratio", 0.3)
        self.fatal_vote_ratio = guardian_cfg.get("fatal_vote_ratio", 0.5)
        # When True, any FATAL vote forces L3_EMERGENCY_STOP regardless of ratios.
        self.fatal_veto = guardian_cfg.get("fatal_veto_enabled", True)

    @staticmethod
    def _command(level, target_mode, reasons) -> SafetyCommand:
        """Build a VOTING ``SafetyCommand`` with a fresh 8-character id."""
        return SafetyCommand(
            command_id=str(uuid.uuid4())[:8],
            level=level,
            target_mode=target_mode,
            reasons=reasons,
            strategy=DecisionStrategy.VOTING,
        )

    def decide(self, aggregated: List[AggregatedStatus]) -> SafetyCommand:
        """Turn the modules' votes into a safety command.

        Args:
            aggregated: Aggregated status of each detection module. An empty
                list yields OK / L0_NORMAL.

        Returns:
            A VOTING ``SafetyCommand``. Any non-OK vote yields at least
            L1_LIMITED. The ERROR+FATAL ratio decides L2/L3.
        """
        if not aggregated:
            return self._command(SafetyLevel.OK, DegradationMode.L0_NORMAL,
                                 ["No checks active"])

        total = len(aggregated)
        vote_counts = Counter(s.overall_level for s in aggregated)
        reasons = [f"{s.module}:{s.overall_level.value}" for s in aggregated
                   if s.overall_level != SafetyLevel.OK]

        fatal_votes = vote_counts.get(SafetyLevel.FATAL, 0)
        # FATAL veto: a single FATAL vote overrides all other votes.
        if self.fatal_veto and fatal_votes > 0:
            return self._command(SafetyLevel.FATAL,
                                 DegradationMode.L3_EMERGENCY_STOP, reasons)

        error_and_fatal = vote_counts.get(SafetyLevel.ERROR, 0) + fatal_votes
        abnormal_votes = error_and_fatal + vote_counts.get(SafetyLevel.WARN, 0)
        # total > 0 is guaranteed by the empty-input guard above.
        error_ratio = error_and_fatal / total

        if error_ratio >= self.fatal_vote_ratio:
            target_mode = DegradationMode.L3_EMERGENCY_STOP
            level = SafetyLevel.FATAL
        elif error_ratio >= self.error_vote_ratio:
            target_mode = DegradationMode.L2_SOFT_STOP
            level = SafetyLevel.ERROR
        elif abnormal_votes > 0:
            # ERROR/FATAL votes below the ratio thresholds are at least as
            # severe as WARN. Previously they fell through to L0_NORMAL.
            target_mode = DegradationMode.L1_LIMITED
            level = SafetyLevel.WARN
        else:
            target_mode = DegradationMode.L0_NORMAL
            level = SafetyLevel.OK

        return self._command(level, target_mode,
                             reasons if reasons else ["All systems nominal"])
class PriorityBasedStrategy:
    """Priority-weighted decision strategy.

    Each module has a priority, where a larger value means more important.
    A non-OK module contributes
    ``PRIORITY_WEIGHT_SCORES[level] * priority`` to a total that is compared
    against the ERROR/FATAL thresholds. Any FATAL module forces an
    emergency stop.
    """

    # Default module priorities (larger value = higher priority). Instances
    # take a private copy in __init__, so config overrides never leak into
    # this shared class-level table.
    MODULE_PRIORITY = {
        "control": 10,
        "collision": 9,
        "localization": 8,
        "communication": 7,
        "vehicle": 6,
        "resource": 5,
        "system": 4,
        "unknown": 1,
    }

    # Base score per level; multiplied by the module's priority.
    PRIORITY_WEIGHT_SCORES = {
        SafetyLevel.WARN: 2,
        SafetyLevel.ERROR: 8,
        SafetyLevel.FATAL: 25,
    }

    def __init__(self, config: ConfigManager):
        """Read thresholds and optional priority overrides from ``guardian``.

        Args:
            config: Configuration source. ``guardian.module_priority``, if
                present, overrides entries of the default priority table for
                this instance only.
        """
        guardian_cfg = config.get("guardian") or {}
        self.error_threshold = guardian_cfg.get("priority_error_threshold", 20)
        self.fatal_threshold = guardian_cfg.get("priority_fatal_threshold", 40)
        # Copy before applying overrides. Updating the class attribute in
        # place would change the table for every other (and future) instance.
        priorities = dict(type(self).MODULE_PRIORITY)
        priorities.update(guardian_cfg.get("module_priority") or {})
        self.MODULE_PRIORITY = priorities

    def decide(self, aggregated: List[AggregatedStatus]) -> SafetyCommand:
        """Turn aggregated statuses into a priority-weighted safety command.

        Args:
            aggregated: Aggregated status of each detection module.

        Returns:
            A PRIORITY_BASED ``SafetyCommand`` chosen from the weighted total,
            or FATAL / L3_EMERGENCY_STOP if any module is FATAL.
        """
        # Modules missing from the table fall back to the "unknown" priority.
        fallback_priority = self.MODULE_PRIORITY.get("unknown", 1)
        total_weighted_score = 0
        reasons = []
        for status in aggregated:
            if status.overall_level == SafetyLevel.OK:
                continue
            priority = self.MODULE_PRIORITY.get(status.module, fallback_priority)
            base_score = self.PRIORITY_WEIGHT_SCORES.get(status.overall_level, 0)
            weighted_score = base_score * priority
            total_weighted_score += weighted_score
            reasons.append(
                f"{status.module}:{status.overall_level.value}"
                f"(pri={priority},score={weighted_score})"
            )

        # A FATAL module is an unconditional veto; otherwise use the thresholds.
        has_fatal = any(s.overall_level == SafetyLevel.FATAL for s in aggregated)
        if has_fatal or total_weighted_score >= self.fatal_threshold:
            target_mode = DegradationMode.L3_EMERGENCY_STOP
            level = SafetyLevel.FATAL
        elif total_weighted_score >= self.error_threshold:
            target_mode = DegradationMode.L2_SOFT_STOP
            level = SafetyLevel.ERROR
        elif total_weighted_score > 0:
            target_mode = DegradationMode.L1_LIMITED
            level = SafetyLevel.WARN
        else:
            target_mode = DegradationMode.L0_NORMAL
            level = SafetyLevel.OK

        return SafetyCommand(
            command_id=str(uuid.uuid4())[:8],
            level=level,
            target_mode=target_mode,
            reasons=reasons if reasons else ["All systems nominal"],
            strategy=DecisionStrategy.PRIORITY_BASED,
        )
class SafetyDecider:
    """Safety decider.

    Runs the configured decision strategy over all check results to produce
    a braking command (soft stop / emergency stop). It suppresses repeated
    commands for an unchanged target mode (FATAL is always re-issued) and
    keeps a bounded history of mode transitions.
    """

    def __init__(self, config_manager: ConfigManager):
        """Select and build the decision strategy from ``guardian.decision_strategy``.

        Args:
            config_manager: Configuration source. A missing, empty, or unknown
                ``decision_strategy`` falls back to SCORE_BASED.
        """
        self.config = config_manager
        guardian_cfg = config_manager.get("guardian") or {}

        # str() tolerates non-string config values that would break .upper().
        strategy_str = str(guardian_cfg.get("decision_strategy") or "SCORE_BASED").upper()
        try:
            self._strategy_type = DecisionStrategy[strategy_str]
        except KeyError:
            logger.warning("Unknown decision strategy %r, falling back to SCORE_BASED",
                           strategy_str)
            self._strategy_type = DecisionStrategy.SCORE_BASED

        self._strategy = self._create_strategy(config_manager)

        # Last issued command, used for de-duplication.
        self._last_command: Optional[SafetyCommand] = None

        # Bounded mode-transition history; the oldest entry is dropped automatically.
        self._max_transitions = 100
        self._mode_transitions: "deque[ModeTransition]" = deque(maxlen=self._max_transitions)

    def _create_strategy(self, config_manager: ConfigManager):
        """Instantiate the strategy object matching ``self._strategy_type``."""
        if self._strategy_type == DecisionStrategy.VOTING:
            logger.info("Decision strategy: VOTING")
            return VotingStrategy(config_manager)
        elif self._strategy_type == DecisionStrategy.PRIORITY_BASED:
            logger.info("Decision strategy: PRIORITY_BASED")
            return PriorityBasedStrategy(config_manager)
        else:
            logger.info("Decision strategy: SCORE_BASED")
            return ScoreBasedStrategy(config_manager)

    def decide(self, aggregated: List[AggregatedStatus]) -> SafetyCommand:
        """Produce the safety command for the current aggregated results.

        Args:
            aggregated: Aggregated status of each detection module.

        Returns:
            The new command. If the target mode is unchanged and the level is
            not FATAL, the previously issued command object is returned instead.
        """
        cmd = self._strategy.decide(aggregated)
        last = self._last_command

        # De-duplication: do not re-send the same target mode (FATAL excepted).
        if (last is not None and
                last.target_mode == cmd.target_mode and
                cmd.level != SafetyLevel.FATAL):
            logger.debug("Command unchanged, suppressed: %s", cmd.to_log())
            return last

        if last is not None and last.target_mode != cmd.target_mode:
            self._record_transition(last, cmd)

        self._last_command = cmd
        logger.info("Decided: %s", cmd.to_log())
        return cmd

    def _record_transition(self, previous: SafetyCommand, cmd: SafetyCommand) -> None:
        """Append a ModeTransition from ``previous`` to ``cmd`` and log its direction."""
        transition = ModeTransition(
            from_mode=previous.target_mode,
            to_mode=cmd.target_mode,
            trigger_level=cmd.level,
            # Only the first three reasons are kept to bound the string length.
            trigger_source=", ".join(cmd.reasons[:3]),
            command_id=cmd.command_id,
        )
        self._mode_transitions.append(transition)
        if transition.is_escalation():
            logger.warning("Mode escalation: %s -> %s (trigger: %s)",
                           transition.from_mode.name, transition.to_mode.name,
                           cmd.level.value)
        else:
            logger.info("Mode de-escalation: %s -> %s",
                        transition.from_mode.name, transition.to_mode.name)

    def get_strategy_name(self) -> str:
        """Return the name of the active decision strategy."""
        return self._strategy_type.name

    def get_mode_transitions(self, count: int = 10) -> List[ModeTransition]:
        """Return up to ``count`` most recent mode transitions, oldest first.

        A non-positive ``count`` returns an empty list. Previously ``count=0``
        sliced as ``[-0:]`` and returned the whole history.
        """
        if count <= 0:
            return []
        return list(self._mode_transitions)[-count:]

    def get_last_command(self) -> Optional[SafetyCommand]:
        """Return the most recently issued safety command, or None."""
        return self._last_command