import copy
import math
import random

import numpy as np


class MCTSNode:
    """A single node of a Monte Carlo search tree.

    Each node stores its visit statistics, the prior probability it was
    created with, a link to its parent and a mapping from actions to child
    nodes.
    """

    def __init__(self, parent, prior_p):
        """Create a node with no children and zeroed statistics.

        Args:
            parent: The parent ``MCTSNode``, or ``None`` for a root node.
            prior_p: Prior probability stored on this node as ``P``.
        """
        self.parent = parent
        self.children = {}  # action -> MCTSNode
        self.N = 0  # visit count
        self.W = 0.0  # total (accumulated) value
        self.Q = 0.0  # mean value, W / N
        self.P = prior_p  # prior probability

    def expand(self, action_priors):
        """Add a child for every action that does not already have one.

        Args:
            action_priors: Iterable of ``(action, prior_probability)`` pairs.
                Actions that are already children keep their existing node
                (and its prior) untouched.
        """
        for action, prob in action_priors:
            if action not in self.children:
                self.children[action] = MCTSNode(self, prob)

    def select(self, c_puct):
        """Return the ``(action, child)`` pair with the highest score.

        A child's score is ``Q + c_puct * P * sqrt(N_parent) / (1 + N_child)``.
        Ties go to the child that was inserted first. While this node's own
        ``N`` is 0 the second term vanishes for every child, so only ``Q``
        differentiates them.

        Args:
            c_puct: Weight of the prior/visit-count term relative to ``Q``.

        Raises:
            ValueError: If the node has no children (``max`` of an empty
                sequence).
        """
        # The parent's visit count is the same for every child, so take the
        # square root once instead of once per child.
        sqrt_parent_visits = math.sqrt(self.N)

        def score(item):
            child = item[1]
            return child.Q + c_puct * child.P * sqrt_parent_visits / (1 + child.N)

        return max(self.children.items(), key=score)

    def update(self, leaf_value):
        """Record one visit with value ``leaf_value`` and refresh ``Q``."""
        self.N += 1
        self.W += leaf_value
        self.Q = self.W / self.N

    def update_recursive(self, leaf_value):
        """Back a value up from this node to the root.

        This node is updated with ``leaf_value``, its parent with
        ``-leaf_value``, the grandparent with ``leaf_value`` again, and so on:
        the sign flips at every level.

        The walk is iterative rather than recursive so that very deep trees
        cannot exhaust the interpreter's recursion limit.
        """
        node = self
        value = leaf_value
        while node is not None:
            node.update(value)
            value = -value
            node = node.parent

    def is_leaf(self):
        """Return ``True`` if the node has no children."""
        return not self.children

    def is_root(self):
        """Return ``True`` if the node has no parent."""
        return self.parent is None