You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

55 lines
2.4 KiB

from mindspore import load_checkpoint, load_param_into_net
from mindspore import nn, context
def load_parameters(file_name):
    """Load a checkpoint and rename its entries to the ``l<N>`` scheme.

    Entries whose name starts with ``'moments.'`` are skipped entirely.
    A leading ``'layers.'`` prefix is replaced by ``'l'``, so for example
    ``'layers.0.weight'`` becomes ``'l0.weight'``, which matches the
    attribute names used by ``Vgg19``. All other entries are copied
    unchanged.

    Args:
        file_name: Checkpoint path, forwarded as-is to ``load_checkpoint``.

    Returns:
        dict: A new dict mapping the (possibly renamed) names to the values
        returned by ``load_checkpoint``, in the checkpoint's original order.
    """
    prefix = 'layers.'
    renamed = {}
    for name, value in load_checkpoint(file_name).items():
        if name.startswith('moments.'):
            # NOTE(review): presumably optimizer state rather than network
            # weights; confirm against how the checkpoint was produced.
            continue
        if name.startswith(prefix):
            name = 'l' + name[len(prefix):]
        renamed[name] = value
    return renamed
class Vgg19(nn.Cell):
    """Convolutional trunk of VGG19 returning six intermediate ReLU activations.

    The convolution cells are named ``l<N>`` so that names produced by
    ``load_parameters`` (``'layers.<N>...'`` -> ``'l<N>...'``) map onto them.
    Weights are initialised to ones as a placeholder. Presumably they are
    overwritten with pretrained values via ``load_param_into_net``; confirm
    with the caller.

    NOTE(review): ``nn.Conv2d`` is created without ``has_bias``. According to
    the MindSpore docs that defaults to ``False``, so any bias tensors in the
    checkpoint would not be loaded. Confirm this matches the source weights.
    """

    def __init__(self):
        super().__init__()
        # Presumably the index gaps correspond to ReLU / pooling positions in
        # the sequential ``layers`` container the checkpoint came from.
        self.l0 = nn.Conv2d(3, 64, kernel_size=3, weight_init='ones')
        self.l2 = nn.Conv2d(64, 64, kernel_size=3, weight_init='ones')
        self.l5 = nn.Conv2d(64, 128, kernel_size=3, weight_init='ones')
        self.l7 = nn.Conv2d(128, 128, kernel_size=3, weight_init='ones')
        self.l10 = nn.Conv2d(128, 256, kernel_size=3, weight_init='ones')
        self.l12 = nn.Conv2d(256, 256, kernel_size=3, weight_init='ones')
        self.l14 = nn.Conv2d(256, 256, kernel_size=3, weight_init='ones')
        self.l16 = nn.Conv2d(256, 256, kernel_size=3, weight_init='ones')
        self.l19 = nn.Conv2d(256, 512, kernel_size=3, weight_init='ones')
        self.l21 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        self.l23 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        self.l25 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        self.l28 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        # l30, l32, l34 and ``flatten`` are not used by ``construct``.
        # Presumably they are kept so the network's parameter set matches the
        # checkpoint; verify before removing them.
        self.l30 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        self.l32 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        self.l34 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the trunk on ``x`` and return the selected activations.

        Args:
            x: Input tensor with 3 channels (``l0`` expects 3 input channels).
                Presumably NCHW, MindSpore's Conv2d default; confirm.

        Returns:
            list: ``[layer_1, layer_2, layer_3, layer_4, layer_4_2, layer_5]``
            with 64, 128, 256, 512, 512 and 512 channels. These are taken
            after 0, 1, 2, 3, 3 and 4 max-pool (stride 2) steps respectively.
        """
        # Stage 1: l0 (3 -> 64) -> ReLU.
        layer_1 = self.relu(self.l0(x))

        # Stage 2: l2 -> ReLU -> pool -> l5 (64 -> 128) -> ReLU.
        h = self.relu(self.l2(layer_1))
        layer_2 = self.relu(self.l5(self.mp(h)))

        # Stage 3: l7 -> ReLU -> pool -> l10 (128 -> 256) -> ReLU.
        # The pool must follow l7, as in every other stage. Pooling before
        # l7 keeps the same output shape but computes different features.
        h = self.relu(self.l7(layer_2))
        layer_3 = self.relu(self.l10(self.mp(h)))

        # Stage 4: l12 -> l14 -> l16 (each + ReLU) -> pool -> l19 (256 -> 512) -> ReLU.
        h = self.relu(self.l12(layer_3))
        h = self.relu(self.l14(h))
        h = self.relu(self.l16(h))
        layer_4 = self.relu(self.l19(self.mp(h)))

        # One more conv at the same resolution.
        layer_4_2 = self.relu(self.l21(layer_4))

        # Stage 5: l23 -> l25 (each + ReLU) -> pool -> l28 -> ReLU.
        h = self.relu(self.l23(layer_4_2))
        h = self.relu(self.l25(h))
        layer_5 = self.relu(self.l28(self.mp(h)))

        return [layer_1, layer_2, layer_3, layer_4, layer_4_2, layer_5]