You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

148 lines
5.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from flask import Flask, render_template, request, jsonify
import torch
import pandas as pd
import numpy as np
import pickle
from torch_geometric.data import Data
import config
from model import ADSB_GAT
app = Flask(__name__)
# Module globals, populated once by load_model_once() before the server starts.
model = None          # trained ADSB_GAT network, eval mode, on config.DEVICE
scaler = None         # feature scaler fitted at training time (scaler.pkl)
le_icao = None        # label encoder for icao24 aircraft IDs (le_icao.pkl)
FEATURE_COLS = None   # ordered list of model input columns (feature_cols.pkl)
def load_model_once():
    """Load the trained GAT model and preprocessing artifacts into module globals.

    Must run once before the first /predict request. Restores the network
    weights from config.MODEL_SAVE_PATH and unpickles the three artifacts
    produced by the training run (scaler.pkl, le_icao.pkl, feature_cols.pkl),
    which must match training exactly.
    """
    global model, scaler, le_icao, FEATURE_COLS

    print("正在加载模型和预处理参数...")

    # Rebuild the network, restore its trained weights, pin to the device.
    net = ADSB_GAT(in_channels=8)
    state = torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE)
    net.load_state_dict(state)
    net.to(config.DEVICE)
    net.eval()
    model = net

    def _unpickle(path):
        # NOTE(review): pickle.load is only safe on trusted local files.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    scaler = _unpickle("scaler.pkl")
    le_icao = _unpickle("le_icao.pkl")
    FEATURE_COLS = _unpickle("feature_cols.pkl")

    print("✅ 加载完成!服务已就绪")
    print(f" 模型输入特征:{FEATURE_COLS}")
@app.route('/')
def index():
    """Serve the single-page upload UI."""
    return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
    """Run anomaly detection over an uploaded ADS-B CSV file.

    Expects a multipart upload under the form field "file". The CSV is
    preprocessed with the same steps used at training time, split into time
    windows, turned into one graph per window, and scored by the GAT model.

    Returns a JSON summary: original/valid row counts, normal/attack counts,
    attack rate (%), and accuracy (%) when a ground-truth "label" column is
    present; 400 on a missing/empty upload, 500 on any unexpected failure.
    """
    try:
        # 1. Fetch the upload. Using .get() instead of request.files['file']
        #    avoids a KeyError (-> generic 500) when the form field is absent;
        #    the client gets the intended 400 message instead.
        file = request.files.get('file')
        if file is None or file.filename == '':
            return jsonify({"error": "请先上传CSV文件"}), 400

        # 2. Load the data, capped at the first 10k rows to bound latency.
        df = pd.read_csv(file)
        total_original = len(df)
        MAX_LINES = 10000
        if total_original > MAX_LINES:
            df = df.iloc[:MAX_LINES]

        df = df.dropna(axis=0)
        total_valid = len(df)
        if total_valid == 0:
            return jsonify({"error": "文件中没有有效数据"}), 400

        # 3. Preprocess exactly as at training time.
        # 3.1 Encode icao24; aircraft IDs unseen at training fall back to 0.
        if "icao24" in df.columns:
            df["icao24"] = df["icao24"].astype(str)
            known_icao = set(le_icao.classes_)
            df["icao24_enc"] = df["icao24"].apply(
                lambda x: le_icao.transform([x])[0] if x in known_icao else 0
            )
        else:
            df["icao24_enc"] = 0

        # 3.2 Derived feature: per-row velocity delta as acceleration.
        #     Assumes the CSV has "time" and "velocity" columns, as produced
        #     by the training pipeline — TODO confirm against data spec.
        df = df.sort_values(["time"]).reset_index(drop=True)
        df['acceleration'] = df['velocity'].diff().fillna(0)

        # 3.3 Normalize with the full training feature set so the column
        #     order/scale matches the scaler fitted during training.
        df[FEATURE_COLS] = scaler.transform(df[FEATURE_COLS])

        # 4. Slice into time windows and build one graph per window.
        df["time_window"] = (df["time"] // config.TIME_WINDOW).astype(int)
        all_preds = []
        all_labels = []
        for window_id, window_df in df.groupby("time_window"):
            n = len(window_df)
            # Too few nodes to form a meaningful graph: classify as normal.
            if n < 3:
                all_preds.extend([0] * n)
                if "label" in window_df.columns:
                    all_labels.extend(window_df["label"].tolist())
                continue

            # Node features for this window.
            x = torch.tensor(
                window_df[FEATURE_COLS].values, dtype=torch.float
            ).to(config.DEVICE)

            # Edges: each node connects (bidirectionally) to its next 5
            # temporal neighbours — same rule as training-time graphs.
            edge_index = []
            for i in range(n):
                for j in range(i + 1, min(i + 6, n)):
                    edge_index.append([i, j])
                    edge_index.append([j, i])
            edge_index = (
                torch.tensor(edge_index, dtype=torch.long)
                .t().contiguous().to(config.DEVICE)
            )

            # Model inference; class 1 = attack, class 0 = normal.
            with torch.no_grad():
                out, _ = model(x, edge_index)
            all_preds.extend(out.argmax(dim=1).cpu().numpy().tolist())

            if "label" in window_df.columns:
                all_labels.extend(window_df["label"].tolist())

        # 5. Aggregate per-row predictions. total >= 1 here because
        #    total_valid > 0 and every row produced exactly one prediction.
        total = len(all_preds)
        attack_count = int(np.sum(np.array(all_preds) == 1))
        normal_count = total - attack_count
        attack_rate = round(attack_count / total * 100, 2)

        # Accuracy only when every scored row carried a ground-truth label.
        accuracy = None
        if all_labels and len(all_labels) == total:
            correct = np.sum(np.array(all_preds) == np.array(all_labels))
            accuracy = round(correct / total * 100, 2)

        return jsonify({
            "status": "success",
            "total_original": total_original,
            "total_valid": total,
            "normal": normal_count,
            "attack": attack_count,
            "attack_rate": attack_rate,
            "accuracy": accuracy
        })
    except Exception as e:
        # Top-level boundary: surface any unexpected failure as a JSON 500.
        return jsonify({"error": f"处理失败:{str(e)}"}), 500
if __name__ == '__main__':
    # Load model and preprocessing artifacts before accepting any requests.
    load_model_once()
    print("服务启动成功!访问 http://localhost:5000")
    # debug=True is development-only; disable for any production deployment.
    app.run(debug=True, host='0.0.0.0', port=5000)