# NOTE(review): the three lines below are web-page residue from the code-host
# UI (topic hint, line count, file size) — not part of the program. Kept as
# comments so the module remains importable.
# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 144 lines
# 4.4 KiB
from policynet import PolicyValueNet
from net import SimpleNet
import numpy as np
import time
import threading
import concurrent.futures
from collections import deque
import random
import pickle
import os
from shutil import copy
from mcts import MCTS
from board import GomokuBoard
import torch
# Global hyper-parameter table shared by the board, the search, and training.
config = {
    # GomokuBoard: board side length, stones in a row to win, and how many
    # recent moves the board caches.
    "n": 15,
    "n_in_row": 5,
    "moveCacheNum": 10,
    # NeuralNet: where the TorchScript model lives and whether to use CUDA.
    "model_path": "./Models/Libtorch/",
    "use_gpu": True,
    # MCTS: parallel search threads, playouts per move, exploration constant,
    # and the virtual-loss weight used by the parallel tree search.
    "thread_num": 64,
    "mcts_branch_num": 1000,
    "c_puct": 5,
    "c_virtual_loss": 3,
    # train: optimizer/self-play settings.
    "train_model_path": "./Models/Pytorch/",
    "samples_path": "./Data/checkpoint.example",
    "train_iters": 1000,
    "batch_size": 512,
    "epochs": 15,
    # Number of opening moves played at full temperature before the policy
    # is sharpened (see TrainPipeline.self_play).
    "explore_threshold": 18,
    # Dirichlet root-noise parameters: alpha is the concentration, theta the
    # mixing weight of noise vs. the MCTS policy.
    "dirichlet_alpha": 0.06,
    "dirichlet_theta": 0.25,
    "lr": 0.0005,
    "temp": 1,
    "self_play_threadNum": 10,
    # Replay-buffer capacity: maximum number of training samples kept.
    "max_buffer_len": 80000,
    "comparison_freq": 25,
    "comparison_times": 10,
    # Win-rate a candidate model must reach against the current best to be
    # promoted.
    "update_threshold": 0.55,
    "explore_prob": 0.8,
    "trainSamples_save_freq": 10,
    "Yixin_freq": 5
}
class TrainPipeline:
    """AlphaZero-style self-play data generation for Gomoku."""

    def __init__(self, config):
        # Hyper-parameter dict; see the module-level `config` for the keys
        # this class reads (board size, MCTS settings, noise parameters...).
        self.config = config

    def self_play(self, policy_net):
        """Play one full self-play game and collect training samples.

        Args:
            policy_net: policy/value network consumed by MCTS.

        Returns:
            (samples, move_count, elapsed) where `samples` is a list of
            (board_data, policy_probs, value) triples already expanded by
            the 8-fold symmetry augmentation, `move_count` is the number
            of moves played, and `elapsed` is the game duration in seconds.
        """
        board = GomokuBoard(self.config["n"], self.config["n_in_row"],
                            self.config["moveCacheNum"])
        mcts = MCTS(policy_net, c_puct=self.config["c_puct"],
                    n_playout=self.config["mcts_branch_num"])
        train_data = {"board_data": [], "policy_probs": [], "value": []}
        temp = self.config["temp"]
        move_count = 0
        start = time.time()
        # Board state 0 means the game is still in progress.
        while board.get_board_state() == 0:
            board_data = board.get_board_data()
            acts, probs = mcts.get_move_probs(board, temp=temp)
            # Data augmentation: 8 board symmetries (rotations/reflections).
            # Note the *raw* MCTS probabilities are stored as targets; the
            # Dirichlet noise below only affects move selection.
            equi_boards, equi_probs = self.getEquiDataSet(board_data, probs)
            train_data["board_data"] += equi_boards
            train_data["policy_probs"] += equi_probs
            # Dirichlet noise over the legal moves, for exploration.
            legal_moves = board.get_available_move()
            legal_indices = [i for i, x in enumerate(legal_moves) if x == 1]
            noise = self.config["dirichlet_theta"] * np.random.dirichlet(
                self.config["dirichlet_alpha"] * np.ones(len(legal_indices))
            )
            for i, idx in enumerate(legal_indices):
                probs[idx] = (1 - self.config["dirichlet_theta"]) * probs[idx] + noise[i]
            probs /= np.sum(probs)
            action = np.random.choice(len(probs), p=probs)
            board.execute_move(action)
            mcts.update_with_move(action)
            move_count += 1
            if move_count >= self.config["explore_threshold"]:
                # Near-greedy play once the opening phase is over.
                temp = 1e-3
        end = time.time()
        # Fill in the value targets from the final outcome.
        # FIX: was `board.getBoardState()`, which does not match the
        # `board.get_board_state()` spelling used by the loop condition above.
        result = board.get_board_state()
        if result == board.BLACK_WIN:
            outcome = 1
        elif result == board.WHITE_WIN:
            outcome = -1
        else:
            outcome = 0
        # Alternate the sign so each sample's value is from the perspective
        # of the player to move; each move produced 8 augmented samples.
        for i in range(move_count):
            v = outcome if i % 2 == 0 else -outcome
            train_data["value"] += 8 * [[v]]
        return (list(zip(train_data["board_data"],
                         train_data["policy_probs"],
                         train_data["value"])),
                move_count, end - start)
def collect_selfplay_data(net, num_games):
    """Run `num_games` self-play games with `net` and gather their samples.

    Args:
        net: policy/value network passed through to TrainPipeline.self_play.
        num_games: number of games to play.

    Returns:
        A flat list of (board_data, policy_probs, value) training triples.
    """
    dataset = []
    game_count = 0
    total_time = 0.0
    total_moves = 0
    # FIX: the pipeline holds no per-game state, so build it once instead of
    # re-instantiating it on every iteration.
    trainer = TrainPipeline(config)
    for i in range(num_games):
        data, move_count, duration = trainer.self_play(net)
        dataset.extend(data)
        game_count += 1
        total_moves += move_count
        total_time += duration
        print(f"Game {i+1}: {move_count} moves, time {duration:.2f}s")
    # FIX: guard the average — the original divided by game_count and raised
    # ZeroDivisionError when num_games == 0.
    if game_count > 0:
        print(f"Total moves: {total_moves}, average time per game: {total_time / game_count:.2f}s")
    print(f"\nGenerated {len(dataset)} training samples from {game_count} games.")
    return dataset
def save_dataset(dataset, filename):
    """Pickle `dataset` to `filename` and report the destination path.

    Args:
        dataset: any picklable object (here, a list of training triples).
        filename: target path for the pickle file.
    """
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)
    # FIX: the message printed the literal text "(unknown)" instead of the
    # destination path — restore the f-string interpolation.
    print(f"Dataset saved to {filename}")
if __name__ == "__main__":
    # Honor the config's GPU preference, but only when CUDA is present.
    want_cuda = config["use_gpu"] and torch.cuda.is_available()
    device = torch.device("cuda" if want_cuda else "cpu")

    # Build the network, generate ten self-play games, and persist them.
    net = SimpleNet().to(device)
    dataset = collect_selfplay_data(net, num_games=10)
    save_dataset(dataset, 'selfplay_dataset.pkl')
# ...existing code...