You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

225 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tkinter as tk
from PIL import Image, ImageTk
global Viewcanvas # 定义画布
global Root # 主窗口
global AllModelObj #网络对象
'''
【编程16.5】编制程序依据AllModelObj和AllModelConn数据结构产生如图16.2的输出界面。
【目的及编程说明】读者通过编程16.5可理解卷积神经网络模型构建的输出界面。数据结构及初始值参见【编程16.1】。
'''
# 定义图元对象类
class ModelObj:
def __init__(self, ObjID, ObjType, ObjLable, ProcFunc, SetParaFunc, ParaString, ObjX, ObjY):
self.ObjID = ObjID # 图元号
self.ObjType = ObjType # 图元类别
self.ObjLable = ObjLable # 对象标签
self.ProcFunc = ProcFunc # 基本操作函数
self.SetParaFunc = SetParaFunc # 参数设置函数
self.ParaString = ParaString # 参数字符串
self.ObjX = ObjX # 对象位置x坐标
self.ObjY = ObjY # 对象位置y坐标
def output(self): # 输出方法
# 创建一个空列表
result = []
# 将对象的属性添加到列表中
result.append(self.ObjID)
result.append(self.ObjType)
result.append(self.ObjLable)
result.append(self.ProcFunc)
result.append(self.SetParaFunc)
result.append(self.ParaString)
result.append(self.ObjX)
result.append(self.ObjY)
# 返回列表
return result
# 定义网络连接对象类
class ModelConn:
def __init__(self, ConnObjID, ConnType, NobjS, NobjE):
self.ConnObjID = ConnObjID # 连接线编号
self.ConnType = ConnType # 连接线类别
self.NobjS = NobjS # 源图元对象
self.NobjE = NobjE # 目标图元对象
def __repr__(self):
return f"{self.ConnObjID}, {self.ConnType}, {self.NobjS}, {self.NobjE}"
def create_instance():
# 创建图元对象实例
DataSet = ModelObj("DataSet", 1, "数据集", "LoadData", "SetDataPara", ".", 120, 330).output()
Conv = ModelObj("Conv", 2, "卷积", "ConvProc", "SetConvPara", ".", 250, 330).output()
Pool = ModelObj("Pool", 3, "最大池化", "MaxPoolProc", "SetPollPara", ".", 380, 330).output()
FullConn = ModelObj("FullConn", 4, "全连接", "FullConnProc", "SetFullConnPara", ".", 510, 330).output()
Nonline = ModelObj("Nonline", 5, "非线性函数", "NonlinearProc", "SetNonLPara", ".", 640, 330).output()
Classifier = ModelObj("Classifier", 6, "分类", "ClassifierProc", "SetClassifyPara", ".", 780, 330).output()
Error = ModelObj("Error", 7, "误差计算", "ErrorProc", "SetErrorPara", ".", 710, 124).output()
AjConv = ModelObj("AjConv", 8, "卷积调整", "AjConvProc", "SetAjConvPara", ".", 250, 70).output()
AjFullconn = ModelObj("AjFullconn", 9, "全连接调整", "AjFullconnProc", "SetAjFCPara", ".", 510, 120).output()
listinstance = [DataSet,Conv,Pool,FullConn,Nonline,Classifier,Error,AjConv,AjFullconn]
return listinstance # 还回图元对象实例列表
# if __name__ == '__main__':
# listinstance = create_instance()
# for instance in listinstance:
# print(instance)
def connect_class(listinstance):
# 创建连接对象实例
Line1 = ModelConn(1, 1, listinstance[0], listinstance[1])
Line2 = ModelConn(2, 1, listinstance[1], listinstance[2])
Line3 = ModelConn(3, 1, listinstance[2], listinstance[3])
Line4 = ModelConn(4, 1, listinstance[3], listinstance[4])
Line5 = ModelConn(5, 1, listinstance[4], listinstance[5])
Line6 = ModelConn(6, 1, listinstance[5], listinstance[6])
Line7 = ModelConn(7, 2, listinstance[6], listinstance[8])
Line8 = ModelConn(8, 2, listinstance[6], listinstance[7])
Line9 = ModelConn(9, 2, listinstance[8], listinstance[3])
Line10 = ModelConn(10, 2, listinstance[7], listinstance[1])
listclass = [Line1,Line2, Line3, Line4, Line5, Line6, Line7, Line8, Line9, Line10]
return listclass # # 还回连接对象实例
# if __name__ == '__main__':
# listinstance = create_instance()
# listclass = connect_class(listinstance)
# for iclass in listclass:
# print(iclass)
def element(path):
imgs = Image.open(path) # 加载图元对应的图片文件
imgs = imgs.resize((60, 50)) # 使用resize方法调整图片
imgs = ImageTk.PhotoImage(imgs) # 把Image对象转换成PhotoImage对象
Root.img = imgs # 保存图片的引用,防止被垃圾回收
return imgs
# if __name__ == '__main__':
# Root = tk.Tk() # 创建一个主窗口
# img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
# "img/classifier.png", "img/error.png", "img/adjust.png"] # 图元路径
# list_image = [] # 定义一个列表存储PhotoImage对象
# for path in img_path:
# list_image.append(element(path))
# for image in list_image:
# print(image) # 打印结果
def window():
global Root
global Viewcanvas
Root = tk.Tk() # 创建一个主窗口
# 设置窗口的大小为1200*750
window_width = 900 # 窗口的宽度
window_height = 550 # 窗口的高度
Root.title("神经网络可视化")
Root.geometry("900x550") # 设置窗口的大小和位置
# 创建一个画布,用于绘制矩形框,设置画布的大小和背景色
Viewcanvas = tk.Canvas(Root, width=window_width, height=window_height, bg="white")
# 将画布添加到主窗口中
Viewcanvas.pack()
# 绘制矩形框,使用不同的颜色和线宽,指定矩形框的左上角和右下角坐标,填充色,边框色和边框宽度
Viewcanvas.create_rectangle(5, 5, 895, 545, fill=None, outline="lightblue", width=2)
# if __name__ == '__main__':
# window()
# path = "img/data.png" # 图元路径
# image =element(path)
# print(image) # 打印结果
# Root.mainloop()
def connecting_lines(obj_x, obj_y, obj_x1, obj_x2, obj_x3, obj_y1, obj_y2, obj_y3, image, text, smooth, width):
Viewcanvas.create_image(obj_x, obj_y, image=image) # 创建图元对象
Viewcanvas.create_text(obj_x1, obj_y1, text=text, font=("黑体", 14)) # 创建图元对象的标签
Viewcanvas.create_line(obj_x2, obj_y2, obj_x3, obj_y3, arrow=tk.LAST, # 创建数据线箭头
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth, width=width)
def connectings_lines(obj_x, obj_y, obj_x1, obj_x2, obj_x3,obj_x4, obj_y1, obj_y2, obj_y3, obj_y4, image, text, smooth, width):
# 创建图元对象
Viewcanvas.create_image(obj_x, obj_y, image=image)
# 创建图元对象的标签
Viewcanvas.create_text(obj_x1, obj_y1, text=text, font=("黑体", 14))
# 创建数据线箭头
Viewcanvas.create_line(obj_x2, obj_y2, obj_x3, obj_y3, obj_x4, obj_y4, arrow=tk.LAST,
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth, width=width)
# if __name__ == '__main__':
# window()
# listinstance = create_instance()
# # 创建网络对象总表和网络连接对象总表
# AllModelObj = [listinstance[0], listinstance[1], listinstance[2], listinstance[3], listinstance[4], listinstance[5],
# listinstance[6], listinstance[7], listinstance[8]]
# img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
# "img/classifier.png", "img/error.png", "img/adjust.png"]
# list_image = []
# for path in img_path:
# list_image.append(element(path))
# obj_x = AllModelObj[0][6] # 根据对象的id计算x坐标
# obj_y = AllModelObj[0][7] # 根据对象的id计算y坐标
# obj_x2 = AllModelObj[5][6] # 根据对象的id计算x坐标
# obj_y2 = AllModelObj[5][7] # 根据对象的id计算y坐标
# connecting_lines(obj_x, obj_y, 0, 32, 100, 50, 0, 0, list_image[0], " 加载" + "\n" + "数据集", True, 3)
# connectings_lines(obj_x2, obj_y2, 0, 0, 0, -50, 50, -30, -120, -180, list_image[5], "类别", False, 3)
# Root.mainloop()
def switch(obj_type, obj_x, obj_y,listimage):
if obj_type == 1: # 加载数据集
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+32, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[0], " 加载" + "\n" + "数据集", True, 3)
elif obj_type == 2: # 卷积
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[1], "卷积", True, 3)
elif obj_type == 3: # 池化
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[2], "池化", True, 3)
elif obj_type == 4: # 全连接
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+100, obj_y+50, obj_y+0, obj_y+0, listimage[3], "全连接" + "\n" + " 函数", True, 3)
elif obj_type == 5: # 非线性
connecting_lines(obj_x, obj_y, obj_x+0, obj_x+30, obj_x+110, obj_y+50, obj_y+0, obj_y+0, listimage[4], "非线性" + "\n" + " 函数", True, 3)
elif obj_type == 6: # 分类
connectings_lines(obj_x, obj_y, obj_x+0, obj_x+0, obj_x+0, obj_x-50, obj_y+50, obj_y-30, obj_y-120, obj_y-180, listimage[5], "类别", False, 3)
elif obj_type == 7: # 误差计算
connectings_lines(obj_x, obj_y, obj_x+0, obj_x-20, obj_x-50, obj_x-420, obj_y-40,obj_y -20,obj_y -60,obj_y -60, listimage[6], "误差", False, 2)
connecting_lines(obj_x, obj_y, obj_x+0, obj_x-40, obj_x-170, obj_y-40, obj_y+0, obj_y+0, listimage[6], "误差", False, 2)
elif obj_type == 8: # 调整
connecting_lines(obj_x, obj_y, obj_x-80, obj_x+0, obj_x+0, obj_y+0, obj_y+30, obj_y+235, listimage[7], "调整1", False, 2)
elif obj_type == 9: # 调整
connecting_lines(obj_x, obj_y, obj_x-80, obj_x+0, obj_x+0, obj_y+0,obj_y+ 30,obj_y+ 183, listimage[7], "调整2", False, 2)
def creating_elements(AllModelObj,listimage):
# 遍历AllModelObj列表在窗口左侧创建图元菜单
for obj in AllModelObj:
# 获取图元对象的类型、标签等信息
obj_type = obj[1]
# 并且要根据需求调整每个对象的位置
obj_x = obj[6] # 根据对象的id计算x坐标
obj_y = obj[7] # 根据对象的id计算y坐标
# 根据对象的类型,绘制相应的图形
switch(obj_type, obj_x, obj_y,listimage)
def main():
global AllModelObj
window()
listinstance = create_instance()
listclass = connect_class(listinstance)
# 创建网络对象总表和网络连接对象总表
AllModelObj = [listinstance[0],listinstance[1], listinstance[2], listinstance[3], listinstance[4], listinstance[5],
listinstance[6], listinstance[7], listinstance[8]]
AllModelConn = [listclass[0], listclass[1], listclass[2], listclass[3], listclass[4], listclass[5], listclass[6],
listclass[7], listclass[8], listclass[9]]
img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
"img/classifier.png", "img/error.png", "img/adjust.png"]
list_image = []
for path in img_path:
list_image.append(element(path))
creating_elements(AllModelObj, list_image)
print(1)
Root.mainloop()
if __name__ == '__main__':
main()