You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
26 lines
795 B
26 lines
795 B
import config
|
|
from data_utils import load_and_preprocess_data
|
|
from graph_builder import build_graphs_from_dataframe
|
|
from model import ADSB_GAT
|
|
from trainer import train_model
|
|
|
|
def main():
    """Run the data-loading, graph-building, and training pipeline end to end.

    Steps:
        1. Load and preprocess the data via ``load_and_preprocess_data``.
        2. Build train/test graphs via ``build_graphs_from_dataframe``.
        3. Initialize an ``ADSB_GAT`` model whose input width equals the
           number of feature columns.
        4. Hand the model and graphs to ``train_model``.

    Returns:
        int: process exit status -- ``1`` when no training graphs could be
        built (training is skipped), ``0`` after ``train_model`` returns.
    """
    import sys  # local import: only needed for stderr diagnostics

    # 1. Load and preprocess the data.
    df, feature_cols, label_col = load_and_preprocess_data()

    # 2. Build the graphs.
    train_graphs, test_graphs = build_graphs_from_dataframe(
        df, feature_cols, label_col
    )

    if not train_graphs:
        # Report on stderr and return a non-zero status so that shells / CI
        # can detect the failure instead of seeing a clean exit.
        print(
            "Error: No graphs were built. Check your data and time window settings.",
            file=sys.stderr,
        )
        return 1

    if not test_graphs:
        # NOTE(review): how train_model handles an empty test set is not
        # visible from this file; warn instead of aborting so that the
        # training path itself is unchanged.
        print(
            "Warning: No test graphs were built; evaluation may be skipped or fail.",
            file=sys.stderr,
        )

    # 3. Initialize the model; input width is the number of feature columns.
    model = ADSB_GAT(in_channels=len(feature_cols))
    print(f"Model initialized with {len(feature_cols)} input features.")

    # 4. Start training.
    train_model(model, train_graphs, test_graphs)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Forward main()'s return value as the process exit status; a None
    # return is treated by SystemExit as success (exit code 0).
    raise SystemExit(main())