You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

53 lines
2.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
class ModelObj: # 网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
self.ObjID = ObjID # 图元号
self.ObjType = ObjType # 图元类别
self.ObjLable = ObjLable # 对象标签
self.ParaString = ParaString # 参数字符串
self.ObjX = ObjX # 对象位置x坐标
self.ObjY = ObjY # 对象位置y坐标
class Pool_Class(ModelObj): # 池化对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
super().__init__(ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY)
self.MaxPoolProc = self.pool_proc # 基本操作函数
self.SetPollPara = self.setpool_para # 参数设置函数
def pool_proc(self, image, PoolPara):
pool_mode = PoolPara["pool_mode"]
pool_size = PoolPara["pool_size"]
stride = PoolPara["stride"]
c, h, w = image.shape # 获取输入特征图的高度和宽度
out_h = int((h - pool_size) / stride) + 1 # 计算输出特征图的高度
out_w = int((w - pool_size) / stride) + 1 # 计算输出特征图的宽度
out = np.zeros((c, out_h, out_w)) # 初始化输出特征图为全零数组
for k in range(c): # 对于输出的每一个位置上计算:
for i in range(out_h):
for j in range(out_w):
window = image[k, i * stride:i * stride + pool_size,
j * stride:j * stride + pool_size]
if pool_mode == "max": # 最大池化
out[k][i][j] = np.max(window)
elif pool_mode == "avg": # 平均池化
out[k][i][j] = np.mean(window)
elif pool_mode == "min": # 最小池化
out[k][i][j] = np.min(window)
else: # 无效的池化类型
raise ValueError("Invalid pooling mode")
return out # 返回特征图。
def setpool_para(self): # 定义设置池化参数的函数
pool_mode = input("请输入池化模式max/avg/min: ") # 用户输入池化模式
pool_size = int(input("请输入池化大小: ")) # 用户输入池化大小
stride = int(input("请输入步长: ")) # 用户输入步长
PoolPara = {"pool_mode": pool_mode, "pool_size": pool_size,
"stride": stride} # 返回PoolPara参数这里用字典来存储
return PoolPara # 返回PoolPara参数
if __name__ == '__main__':
Pool = Pool_Class("Pool1", 3, "池化1", [], 300, 400)
PoolPara = Pool.SetPollPara()
print(PoolPara)