You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
2.4 KiB
55 lines
2.4 KiB
|
|
from mindspore import load_checkpoint, load_param_into_net
|
|
from mindspore import nn, context
|
|
|
|
|
|
def load_parameters(file_name):
    """Load a MindSpore checkpoint and remap parameter names for Vgg19.

    Checkpoint keys of the form ``layers.N.<rest>`` are renamed to
    ``lN.<rest>`` so they line up with the ``self.lN`` attribute names
    used by the ``Vgg19`` cell in this file.  Optimizer state
    (``moments.*``) is dropped since it is not needed for inference.

    Args:
        file_name: Path to the ``.ckpt`` checkpoint file.

    Returns:
        dict: Parameter name -> parameter value, ready for
        ``load_param_into_net``.
    """
    param_dict = load_checkpoint(file_name)

    param_dict_new = {}
    layers_prefix = 'layers.'
    for key, values in param_dict.items():
        if key.startswith('moments.'):
            # Optimizer moment accumulators — skip.
            continue
        if key.startswith(layers_prefix):
            # 'layers.N.xxx' -> 'lN.xxx' to match Vgg19 attribute names.
            param_dict_new['l' + key[len(layers_prefix):]] = values
        else:
            param_dict_new[key] = values
    return param_dict_new
|
|
|
|
|
|
class Vgg19(nn.Cell):
    """Truncated VGG-19 convolutional feature extractor.

    Defines the VGG-19 conv layers ``l0`` … ``l34`` (names matching the
    checkpoint keys produced by ``load_parameters``) and returns a list
    of six intermediate feature maps from ``construct``.  Weights are
    initialized to ones and are expected to be overwritten by a loaded
    checkpoint.
    """

    # (attribute index, in_channels, out_channels) for every conv layer.
    # Indices mirror torchvision's VGG-19 feature-layer numbering.
    _CONV_CFG = (
        (0, 3, 64), (2, 64, 64),
        (5, 64, 128), (7, 128, 128),
        (10, 128, 256), (12, 256, 256), (14, 256, 256), (16, 256, 256),
        (19, 256, 512), (21, 512, 512), (23, 512, 512), (25, 512, 512),
        (28, 512, 512), (30, 512, 512), (32, 512, 512), (34, 512, 512),
    )

    def __init__(self):
        super().__init__()
        # Register each conv as self.l<idx>; setattr on a Cell registers
        # sub-cells exactly like a direct attribute assignment would.
        for idx, c_in, c_out in self._CONV_CFG:
            setattr(
                self, 'l' + str(idx),
                nn.Conv2d(c_in, c_out, kernel_size=3, weight_init='ones'),
            )
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the conv stack and collect six intermediate activations."""
        relu = self.relu
        pool = self.mp

        feat_1 = relu(self.l0(x))  # 3 -> 64 channels

        t = pool(relu(self.l2(feat_1)))
        feat_2 = relu(self.l5(t))

        t = relu(self.l7(pool(feat_2)))
        feat_3 = relu(self.l10(t))

        t = relu(self.l12(feat_3))
        t = relu(self.l14(t))
        t = pool(relu(self.l16(t)))
        feat_4 = relu(self.l19(t))

        feat_4_2 = relu(self.l21(feat_4))

        t = relu(self.l23(feat_4_2))
        t = pool(relu(self.l25(t)))
        feat_5 = relu(self.l28(t))

        return [feat_1, feat_2, feat_3, feat_4, feat_4_2, feat_5]
|
|
|
|
|
|
|
|
|