"""A* path finding on a 0/1 occupancy grid, with a step-by-step matplotlib animation.

Grid cells are addressed as (row, col) tuples. A cell value of 0 is free;
``a_star`` treats any other value as blocked, and ``draw_grid`` paints cells
equal to 1 as obstacles.
"""
import heapq
import random
import time

import matplotlib.pyplot as plt
import numpy as np

DEFAULT_GRID_FILE = "./correlated_01_grid.txt"

# The four axis-aligned neighbour offsets, as (d_row, d_col).
_MOVES = ((0, 1), (1, 0), (0, -1), (-1, 0))


def heuristic(a, b):
    """Return the Manhattan distance between cells ``a`` and ``b``.

    For unit-cost 4-connected moves this never overestimates the remaining
    cost, so A* with it returns a shortest path.
    """
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def a_star(start, goal, grid):
    """Find a shortest 4-connected path from ``start`` to ``goal``.

    Args:
        start: (row, col) of the first cell.
        goal: (row, col) of the target cell.
        grid: 2-D indexable grid (list of lists or 2-D array); only cells equal
            to 0 may be entered.

    Returns:
        List of (row, col) cells from ``start`` to ``goal`` inclusive, or
        ``None`` when the goal cannot be reached.
    """
    rows = len(grid)
    cols = len(grid[0]) if rows else 0

    # Heap entries are (f = g + h, g, cell); ties are broken on g, then cell.
    open_heap = [(heuristic(start, goal), 0, start)]
    closed = set()
    came_from = {}
    g_score = {start: 0}

    while open_heap:
        _, current_g, current = heapq.heappop(open_heap)
        if current in closed:
            # Stale duplicate: this cell was already expanded via a cheaper entry.
            continue
        if current == goal:
            path = [current]
            while current in came_from:
                current = came_from[current]
                path.append(current)
            return path[::-1]
        closed.add(current)

        for d_row, d_col in _MOVES:
            r, c = current[0] + d_row, current[1] + d_col
            if not (0 <= r < rows and 0 <= c < cols) or grid[r][c] != 0:
                continue
            neighbor = (r, c)
            if neighbor in closed:
                continue
            tentative_g = current_g + 1
            if tentative_g < g_score.get(neighbor, float("inf")):
                g_score[neighbor] = tentative_g
                came_from[neighbor] = current
                f_score = tentative_g + heuristic(neighbor, goal)
                heapq.heappush(open_heap, (f_score, tentative_g, neighbor))
    return None


def draw_grid(ax, grid, path=None, open_list=None, closed_list=None,
              start=None, goal=None, pause=0.1):
    """Redraw ``grid`` on ``ax`` and pause briefly so the frame is visible.

    Cell (i, j) is drawn as the unit square whose lower-left corner is (j, i):
    the column is on the x axis and the row on the y axis. Colour priority per
    cell: obstacle (black) > start (red) > goal (blue) > open (purple) >
    path (green) > closed (lightblue); any other cell is left blank.

    Args:
        ax: matplotlib Axes to draw on; it is cleared first.
        grid: 2-D grid of cell values (1 = obstacle). Need not be square.
        path, open_list, closed_list: optional iterables of (row, col) cells.
        start, goal: optional (row, col) cells.
        pause: seconds passed to ``plt.pause`` after drawing.
    """
    rows = len(grid)
    cols = len(grid[0]) if rows else 0
    path_cells = set(path) if path else set()
    open_cells = set(open_list) if open_list else set()
    closed_cells = set(closed_list) if closed_list else set()

    ax.clear()
    # Grid lines on integer coordinates so they coincide with the edges of the
    # unit-square cells drawn below (half-integer lines would cut cells in two).
    ax.set_xticks(np.arange(cols + 1), minor=True)
    ax.set_yticks(np.arange(rows + 1), minor=True)
    ax.grid(which='minor', color='gray', linestyle='--', linewidth=1)
    ax.set_xlim(0, cols)
    ax.set_ylim(0, rows)
    ax.set_aspect('equal')

    for i in range(rows):
        for j in range(cols):
            cell = (i, j)
            if grid[i][j] == 1:
                color = 'black'        # obstacle
            elif cell == start:
                color = 'red'          # start
            elif cell == goal:
                color = 'blue'         # goal
            elif cell in open_cells:
                color = 'purple'       # open list
            elif cell in path_cells:
                color = 'green'        # path
            elif cell in closed_cells:
                color = 'lightblue'    # closed list
            else:
                continue
            ax.add_patch(plt.Rectangle((j, i), 1, 1, color=color))

    plt.draw()          # refresh the figure
    plt.pause(pause)    # give the GUI time to show this frame


def generate_random_grid(size, obstacle_probability=0.45):
    """Return a random ``size`` x ``size`` integer grid of 0s and 1s.

    Each cell is independently 1 (obstacle) with probability
    ``obstacle_probability`` and 0 (free) otherwise.
    """
    grid = np.random.choice([0, 1], size=(size, size),
                            p=[1 - obstacle_probability, obstacle_probability])
    return grid


def save_grid_to_txt(grid, filename="grid_data.txt"):
    """Write ``grid`` to ``filename``: one row per line, values separated by spaces."""
    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(" ".join(map(str, row)) + "\n" for row in grid)
    print(f"网格数据已保存到 {filename}")


def read_grid_from_txt(filename="grid_data.txt"):
    """Load a grid written by ``save_grid_to_txt`` and return it as a NumPy array.

    Blank lines (e.g. a trailing empty line) are skipped; otherwise they would
    become empty rows and make the array ragged.
    """
    with open(filename, "r", encoding="utf-8") as f:
        rows = [list(map(int, line.split())) for line in f if line.strip()]
    return np.array(rows)


def _candidate_grids(grid_file, grid_size, max_attempts):
    """Yield the stored grid (if present and non-empty), then up to ``max_attempts`` random grids."""
    try:
        stored = read_grid_from_txt(filename=grid_file)
    except FileNotFoundError:
        stored = None
    if stored is not None and stored.ndim == 2 and stored.size:
        yield stored
    for _ in range(max_attempts):
        yield generate_random_grid(grid_size, obstacle_probability=0.45)


def main(grid_file=DEFAULT_GRID_FILE, grid_size=10, max_attempts=100):
    """Find a solvable grid, save it, print the A* path and animate it.

    The grid stored in ``grid_file`` is tried first; if it is missing, empty or
    has no path, random ``grid_size`` x ``grid_size`` grids are tried (at most
    ``max_attempts``) instead of retrying the same file forever. Start is the
    top-left corner (0, 0) and goal the opposite corner of the grid actually used.

    Returns:
        1 if no solvable grid was found, otherwise ``None``.
    """
    for grid in _candidate_grids(grid_file, grid_size, max_attempts):
        rows, cols = grid.shape
        start = (0, 0)
        goal = (rows - 1, cols - 1)
        # Make sure the endpoints themselves are not obstacles.
        grid[start] = 0
        grid[goal] = 0
        path = a_star(start, goal, grid)
        if path:
            break
        print('try again')
    else:
        print("没有找到路径")
        return 1

    save_grid_to_txt(grid, filename=grid_file)
    print(path)

    # Cell centres: column on the x axis, row on the y axis (matches draw_grid).
    xs = [col + 0.5 for _, col in path]
    ys = [row + 0.5 for row, _ in path]

    _, ax = plt.subplots(figsize=(cols + 1, rows + 1))
    # Animate the path growing one cell at a time; the newest cell is highlighted.
    for step, cell in enumerate(path):
        draw_grid(ax, grid, path=path[:step + 1], open_list=[cell],
                  closed_list=[], start=start, goal=goal)

    ax.plot(xs, ys, marker='o', color='b', linestyle='-', markersize=5, label='Path')
    plt.draw()
    plt.show()
    return None


if __name__ == "__main__":
    raise SystemExit(main())