|
|
import tkinter as tk
|
|
|
from PIL import Image, ImageTk
|
|
|
from X1 import *
|
|
|
global Viewcanvas # 定义画布
|
|
|
global Root # 主窗口
|
|
|
global AllModelObj #网络对象
|
|
|
|
|
|
'''
|
|
|
【编程16.5】编制程序:依据AllModelObj和AllModelConn数据结构产生如图16.2的输出界面。
|
|
|
【目的及编程说明】读者通过编程16.5可理解卷积神经网络模型构建的输出界面。数据结构及初始值参见【编程16.1】。
|
|
|
'''
|
|
|
|
|
|
|
|
|
def create_instance():
|
|
|
global AllModelObj
|
|
|
global DataSet, Conv, Pool, FullConn, Nonline, Classifier, Error, AjConv, AjFullconn
|
|
|
DataSet = Data_Class("DataSet1", 1, "数据集1", ".", 120, 330)
|
|
|
Conv = Conv_Class("Conv1", 2, "卷积1", ".", 250, 330)
|
|
|
Pool = Pool_Class("Pool1", 3, "最大池化1", ".", 380, 330)
|
|
|
FullConn = FullConn_Class("FullConn1", 4, "全连接1", ".", 510, 330)
|
|
|
Nonline = Nonline_Class("Nonline1", 5, "非线性函数1", ".", 640, 330)
|
|
|
Classifier = Classifier_Class("Classifier1", 6, "分类1", ".", 780, 330)
|
|
|
Error = Error_Class("Error1", 7, "误差计算1", ".", 710, 124)
|
|
|
AjConv = AjConv_Class("AjConv1", 8, "卷积调整1", ".", 250, 70)
|
|
|
AjFullconn = AjFullconn_Class("AjFullconn1", 9, "全连接调整1", ".", 510, 120)
|
|
|
AllModelObj = [DataSet, Conv, Pool, FullConn, Nonline, Classifier, Error, AjConv, AjFullconn]
|
|
|
|
|
|
def connect_class():
|
|
|
global AllModelConn
|
|
|
# 创建连接对象实例
|
|
|
Line1 = ModelConn(1, 1, DataSet.ObjID, Conv.ObjID).output()
|
|
|
Line2 = ModelConn(2, 1, Conv.ObjID, Pool.ObjID).output()
|
|
|
Line3 = ModelConn(3, 1, Pool.ObjID, FullConn.ObjID).output()
|
|
|
Line4 = ModelConn(4, 1, FullConn.ObjID, Nonline.ObjID).output()
|
|
|
Line5 = ModelConn(5, 1, Nonline.ObjID, Classifier.ObjID).output()
|
|
|
Line6 = ModelConn(6, 1, Classifier.ObjID, Error.ObjID).output()
|
|
|
Line7 = ModelConn(7, 2, Error.ObjID, AjFullconn.ObjID).output()
|
|
|
Line8 = ModelConn(8, 2, Error.ObjID, AjConv.ObjID).output()
|
|
|
Line9 = ModelConn(9, 2, AjFullconn.ObjID, FullConn.ObjID).output()
|
|
|
Line10 = ModelConn(10, 2, AjConv.ObjID, Conv.ObjID).output()
|
|
|
# 网络连接对象总表
|
|
|
AllModelConn = [Line1, Line2, Line3, Line4,
|
|
|
Line5, Line6, Line7, Line8,
|
|
|
Line9, Line10]
|
|
|
|
|
|
def element(path):
|
|
|
img = Image.open(path) # 加载图元对应的图片文件
|
|
|
img = img.resize((60, 50)) # 使用resize方法调整图片
|
|
|
img = ImageTk.PhotoImage(img) # 把Image对象转换成PhotoImage对象
|
|
|
Root.img = img # 保存图片的引用,防止被垃圾回收
|
|
|
return img
|
|
|
|
|
|
def window():
|
|
|
global Root
|
|
|
global Viewcanvas
|
|
|
Root = tk.Tk() # 创建一个主窗口
|
|
|
# 设置窗口的大小为1200*750
|
|
|
window_width = 900 # 窗口的宽度
|
|
|
window_height = 550 # 窗口的高度
|
|
|
Root.title("神经网络可视化")
|
|
|
Root.geometry("900x550") # 设置窗口的大小和位置
|
|
|
# 创建一个画布,用于绘制矩形框,设置画布的大小和背景色
|
|
|
Viewcanvas = tk.Canvas(Root, width=window_width, height=window_height, bg="white")
|
|
|
# 将画布添加到主窗口中
|
|
|
Viewcanvas.pack()
|
|
|
# 绘制矩形框,使用不同的颜色和线宽,指定矩形框的左上角和右下角坐标,填充色,边框色和边框宽度
|
|
|
Viewcanvas.create_rectangle(5, 5, 895, 545, fill=None, outline="lightblue", width=2)
|
|
|
|
|
|
def connecting_lines(obj_x, obj_y, text, text_record,image):
|
|
|
Viewcanvas.create_image(obj_x, obj_y, image=image) # 创建图元对象
|
|
|
Viewcanvas.create_text(obj_x + text_record[0], obj_y + text_record[1], text=text, font=("黑体", 14)) # 创建图元对象的标签
|
|
|
|
|
|
def conn_lines(start, end, index):
|
|
|
smooth = [False, True]
|
|
|
width = [2, 4]
|
|
|
if start[0] == end[0]:
|
|
|
Viewcanvas.create_line(start[0], start[1] + 30, end[0] , end[1] - 30, arrow=tk.LAST,
|
|
|
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
|
elif start[1] == end[1]:
|
|
|
Viewcanvas.create_line(start[0] + 30, start[1], end[0] - 30, end[1], arrow=tk.LAST,
|
|
|
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
|
else:
|
|
|
if abs(start[0]-end[0]) > abs(start[1]-end[1]):
|
|
|
# 创建数据线箭头
|
|
|
Viewcanvas.create_line(start[0]-15, start[1], int((start[0] + end[0])/2), end[1], end[0] + 30, end[1], arrow=tk.LAST,
|
|
|
arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
|
else:
|
|
|
# 创建数据线箭头
|
|
|
Viewcanvas.create_line(start[0], start[1] - 20, start[0], end[1], end[0] + 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
|
|
|
|
def creating_elements():
|
|
|
text_record = [(0, -50), (0, 50), (-80, 0)]
|
|
|
# 遍历AllModelObj列表,在窗口左侧创建图元菜单
|
|
|
for obj in AllModelObj:
|
|
|
# 并且要根据需求调整每个对象的位置
|
|
|
obj_x = obj.ObjX # 根据对象的id计算x坐标
|
|
|
obj_y = obj.ObjY # 根据对象的id计算y坐标
|
|
|
Item_Record.append((obj_x, obj_y))
|
|
|
Item_Name.append(obj.ObjID)
|
|
|
# 根据对象的类型,绘制相应的图形
|
|
|
if 'Error' in obj.ObjID:
|
|
|
connecting_lines(obj_x, obj_y, obj.ObjLable, text_record[0], list_image[obj.ObjType - 1])
|
|
|
elif 'Aj' in obj.ObjID:
|
|
|
connecting_lines(obj_x, obj_y, obj.ObjLable, text_record[2], list_image[-1])
|
|
|
else:
|
|
|
connecting_lines(obj_x, obj_y, obj.ObjLable, text_record[1], list_image[obj.ObjType - 1])
|
|
|
|
|
|
def ligature(): # 连接线
|
|
|
# print(Item_Record)
|
|
|
for conn in AllModelConn:
|
|
|
starting = Item_Name.index(conn[2])
|
|
|
# print(starting)
|
|
|
ending = Item_Name.index(conn[3])
|
|
|
if conn[1] == 1:
|
|
|
# print(Item_Record[starting])
|
|
|
conn_lines(Item_Record[starting], Item_Record[ending], 1)
|
|
|
else:
|
|
|
conn_lines(Item_Record[starting], Item_Record[ending], 0)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
global AllModelObj
|
|
|
Item_Record = []
|
|
|
Item_Name = []
|
|
|
window()
|
|
|
create_instance()
|
|
|
connect_class()
|
|
|
img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png",
|
|
|
"img/classifier.png", "img/error.png", "img/adjust.png"]
|
|
|
list_image = []
|
|
|
for path in img_path:
|
|
|
list_image.append(element(path))
|
|
|
creating_elements()
|
|
|
ligature()
|
|
|
Root.mainloop()
|
|
|
# print(Item_Record) |