import os
import pickle
import threading

import numpy as np
import pandas as pd
import torch
from flask import Flask, render_template, request, jsonify
from torch_geometric.data import Data

import config
from model import ADSB_GAT

app = Flask(__name__)

# Global inference state, populated by load_model_once().
model = None
scaler = None
le_icao = None
FEATURE_COLS = None

# Serialises the (lazy) first load so concurrent first requests don't load twice.
_load_lock = threading.Lock()

# Only the first MAX_ROWS rows of an upload are analysed.
MAX_ROWS = 10000

# Graph rule: each node is linked, in both directions, to at most this many following nodes.
MAX_FORWARD_NEIGHBORS = 5

# Windows with fewer rows than this skip the model and are all predicted normal (0).
MIN_NODES_PER_WINDOW = 3

# Columns the preprocessing steps read directly (besides FEATURE_COLS).
REQUIRED_RAW_COLS = ("time", "velocity")


def load_model_once():
    """Load the trained GAT model and preprocessing artifacts into module globals.

    Idempotent: returns immediately once loaded. This makes it safe to call both from
    the ``__main__`` entry point and lazily from request handlers (e.g. under
    ``flask run`` or a WSGI server, where the ``__main__`` block never runs).
    Globals are only assigned after every artifact has loaded successfully, so a
    failed load never leaves a half-initialised service behind.
    """
    global model, scaler, le_icao, FEATURE_COLS
    with _load_lock:
        if model is not None:
            return

        print("正在加载模型和预处理参数...")

        # Preprocessing parameters (must be exactly the ones used at training time).
        # NOTE: pickle.load can execute arbitrary code; these files must be trusted local artifacts.
        with open("scaler.pkl", "rb") as f:
            loaded_scaler = pickle.load(f)
        with open("le_icao.pkl", "rb") as f:
            loaded_le_icao = pickle.load(f)
        with open("feature_cols.pkl", "rb") as f:
            loaded_feature_cols = pickle.load(f)

        # The model input width must equal the number of feature columns fed to it in predict().
        loaded_model = ADSB_GAT(in_channels=len(loaded_feature_cols))
        loaded_model.load_state_dict(torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE))
        loaded_model.to(config.DEVICE)
        loaded_model.eval()

        scaler = loaded_scaler
        le_icao = loaded_le_icao
        FEATURE_COLS = loaded_feature_cols
        model = loaded_model  # assigned last: it is the "already loaded" sentinel checked above

        print("✅ 加载完成!服务已就绪")
        print(f" 模型输入特征:{FEATURE_COLS}")


@app.route('/')
def index():
    """Serve the upload page."""
    return render_template('index.html')


def _encode_icao(df):
    """Add an integer ``icao24_enc`` column using the fitted label encoder.

    ICAO24 codes not seen by the encoder, and files without an ``icao24`` column,
    are encoded as 0 (same fallback as before).
    """
    if "icao24" not in df.columns:
        df["icao24_enc"] = 0
        return df
    df["icao24"] = df["icao24"].astype(str)
    # Vectorised equivalent of ``le_icao.transform([x])[0]``: a LabelEncoder's code
    # for a known class is that class's index in ``classes_``.
    code_of = {cls: idx for idx, cls in enumerate(le_icao.classes_)}
    df["icao24_enc"] = df["icao24"].map(code_of).fillna(0).astype("int64")
    return df


def _build_edge_index(n, max_forward=MAX_FORWARD_NEIGHBORS):
    """Return a ``(2, E)`` long tensor for ``n`` nodes using the training-time edge rule.

    Node ``i`` is connected to nodes ``i+1 .. i+max_forward`` (clipped to ``n``),
    with one edge in each direction.
    """
    edges = []
    for i in range(n):
        for j in range(i + 1, min(i + 1 + max_forward, n)):
            edges.append((i, j))
            edges.append((j, i))
    if not edges:
        # Keep the (2, 0) shape for tiny graphs instead of a malformed 1-D tensor.
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(edges, dtype=torch.long).t().contiguous()


def _predict_windows(df):
    """Run the GAT over each ``time_window`` group of ``df``.

    Returns ``(predictions, labels)`` as flat lists in window order. ``labels`` is
    empty when ``df`` has no ``label`` column.
    """
    has_labels = "label" in df.columns
    all_preds = []
    all_labels = []
    for _window_id, window_df in df.groupby("time_window"):
        n = len(window_df)
        if n < MIN_NODES_PER_WINDOW:
            # Too few nodes to form a meaningful graph: treat as normal.
            preds = [0] * n
        else:
            x = torch.tensor(window_df[FEATURE_COLS].values, dtype=torch.float).to(config.DEVICE)
            edge_index = _build_edge_index(n).to(config.DEVICE)
            with torch.no_grad():
                out, _ = model(x, edge_index)
            preds = out.argmax(dim=1).cpu().numpy().tolist()
        all_preds.extend(preds)
        if has_labels:
            all_labels.extend(window_df["label"].tolist())
    return all_preds, all_labels


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the rows of an uploaded CSV (multipart field ``file``) as normal/attack.

    Returns JSON with row counts, attack count/rate and, when the CSV has a ``label``
    column, the accuracy against it. Client-side problems (no file, unreadable CSV,
    no valid rows, missing columns) yield 400; unexpected failures yield 500.
    """
    try:
        load_model_once()  # no-op once loaded; covers deployments that skip __main__

        # 1. Uploaded file (``.get`` so a missing field is a 400, not a BadRequestKeyError -> 500).
        file = request.files.get('file')
        if not file:
            return jsonify({"error": "请先上传CSV文件"}), 400

        # 2. Read the CSV, then keep only the first MAX_ROWS rows.
        try:
            df = pd.read_csv(file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
            return jsonify({"error": f"无法解析CSV文件:{e}"}), 400
        total_original = len(df)
        df = df.iloc[:MAX_ROWS]

        # NOTE(review): drops rows with NaN in *any* column, including ones the model
        # never uses; kept as-is for parity with training — confirm.
        df = df.dropna(axis=0)
        if len(df) == 0:
            return jsonify({"error": "文件中没有有效数据"}), 400

        missing = [c for c in REQUIRED_RAW_COLS if c not in df.columns]
        if missing:
            return jsonify({"error": f"缺少必要的列:{', '.join(missing)}"}), 400

        # 3. Preprocessing (same steps as training).
        df = _encode_icao(df)

        # NOTE(review): rows are sorted by time only, so the velocity diff runs across
        # consecutive rows of *different* aircraft; kept to match training — confirm.
        df = df.sort_values(["time"]).reset_index(drop=True)
        df['acceleration'] = df['velocity'].diff().fillna(0)

        missing = [c for c in FEATURE_COLS if c not in df.columns]
        if missing:
            return jsonify({"error": f"缺少必要的列:{', '.join(missing)}"}), 400

        # Normalise all model features with the training-time scaler.
        df[FEATURE_COLS] = scaler.transform(df[FEATURE_COLS])

        # 4. Slice into time windows and run the model on each window graph.
        df["time_window"] = (df["time"] // config.TIME_WINDOW).astype(int)
        all_preds, all_labels = _predict_windows(df)

        # 5. Summary statistics.
        total = len(all_preds)
        preds_arr = np.asarray(all_preds)
        attack_count = int(np.sum(preds_arr == 1))
        normal_count = total - attack_count
        attack_rate = round(attack_count / total * 100, 2)

        # Accuracy only when ground-truth labels are available for every prediction.
        accuracy = None
        if all_labels and len(all_labels) == total:
            correct = int(np.sum(preds_arr == np.asarray(all_labels)))
            accuracy = round(correct / total * 100, 2)

        return jsonify({
            "status": "success",
            "total_original": total_original,
            "total_valid": total,
            "normal": normal_count,
            "attack": attack_count,
            "attack_rate": attack_rate,
            "accuracy": accuracy
        })

    except Exception as e:
        app.logger.exception("prediction request failed")
        return jsonify({"error": f"处理失败:{str(e)}"}), 500


if __name__ == '__main__':
    load_model_once()
    print("服务启动成功!访问 http://localhost:5000")
    # The Werkzeug debugger permits arbitrary code execution, so it must not be exposed on
    # 0.0.0.0 by default. Opt in explicitly for local development with FLASK_DEBUG=1.
    debug = os.environ.get("FLASK_DEBUG", "").lower() not in {"", "0", "false", "no"}
    app.run(debug=debug, host='0.0.0.0', port=5000)