parent
6c2ee7be3b
commit
6759e1c19a
@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import copy
|
||||
import random
|
||||
|
||||
class MCTSNode:
    """A single node of a Monte Carlo search tree.

    Each node keeps a link to its parent, a mapping from actions to child
    nodes, and the running statistics used when choosing among children:
    visit count ``N``, total value ``W``, mean value ``Q`` and prior
    probability ``P``.
    """

    def __init__(self, parent: "MCTSNode | None", prior_p: float):
        """Create a node with no children and zeroed statistics.

        Args:
            parent: The node above this one, or ``None`` for a root.
            prior_p: Prior probability stored on this node as ``P``.
        """
        self.parent = parent
        self.children = {}  # action -> MCTSNode
        self.N = 0  # visit count
        self.W = 0.0  # total (summed) value
        self.Q = 0.0  # mean value, W / N
        self.P = prior_p  # prior probability

    def expand(self, action_priors) -> None:
        """Add a child for every action that does not have one yet.

        Args:
            action_priors: Iterable of ``(action, prior_probability)`` pairs.
                Actions already present in ``children`` are left untouched,
                so their existing statistics and prior are preserved.
        """
        for action, prob in action_priors:
            if action not in self.children:
                self.children[action] = MCTSNode(self, prob)

    def select(self, c_puct: float):
        """Return the ``(action, child)`` pair with the highest score.

        A child's score is ``Q + c_puct * P * sqrt(self.N) / (1 + child.N)``,
        where ``Q``, ``P`` and ``child.N`` are the child's statistics and
        ``self.N`` is this node's visit count.

        Args:
            c_puct: Weight of the prior/visit-count term relative to ``Q``.

        Raises:
            ValueError: If this node has no children (``max`` of an empty
                iterable).
        """
        # The parent's visit count is the same for every child; compute the
        # square root once instead of once per child.
        sqrt_parent_visits = math.sqrt(self.N)

        def score(act_node):
            child = act_node[1]
            return child.Q + c_puct * child.P * sqrt_parent_visits / (1 + child.N)

        return max(self.children.items(), key=score)

    def update(self, leaf_value: float) -> None:
        """Record one visit worth ``leaf_value`` and refresh the mean ``Q``."""
        self.N += 1
        self.W += leaf_value
        self.Q = self.W / self.N

    def update_recursive(self, leaf_value: float) -> None:
        """Apply ``update`` to this node and to every ancestor.

        The sign of the value is flipped at each step towards the root:
        this node receives ``leaf_value``, its parent ``-leaf_value``, the
        grandparent ``leaf_value`` again, and so on.

        The walk is iterative so that very deep trees cannot exhaust the
        interpreter's recursion limit.
        """
        node = self
        value = leaf_value
        while node is not None:
            node.update(value)
            value = -value
            node = node.parent

    def is_leaf(self) -> bool:
        """Return ``True`` when the node has no children."""
        return not self.children

    def is_root(self) -> bool:
        """Return ``True`` when the node has no parent."""
        return self.parent is None
|
||||
Loading…
Reference in new issue