# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 39 lines
# 1.2 KiB

import numpy as np
import math
import copy
import random
class MCTSNode:
    """A node of a Monte Carlo search tree.

    Each node stores a link to its parent, a mapping from actions to child
    nodes, and the running statistics used during selection:

    * ``N`` -- number of times the node has been updated (visit count)
    * ``W`` -- sum of all values passed to :meth:`update`
    * ``Q`` -- mean value, ``W / N``
    * ``P`` -- prior probability supplied at construction
    """

    def __init__(self, parent: "MCTSNode | None", prior_p: float) -> None:
        """Create a node with no children and zeroed statistics.

        Args:
            parent: The parent node, or ``None`` for a root node.
            prior_p: Prior probability assigned to this node.
        """
        self.parent = parent
        self.children = {}  # action -> MCTSNode
        self.N = 0  # visit count
        self.W = 0.0  # total value
        self.Q = 0.0  # mean value
        self.P = prior_p  # prior probability

    def expand(self, action_priors) -> None:
        """Add a child for every action that does not already have one.

        Args:
            action_priors: Iterable of ``(action, prior_probability)`` pairs.
                Actions already present in ``children`` are left untouched,
                so existing statistics are never overwritten.
        """
        for action, prob in action_priors:
            if action not in self.children:
                self.children[action] = MCTSNode(self, prob)

    def select(self, c_puct: float):
        """Return the ``(action, child)`` pair with the highest score.

        A child's score is ``Q + c_puct * P * sqrt(self.N) / (1 + child.N)``,
        where ``Q``, ``P`` and ``child.N`` are the child's statistics and
        ``self.N`` is this node's visit count.

        Args:
            c_puct: Weight of the prior/visit-count term relative to ``Q``.

        Returns:
            The ``(action, MCTSNode)`` item of ``children`` with the largest
            score; on ties the first one in dict iteration order wins.

        Raises:
            ValueError: If the node has no children.
        """
        # Identical for every child, so compute it once instead of per child.
        sqrt_visits = math.sqrt(self.N)

        def score(item):
            child = item[1]
            return child.Q + c_puct * child.P * sqrt_visits / (1 + child.N)

        return max(self.children.items(), key=score)

    def update(self, leaf_value) -> None:
        """Record one visit with value ``leaf_value`` and refresh the mean."""
        self.N += 1
        self.W += leaf_value
        self.Q = self.W / self.N

    def update_recursive(self, leaf_value) -> None:
        """Apply ``leaf_value`` to this node and all of its ancestors.

        The sign of the value is flipped at every step towards the root:
        this node receives ``leaf_value``, its parent ``-leaf_value``, the
        grandparent ``leaf_value`` again, and so on.

        The walk is iterative so that very deep paths cannot exceed the
        interpreter's recursion limit.
        """
        node = self
        value = leaf_value
        while node is not None:
            node.update(value)
            value = -value
            node = node.parent

    def is_leaf(self) -> bool:
        """Return ``True`` if the node has no children."""
        return not self.children

    def is_root(self) -> bool:
        """Return ``True`` if the node has no parent."""
        return self.parent is None