parent
d8542794d9
commit
6c2ee7be3b
@ -0,0 +1,65 @@
|
||||
from node import MCTSNode
|
||||
from board import GomokuBoard
|
||||
import numpy as np
|
||||
import copy
|
||||
import torch
|
||||
|
||||
class MCTS:
|
||||
def __init__(self, policy_value_fn, c_puct=1.0, n_playout=100):
|
||||
self.root = MCTSNode(None, 1.0)
|
||||
self.policy_value_fn = policy_value_fn
|
||||
self.c_puct = c_puct
|
||||
self.n_playout = n_playout
|
||||
|
||||
def playout(self, state):
|
||||
node = self.root
|
||||
board = copy.deepcopy(state) # 这里改成用传入的状态深拷贝
|
||||
|
||||
# 选择阶段
|
||||
while not node.is_leaf():
|
||||
action, node = node.select(self.c_puct)
|
||||
board.execute_move(action)
|
||||
|
||||
# 评估阶段
|
||||
board_state = board.get_board_state()
|
||||
if board_state == 1: # 黑赢
|
||||
value = 1.0
|
||||
elif board_state == -1: # 白赢
|
||||
value = -1.0
|
||||
elif board_state == 2: # 平局
|
||||
value = 0.0
|
||||
else:
|
||||
|
||||
input_tensor = board.get_board_data() # numpy array
|
||||
input_tensor = torch.from_numpy(input_tensor).float() # 转成float tensor
|
||||
if torch.cuda.is_available():
|
||||
input_tensor = input_tensor.cuda()
|
||||
input_tensor = input_tensor # 加 batch 维度 (1, C, H, W)
|
||||
|
||||
act_probs, value = self.policy_value_fn(input_tensor)
|
||||
|
||||
legal_moves = board.get_available_move()
|
||||
legal_indices = [i for i, flag in enumerate(legal_moves) if flag == 1]
|
||||
probs = [(i, act_probs[0][i]) for i in legal_indices]
|
||||
node.expand(probs)
|
||||
|
||||
# 回传阶段
|
||||
node.update_recursive(-value)
|
||||
|
||||
|
||||
def get_move_probs(self, state, temp=1e-3):
|
||||
for _ in range(self.n_playout):
|
||||
self.playout(state)
|
||||
|
||||
act_visits = [(act, node.N) for act, node in self.root.children.items()]
|
||||
acts, visits = zip(*act_visits)
|
||||
act_probs = np.power(visits, 1.0 / temp)
|
||||
act_probs /= np.sum(act_probs)
|
||||
return acts, act_probs
|
||||
|
||||
def update_with_move(self, last_move):
|
||||
if last_move in self.root.children:
|
||||
self.root = self.root.children[last_move]
|
||||
self.root.parent = None
|
||||
else:
|
||||
self.root = MCTSNode(None, 1.0)
|
||||
Loading…
Reference in new issue