ADD file via upload

main
hnu202410040106 12 months ago
parent d8542794d9
commit 6c2ee7be3b

@ -0,0 +1,65 @@
from node import MCTSNode
from board import GomokuBoard
import numpy as np
import copy
import torch
class MCTS:
    """Monte Carlo tree search guided by a policy/value function.

    The tree is made of ``MCTSNode`` objects. Each playout walks from the
    root to a leaf, evaluates the leaf (terminal result or
    ``policy_value_fn``), expands it and backs the value up the tree.
    """

    def __init__(self, policy_value_fn, c_puct=1.0, n_playout=100):
        """Create a search tree with a fresh root.

        Args:
            policy_value_fn: Callable taking the board tensor and returning
                ``(act_probs, value)``; ``act_probs`` is indexed as
                ``act_probs[0][move_index]``.
            c_puct: Exploration constant forwarded to ``MCTSNode.select``.
            n_playout: Number of playouts run per ``get_move_probs`` call.
        """
        self.root = MCTSNode(None, 1.0)
        self.policy_value_fn = policy_value_fn
        self.c_puct = c_puct
        self.n_playout = n_playout

    @staticmethod
    def _to_float(x):
        """Convert a one-element tensor/array or a number to a Python float."""
        return float(x.item()) if hasattr(x, "item") else float(x)

    def playout(self, state):
        """Run a single playout from the root and back up its value.

        Args:
            state: Board to search from. It is deep-copied, so the caller's
                object is not modified by the moves played here.
        """
        node = self.root
        board = copy.deepcopy(state)

        # Selection: descend until a leaf, replaying the moves on the copy.
        while not node.is_leaf():
            action, node = node.select(self.c_puct)
            board.execute_move(action)

        # Evaluation.
        # NOTE(review): terminal values below are fixed to black's point of
        # view, yet the backup always negates the value. Confirm this matches
        # the perspective used by the network's value output and by
        # MCTSNode.update_recursive.
        board_state = board.get_board_state()
        if board_state == 1:  # black wins
            value = 1.0
        elif board_state == -1:  # white wins
            value = -1.0
        elif board_state == 2:  # draw
            value = 0.0
        else:
            input_tensor = torch.from_numpy(board.get_board_data()).float()
            if torch.cuda.is_available():
                input_tensor = input_tensor.cuda()
            # NOTE(review): the original carried a no-op assignment here,
            # commented as "add batch dimension (1, C, H, W)". No dimension
            # is added; presumably get_board_data() already returns a batched
            # array -- TODO confirm.
            with torch.no_grad():  # search only needs inference, not grads
                act_probs, value = self.policy_value_fn(input_tensor)
            # Store plain floats in the tree so nodes do not keep tensors
            # (and any device memory behind them) alive.
            value = self._to_float(value)
            legal_moves = board.get_available_move()
            probs = [
                (index, self._to_float(act_probs[0][index]))
                for index, flag in enumerate(legal_moves)
                if flag == 1
            ]
            node.expand(probs)

        # Backup.
        node.update_recursive(-value)

    def get_move_probs(self, state, temp=1e-3):
        """Run ``n_playout`` playouts and return the root move distribution.

        Probabilities are proportional to ``visits ** (1 / temp)``. They are
        computed as a softmax over ``log(visits) / temp`` so that a small
        ``temp`` (such as the default) does not overflow to ``inf``/``nan``.

        Args:
            state: Board passed to every playout.
            temp: Temperature; smaller values concentrate the distribution
                on the most visited move.

        Returns:
            ``(acts, act_probs)``: a tuple of the root's child actions and a
            NumPy array of matching probabilities that sums to 1.

        Raises:
            ValueError: If the root has no children after the playouts.
            ZeroDivisionError: If ``temp`` is zero.
        """
        for _ in range(self.n_playout):
            self.playout(state)

        act_visits = [(act, node.N) for act, node in self.root.children.items()]
        if not act_visits:
            raise ValueError(
                "root has no children: no playout expanded it "
                "(terminal position or n_playout == 0)"
            )
        acts, visits = zip(*act_visits)

        inv_temp = 1.0 / temp
        # The small epsilon keeps log() finite for unvisited children.
        logits = inv_temp * np.log(np.asarray(visits, dtype=np.float64) + 1e-10)
        logits -= np.max(logits)  # shift for numerical stability
        act_probs = np.exp(logits)
        act_probs /= np.sum(act_probs)
        return acts, act_probs

    def update_with_move(self, last_move):
        """Advance the root after ``last_move`` has been played.

        The matching child becomes the new root, keeping its subtree; if the
        move is not among the root's children the tree is reset.
        """
        if last_move in self.root.children:
            self.root = self.root.children[last_move]
            self.root.parent = None
        else:
            self.root = MCTSNode(None, 1.0)
Loading…
Cancel
Save