You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

53 lines
2.6 KiB

8 months ago
import numpy as np
class ModelObj: # 网络对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
self.ObjID = ObjID # 图元号
self.ObjType = ObjType # 图元类别
self.ObjLable = ObjLable # 对象标签
self.ParaString = ParaString # 参数字符串
self.ObjX = ObjX # 对象位置x坐标
self.ObjY = ObjY # 对象位置y坐标
class Pool_Class(ModelObj): # 池化对象
def __init__(self, ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY):
super().__init__(ObjID, ObjType, ObjLable, ParaString, ObjX, ObjY)
self.MaxPoolProc = self.pool_proc # 基本操作函数
self.SetPollPara = self.setpool_para # 参数设置函数
def pool_proc(self, image, PoolPara):
pool_mode = PoolPara["pool_mode"]
pool_size = PoolPara["pool_size"]
stride = PoolPara["stride"]
c, h, w = image.shape # 获取输入特征图的高度和宽度
out_h = int((h - pool_size) / stride) + 1 # 计算输出特征图的高度
out_w = int((w - pool_size) / stride) + 1 # 计算输出特征图的宽度
out = np.zeros((c, out_h, out_w)) # 初始化输出特征图为全零数组
for k in range(c): # 对于输出的每一个位置上计算:
for i in range(out_h):
for j in range(out_w):
window = image[k, i * stride:i * stride + pool_size,
j * stride:j * stride + pool_size]
if pool_mode == "max": # 最大池化
out[k][i][j] = np.max(window)
elif pool_mode == "avg": # 平均池化
out[k][i][j] = np.mean(window)
elif pool_mode == "min": # 最小池化
out[k][i][j] = np.min(window)
else: # 无效的池化类型
raise ValueError("Invalid pooling mode")
return out # 返回特征图。
def setpool_para(self): # 定义设置池化参数的函数
pool_mode = input("请输入池化模式max/avg/min: ") # 用户输入池化模式
pool_size = int(input("请输入池化大小: ")) # 用户输入池化大小
stride = int(input("请输入步长: ")) # 用户输入步长
PoolPara = {"pool_mode": pool_mode, "pool_size": pool_size,
"stride": stride} # 返回PoolPara参数这里用字典来存储
return PoolPara # 返回PoolPara参数
if __name__ == '__main__':
Pool = Pool_Class("Pool1", 3, "池化1", [], 300, 400)
PoolPara = Pool.SetPollPara()
print(PoolPara)