from node import MCTSNode
from board import GomokuBoard
import numpy as np
import copy
import torch


class MCTS:
    """Monte-Carlo Tree Search driven by a combined policy/value network.

    `policy_value_fn` takes a batched board tensor of shape (1, C, H, W) and
    returns `(act_probs, value)` where `act_probs[0]` holds a probability per
    board position and `value` is the network's evaluation of the position.
    """

    def __init__(self, policy_value_fn, c_puct=1.0, n_playout=100):
        # Root starts with no parent and a neutral prior of 1.0.
        self.root = MCTSNode(None, 1.0)
        self.policy_value_fn = policy_value_fn
        self.c_puct = c_puct        # exploration constant for PUCT selection
        self.n_playout = n_playout  # playouts per move decision

    def playout(self, state):
        """Run one playout from the root: select to a leaf, evaluate, back up.

        `state` is deep-copied so the caller's board is never mutated.
        """
        node = self.root
        board = copy.deepcopy(state)

        # --- Selection: walk down the tree while nodes are expanded. ---
        while not node.is_leaf():
            action, node = node.select(self.c_puct)
            board.execute_move(action)

        # --- Evaluation: terminal result, or network prediction at the leaf. ---
        board_state = board.get_board_state()
        if board_state == 1:      # black wins
            value = 1.0
        elif board_state == -1:   # white wins
            value = -1.0
        elif board_state == 2:    # draw
            value = 0.0
        else:
            input_tensor = torch.from_numpy(board.get_board_data()).float()
            if torch.cuda.is_available():
                input_tensor = input_tensor.cuda()
            # BUG FIX: the original assigned the tensor to itself here; the
            # batch dimension the comment promised was never added, so the
            # network received (C, H, W) instead of (1, C, H, W).
            input_tensor = input_tensor.unsqueeze(0)

            # Inference only — no gradients needed during search.
            with torch.no_grad():
                act_probs, value = self.policy_value_fn(input_tensor)

            # Expand the leaf with priors for the legal moves only.
            legal_moves = board.get_available_move()
            legal_indices = [i for i, flag in enumerate(legal_moves) if flag == 1]
            probs = [(i, act_probs[0][i]) for i in legal_indices]
            node.expand(probs)

        # --- Backup: propagate the (negated) value up the visited path. ---
        # NOTE(review): the terminal values above are absolute (black=+1,
        # white=-1) while the negation assumes a value relative to the player
        # to move at the leaf — confirm `update_recursive`'s sign convention
        # against MCTSNode.
        node.update_recursive(-value)

    def get_move_probs(self, state, temp=1e-3):
        """Run `n_playout` playouts from `state` and return (actions, probs).

        Probabilities are proportional to visit-count ** (1/temp). Computed in
        log space: the original `np.power(visits, 1/temp)` overflows to inf
        for temp=1e-3 whenever a child has more than one visit, yielding NaNs.
        """
        for _ in range(self.n_playout):
            self.playout(state)

        act_visits = [(act, node.N) for act, node in self.root.children.items()]
        acts, visits = zip(*act_visits)
        # Stable softmax of (1/temp) * log(visits); +1e-10 guards log(0).
        logits = np.log(np.asarray(visits, dtype=np.float64) + 1e-10) / temp
        logits -= np.max(logits)
        act_probs = np.exp(logits)
        act_probs /= np.sum(act_probs)
        return acts, act_probs

    def update_with_move(self, last_move):
        """Advance the root after a real move, reusing the subtree if present."""
        if last_move in self.root.children:
            self.root = self.root.children[last_move]
            self.root.parent = None
        else:
            # Move not in tree (e.g. opponent's novel move) — start fresh.
            self.root = MCTSNode(None, 1.0)