You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.2 KiB
39 lines
1.2 KiB
import numpy as np
|
|
import math
|
|
import copy
|
|
import random
|
|
|
|
class MCTSNode:
    """A node of a Monte Carlo search tree holding per-node statistics.

    Attributes:
        parent: The node this one hangs off, or ``None`` for the root.
        children: Mapping from action to the child ``MCTSNode`` it leads to.
        N: Visit count.
        W: Total (summed) value backed up through this node.
        Q: Mean value, kept equal to ``W / N`` by ``update``.
        P: Prior probability assigned to this node when it was created.
    """

    def __init__(self, parent, prior_p):
        """Create a node with zeroed statistics.

        Args:
            parent: Parent ``MCTSNode``, or ``None`` for a root node.
            prior_p: Prior probability stored as ``P``.
        """
        self.parent = parent
        self.children = dict()  # action -> MCTSNode
        self.N = 0  # visit count
        self.W = 0.0  # total value
        self.Q = 0.0  # mean value
        self.P = prior_p  # prior probability

    def expand(self, action_priors):
        """Add a child for every action that does not already have one.

        Args:
            action_priors: Iterable of ``(action, prior_probability)`` pairs.
                Actions already present in ``children`` are left untouched,
                so an existing child's statistics and prior are preserved.
        """
        for action, prob in action_priors:
            if action not in self.children:
                self.children[action] = MCTSNode(self, prob)

    def select(self, c_puct):
        """Return the ``(action, child)`` pair with the highest score.

        A child's score is
        ``child.Q + c_puct * child.P * sqrt(self.N) / (1 + child.N)``.
        Ties resolve to the first such child in ``children`` iteration order.

        Args:
            c_puct: Weight of the prior/visit-count term relative to ``Q``.

        Returns:
            A ``(action, MCTSNode)`` tuple.

        Raises:
            ValueError: If this node has no children (``max`` of an empty
                sequence).
        """
        # The parent's visit count is the same for every child, so take the
        # square root once instead of once per candidate.
        sqrt_parent_visits = math.sqrt(self.N)

        def score(act_node):
            child = act_node[1]
            return child.Q + c_puct * child.P * sqrt_parent_visits / (1 + child.N)

        return max(self.children.items(), key=score)

    def update(self, leaf_value):
        """Record one visit with value ``leaf_value`` and refresh the mean.

        Args:
            leaf_value: Value to add to ``W`` for this visit.
        """
        self.N += 1
        self.W += leaf_value
        self.Q = self.W / self.N

    def update_recursive(self, leaf_value):
        """Back ``leaf_value`` up from this node to the root.

        This node is updated with ``leaf_value``, its parent with
        ``-leaf_value``, its grandparent with ``leaf_value`` again, and so on
        (presumably because alternate levels belong to opposing players --
        confirm against the caller).

        The walk is iterative, so the depth of the tree is not limited by the
        interpreter's recursion limit.

        Args:
            leaf_value: Value from the point of view of this node.
        """
        node = self
        value = leaf_value
        while node is not None:
            node.update(value)
            node = node.parent
            value = -value

    def is_leaf(self):
        """Return ``True`` when this node has no children."""
        return not self.children

    def is_root(self):
        """Return ``True`` when this node has no parent."""
        return self.parent is None
|