You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
144 lines
4.4 KiB
144 lines
4.4 KiB
from policynet import PolicyValueNet
|
|
from net import SimpleNet
|
|
import numpy as np
|
|
import time
|
|
import threading
|
|
import concurrent.futures
|
|
from collections import deque
|
|
import random
|
|
import pickle
|
|
import os
|
|
from shutil import copy
|
|
from mcts import MCTS
|
|
from board import GomokuBoard
|
|
import torch
|
|
# Hyper-parameters shared by the board, the network, MCTS and the training loop.
config = {
    # --- GomokuBoard ---
    "n": 15,
    "n_in_row": 5,
    "moveCacheNum": 10,

    # --- NeuralNet ---
    "model_path": "./Models/Libtorch/",
    "use_gpu": True,

    # --- MCTS ---
    "thread_num": 64,
    "mcts_branch_num": 1000,
    "c_puct": 5,
    "c_virtual_loss": 3,

    # --- train ---
    "train_model_path": "./Models/Pytorch/",
    "samples_path": "./Data/checkpoint.example",
    "train_iters": 1000,
    "batch_size": 512,
    "epochs": 15,
    "explore_threshold": 18,
    "dirichlet_alpha": 0.06,
    "dirichlet_theta": 0.25,
    "lr": 0.0005,
    "temp": 1,
    "self_play_threadNum": 10,
    "max_buffer_len": 80000,  # number of samples the training-data buffer can hold
    "comparison_freq": 25,
    "comparison_times": 10,
    "update_threshold": 0.55,
    "explore_prob": 0.8,
    "trainSamples_save_freq": 10,
    "Yixin_freq": 5,
}
|
|
|
|
class TrainPipeline:
    """Generate self-play training samples for Gomoku using MCTS.

    All hyper-parameters (board size, MCTS settings, Dirichlet noise,
    exploration threshold, ...) are read from the ``config`` mapping given to
    the constructor.

    NOTE(review): ``self_play`` calls ``self.getEquiDataSet``, which is not
    defined in the visible part of this class -- confirm it is provided
    elsewhere before running.
    """

    def __init__(self, config):
        """Keep a reference to the configuration mapping used by ``self_play``."""
        self.config = config

    def self_play(self, policy_net):
        """Play one game against itself and return the collected samples.

        Each turn the MCTS move probabilities are recorded (after data
        augmentation via ``getEquiDataSet``) as policy targets; Dirichlet
        noise is then mixed into a *copy* of those probabilities to pick the
        move that is actually played.  After ``explore_threshold`` moves the
        temperature drops to ``1e-3``.

        Args:
            policy_net: Network object handed straight to ``MCTS``.

        Returns:
            A 3-tuple ``(samples, move_count, seconds)`` where ``samples`` is
            a list of ``(board_data, policy_probs, [value])`` tuples,
            ``move_count`` is the number of moves played and ``seconds`` is
            the wall-clock duration of the game loop.
        """
        cfg = self.config
        theta = cfg["dirichlet_theta"]
        alpha = cfg["dirichlet_alpha"]

        board = GomokuBoard(cfg["n"], cfg["n_in_row"], cfg["moveCacheNum"])
        mcts = MCTS(policy_net, c_puct=cfg["c_puct"], n_playout=cfg["mcts_branch_num"])

        boards, policies, values = [], [], []
        samples_per_move = []  # how many augmented samples each move produced
        temp = cfg["temp"]
        move_count = 0
        start = time.time()

        # State 0 is treated as "game still running" -- presumably; verify
        # against GomokuBoard.
        while board.get_board_state() == 0:
            board_data = board.get_board_data()
            _, probs = mcts.get_move_probs(board, temp=temp)

            # Data augmentation: store the clean (noise-free) policy targets.
            equi_boards, equi_probs = self.getEquiDataSet(board_data, probs)
            before = len(boards)
            boards += equi_boards
            policies += equi_probs
            samples_per_move.append(len(boards) - before)

            # Add Dirichlet noise for exploration.  Work on a float64 copy so
            # the targets stored above can never be altered through aliasing
            # and so np.random.choice gets well-normalised probabilities.
            legal_moves = board.get_available_move()
            legal_indices = [i for i, flag in enumerate(legal_moves) if flag == 1]
            noise = theta * np.random.dirichlet(alpha * np.ones(len(legal_indices)))
            move_probs = np.array(probs, dtype=np.float64)
            move_probs[legal_indices] = (1 - theta) * move_probs[legal_indices] + noise
            move_probs /= move_probs.sum()

            action = np.random.choice(len(move_probs), p=move_probs)
            board.execute_move(action)
            mcts.update_with_move(action)
            move_count += 1

            if move_count >= cfg["explore_threshold"]:
                temp = 1e-3

        end = time.time()

        # Fill in the value targets.  Use the same accessor as the loop
        # condition (the original mixed in a camelCase ``getBoardState``).
        result = board.get_board_state()
        if result == board.BLACK_WIN:
            outcome = 1
        elif result == board.WHITE_WIN:
            outcome = -1
        else:
            outcome = 0

        # Even-indexed moves receive ``outcome``, odd-indexed moves its
        # negation (assumes black makes move 0 -- TODO confirm).  Every
        # augmented sample of a move gets its own ``[v]`` list.
        for i, n_samples in enumerate(samples_per_move):
            v = outcome if i % 2 == 0 else -outcome
            values += [[v] for _ in range(n_samples)]

        return list(zip(boards, policies, values)), move_count, end - start
|
|
|
|
|
|
def collect_selfplay_data(net, num_games):
    """Run ``num_games`` self-play games and gather every sample produced.

    A fresh ``TrainPipeline`` (built from the module-level ``config``) plays
    each game.  Per-game and running statistics are printed as games finish.

    Args:
        net: Network passed to ``TrainPipeline.self_play``.
        num_games: Number of games to play.

    Returns:
        A flat list with the samples of all games, in play order.
    """
    samples = []
    moves_so_far = 0
    seconds_so_far = 0.0
    games_played = 0

    for game_no in range(1, num_games + 1):
        pipeline = TrainPipeline(config)
        game_samples, n_moves, elapsed = pipeline.self_play(net)

        samples.extend(game_samples)
        games_played = game_no
        moves_so_far += n_moves
        seconds_so_far += elapsed

        print(f"Game {game_no}: {n_moves} moves, time {elapsed:.2f}s")
        print(f"Total moves: {moves_so_far}, average time per game: {seconds_so_far / games_played:.2f}s")

    print(f"\nGenerated {len(samples)} training samples from {games_played} games.")
    return samples
|
|
|
|
def save_dataset(dataset, filename):
    """Serialize ``dataset`` with pickle into ``filename`` and report the path.

    Args:
        dataset: Any picklable object (here: the list of self-play samples).
        filename: Destination path; the file is opened in binary write mode
            and overwritten if it exists.
    """
    with open(filename, "wb") as out_file:
        pickle.dump(dataset, out_file)
    print(f"Dataset saved to {filename}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run on CUDA only when the config asks for it and a device is available.
    want_cuda = config["use_gpu"] and torch.cuda.is_available()
    device = torch.device("cuda" if want_cuda else "cpu")

    # Play ten self-play games with the network and persist the samples.
    net = SimpleNet().to(device)
    dataset = collect_selfplay_data(net, num_games=10)
    save_dataset(dataset, 'selfplay_dataset.pkl')
    # ...existing code...
|