You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

45 lines
1.6 KiB

import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
    """Small convolutional policy/value network for a square board.

    A three-layer convolutional trunk (ReLU after each layer) feeds two heads:

    * policy head: 1x1 conv to 4 planes -> fully connected layer -> softmax
      over ``action_size`` outputs;
    * value head: 1x1 conv to 2 planes -> two fully connected layers ->
      tanh, giving one scalar in [-1, 1] per sample.

    Args:
        board_size: Side length of the square board. Inputs must be
            ``board_size x board_size`` spatially, because the fully
            connected layers are sized from it.
        input_channels: Number of input feature planes.
        action_size: Number of policy outputs (default 225 = 15 * 15).
    """

    def __init__(self, board_size=15, input_channels=13, action_size=225):
        super().__init__()
        self.board_size = board_size
        self.input_channels = input_channels
        self.action_size = action_size

        # Input convolution trunk; padding=1 with kernel 3 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Policy head.
        self.policy_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.policy_fc = nn.Linear(4 * board_size * board_size, action_size)

        # Value head.
        self.value_conv = nn.Conv2d(128, 2, kernel_size=1)
        self.value_fc1 = nn.Linear(2 * board_size * board_size, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x: torch.Tensor):
        """Compute move probabilities and a position value.

        Args:
            x: Tensor of shape ``(batch, input_channels, board_size, board_size)``,
                or a single unbatched sample of shape
                ``(input_channels, board_size, board_size)``, which is treated
                as a batch of one.

        Returns:
            ``(policy, value)``: ``policy`` has shape ``(batch, action_size)``
            with each row summing to 1 (softmax); ``value`` has shape
            ``(batch,)`` with entries in [-1, 1] (tanh).

        Raises:
            ValueError: if ``x`` is not 3-D/4-D or its trailing dimensions are
                not ``(input_channels, board_size, board_size)``. Without this
                check, a mismatched board size could be silently folded into
                the batch dimension by the flatten step.
        """
        if x.dim() == 3:
            # Single unbatched sample: give it a batch axis of size one.
            x = x.unsqueeze(0)
        expected = (self.input_channels, self.board_size, self.board_size)
        if x.dim() != 4 or tuple(x.shape[1:]) != expected:
            raise ValueError(
                f"SimpleNet expected input of shape (batch, {expected[0]}, "
                f"{expected[1]}, {expected[2]}) or {expected}, "
                f"got {tuple(x.shape)}"
            )

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Policy head: flatten each sample (keeping the batch axis) -> logits -> probabilities.
        p = F.relu(self.policy_conv(x))
        p = torch.flatten(p, 1)
        p = self.policy_fc(p)
        policy = F.softmax(p, dim=1)

        # Value head: flatten each sample -> hidden layer -> scalar squashed to [-1, 1].
        v = F.relu(self.value_conv(x))
        v = torch.flatten(v, 1)
        v = F.relu(self.value_fc1(v))
        value = torch.tanh(self.value_fc2(v))

        # Drop the trailing singleton so value has shape (batch,).
        return policy, value.squeeze(-1)