import tkinter as tk from PIL import Image, ImageTk from X1 import * def create_instance(): global AllModelObj global DataSet, Conv, Pool, FullConn, Nonline, Classifier, Error, AjConv, AjFullconn DataSet = Data_Class("DataSet1", 1, "数据集1", ".", 120, 330) Conv = Conv_Class("Conv1", 2, "卷积1", ".", 250, 330) Pool = Pool_Class("Pool1", 3, "最大池化1", ".", 380, 330) FullConn = FullConn_Class("FullConn1", 4, "全连接1", ".", 510, 330) Nonline = Nonline_Class("Nonline1", 5, "非线性函数1", ".", 640, 330) Classifier = Classifier_Class("Classifier1", 6, "分类1", ".", 780, 330) Error = Error_Class("Error1", 7, "误差计算1", ".", 710, 124) AjConv = AjConv_Class("AjConv1", 8, "卷积调整1", ".", 250, 70) AjFullconn = AjFullconn_Class("AjFullconn1", 9, "全连接调整1", ".", 510, 120) AllModelObj = [DataSet, Conv, Pool, FullConn, Nonline, Classifier, Error, AjConv, AjFullconn] def connect_class(): global AllModelConn # 创建连接对象实例 Line1 = ModelConn(1, 1, DataSet.ObjID, Conv.ObjID).output() Line2 = ModelConn(2, 1, Conv.ObjID, Pool.ObjID).output() Line3 = ModelConn(3, 1, Pool.ObjID, FullConn.ObjID).output() Line4 = ModelConn(4, 1, FullConn.ObjID, Nonline.ObjID).output() Line5 = ModelConn(5, 1, Nonline.ObjID, Classifier.ObjID).output() Line6 = ModelConn(6, 1, Classifier.ObjID, Error.ObjID).output() Line7 = ModelConn(7, 2, Error.ObjID, AjFullconn.ObjID).output() Line8 = ModelConn(8, 2, Error.ObjID, AjConv.ObjID).output() Line9 = ModelConn(9, 2, AjFullconn.ObjID, FullConn.ObjID).output() Line10 = ModelConn(10, 2, AjConv.ObjID, Conv.ObjID).output() # 网络连接对象总表 AllModelConn = [Line1, Line2, Line3, Line4, Line5, Line6, Line7, Line8, Line9, Line10] class Networking: def __init__(self): self.Root = tk.Tk() # 创建一个主窗口 # 设置窗口的大小为1200*750 self.window_width = 900 # 窗口的宽度 self.window_height = 550 # 窗口的高度 self.list_image = [] self.Item_Record = [[], []] # 记录图元坐标与图元号 def window(self): self.Root.title("神经网络可视化") self.Root.geometry("900x550") # 设置窗口的大小和位置 # 创建一个画布,用于绘制矩形框,设置画布的大小和背景色 self.Viewcanvas = tk.Canvas(self.Root, width=self.window_width, height=self.window_height, bg="white") # 将画布添加到主窗口中 self.Viewcanvas.pack() # 绘制矩形框,使用不同的颜色和线宽,指定矩形框的左上角和右下角坐标,填充色,边框色和边框宽度 self.Viewcanvas.create_rectangle(5, 5, 895, 545, fill=None, outline="lightblue", width=2) def connecting_lines(self, obj): obj_x = obj.ObjX # 根据对象的id计算x坐标 obj_y = obj.ObjY # 根据对象的id计算y坐标 text = obj.ObjLable if 'Error' in obj.ObjID: x, y = 0, -50 elif 'Aj' in obj.ObjID: x, y = -80, 0 else: x, y = 0, 50 self.Viewcanvas.create_image(obj_x, obj_y, image=self.list_image[obj.ObjType - 1]) # 创建图元对象 self.Viewcanvas.create_text(obj_x + x, obj_y + y, text=text, font=("黑体", 14)) # 创建图元对象的标签 def conn_lines(self, conn): starting = self.Item_Record[1].index(conn[2]) ending = self.Item_Record[1].index(conn[3]) smooth = [False, True] width = [2, 4] start, end = self.Item_Record[0][starting], self.Item_Record[0][ending] index = 1 if conn[1] == 1 else 0 if start[0] == end[0]: self.Viewcanvas.create_line(start[0], start[1] + 30, end[0], end[1] - 30, arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index]) elif start[1] == end[1]: self.Viewcanvas.create_line(start[0] + 30, start[1], end[0] - 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index]) else: if abs(start[0] - end[0]) > abs(start[1] - end[1]): # 创建数据线箭头 self.Viewcanvas.create_line(start[0] - 15, start[1], int((start[0] + end[0]) / 2), end[1], end[0] + 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index]) else: # 创建数据线箭头 self.Viewcanvas.create_line(start[0], start[1] - 20, start[0], end[1], end[0] + 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index]) def creating_elements(self): text_record = [(0, -50), (0, 50), (-80, 0)] # 遍历AllModelObj列表,在窗口左侧创建图元菜单 for obj in self.AllModelObj: # 并且要根据需求调整每个对象的位置 obj_x = obj.ObjX # 根据对象的id计算x坐标 obj_y = obj.ObjY # 根据对象的id计算y坐标 self.Item_Record.append((obj_x, obj_y)) self.connecting_lines(obj,) def element(self, path): img = Image.open(path) # 加载图元对应的图片文件 img = img.resize((60, 50)) # 使用resize方法调整图片 img = ImageTk.PhotoImage(img) # 把Image对象转换成PhotoImage对象 self.Root.img = img # 保存图片的引用,防止被垃圾回收 return img def read_element(self): img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png", "img/classifier.png", "img/error.png", "img/adjust.png", "img/adjust.png"] for path in img_path: self.list_image.append(self.element(path)) def visual_output(self, AllModelObj, AllModelConn): for obj in AllModelObj: # 遍历AllModelObj列表,在窗口创建图元 self.Item_Record[0].append((obj.ObjX, obj.ObjY)) # 记录图元坐标 self.Item_Record[1].append(obj.ObjID) # 记录图元号 self.connecting_lines(obj) # 根据图元对象信息在画布上画图元 for conn in AllModelConn: # 遍历AllModelConn列表,在窗口连线图元 self.conn_lines(conn) if __name__ == '__main__': create_instance() connect_class() Net = Networking() Net.window() Net.read_element() Net.visual_output(AllModelObj, AllModelConn) Net.Root.mainloop()