|
|
from flask import Flask, render_template, request, jsonify
|
|
|
import torch
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
import pickle
|
|
|
from torch_geometric.data import Data
|
|
|
import config
|
|
|
from model import ADSB_GAT
|
|
|
|
|
|
app = Flask(__name__)

# Module-level state shared by the request handlers. Everything starts out
# as None and is filled in by load_model_once():
#   model        - the trained ADSB_GAT network
#   scaler       - the unpickled feature scaler (used via .transform)
#   le_icao      - the unpickled ICAO24 encoder (used via .classes_/.transform)
#   FEATURE_COLS - the unpickled list of model input column names
model = scaler = le_icao = FEATURE_COLS = None
|
|
|
|
|
|
def load_model_once():
    """Load the trained GAT model and its preprocessing artefacts into module globals.

    Populates ``scaler``, ``le_icao`` and ``FEATURE_COLS`` by unpickling
    ``scaler.pkl``, ``le_icao.pkl`` and ``feature_cols.pkl`` (resolved
    relative to the current working directory). It then builds ``model``
    from ``config.MODEL_SAVE_PATH`` on ``config.DEVICE`` in eval mode.
    Calling it again once a model is loaded does nothing.
    """
    global model, scaler, le_icao, FEATURE_COLS

    if model is not None:
        return

    print("正在加载模型和预处理参数...")

    # 1. Load the preprocessing parameters first; they must be identical to
    #    the ones used during training. The feature list also determines the
    #    model's input width.
    # NOTE: pickle.load can execute arbitrary code — these files must come
    # from a trusted source (never from user uploads).
    with open("scaler.pkl", "rb") as f:
        scaler = pickle.load(f)
    with open("le_icao.pkl", "rb") as f:
        le_icao = pickle.load(f)
    with open("feature_cols.pkl", "rb") as f:
        FEATURE_COLS = pickle.load(f)

    # 2. Build the GAT model with one input channel per feature column, then
    #    load the trained weights. The global is only bound after the weights
    #    are in place, so a failed load never exposes an untrained network.
    net = ADSB_GAT(in_channels=len(FEATURE_COLS))
    net.load_state_dict(torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE))
    net.to(config.DEVICE)
    net.eval()
    model = net

    print("✅ 加载完成!服务已就绪")
    print(f" 模型输入特征:{FEATURE_COLS}")
|
|
|
|
|
|
@app.route('/')
def index():
    """Render the front-end page from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
|
|
# Maximum number of CSV rows analysed per request; later rows are ignored.
MAX_LINES = 10000

# Feature columns computed by _preprocess rather than read from the upload.
_DERIVED_COLS = {"icao24_enc", "acceleration"}


def _encode_icao(df):
    """Add an ``icao24_enc`` column to *df* using the fitted ICAO24 encoder."""
    if "icao24" not in df.columns:
        df["icao24_enc"] = 0
        return df
    df["icao24"] = df["icao24"].astype(str)
    # Encode every known address in a single vectorised transform call
    # instead of one call per row. Addresses unseen at fit time become 0.
    # NOTE(review): 0 is also the code of le_icao.classes_[0], so unknown
    # aircraft collide with that class — confirm this matches training.
    known = df["icao24"].isin(le_icao.classes_).to_numpy()
    codes = np.zeros(len(df), dtype=np.int64)
    if known.any():
        codes[known] = le_icao.transform(df.loc[known, "icao24"])
    df["icao24_enc"] = codes
    return df


def _preprocess(df):
    """Apply the training-time feature pipeline to *df* and assign time windows."""
    df = _encode_icao(df)
    # NOTE(review): rows are sorted by time across all aircraft, so
    # "acceleration" is the velocity difference between consecutive rows,
    # which may belong to different aircraft. Kept as-is; confirm it matches
    # how the training data was built.
    df = df.sort_values(["time"]).reset_index(drop=True)
    df["acceleration"] = df["velocity"].diff().fillna(0)
    # Normalise all model input features with the scaler fitted in training.
    df[FEATURE_COLS] = scaler.transform(df[FEATURE_COLS])
    # NOTE(review): if "time" is itself one of FEATURE_COLS, the windows below
    # are computed from the *scaled* time — confirm training did the same.
    df["time_window"] = (df["time"] // config.TIME_WINDOW).astype(int)
    return df


def _build_edge_index(n, device):
    """Return a 2 x E edge index linking each node to its next 5 nodes, in both directions."""
    pairs = []
    for i in range(n):
        for j in range(i + 1, min(i + 6, n)):
            pairs.append([i, j])
            pairs.append([j, i])
    return torch.tensor(pairs, dtype=torch.long).t().contiguous().to(device)


def _predict_windows(df):
    """Run the GAT on each time-window graph; return (predictions, labels) in window order."""
    has_labels = "label" in df.columns
    all_preds, all_labels = [], []
    for _window_id, window_df in df.groupby("time_window"):
        n = len(window_df)
        if n < 3:
            # Too few nodes to build a graph: classify every row as normal (0).
            preds = [0] * n
        else:
            x = torch.tensor(window_df[FEATURE_COLS].values, dtype=torch.float).to(config.DEVICE)
            edge_index = _build_edge_index(n, config.DEVICE)
            with torch.no_grad():
                out, _aux = model(x, edge_index)
            preds = out.argmax(dim=1).cpu().numpy().tolist()
        all_preds.extend(preds)
        if has_labels:
            all_labels.extend(window_df["label"].tolist())
    return all_preds, all_labels


@app.route('/predict', methods=['POST'])
def predict():
    """Classify each row of an uploaded CSV as normal (0) or attack (1).

    Expects a multipart upload under the form field ``file``. Only the first
    ``MAX_LINES`` rows are analysed, and rows containing any NaN are dropped.
    If the CSV has a ``label`` column, accuracy against it is also reported.

    Returns:
        JSON with ``status``, ``total_original`` (rows in the file),
        ``total_valid`` (rows classified), ``normal``, ``attack``,
        ``attack_rate`` and ``accuracy`` (both percentages; ``accuracy`` is
        null without usable labels). Client errors return ``{"error": ...}``
        with HTTP 400; unexpected failures return it with HTTP 500.
    """
    try:
        # Under `flask run` or a WSGI server the __main__ block never runs,
        # so make sure the model and preprocessors are loaded before use.
        if model is None:
            load_model_once()

        # 1. Receive the uploaded file. .get() avoids a KeyError (-> 500)
        #    when the form field is missing entirely.
        file = request.files.get('file')
        if not file:
            return jsonify({"error": "请先上传CSV文件"}), 400

        # 2. Read the whole file (to report its size), then keep only the
        #    first MAX_LINES rows.
        df = pd.read_csv(file)
        total_original = len(df)
        df = df.iloc[:MAX_LINES]

        # Report missing input columns as a client error, not an opaque 500.
        needed = dict.fromkeys(["time", "velocity", *FEATURE_COLS])
        missing = [c for c in needed if c not in _DERIVED_COLS and c not in df.columns]
        if missing:
            return jsonify({"error": f"缺少必要的列:{', '.join(map(str, missing))}"}), 400

        df = df.dropna(axis=0)
        if len(df) == 0:
            return jsonify({"error": "文件中没有有效数据"}), 400

        # 3. Preprocess exactly as in training, then 4. classify per window.
        df = _preprocess(df)
        all_preds, all_labels = _predict_windows(df)

        # 5. Aggregate the results.
        total = len(all_preds)
        preds_arr = np.asarray(all_preds)
        attack_count = int(np.sum(preds_arr == 1))
        normal_count = total - attack_count
        attack_rate = round(attack_count / total * 100, 2)

        # Accuracy is reported only when ground-truth labels are available.
        accuracy = None
        if all_labels and len(all_labels) == total:
            correct = int(np.sum(preds_arr == np.asarray(all_labels)))
            accuracy = round(correct / total * 100, 2)

        return jsonify({
            "status": "success",
            "total_original": total_original,
            "total_valid": total,
            "normal": normal_count,
            "attack": attack_count,
            "attack_rate": attack_rate,
            "accuracy": accuracy
        })

    except pd.errors.EmptyDataError:
        # An empty upload is a client error, not a server failure.
        return jsonify({"error": "文件中没有有效数据"}), 400
    except Exception as e:
        app.logger.exception("prediction failed")
        return jsonify({"error": f"处理失败:{str(e)}"}), 500
|
|
|
|
|
|
if __name__ == '__main__':
    import os

    load_model_once()
    print("服务启动成功!访问 http://localhost:5000")
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it must never be on by default while listening on
    # 0.0.0.0. Opt in explicitly with FLASK_DEBUG=1 for local development.
    # (In debug mode the reloader also re-executes this block in a child
    # process, which loads the model a second time.)
    debug_mode = os.environ.get("FLASK_DEBUG") == "1"
    app.run(debug=debug_mode, host='0.0.0.0', port=5000)