|
|
import tkinter as tk
|
|
|
from PIL import Image, ImageTk
|
|
|
|
|
|
global Viewcanvas # 定义画布
|
|
|
global Root # 主窗口
|
|
|
global AllModelObj #网络对象
|
|
|
|
|
|
'''
|
|
|
【编程16.5】编制程序:依据AllModelObj和AllModelConn数据结构产生如图16.2的输出界面。
|
|
|
【目的及编程说明】读者通过编程16.5可理解卷积神经网络模型构建的输出界面。数据结构及初始值参见【编程16.1】。
|
|
|
'''
|
|
|
# 定义图元对象类
|
|
|
class ModelObj:
|
|
|
def __init__(self, ObjID, ObjType, ObjLable, ProcFunc, SetParaFunc, ParaString, ObjX, ObjY):
|
|
|
self.ObjID = ObjID # 图元号
|
|
|
self.ObjType = ObjType # 图元类别
|
|
|
self.ObjLable = ObjLable # 对象标签
|
|
|
self.ProcFunc = ProcFunc # 基本操作函数
|
|
|
self.SetParaFunc = SetParaFunc # 参数设置函数
|
|
|
self.ParaString = ParaString # 参数字符串
|
|
|
self.ObjX = ObjX # 对象位置x坐标
|
|
|
self.ObjY = ObjY # 对象位置y坐标
|
|
|
|
|
|
def output(self): # 输出方法
|
|
|
# 创建一个空列表
|
|
|
result = []
|
|
|
# 将对象的属性添加到列表中
|
|
|
result.append(self.ObjID)
|
|
|
result.append(self.ObjType)
|
|
|
result.append(self.ObjLable)
|
|
|
result.append(self.ProcFunc)
|
|
|
result.append(self.SetParaFunc)
|
|
|
result.append(self.ParaString)
|
|
|
result.append(self.ObjX)
|
|
|
result.append(self.ObjY)
|
|
|
# 返回列表
|
|
|
return result
|
|
|
|
|
|
# 定义网络连接对象类
|
|
|
class ModelConn:
|
|
|
def __init__(self, ConnObjID, ConnType, NobjS, NobjE):
|
|
|
self.ConnObjID = ConnObjID # 连接线编号
|
|
|
self.ConnType = ConnType # 连接线类别
|
|
|
self.NobjS = NobjS # 源图元对象
|
|
|
self.NobjE = NobjE # 目标图元对象
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f"{self.ConnObjID}, {self.ConnType}, {self.NobjS}, {self.NobjE}"
|
|
|
|
|
|
def create_instance():
|
|
|
# 创建图元对象实例
|
|
|
DataSet = ModelObj("DataSet", 1, "数据集", "LoadData", "SetDataPara", ".", 120, 330).output()
|
|
|
Conv = ModelObj("Conv", 2, "卷积", "ConvProc", "SetConvPara", ".", 250, 330).output()
|
|
|
Pool = ModelObj("Pool", 3, "最大池化", "MaxPoolProc", "SetPollPara", ".", 380, 330).output()
|
|
|
FullConn = ModelObj("FullConn", 4, "全连接", "FullConnProc", "SetFullConnPara", ".", 510, 330).output()
|
|
|
Nonline = ModelObj("Nonline", 5, "非线性函数", "NonlinearProc", "SetNonLPara", ".", 640, 330).output()
|
|
|
Classifier = ModelObj("Classifier", 6, "分类", "ClassifierProc", "SetClassifyPara", ".", 780, 330).output()
|
|
|
Error = ModelObj("Error", 7, "误差计算", "ErrorProc", "SetErrorPara", ".", 710, 124).output()
|
|
|
AjConv = ModelObj("AjConv", 8, "卷积调整", "AjConvProc", "SetAjConvPara", ".", 250, 70).output()
|
|
|
AjFullconn = ModelObj("AjFullconn", 9, "全连接调整", "AjFullconnProc", "SetAjFCPara", ".", 510, 120).output()
|
|
|
listinstance = [DataSet,Conv,Pool,FullConn,Nonline,Classifier,Error,AjConv,AjFullconn]
|
|
|
return listinstance # 还回图元对象实例列表
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
# listinstance = create_instance()
|
|
|
# for instance in listinstance:
|
|
|
# print(instance)
|
|
|
|
|
|
def connect_class(listinstance):
|
|
|
# 创建连接对象实例
|
|
|
Line1 = ModelConn(1, 1, listinstance[0], listinstance[1])
|
|
|
Line2 = ModelConn(2, 1, listinstance[1], listinstance[2])
|
|
|
Line3 = ModelConn(3, 1, listinstance[2], listinstance[3])
|
|
|
Line4 = ModelConn(4, 1, listinstance[3], listinstance[4])
|
|
|
Line5 = ModelConn(5, 1, listinstance[4], listinstance[5])
|
|
|
Line6 = ModelConn(6, 1, listinstance[5], listinstance[6])
|
|
|
Line7 = ModelConn(7, 2, listinstance[6], listinstance[8])
|
|
|
Line8 = ModelConn(8, 2, listinstance[6], listinstance[7])
|
|
|
Line9 = ModelConn(9, 2, listinstance[8], listinstance[3])
|
|
|
Line10 = ModelConn(10, 2, listinstance[7], listinstance[1])
|
|
|
listclass = [Line1,Line2, Line3, Line4, Line5, Line6, Line7, Line8, Line9, Line10]
|
|
|
return listclass # # 还回连接对象实例
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
# listinstance = create_instance()
|
|
|
# listclass = connect_class(listinstance)
|
|
|
# for iclass in listclass:
|
|
|
# print(iclass)
|
|
|
|
|
|
def element(path):
|
|
|
imgs = Image.open(path) # 加载图元对应的图片文件
|
|
|
imgs = imgs.resize((60, 50)) # 使用resize方法调整图片
|
|
|
imgs = ImageTk.PhotoImage(imgs) # 把Image对象转换成PhotoImage对象
|
|
|
Root.img = imgs # 保存图片的引用,防止被垃圾回收
|
|
|
return imgs
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
# Root = tk.Tk() # 创建一个主窗口
|
|
|
# img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
|
|
|
# "img/classifier.png", "img/error.png", "img/adjust.png"] # 图元路径
|
|
|
# list_image = [] # 定义一个列表,存储PhotoImage对象
|
|
|
# for path in img_path:
|
|
|
# list_image.append(element(path))
|
|
|
# for image in list_image:
|
|
|
# print(image) # 打印结果
|
|
|
|
|
|
def window():
|
|
|
global Root
|
|
|
global Viewcanvas
|
|
|
Root = tk.Tk() # 创建一个主窗口
|
|
|
# 设置窗口的大小为1200*750
|
|
|
window_width = 900 # 窗口的宽度
|
|
|
window_height = 550 # 窗口的高度
|
|
|
Root.title("神经网络可视化")
|
|
|
Root.geometry("900x550") # 设置窗口的大小和位置
|
|
|
# 创建一个画布,用于绘制矩形框,设置画布的大小和背景色
|
|
|
Viewcanvas = tk.Canvas(Root, width=window_width, height=window_height, bg="white")
|
|
|
# 将画布添加到主窗口中
|
|
|
Viewcanvas.pack()
|
|
|
# 绘制矩形框,使用不同的颜色和线宽,指定矩形框的左上角和右下角坐标,填充色,边框色和边框宽度
|
|
|
Viewcanvas.create_rectangle(5, 5, 895, 545, fill=None, outline="lightblue", width=2)
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
# window()
|
|
|
# path = "img/data.png" # 图元路径
|
|
|
# image =element(path)
|
|
|
# print(image) # 打印结果
|
|
|
# Root.mainloop()
|
|
|
|
|
|
def connecting_lines(obj_x, obj_y, obj_x1, obj_x2, obj_x3, obj_y1, obj_y2, obj_y3, image, text, smooth, width):
|
|
|
Viewcanvas.create_image(obj_x, obj_y, image=image) # 创建图元对象
|
|
|
Viewcanvas.create_text(obj_x1, obj_y1, text=text, font=("黑体", 14)) # 创建图元对象的标签
|
|
|
Viewcanvas.create_line(obj_x2, obj_y2, obj_x3, obj_y3, arrow=tk.LAST, # 创建数据线箭头
|
|
|
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth, width=width)
|
|
|
|
|
|
def connectings_lines(obj_x, obj_y, obj_x1, obj_x2, obj_x3,obj_x4, obj_y1, obj_y2, obj_y3, obj_y4, image, text, smooth, width):
|
|
|
# 创建图元对象
|
|
|
Viewcanvas.create_image(obj_x, obj_y, image=image)
|
|
|
# 创建图元对象的标签
|
|
|
Viewcanvas.create_text(obj_x1, obj_y1, text=text, font=("黑体", 14))
|
|
|
# 创建数据线箭头
|
|
|
Viewcanvas.create_line(obj_x2, obj_y2, obj_x3, obj_y3, obj_x4, obj_y4, arrow=tk.LAST,
|
|
|
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth, width=width)
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
# window()
|
|
|
# listinstance = create_instance()
|
|
|
# # 创建网络对象总表和网络连接对象总表
|
|
|
# AllModelObj = [listinstance[0], listinstance[1], listinstance[2], listinstance[3], listinstance[4], listinstance[5],
|
|
|
# listinstance[6], listinstance[7], listinstance[8]]
|
|
|
# img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
|
|
|
# "img/classifier.png", "img/error.png", "img/adjust.png"]
|
|
|
# list_image = []
|
|
|
# for path in img_path:
|
|
|
# list_image.append(element(path))
|
|
|
# obj_x = AllModelObj[0][6] # 根据对象的id计算x坐标
|
|
|
# obj_y = AllModelObj[0][7] # 根据对象的id计算y坐标
|
|
|
# obj_x2 = AllModelObj[5][6] # 根据对象的id计算x坐标
|
|
|
# obj_y2 = AllModelObj[5][7] # 根据对象的id计算y坐标
|
|
|
# connecting_lines(obj_x, obj_y, 0, 32, 100, 50, 0, 0, list_image[0], " 加载" + "\n" + "数据集", True, 3)
|
|
|
# connectings_lines(obj_x2, obj_y2, 0, 0, 0, -50, 50, -30, -120, -180, list_image[5], "类别", False, 3)
|
|
|
# Root.mainloop()
|
|
|
def switch(obj_type, obj_x, obj_y,listimage):
|
|
|
if obj_type == 1: # 加载数据集
|
|
|
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+32, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[0], " 加载" + "\n" + "数据集", True, 3)
|
|
|
elif obj_type == 2: # 卷积
|
|
|
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[1], "卷积", True, 3)
|
|
|
elif obj_type == 3: # 池化
|
|
|
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[2], "池化", True, 3)
|
|
|
elif obj_type == 4: # 全连接
|
|
|
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[3], "全连接" + "\n" + " 函数", True, 3)
|
|
|
elif obj_type == 5: # 非线性
|
|
|
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+110, obj_y+50, obj_y+0, obj_y+0, listimage[4], "非线性" + "\n" + " 函数", True, 3)
|
|
|
elif obj_type == 6: # 分类
|
|
|
connectings_lines(obj_x, obj_y, obj_x+0, obj_x+0, obj_x+0, obj_x-50, obj_y+50, obj_y-30, obj_y-120, obj_y-180, listimage[5], "类别", False, 3)
|
|
|
elif obj_type == 7: # 误差计算
|
|
|
connectings_lines(obj_x, obj_y, obj_x+0, obj_x-20, obj_x-50, obj_x-420, obj_y-40,obj_y -20,obj_y -60,obj_y -60, listimage[6], "误差", False, 2)
|
|
|
connecting_lines(obj_x, obj_y, obj_x+0, obj_x-40, obj_x-170, obj_y-40, obj_y+0, obj_y+0, listimage[6], "误差", False, 2)
|
|
|
elif obj_type == 8: # 调整
|
|
|
connecting_lines(obj_x, obj_y, obj_x-80, obj_x+0, obj_x+0, obj_y+0, obj_y+30, obj_y+235, listimage[7], "调整1", False, 2)
|
|
|
elif obj_type == 9: # 调整
|
|
|
connecting_lines(obj_x, obj_y, obj_x-80, obj_x+0, obj_x+0, obj_y+0,obj_y+ 30,obj_y+ 183, listimage[7], "调整2", False, 2)
|
|
|
|
|
|
def creating_elements(AllModelObj,listimage):
|
|
|
# 遍历AllModelObj列表,在窗口左侧创建图元菜单
|
|
|
for obj in AllModelObj:
|
|
|
# 获取图元对象的类型、标签等信息
|
|
|
obj_type = obj[1]
|
|
|
# 并且要根据需求调整每个对象的位置
|
|
|
obj_x = obj[6] # 根据对象的id计算x坐标
|
|
|
obj_y = obj[7] # 根据对象的id计算y坐标
|
|
|
# 根据对象的类型,绘制相应的图形
|
|
|
switch(obj_type, obj_x, obj_y,listimage)
|
|
|
|
|
|
|
|
|
def main():
|
|
|
global AllModelObj
|
|
|
window()
|
|
|
listinstance = create_instance()
|
|
|
listclass = connect_class(listinstance)
|
|
|
# 创建网络对象总表和网络连接对象总表
|
|
|
AllModelObj = [listinstance[0],listinstance[1], listinstance[2], listinstance[3], listinstance[4], listinstance[5],
|
|
|
listinstance[6], listinstance[7], listinstance[8]]
|
|
|
AllModelConn = [listclass[0], listclass[1], listclass[2], listclass[3], listclass[4], listclass[5], listclass[6],
|
|
|
listclass[7], listclass[8], listclass[9]]
|
|
|
|
|
|
img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
|
|
|
"img/classifier.png", "img/error.png", "img/adjust.png"]
|
|
|
list_image = []
|
|
|
for path in img_path:
|
|
|
list_image.append(element(path))
|
|
|
creating_elements(AllModelObj, list_image)
|
|
|
print(1)
|
|
|
Root.mainloop()
|
|
|
if __name__ == '__main__':
|
|
|
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|