You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.7 KiB
36 lines
1.7 KiB
class AjFullconn_Class(ModelObj): # 全连接调整对象
|
|
def __init__(self, ObjID,ObjType,ObjLable,ParaString,ObjX,ObjY):
|
|
super().__init__(ObjID,ObjType,ObjLable,ParaString,ObjX,ObjY)
|
|
self.AjFullconnProc = self.ajfullconn_proc # 基本操作函数
|
|
self.SetAjFCPara = self.setajfc_para # 参数设置函数
|
|
def ajfullconn_proc(self, AjFCPara):
|
|
# 根据激活函数的参数选择相应的函数和导数
|
|
# 计算权重矩阵和偏置向量的梯度,使用链式法则
|
|
gradient_weights = np.outer(AjFCPara['loss'],
|
|
AjFCPara['learning_rate'])
|
|
# 更新权重矩阵和偏置向量
|
|
weight_matrix = AjFCPara['weights'] - gradient_weights
|
|
bias_vector = AjFCPara['bias'] - AjFCPara[
|
|
'learning_rate'] * AjFCPara['bias']
|
|
# 返回更新后的权重矩阵和偏置向量
|
|
return weight_matrix, bias_vector
|
|
|
|
def setajfc_para(self, loss, FullConnPara):
|
|
weights = FullConnPara["weights"]
|
|
bias = FullConnPara["bias"]
|
|
loss = np.array([loss])
|
|
AjFCPara = {
|
|
'weights': weights, # 全连接权重
|
|
'bias': bias, # 全连接偏置
|
|
'learning_rate': 0.01, # 学习率
|
|
'loss': loss # 误差值
|
|
}
|
|
return AjFCPara
|
|
if __name__ == '__main__':
|
|
AjFullconn = AjFullconn_Class("AjFullconn1", 9,
|
|
"全连接调整1", [], 510, 120)
|
|
···
|
|
AjFCPara = AjFullconn.SetAjFCPara(loss, FullConnPara)
|
|
weight, bias = AjFullconn.AjFullconnProc(AjFCPara)
|
|
FullConnPara['weights'] = weight
|
|
FullConnPara['bias'] = bias |