You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
project/Src/command_center/rrt_algorithm.py

316 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
# File: rrt_algorithm.py
# Purpose: 实现RRT(快速随机树)路径规划算法,支持避开危险区域
import numpy as np
import math
import random
from typing import List, Tuple, Dict, Optional
class Node:
"""RRT树节点"""
def __init__(self, x: float, y: float):
self.x = x
self.y = y
self.parent = None
self.cost = 0.0 # 从起点到当前节点的代价
class RRTAlgorithm:
"""RRT路径规划算法"""
def __init__(self,
grid_resolution: float = 5.0,
max_iterations: int = 2000,
step_size: float = 20.0,
goal_sample_rate: float = 0.1,
search_radius: float = 50.0):
"""
初始化RRT算法
Args:
grid_resolution: 网格分辨率
max_iterations: 最大迭代次数
step_size: 步长
goal_sample_rate: 采样目标点的概率
search_radius: 搜索半径
"""
self.grid_resolution = grid_resolution
self.max_iterations = max_iterations
self.step_size = step_size
self.goal_sample_rate = goal_sample_rate
self.search_radius = search_radius
def plan(self, start: Tuple[float, float],
goal: Tuple[float, float],
map_width: int, map_height: int,
threat_areas: List[Dict],
obstacles: List[Tuple[float, float]] = None) -> List[Tuple[float, float]]:
"""
使用RRT算法规划路径
Args:
start: 起点坐标 (x, y)
goal: 终点坐标 (x, y)
map_width: 地图宽度
map_height: 地图高度
threat_areas: 危险区域列表
obstacles: 障碍物列表
Returns:
路径点列表,如果找不到路径返回空列表
"""
# 创建起点和终点节点
start_node = Node(start[0], start[1])
goal_node = Node(goal[0], goal[1])
# 初始化树
tree = [start_node]
# 迭代构建树
for i in range(self.max_iterations):
# 以一定概率直接取目标点,提高搜索效率
if random.random() < self.goal_sample_rate:
random_point = (goal_node.x, goal_node.y)
else:
# 随机采样点
random_point = (
random.uniform(0, map_width),
random.uniform(0, map_height)
)
# 找到树中离随机点最近的节点
nearest_node = self._find_nearest_node(tree, random_point)
# 朝随机点方向扩展固定步长
new_node = self._steer(nearest_node, random_point, self.step_size)
# 检查路径是否有效
if new_node and self._is_path_valid(nearest_node, new_node, threat_areas, obstacles):
# 将新节点添加到树中
new_node.parent = nearest_node
new_node.cost = nearest_node.cost + self._distance((nearest_node.x, nearest_node.y),
(new_node.x, new_node.y))
tree.append(new_node)
# 检查是否可以连接到目标点
dist_to_goal = self._distance((new_node.x, new_node.y), (goal_node.x, goal_node.y))
if dist_to_goal < self.step_size:
# 尝试直接连接到目标点
if self._is_path_valid(new_node, goal_node, threat_areas, obstacles):
goal_node.parent = new_node
goal_node.cost = new_node.cost + dist_to_goal
# 找到路径,提前结束
return self._extract_path(goal_node)
# RRT* 优化 (可选,提高路径质量)
# 尝试重新连接附近节点以获得更优路径
self._rewire(tree, new_node, threat_areas, obstacles)
# 如果达到最大迭代次数但仍未找到到达目标的路径
# 尝试找到离目标最近的节点,并返回到该节点的路径
closest_node = self._find_nearest_node(tree, (goal_node.x, goal_node.y))
path = self._extract_path(closest_node)
# 如果最近节点离目标足够近,则认为找到了路径
if self._distance((closest_node.x, closest_node.y), (goal_node.x, goal_node.y)) < self.step_size * 2:
return path
# 否则找不到有效路径,返回空列表
return []
def _find_nearest_node(self, tree: List[Node], point: Tuple[float, float]) -> Node:
"""找到树中离指定点最近的节点"""
min_dist = float('inf')
nearest_node = None
for node in tree:
dist = self._distance((node.x, node.y), point)
if dist < min_dist:
min_dist = dist
nearest_node = node
return nearest_node
def _steer(self, from_node: Node, to_point: Tuple[float, float], step_size: float) -> Optional[Node]:
"""从起始节点向目标点方向移动固定步长"""
dx = to_point[0] - from_node.x
dy = to_point[1] - from_node.y
dist = math.sqrt(dx*dx + dy*dy)
if dist < 0.0001: # 距离近乎为0
return None
# 按照指定的步长移动
ratio = min(step_size / dist, 1.0)
new_x = from_node.x + dx * ratio
new_y = from_node.y + dy * ratio
new_node = Node(new_x, new_y)
return new_node
def _is_path_valid(self, from_node: Node, to_node: Node,
threat_areas: List[Dict], obstacles: List[Tuple[float, float]] = None) -> bool:
"""检查两个节点之间的路径是否有效(不经过障碍物和危险区域)"""
# 采样点检测
dist = self._distance((from_node.x, from_node.y), (to_node.x, to_node.y))
num_samples = max(int(dist / self.grid_resolution), 5)
for i in range(num_samples + 1):
t = i / num_samples
x = from_node.x + t * (to_node.x - from_node.x)
y = from_node.y + t * (to_node.y - from_node.y)
# 检查是否在危险区域内
if self._is_in_threat_areas(x, y, threat_areas):
return False
# 检查是否与障碍物碰撞
if obstacles and self._is_in_obstacles(x, y, obstacles):
return False
return True
def _rewire(self, tree: List[Node], new_node: Node,
threat_areas: List[Dict], obstacles: List[Tuple[float, float]] = None) -> None:
"""重新连接树中节点以优化路径代价 (RRT* 算法的核心步骤)"""
# 找到搜索半径内的所有节点
near_nodes = []
for node in tree:
if node is not new_node: # 排除新节点本身
dist = self._distance((node.x, node.y), (new_node.x, new_node.y))
if dist < self.search_radius:
near_nodes.append(node)
# 对于搜索半径内的每个节点,检查是否通过新节点创建的路径更好
for node in near_nodes:
# 计算通过新节点到达当前节点的潜在代价
potential_cost = new_node.cost + self._distance((new_node.x, new_node.y), (node.x, node.y))
# 如果代价更低且路径有效,则重新连接
if potential_cost < node.cost and self._is_path_valid(new_node, node, threat_areas, obstacles):
node.parent = new_node
node.cost = potential_cost
def _extract_path(self, goal_node: Node) -> List[Tuple[float, float]]:
"""从目标节点回溯生成路径"""
path = []
current = goal_node
# 回溯到起点
while current:
path.append((current.x, current.y))
current = current.parent
# 反转路径顺序(从起点到终点)
return list(reversed(path))
def _distance(self, p1: Tuple[float, float], p2: Tuple[float, float]) -> float:
"""计算两点间欧几里得距离"""
return math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2)
def _is_in_threat_areas(self, x: float, y: float, threat_areas: List[Dict]) -> bool:
"""检查点是否在危险区域中"""
for area in threat_areas:
area_type = area.get('type', '')
if area_type == 'circle':
center = area.get('center', (0, 0))
radius = area.get('radius', 0)
distance = math.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2)
if distance <= radius:
return True
elif area_type == 'rectangle':
rect = area.get('rect', (0, 0, 0, 0))
x1, y1, width, height = rect
if x1 <= x <= x1 + width and y1 <= y <= y1 + height:
return True
elif area_type == 'polygon':
points = area.get('points', [])
if self._point_in_polygon(x, y, points):
return True
return False
def _point_in_polygon(self, x: float, y: float, polygon: List[Tuple[float, float]]) -> bool:
"""检查点是否在多边形内(射线法)"""
if len(polygon) < 3:
return False
inside = False
j = len(polygon) - 1
for i in range(len(polygon)):
xi, yi = polygon[i]
xj, yj = polygon[j]
# 检查点是否在多边形边上
if (yi == y and xi == x) or (yj == y and xj == x):
return True
# 射线法判断点是否在多边形内部
intersect = ((yi > y) != (yj > y)) and (x < (xj - xi) * (y - yi) / (yj - yi) + xi)
if intersect:
inside = not inside
j = i
return inside
def _is_in_obstacles(self, x: float, y: float, obstacles: List[Tuple[float, float]]) -> bool:
"""检查点是否在障碍物中"""
for obstacle_x, obstacle_y in obstacles:
distance = math.sqrt((x - obstacle_x) ** 2 + (y - obstacle_y) ** 2)
if distance < self.grid_resolution:
return True
return False
def smooth_path(self, path: List[Tuple[float, float]],
threat_areas: List[Dict],
obstacles: List[Tuple[float, float]] = None,
weight_data: float = 0.5,
weight_smooth: float = 0.3,
tolerance: float = 0.000001) -> List[Tuple[float, float]]:
"""
使用平滑算法优化路径 (与A*接口兼容)
Args:
path: 原始路径
threat_areas: 危险区域
obstacles: 障碍物
weight_data: 数据权重
weight_smooth: 平滑权重
tolerance: 收敛阈值
Returns:
平滑后的路径
"""
# 如果路径点太少,不需要平滑
if len(path) <= 2:
return path
# 将路径转换为numpy数组方便操作
path_array = np.array(path)
smooth_path = path_array.copy()
# 迭代优化
change = tolerance
while change >= tolerance:
change = 0.0
# 对每个点进行优化(除了起点和终点)
for i in range(1, len(path) - 1):
for j in range(2): # x, y
aux = smooth_path[i][j]
smooth_path[i][j] += weight_data * (path_array[i][j] - smooth_path[i][j])
smooth_path[i][j] += weight_smooth * (smooth_path[i-1][j] + smooth_path[i+1][j] - 2.0 * smooth_path[i][j])
change += abs(aux - smooth_path[i][j])
# 检查平滑后的路径是否经过危险区域或障碍物
for i in range(len(smooth_path)):
x, y = smooth_path[i]
if self._is_in_threat_areas(x, y, threat_areas) or (obstacles and self._is_in_obstacles(x, y, obstacles)):
return path # 如果平滑后路径穿过危险区域,返回原路径
return [(x, y) for x, y in smooth_path]