You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.3 KiB
39 lines
1.3 KiB
import paddle.fluid as fluid
|
|
from 口罩检测.ConvPool import ConvPool
|
|
class VGGNet(fluid.dygraph.Layer):
    """VGG-16 style classifier: five ConvPool stages followed by three FC layers.

    The final layer applies softmax, so ``forward`` returns class
    probabilities rather than raw logits.

    Args:
        num_classes: number of output classes. Defaults to 2, the value the
            original implementation hard-coded.
    """

    def __init__(self, num_classes=2):
        super(VGGNet, self).__init__()

        self.num_classes = num_classes

        # ConvPool positional arguments, per the original author's comment:
        # input channels, number of filters, filter size, pool size,
        # pool stride, number of consecutive conv layers in the stage.
        # Stage depths 2, 2, 3, 3, 3 follow the VGG-16 layout.
        self.convpool01 = ConvPool(3, 64, 3, 2, 2, 2, act='relu')
        self.convpool02 = ConvPool(64, 128, 3, 2, 2, 2, act='relu')
        self.convpool03 = ConvPool(128, 256, 3, 2, 2, 3, act='relu')
        self.convpool04 = ConvPool(256, 512, 3, 2, 2, 3, act='relu')
        self.convpool05 = ConvPool(512, 512, 3, 2, 2, 3, act='relu')

        # Flattened feature size after the fifth stage. The 7x7 spatial size
        # assumes a 224x224 input that each stage halves (pool size 2,
        # stride 2). NOTE(review): confirm against the data reader's resize.
        self.pool5_shape = 512 * 7 * 7

        self.fc01 = fluid.dygraph.Linear(self.pool5_shape, 4096, act='relu')
        self.fc02 = fluid.dygraph.Linear(4096, 4096, act='relu')
        self.fc03 = fluid.dygraph.Linear(4096, num_classes, act='softmax')

    def forward(self, input, label=None):
        """Run the network on a batch of images.

        Args:
            input: image batch tensor. Presumably (batch, 3, 224, 224); the
                3 input channels come from ``convpool01`` and the spatial size
                is implied by ``pool5_shape`` (TODO confirm).
            label: optional ground-truth labels. When given, top-1 accuracy
                is also computed. Presumably an int64 tensor of shape
                (batch, 1), as ``fluid.layers.accuracy`` expects (verify).

        Returns:
            ``out`` (softmax probabilities, one column per class) when
            ``label`` is None; otherwise the tuple ``(out, acc)``.
        """
        x = self.convpool01(input)
        x = self.convpool02(x)
        x = self.convpool03(x)
        x = self.convpool04(x)
        x = self.convpool05(x)

        # Flatten to (batch, pool5_shape). Reusing the attribute keeps this
        # reshape in lock-step with fc01's declared input size.
        x = fluid.layers.reshape(x, shape=[-1, self.pool5_shape])

        out = self.fc01(x)
        out = self.fc02(out)
        out = self.fc03(out)

        if label is None:
            return out

        acc = fluid.layers.accuracy(input=out, label=label)
        return out, acc