You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
2.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
import pandas as pd
import numpy as np
from torch_geometric.data import Data
import config
import random
def _sequential_edge_index(n, num_neighbors):
    """Return a (2, E) long tensor linking each node to its next ``num_neighbors`` nodes.

    Nodes are identified by row position 0..n-1. Every pair (i, j) with
    0 < j - i <= num_neighbors contributes two directed edges, [i, j] then
    [j, i], so the graph is symmetric. Edges are ordered by i, then by j,
    which is exactly the order the original nested Python loop produced.
    """
    offsets = torch.arange(1, num_neighbors + 1, dtype=torch.long)
    src = torch.arange(n, dtype=torch.long).repeat_interleave(num_neighbors)
    dst = src + offsets.repeat(n)
    keep = dst < n  # drop neighbours that would fall past the last node
    src, dst = src[keep], dst[keep]
    forward = torch.stack([src, dst], dim=1)
    backward = torch.stack([dst, src], dim=1)
    # Interleave so each [i, j] is immediately followed by its reverse [j, i].
    pairs = torch.stack([forward, backward], dim=1).reshape(-1, 2)
    return pairs.t().contiguous()


def build_graphs_from_dataframe(df, feature_cols, label_col, *, time_window=None,
                                min_nodes=3, num_neighbors=5, max_graphs=15000,
                                train_ratio=0.7, seed=None):
    """Bucket ``df`` into time windows and build one PyG graph per window.

    Rows are grouped by ``df["time"] // time_window``. Each window with at
    least ``min_nodes`` rows becomes a ``Data`` object whose nodes are the
    window's rows in their existing order. Node features come from
    ``feature_cols`` (float32) and node labels from ``label_col`` (int64).
    Symmetric edges link every node to its next ``num_neighbors`` rows.

    Windows are visited in ascending window-id order, and collection stops
    once ``max_graphs`` graphs exist. The cap therefore keeps the earliest
    windows, not a random sample. The collected graphs are then shuffled and
    split ``train_ratio`` / ``1 - train_ratio`` into train and test lists.

    NOTE(review): the split is random at window level, not chronological.
    Neighbouring windows can land on opposite sides of the split. Shuffling
    also does not guarantee that every class appears in both splits.

    NOTE(review): edges follow row order inside each window. If ``df`` is
    not already sorted by ``time``, edges do not connect temporally adjacent
    rows. Confirm the caller sorts first.

    Args:
        df: DataFrame with a ``"time"`` column plus ``feature_cols`` and
            ``label_col``. It is not modified.
        feature_cols: Column(s) used as node features.
        label_col: Column used as node labels.
        time_window: Window length in ``"time"`` units. Defaults to
            ``config.TIME_WINDOW``.
        min_nodes: Windows with fewer rows than this are skipped.
        num_neighbors: Number of following rows each node is linked to.
        max_graphs: Stop after this many graphs. ``None`` means no limit.
        train_ratio: Fraction of graphs assigned to the training list.
        seed: Seed for a private RNG used for the shuffle. ``None`` uses the
            global ``random`` module state.

    Returns:
        ``(train_graphs, test_graphs)``. Train graphs have ``train_mask``
        all True. Test graphs have ``test_mask`` all True. ``val_mask`` is
        always all False, because no validation split is produced here.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1] or
            ``num_neighbors`` is negative.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    if num_neighbors < 0:
        raise ValueError(f"num_neighbors must be >= 0, got {num_neighbors}")
    if time_window is None:
        time_window = config.TIME_WINDOW

    print("Building graphs... FINAL PAPER-LEVEL VERSION")
    # Window ids live in a separate Series rather than a new "time_window"
    # column, so the caller's DataFrame is left untouched.
    window_ids = (df["time"] // time_window).astype(int)

    graphs = []
    for _, window_df in df.groupby(window_ids):
        n = len(window_df)
        # Only a lower bound on window size; large windows are kept whole.
        if n < min_nodes:
            continue

        x = torch.tensor(window_df[feature_cols].to_numpy(dtype=np.float32))
        y = torch.tensor(window_df[label_col].to_numpy(dtype=np.int64))
        edge_index = _sequential_edge_index(n, num_neighbors)

        # Every graph starts as a training graph. Test graphs are flipped
        # after the global split below.
        graphs.append(Data(
            x=x, edge_index=edge_index, y=y,
            train_mask=torch.ones(n, dtype=torch.bool),
            val_mask=torch.zeros(n, dtype=torch.bool),
            test_mask=torch.zeros(n, dtype=torch.bool),
        ))
        if max_graphs is not None and len(graphs) >= max_graphs:
            break

    print(f"Shuffling {len(graphs)} graphs...")
    # A single shuffle gives a uniformly random permutation; a second pass
    # adds nothing. A seeded private RNG makes the split reproducible.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(graphs)

    split = int(len(graphs) * train_ratio)
    train_graphs, test_graphs = graphs[:split], graphs[split:]
    for g in test_graphs:
        g.train_mask[:] = False
        g.test_mask[:] = True

    print(f"Done! Train: {len(train_graphs)}, Test: {len(test_graphs)}")
    return train_graphs, test_graphs