You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.5 KiB
43 lines
1.5 KiB
6 months ago
|
import paddle.fluid as fluid
|
||
|
class ConvPool(fluid.dygraph.Layer):
    """A stack of ``groups`` Conv2D layers followed by one Pool2D layer.

    Every convolution shares the same ``num_filters``, ``filter_size``,
    ``conv_stride``, ``conv_padding`` and ``act``. The first convolution
    consumes ``num_channels`` input channels; each subsequent convolution
    consumes the ``num_filters`` channels produced by its predecessor.

    Args:
        num_channels: Input channel count of the first convolution.
        num_filters: Output channel count of every convolution.
        filter_size: Kernel size passed to each ``fluid.dygraph.Conv2D``.
        pool_size: Window size passed to ``fluid.dygraph.Pool2D``.
        pool_stride: Stride passed to ``fluid.dygraph.Pool2D``.
        groups: Number of convolutions to stack. This is a repeat count; it
            is NOT forwarded to Conv2D as grouped-convolution ``groups``.
        pool_padding: Padding passed to ``Pool2D``. Defaults to 0.
        pool_type: Pooling type passed to ``Pool2D``. Defaults to ``'max'``.
        conv_stride: Stride of each convolution. Defaults to 1.
        conv_padding: Padding of each convolution. Defaults to 1.
        act: Activation passed to each convolution. Defaults to None.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 groups,
                 pool_padding=0,
                 pool_type='max',
                 conv_stride=1,
                 conv_padding=1,
                 act=None):
        super(ConvPool, self).__init__()

        # Each conv is registered on this Layer via add_sublayer; the list
        # only records the order in which forward() applies them.
        self._conv2d_list = []
        in_channels = num_channels
        for idx in range(groups):
            layer = fluid.dygraph.Conv2D(
                num_channels=in_channels,
                num_filters=num_filters,
                filter_size=filter_size,
                stride=conv_stride,
                padding=conv_padding,
                act=act)
            self._conv2d_list.append(self.add_sublayer('bb_%d' % idx, layer))
            # Every conv after the first reads the previous conv's output.
            in_channels = num_filters

        self._pool2d = fluid.dygraph.Pool2D(
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding)

    def forward(self, inputs):
        """Apply every convolution in order, then pool the result.

        Args:
            inputs: Tensor fed to the first convolution (or straight to the
                pool when ``groups`` was 0).

        Returns:
            The output of the pooling layer.
        """
        out = inputs
        for conv_layer in self._conv2d_list:
            out = conv_layer(out)
        return self._pool2d(out)
|