|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
Created on Wed May 14 08:06:07 2025
|
|
|
|
|
|
@author: 缄默
|
|
|
"""
|
|
|
|
|
|
import concurrent.futures
import csv
import json
import os
import re
from typing import Dict, Any, Tuple

import ollama
import pymysql
|
|
|
|
|
|
# ==================== 数据库 ====================
|
|
|
# Database configuration (adjust to the actual environment).
# Each value can be overridden through an environment variable, so credentials
# do not have to be edited into (and committed with) the source file; the
# literals below are only the fallback defaults.

# Table holding the transcripts to process (read by main()).
input_table_name = os.environ.get("ATC_INPUT_TABLE", "only_text")

# Keyword arguments passed straight to pymysql.connect().
DB_CONFIG = {
    'host': os.environ.get("ATC_DB_HOST", "localhost"),
    'user': os.environ.get("ATC_DB_USER", "root"),
    'password': os.environ.get("ATC_DB_PASSWORD", "111111"),
    'database': os.environ.get("ATC_DB_NAME", "atc"),
    'charset': 'utf8mb4',
}
|
|
|
|
|
|
# 读取函数实现
|
|
|
def read_from_table(table_name, key_name):
    """Read every value of one column of a table.

    :param table_name: name of the table to read from
    :param key_name: name of the column to read
    :return: list holding the column value of each row, in the order the
        server returns them (no ORDER BY is applied); an empty list if any
        database error occurs (the error is printed, not raised).
    """
    def quote_ident(name):
        # Identifiers cannot be bound as %s parameters, so they are quoted
        # with backticks. Embedded backticks must be doubled; otherwise a name
        # containing "`" could close the quoting and inject SQL.
        return "`" + str(name).replace("`", "``") + "`"

    conn = None
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # The cursor's context manager closes it even if execute() fails.
        with conn.cursor() as cursor:
            query = f"SELECT {quote_ident(key_name)} FROM {quote_ident(table_name)}"
            cursor.execute(query)
            # Each fetched row holds exactly the one selected column.
            return [row[0] for row in cursor.fetchall()]
    except Exception as e:
        print(f"Database error: {str(e)}")
        return []
    finally:
        if conn is not None and conn.open:
            conn.close()
|
|
|
|
|
|
# 写入函数实现
|
|
|
def write_to_table(data, headers, table_name):
    """Write rows into a table, creating the table first if it does not exist.

    Every column is created as TEXT. Rows shorter than ``headers`` are padded
    with None (stored as SQL NULL). Rows longer than ``headers`` are truncated
    to ``len(headers)`` fields and a warning is printed, because a row with
    more values than placeholders would make the whole batch insert fail.
    All rows are inserted and committed together; on any error the error is
    printed (not raised) and the insert is rolled back.

    :param data: iterable of rows, each row a sequence of values
    :param headers: column names, in insert order (must not be empty)
    :param table_name: target table name
    """
    def quote_ident(name):
        # Identifiers cannot be bound as %s parameters, so they are quoted
        # with backticks; embedded backticks are doubled to prevent injection.
        return "`" + str(name).replace("`", "``") + "`"

    conn = None
    try:
        if not headers:
            raise ValueError("headers must not be empty")
        width = len(headers)

        # Normalise every row to exactly `width` values.
        processed_data = []
        truncated = 0
        for row in data:
            values = list(row)
            if len(values) > width:
                truncated += 1
                values = values[:width]
            processed_data.append(values + [None] * (width - len(values)))
        if truncated:
            print(f"Warning: truncated {truncated} row(s) with more than {width} fields")

        conn = pymysql.connect(**DB_CONFIG)
        table = quote_ident(table_name)
        with conn.cursor() as cursor:
            # NOTE: CREATE TABLE is DDL and commits implicitly in MySQL, so a
            # later rollback does not undo the table creation.
            columns = ", ".join(f"{quote_ident(h)} TEXT" for h in headers)
            cursor.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns})")

            fields = ", ".join(quote_ident(h) for h in headers)
            placeholders = ", ".join(["%s"] * width)
            insert_sql = f"INSERT INTO {table} ({fields}) VALUES ({placeholders})"
            cursor.executemany(insert_sql, processed_data)
        conn.commit()
    except Exception as e:
        print(f"Database error: {str(e)}")
        # conn is None when connecting (or row preparation) failed; there is
        # then nothing to roll back.
        if conn is not None and conn.open:
            conn.rollback()
    finally:
        if conn is not None and conn.open:
            conn.close()
|
|
|
|
|
|
|
|
|
# Column names of the output table, in the field order the prompt asks the
# model to emit: [id, callsign, behavior, flight_level, location, time].
# (The name is a misspelling of "HEADER"; it is kept because main() uses it.)
HANDER = "ID CallSignal Behavior FlightLevel Location Time".split()
|
|
|
# ==================== 读入文件/处理 ====================
|
|
|
def read_csv_to_array(file_path):
    """Load a UTF-8 CSV file and return all of its records as lists of strings."""
    # newline='' lets the csv module handle quoted fields containing newlines.
    with open(file_path, mode='r', newline='', encoding='utf-8') as handle:
        return list(csv.reader(handle))
|
|
|
def extract_result(text):
    """Return the ``{...}`` payload that follows ``result:`` in a model reply.

    Reasoning models such as deepseek-r1 (model1's default) can emit a
    ``<think>...</think>`` section before the answer (NOTE(review): whether it
    appears in the response text depends on the Ollama version/settings). Such
    sections are removed first, so a ``result:{`` written while reasoning is
    not mistaken for the final answer.

    :param text: raw model response text
    :return: the matched ``{...}`` string with surrounding whitespace
        stripped, or "" if ``text`` is empty or contains no ``result:{...}``.
    """
    if not text:
        return ""

    # Drop complete reasoning blocks; the rest of the reply is searched as-is.
    answer = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)

    # Non-greedy: the match ends at the first "}" after "result:{", which is
    # the closing brace of "{[...],[...]}" since rows use square brackets.
    match = re.search(r'result:(\s*\{.*?\})', answer, re.DOTALL)
    if not match:
        return ""
    return match.group(1).strip()
|
|
|
def parse_input(s):
    """Parse a ``{[a,b,...],[c,d,...]}`` model payload into a list of rows.

    :param s: payload string as returned by extract_result, e.g.
        "{[id,callsign,behavior,level,location,time],[...]}"
    :return: list of rows, each a list of field strings with surrounding
        whitespace removed; [] when the payload is empty (for example when
        the model reply contained no ``result:{...}`` block).
    """
    # Newlines are dropped so a payload wrapped over several lines still parses.
    s = s.replace('\n', '')
    # Remove the outer braces (any number of them, so "{{...}}" works too)
    # together with the whitespace around them.
    s = s.strip().strip('{}').strip()
    if not s:
        # Without this, "" would become [['']] and be written as a junk row.
        return []
    # Remove the first row's "[" and the last row's "]".
    if len(s) >= 2 and s[0] == '[' and s[-1] == ']':
        s = s[1:-1]
    # Rows are separated by "],[" — tolerate whitespace around the comma.
    rows = re.split(r'\]\s*,\s*\[', s)
    # NOTE: fields are split on every comma, so a field that itself contains
    # a comma yields extra fields.
    return [[field.strip() for field in row.split(',')] for row in rows]
|
|
|
# ==================== 核心处理模块 ====================
|
|
|
class CallSignExtractor:
    """Extract call-sign records from ATC transcript text using Ollama models.

    ``model1`` and ``model2`` are Ollama model names. The single-model path
    (extract_call_signs) queries ``model2`` only; the dual-model path
    (extract_call_signs_DoubleModel) queries both concurrently and merges
    the results.
    """

    def __init__(self, model1="deepseek-r1:8b", model2="qwen2"):
        self.model1 = model1
        self.model2 = model2
        # Filled in with str.format(cid=..., ctext=...); the literal braces of
        # the example output are therefore written doubled ("{{" / "}}").
        self.prompt_template = """As an aviation communications analyst, extract all flight communications data from the following ATC dialogue following these strict rules:
Data Extraction Requirements:
Identify all aircraft call signs (e.g. “Lufthansa Seven Thirty Nine”) using standard aviation patterns.
Extraction is performed for each occurrence of the callsign:
Behavior: action phrase following the callsign (climb/descend/left turn/pointing/hold, etc.)
Flight level: number extracted from the phrase “FL310”/“level three zero zero zero”,only output number,such as "three zero zero zero" (NULL if not present)
Position: NAV waypoint (3-5 letter code) or airport name (NULL if none)
Time: UTC timestamp (if explicitly mentioned) (NULL if not present)

Format Rules:
Multiple entries for same callsign must be listed separately
Output must be ordered: [id,callsign,behavior,flight_level,location,time]
Use NULL for missing information
Strict CSV format inside result brackets with NO additional text

Processing Logic:
Analyze context before/after callsign mentions
Preserve original callsign wording exactly
Treat multi-line interactions as single context when connected

Example Input:
"APP-pZiJZT", "AT: lufthansa two juliett whiskey climb FL250\nPI: turn left direct kolad lufthansa two juliett whiskey"

Example Output:
result:{{[APP-pZiJZT,lufthansa two juliett whiskey,climbing,two five zero,NULL,NULL],[APP-pZiJZT,lufthansa two juliett whiskey,turn left,NULL,kolad,NULL]}}/result

Process this transmission:
ID:{cid}
Text:{ctext}"""

    def _call_model(self, model_name: str, prompt: str) -> str:
        """Query one Ollama model and return the ``result:{...}`` payload.

        :return: the payload found by extract_result, or "" when the call
            fails (the error is printed) or the reply has no payload.
        """
        try:
            response = ollama.generate(model=model_name, prompt=prompt)
            return extract_result(response["response"])
        except Exception as e:
            print(f"模型调用错误 ({model_name}): {e}")
            return ""

    def _dual_model_inference(self, prompt: str) -> Tuple[str, str]:
        """Run the prompt on model1 and model2 concurrently.

        :return: (model1 payload, model2 payload)
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            future1 = executor.submit(self._call_model, self.model1, prompt)
            future2 = executor.submit(self._call_model, self.model2, prompt)
            return future1.result(), future2.result()

    def _model_inference(self, prompt: str) -> str:
        """Run the prompt on model2 only and return its payload.

        The call is made directly: submitting one task to a thread pool and
        immediately waiting for it added overhead without any parallelism.
        """
        return self._call_model(self.model2, prompt)

    def _merge_results(self, result1: str, result2: str) -> Dict[str, Any]:
        """Merge two JSON-object payloads; model1's values win on key conflicts.

        A payload that is empty, not valid JSON, or valid JSON but not an
        object contributes nothing to the merge.

        NOTE(review): the prompt asks for "{[...],[...]}" payloads, which are
        not JSON objects, so well-formed replies both parse to nothing here
        and the merge is {}. A row-level merge (over parse_input output) is
        needed before the dual-model path is useful.
        """
        def as_dict(payload):
            if not payload or not payload.strip():
                return {}
            try:
                parsed = json.loads(payload)
            except json.JSONDecodeError:
                return {}
            # json.loads may return a list/str/number; only objects can merge.
            return parsed if isinstance(parsed, dict) else {}

        # In {**a, **b} the later mapping wins, so model1's result goes last
        # to take priority on conflicts.
        return {**as_dict(result2), **as_dict(result1)}

    def extract_call_signs_DoubleModel(self, cid: str, ctext: str) -> Dict[str, Any]:
        """Build the prompt, query both models and return their merged result."""
        prompt = self.prompt_template.format(cid=cid, ctext=ctext)
        result1, result2 = self._dual_model_inference(prompt)
        return self._merge_results(result1, result2)

    def extract_call_signs(self, cid: str, ctext: str) -> str:
        """Build the prompt, query model2 and return its raw ``{...}`` payload.

        :return: payload string suitable for parse_input, or "" on failure.
        """
        prompt = self.prompt_template.format(cid=cid, ctext=ctext)
        return self._model_inference(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
# ==================== 主执行流程 ====================
|
|
|
def main():
    """Run the pipeline: read transcripts, extract call signs, store the rows.

    Reads the ``id`` and ``text`` columns of ``input_table_name``, asks the
    model for call-sign records for each transcript, parses each reply into
    rows and writes all rows to the ``CsExtraction`` table using ``HANDER``
    as the column names.
    """
    # Alternative input source (CSV file):
    # input_data = read_csv_to_array(r'./output_lite.csv')

    id_data = read_from_table(input_table_name, "id")
    text_data = read_from_table(input_table_name, "text")
    # NOTE(review): the columns come from two separate SELECTs without
    # ORDER BY, so pairing them by position assumes the server returns rows
    # in the same order both times — confirm, or read both columns at once.
    if len(id_data) != len(text_data):
        # Previously this raised IndexError (or silently dropped texts).
        print(f"Warning: {len(id_data)} ids but {len(text_data)} texts; "
              "only complete (id, text) pairs are processed")

    extractor = CallSignExtractor()
    all_result = []
    for raw_id, raw_text in zip(id_data, text_data):
        # The dual-model variant (extract_call_signs_DoubleModel) returns a
        # dict, which parse_input cannot handle, so the single-model path is used.
        result = extractor.extract_call_signs(raw_id, raw_text)
        result_list = parse_input(result)
        print(result_list)
        all_result.extend(result_list)

    write_to_table(all_result, HANDER, "CsExtraction")


if __name__ == "__main__":
    main()