You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.6 KiB
45 lines
1.6 KiB
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class SimpleNet(nn.Module):
    """Small convolutional network with a policy head and a value head.

    A shared trunk of three 3x3 convolutions (``input_channels`` -> 64 -> 128
    -> 128, each followed by ReLU; ``padding=1`` keeps the spatial size)
    feeds two heads:

    * policy head: 1x1 conv to 4 channels, ReLU, flatten, linear layer to
      ``action_size`` outputs, softmax over the action dimension;
    * value head: 1x1 conv to 2 channels, ReLU, flatten, linear to 64, ReLU,
      linear to 1, tanh.

    Args:
        board_size: Side length of the square input planes. The fully
            connected layers are sized for exactly ``board_size x board_size``
            inputs.
        input_channels: Number of input feature planes.
        action_size: Number of policy outputs. The default (225) equals
            ``15 * 15``, one output per cell of the default board. However,
            the code does not tie it to ``board_size``.
    """

    def __init__(self, board_size=15, input_channels=13, action_size=225):
        super().__init__()
        self.board_size = board_size
        self.input_channels = input_channels
        self.action_size = action_size

        # Input convolution layers; padding=1 with 3x3 kernels preserves the
        # board_size x board_size spatial shape through the trunk.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Policy head: squeeze to 4 channels, then one FC layer over all cells.
        self.policy_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.policy_fc = nn.Linear(4 * board_size * board_size, action_size)

        # Value head: squeeze to 2 channels, then a 64-unit hidden FC layer.
        self.value_conv = nn.Conv2d(128, 2, kernel_size=1)
        self.value_fc1 = nn.Linear(2 * board_size * board_size, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        """Run the network on a batch of board encodings.

        Args:
            x: Tensor of shape ``(batch, input_channels, board_size,
                board_size)``. A single unbatched tensor of shape
                ``(input_channels, board_size, board_size)`` is treated as a
                batch of one.

        Returns:
            A ``(policy, value)`` tuple:

            * ``policy`` has shape ``(batch, action_size)`` and contains
              softmax probabilities that sum to 1 along dim 1.
            * ``value`` has shape ``(batch,)`` and contains tanh outputs in
              ``[-1, 1]``.

        Raises:
            RuntimeError: If the channel count or spatial size of ``x`` does
                not match the configured ``input_channels`` / ``board_size``.
                The error comes from the conv / linear layers' shape checks.
        """
        # Promote a single unbatched position to a batch of one, so that the
        # flatten below always sees an explicit batch dimension.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # Shared convolutional trunk.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Policy head. flatten(start_dim=1) keeps the batch dimension as-is.
        # A view(-1, ...) would instead infer it from the element count.
        # That would silently split a wrongly-sized board into several rows,
        # and it fails on an empty batch.
        p = F.relu(self.policy_conv(x))
        p = p.flatten(start_dim=1)
        p = self.policy_fc(p)
        policy = F.softmax(p, dim=1)

        # Value head.
        v = F.relu(self.value_conv(x))
        v = v.flatten(start_dim=1)
        v = F.relu(self.value_fc1(v))
        value = torch.tanh(self.value_fc2(v))

        # Drop the trailing singleton dim: value shape (batch, 1) -> (batch,).
        return policy, value.squeeze(-1)
|