You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

128 lines
6.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tkinter as tk
from PIL import Image, ImageTk
from X1 import *
def create_instance():
    """Create the nine model element objects and publish them as module globals.

    Each constructor is called as (ObjID, ObjType, label, ".", x, y). The last
    two values are presumably the canvas position: Networking reads them back
    as obj.ObjX / obj.ObjY. TODO confirm against X1. ObjType 1..9 selects the
    icon that Networking.read_element loads (list index ObjType - 1).
    The objects are also collected, in this order, in the global AllModelObj.
    """
    global AllModelObj
    global DataSet, Conv, Pool, FullConn, Nonline, Classifier, Error, AjConv, AjFullconn
    AllModelObj = [
        Data_Class("DataSet1", 1, "数据集1", ".", 120, 330),
        Conv_Class("Conv1", 2, "卷积1", ".", 250, 330),
        Pool_Class("Pool1", 3, "最大池化1", ".", 380, 330),
        FullConn_Class("FullConn1", 4, "全连接1", ".", 510, 330),
        Nonline_Class("Nonline1", 5, "非线性函数1", ".", 640, 330),
        Classifier_Class("Classifier1", 6, "分类1", ".", 780, 330),
        Error_Class("Error1", 7, "误差计算1", ".", 710, 124),
        AjConv_Class("AjConv1", 8, "卷积调整1", ".", 250, 70),
        AjFullconn_Class("AjFullconn1", 9, "全连接调整1", ".", 510, 120),
    ]
    # Bind each element to its own global name, in the same order as the list.
    (DataSet, Conv, Pool, FullConn, Nonline,
     Classifier, Error, AjConv, AjFullconn) = AllModelObj
def connect_class():
    """Build the connection records between the model elements.

    Each record is ``ModelConn(line_id, line_type, source_id, target_id).output()``.
    Type-1 links chain DataSet -> Conv -> Pool -> FullConn -> Nonline ->
    Classifier -> Error. Type-2 links go from Error to the two adjustment
    elements, and from those back to FullConn and Conv.

    Must run after create_instance(), whose globals it reads. The resulting
    list is stored in the module global AllModelConn.
    """
    global AllModelConn
    # (line id, line type, source element, target element)
    wiring = (
        (1, 1, DataSet, Conv),
        (2, 1, Conv, Pool),
        (3, 1, Pool, FullConn),
        (4, 1, FullConn, Nonline),
        (5, 1, Nonline, Classifier),
        (6, 1, Classifier, Error),
        (7, 2, Error, AjFullconn),
        (8, 2, Error, AjConv),
        (9, 2, AjFullconn, FullConn),
        (10, 2, AjConv, Conv),
    )
    AllModelConn = [
        ModelConn(line_id, line_type, src.ObjID, dst.ObjID).output()
        for line_id, line_type, src, dst in wiring
    ]
class Networking:
    """Tkinter window that draws the neural-network model diagram.

    Typical use (as in the ``__main__`` section of this script):
    ``window()`` builds the canvas, ``read_element()`` loads the element icons,
    then ``visual_output(objs, conns)`` draws the elements and their links.
    """

    def __init__(self):
        self.Root = tk.Tk()  # main application window
        # Window size is 900 x 550 pixels; window() uses these for geometry and canvas.
        self.window_width = 900  # window width in pixels
        self.window_height = 550  # window height in pixels
        # PhotoImage icons, indexed by element ObjType - 1 (filled by read_element).
        self.list_image = []
        # Item_Record[0]: (x, y) of each drawn element;
        # Item_Record[1]: the matching ObjID, in the same order.
        self.Item_Record = [[], []]

    def window(self):
        """Set the window title/size and create a white canvas with a light-blue border."""
        self.Root.title("神经网络可视化")
        # Derive the geometry from the configured size instead of a second hard-coded copy.
        self.Root.geometry(f"{self.window_width}x{self.window_height}")
        self.Viewcanvas = tk.Canvas(self.Root, width=self.window_width,
                                    height=self.window_height, bg="white")
        self.Viewcanvas.pack()
        # Border rectangle inset 5 px from every canvas edge; fill=None leaves it unfilled.
        margin = 5
        self.Viewcanvas.create_rectangle(margin, margin,
                                         self.window_width - margin,
                                         self.window_height - margin,
                                         fill=None, outline="lightblue", width=2)

    def connecting_lines(self, obj):
        """Draw one model element: its icon at (ObjX, ObjY) plus its text label.

        Despite the name, no lines are drawn here; links are drawn by conn_lines.
        The label goes 50 px above the icon for elements whose ObjID contains
        'Error', 80 px to the left when it contains 'Aj', and 50 px below otherwise.
        Requires read_element() to have filled self.list_image.
        """
        obj_x = obj.ObjX
        obj_y = obj.ObjY
        text = obj.ObjLable
        if 'Error' in obj.ObjID:
            x, y = 0, -50
        elif 'Aj' in obj.ObjID:
            x, y = -80, 0
        else:
            x, y = 0, 50
        self.Viewcanvas.create_image(obj_x, obj_y, image=self.list_image[obj.ObjType - 1])
        self.Viewcanvas.create_text(obj_x + x, obj_y + y, text=text, font=("黑体", 14))

    def conn_lines(self, conn):
        """Draw an arrow for one connection record.

        :param conn: indexable record where conn[1] is the connection type,
            conn[2] the source ObjID and conn[3] the target ObjID
            (conn[0] is not used here).
        :raises ValueError: if either endpoint was not recorded in Item_Record
            (raised by list.index).

        Type-1 links are drawn 4 px wide and smoothed; any other type is drawn
        2 px wide and unsmoothed.
        """
        starting = self.Item_Record[1].index(conn[2])
        ending = self.Item_Record[1].index(conn[3])
        start, end = self.Item_Record[0][starting], self.Item_Record[0][ending]
        if conn[1] == 1:
            smooth, width = True, 4
        else:
            smooth, width = False, 2
        style = dict(arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue',
                     smooth=smooth, width=width)
        if start[0] == end[0]:
            # Vertical link: leave a 30 px gap at each icon, pointing toward the target
            # whether it lies below or above the source.
            dy = 30 if end[1] > start[1] else -30
            self.Viewcanvas.create_line(start[0], start[1] + dy, end[0], end[1] - dy, **style)
        elif start[1] == end[1]:
            # Horizontal link: same gap, pointing right or left as needed.
            dx = 30 if end[0] > start[0] else -30
            self.Viewcanvas.create_line(start[0] + dx, start[1], end[0] - dx, end[1], **style)
        elif abs(start[0] - end[0]) > abs(start[1] - end[1]):
            # Mostly-horizontal diagonal: bend at the horizontal midpoint.
            # NOTE(review): the fixed offsets assume the target lies to the LEFT of the
            # source, as with the Error -> Aj* links in this script's layout.
            self.Viewcanvas.create_line(start[0] - 15, start[1],
                                        int((start[0] + end[0]) / 2), end[1],
                                        end[0] + 30, end[1], **style)
        else:
            # Mostly-vertical diagonal: go up from the source, then across to the target.
            # NOTE(review): assumes the target is above and to the left of the source,
            # as with the Classifier -> Error link in this script's layout.
            self.Viewcanvas.create_line(start[0], start[1] - 20, start[0], end[1],
                                        end[0] + 30, end[1], **style)

    def creating_elements(self, AllModelObj=None):
        """Record each element's position and ID in Item_Record, then draw it.

        :param AllModelObj: iterable of model objects (ObjX, ObjY, ObjID, ObjType,
            ObjLable). If omitted, ``self.AllModelObj`` is used for backward
            compatibility; this class never sets that attribute itself, so pass
            the list explicitly.
        """
        objs = self.AllModelObj if AllModelObj is None else AllModelObj
        for obj in objs:
            # Keep the two parallel lists aligned: conn_lines looks up IDs in [1]
            # and reads the coordinates at the same index in [0].
            self.Item_Record[0].append((obj.ObjX, obj.ObjY))
            self.Item_Record[1].append(obj.ObjID)
            self.connecting_lines(obj)

    def element(self, path, size=(60, 50)):
        """Load an image file, resize it and return it as a Tk PhotoImage.

        :param path: path of the image file.
        :param size: target (width, height) in pixels; 60 x 50 by default.
        """
        # The context manager closes the source file deterministically.
        with Image.open(path) as src:
            img = ImageTk.PhotoImage(src.resize(size))
        # Hold a reference so the PhotoImage is not garbage-collected. This attribute
        # only keeps the most recent image; read_element keeps all of them in list_image.
        self.Root.img = img
        return img

    def read_element(self):
        """Load the icon for every element type into self.list_image.

        The list order matches ObjType - 1 as used by connecting_lines; the two
        adjustment types share "img/adjust.png".
        """
        img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png",
                    "img/nonlinear.png", "img/classifier.png", "img/error.png",
                    "img/adjust.png", "img/adjust.png"]
        self.list_image.extend(self.element(path) for path in img_path)

    def visual_output(self, AllModelObj, AllModelConn):
        """Draw all elements first (recording their positions), then every connection."""
        self.creating_elements(AllModelObj)
        for conn in AllModelConn:
            self.conn_lines(conn)
if __name__ == '__main__':
    # Populate the module globals: element objects (AllModelObj) and links (AllModelConn).
    create_instance()
    connect_class()

    # Build the window and canvas, load the icons, draw the diagram, then hand
    # control to the Tk event loop.
    viewer = Networking()
    viewer.window()
    viewer.read_element()
    viewer.visual_output(AllModelObj, AllModelConn)
    viewer.Root.mainloop()