forked from puhvqweop/MachineLearning
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
5.6 KiB
105 lines
5.6 KiB
import tkinter as tk
|
|
from PIL import Image, ImageTk
|
|
class Networking:
|
|
def __init__(self):
|
|
self.Root = tk.Tk() # 创建一个主窗口
|
|
self.window_width = 900 # 窗口的宽度
|
|
self.window_height = 550 # 窗口的高度
|
|
self.list_image = [] # 图元列表
|
|
self.Item_Record = [[], []] # 记录图元坐标与图元号
|
|
def window(self):
|
|
self.Root.title("神经网络可视化")
|
|
self.Root.geometry("900x550") # 设置窗口的大小和位置
|
|
# 创建一个画布,用于绘制矩形框,设置画布的大小和背景色
|
|
self.Viewcanvas = tk.Canvas(self.Root, width=self.window_width, height=self.window_height, bg="white")
|
|
# 将画布添加到主窗口中
|
|
self.Viewcanvas.pack()
|
|
# 绘制矩形框,使用不同的颜色和线宽,指定矩形框的左上角和右下角坐标,填充色,边框色和边框宽度
|
|
self.Viewcanvas.create_rectangle(5, 5, 895, 545, fill=None, outline="lightblue", width=2)
|
|
def connecting_lines(self, obj):
|
|
obj_x = obj.ObjX # 根据对象的id计算x坐标
|
|
obj_y = obj.ObjY # 根据对象的id计算y坐标
|
|
text = obj.ObjLable
|
|
if 'Error' in obj.ObjID:
|
|
x, y = 0, -50
|
|
elif 'Aj' in obj.ObjID:
|
|
x, y = -80, 0
|
|
else:
|
|
x, y = 0, 50
|
|
self.Viewcanvas.create_image(obj_x, obj_y, image=self.list_image[obj.ObjType - 1]) # 创建图元对象
|
|
self.Viewcanvas.create_text(obj_x + x, obj_y + y, text=text, font=("黑体", 14)) # 创建图元对象的标签
|
|
|
|
def conn_lines(self, conn):
|
|
starting = self.Item_Record[1].index(conn[2])
|
|
ending = self.Item_Record[1].index(conn[3])
|
|
smooth = [False, True]
|
|
width = [2, 4]
|
|
start, end = self.Item_Record[0][starting], self.Item_Record[0][ending]
|
|
index = 1 if conn[1] == 1 else 0
|
|
if start[0] == end[0]:
|
|
self.Viewcanvas.create_line(start[0], start[1] + 30, end[0], end[1] - 30, arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
elif start[1] == end[1]:
|
|
self.Viewcanvas.create_line(start[0] + 30, start[1], end[0] - 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
else:
|
|
if abs(start[0] - end[0]) > abs(start[1] - end[1]):
|
|
# 创建数据线箭头
|
|
self.Viewcanvas.create_line(start[0] - 15, start[1], int((start[0] + end[0]) / 2), end[1], end[0] + 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
else:
|
|
# 创建数据线箭头
|
|
self.Viewcanvas.create_line(start[0], start[1] - 20, start[0], end[1], end[0] + 30, end[1], arrow=tk.LAST, arrowshape=(16, 20, 4), fill='lightblue', smooth=smooth[index], width=width[index])
|
|
def element(self, path):
|
|
img = Image.open(path) # 加载图元对应的图片文件
|
|
img = img.resize((60, 50)) # 使用resize方法调整图片
|
|
img = ImageTk.PhotoImage(img) # 把Image对象转换成PhotoImage对象
|
|
self.Root.img = img # 保存图片的引用,防止被垃圾回收
|
|
return img
|
|
def read_element(self):
|
|
img_path = ["img/data.png", "img/conv.png", "img/pool.png", "img/full_connect.png", "img/nonlinear.png", "img/classifier.png", "img/error.png", "img/adjust.png", "img/adjust.png"]
|
|
for path in img_path:
|
|
self.list_image.append(self.element(path))
|
|
def visual_output(self, AllModelObj, AllModelConn):
|
|
# 遍历 AllModelObj 列表,在窗口创建图元
|
|
for obj in AllModelObj:
|
|
# 记录图元坐标
|
|
self.Item_Record[0].append((obj.ObjX, obj.ObjY))
|
|
# 记录图元号
|
|
self.Item_Record[1].append(obj.ObjID)
|
|
# 根据图元对象信息在画布上画图元
|
|
self.connecting_lines(obj)
|
|
# 遍历 AllModelConn 列表,在窗口连线图元
|
|
for conn in AllModelConn:
|
|
# 根据连接对象信息在画布上连接图元
|
|
self.conn_lines(conn)
|
|
if __name__ == '__main__':
|
|
AllModelObj = [
|
|
['DataSet1', 1, '数据集1', 'LoadData',
|
|
'SetDataPara', [], 120, 330],
|
|
['Conv1', 2, '卷积1', 'ConvProc',
|
|
'SetConvPara', [], 250, 330],
|
|
['Pool1', 3, '最大池化1', 'MaxPoolProc',
|
|
'SetPollPara', [], 380, 330],
|
|
['FullConn1', 4, '全连接1', 'FullConnProc',
|
|
'SetFullConnPara', [], 510, 330],
|
|
['Nonline1', 5, '非线性函数1', 'NonlinearProc',
|
|
'SetNonLPara', [], 640, 330],
|
|
['Classifier1', 6, '分类1', 'ClassifierProc',
|
|
'SetClassifyPara', [], 780, 330],
|
|
['Error1', 7, '误差计算1', 'ErrorProc',
|
|
'SetErrorPara', [], 710, 124],
|
|
['AjConv1', 8, '卷积调整1', 'AjConvProc',
|
|
'SetAjConvPara', [], 250, 70],
|
|
['AjFullconn1', 9, '全连接调整1', 'AjFullconnProc',
|
|
'SetAjFCPara', [], 510, 120]]
|
|
AllModelConn = [
|
|
[1, 1, 'DataSet1', 'Conv1'], [2, 1, 'Conv1', 'Pool1'],
|
|
[3, 1, 'Pool1', 'FullConn1'], [4, 1, 'FullConn1', 'Nonline1'],
|
|
[5, 1, 'Nonline1', 'Classifier1'], [6, 1, 'Classifier1', 'Error1'],
|
|
[7, 2, 'Error1', 'AjFullconn1'], [8, 2, 'Error1', 'AjConv1'],
|
|
[9, 2, 'AjFullconn1', 'FullConn1'], [10, 2, 'AjConv1', 'Conv1']]
|
|
Net = Networking() # 创建 Networking 实例
|
|
Net.window() # 构造窗口
|
|
Net.read_element() # 读取图元
|
|
# 在窗口中可视化输出图元和连接
|
|
Net.visual_output(AllModelObj, AllModelConn)
|
|
Net.Root.mainloop() # 启动主事件循环
|