# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

# 236 lines
# 8.2 KiB

# This file contains ambiguous Unicode characters!

# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
Created on Wed May 14 08:06:07 2025
@author: 缄默
"""
import ollama
import concurrent.futures
import json
from typing import Dict, Any, Tuple
import csv
import re
import pymysql
# ==================== Database ====================
# Database configuration (must be adjusted to the actual environment).
# NOTE(review): the credentials are hard-coded in source; consider loading
# them from the environment or a config file kept out of version control.
# Name of the table from which main() reads the "id" and "text" columns.
input_table_name = "only_text"
# Keyword arguments unpacked into pymysql.connect() by the table helpers.
DB_CONFIG = {
'host': 'localhost',
'user': 'root',
'password': '111111',
'database': 'atc',
'charset': 'utf8mb4'
}
# Read helper implementation
def read_from_table(table_name, key_name):
    """
    Read every value of a single column from a table.

    :param table_name: name of the table to read from
    :param key_name: name of the column to select
    :return: list holding the selected column's value for each row;
        an empty list if connecting or querying fails
    """
    # Wrapping an identifier in backticks is not enough on its own: a
    # backtick inside the name would terminate the quoting early. Doubling
    # embedded backticks is MySQL's escape for quoted identifiers.
    safe_table = str(table_name).replace("`", "``")
    safe_column = str(key_name).replace("`", "``")
    query = f"SELECT `{safe_column}` FROM `{safe_table}`"

    conn = None  # stays None when connect() itself fails
    try:
        conn = pymysql.connect(**DB_CONFIG)
        # The cursor context manager closes the cursor even on errors
        with conn.cursor() as cursor:
            cursor.execute(query)
            # Each fetched row is a 1-tuple; flatten to a plain list
            return [row[0] for row in cursor.fetchall()]
    except Exception as e:
        # Best-effort read: report the failure and let the caller go on
        print(f"Database error: {str(e)}")
        return []
    finally:
        if conn is not None and conn.open:
            conn.close()
# Write helper implementation
def write_to_table(data, headers, table_name):
    """
    Write rows into a table, creating the table first if it does not exist.

    Every column is created as TEXT. Rows shorter than ``headers`` are padded
    with None (stored as SQL NULL). Failures are reported on stdout and the
    transaction is rolled back; nothing is raised to the caller.

    :param data: two-dimensional list of row values
    :param headers: list of column names
    :param table_name: name of the target table
    """
    # Double embedded backticks so a name cannot break out of its quoting
    quoted_table = "`" + str(table_name).replace("`", "``") + "`"
    quoted_columns = ["`" + str(col).replace("`", "``") + "`" for col in headers]
    width = len(headers)

    conn = None  # stays None when connect() itself fails
    try:
        conn = pymysql.connect(**DB_CONFIG)
        with conn.cursor() as cursor:
            # Create the table on demand
            columns = [f"{col} TEXT" for col in quoted_columns]
            create_sql = f"CREATE TABLE IF NOT EXISTS {quoted_table} ({','.join(columns)})"
            cursor.execute(create_sql)

            # Pad short rows with None so each row matches the placeholders
            processed_data = [
                list(row) + [None] * (width - len(row)) for row in data
            ]

            if processed_data:
                fields = ', '.join(quoted_columns)
                placeholders = ', '.join(['%s'] * width)
                insert_sql = f"INSERT INTO {quoted_table} ({fields}) VALUES ({placeholders})"
                # Values go through driver-side parameter binding
                cursor.executemany(insert_sql, processed_data)
        conn.commit()
    except Exception as e:
        print(f"Database error: {str(e)}")
        # Only roll back when a connection was actually established;
        # otherwise the handler itself would fail on an unbound name.
        if conn is not None and conn.open:
            conn.rollback()
    finally:
        if conn is not None and conn.open:
            conn.close()
# Column names of the result table written by main(). They follow the field
# order requested in the prompt: [id,callsign,behavior,flight_level,location,time].
# NOTE(review): the name looks like a typo of "HEADER"; left as is because
# main() references it.
HANDER = ["ID","CallSignal","Behavior","FlightLevel","Location","Time"]
# ==================== Input reading / parsing ====================
def read_csv_to_array(file_path):
    """
    Load a CSV file into a list of rows.

    :param file_path: path of the CSV file to read
    :return: list of rows, each row being a list of string fields
    """
    # 'utf-8-sig' reads plain UTF-8 unchanged but drops a leading BOM, which
    # would otherwise end up glued to the first field of the first row.
    # newline='' lets the csv module handle line endings inside quoted fields.
    with open(file_path, 'r', newline='', encoding='utf-8-sig') as file:
        return list(csv.reader(file))
# Matches the first "result:{...}" span; DOTALL lets the braces span lines,
# and the non-greedy body stops at the first closing brace.
_RESULT_PATTERN = re.compile(r'result:(\s*\{.*?\})', re.DOTALL)


def extract_result(text):
    """
    Pull the first ``result:{...}`` payload out of a model reply.

    :param text: raw reply text (may be empty or None)
    :return: the brace-delimited payload including its braces, with
        surrounding whitespace removed; "" when there is no such payload
    """
    if not text:
        # Nothing to search; treat the same as "no match"
        return ""
    match = _RESULT_PATTERN.search(text)
    if match is None:
        return ""
    return match.group(1).strip()
# Separator between two bracketed rows; tolerates whitespace such as "], [".
_ROW_SEPARATOR = re.compile(r'\]\s*,\s*\[')


def parse_input(s):
    """
    Turn a ``{[a,b,...],[c,d,...]}`` payload into a two-dimensional list.

    Fields are split on commas and stripped of surrounding whitespace, so a
    field that itself contains a comma is split into several fields. The
    literal text "NULL" is returned unchanged as a string.

    :param s: payload string as returned by extract_result (may be empty)
    :return: list of rows, each a list of string fields; [] when the payload
        is empty or contains no rows
    """
    if not s:
        # A failed extraction yields ""; that means "no rows", not one blank row
        return []
    # Drop line breaks, then the outermost braces and any padding around them
    body = s.replace('\n', '').strip().strip('{}').strip()
    # Drop the leading '[' of the first row and the trailing ']' of the last
    if len(body) >= 2 and body[0] == '[' and body[-1] == ']':
        body = body[1:-1]
    if not body.strip():
        return []
    # Split into rows, then each row into its fields
    return [
        [field.strip() for field in row.split(',')]
        for row in _ROW_SEPARATOR.split(body)
    ]
# ==================== Core processing module ====================
class CallSignExtractor:
    """Builds the extraction prompt and runs it through one or two Ollama models."""

    def __init__(self, model1="deepseek-r1:8b", model2="qwen2"):
        """
        :param model1: Ollama model name used only by the dual-model path
        :param model2: Ollama model name used by both the single- and
            dual-model paths
        """
        self.model1 = model1
        self.model2 = model2
        # Doubled braces are literal braces after str.format(); {cid} and
        # {ctext} are the only substitution fields.
        self.prompt_template = """As an aviation communications analyst, extract all flight communications data from the following ATC dialogue following these strict rules:
Data Extraction Requirements:
Identify all aircraft call signs (e.g. “Lufthansa Seven Thirty Nine”) using standard aviation patterns.
Extraction is performed for each occurrence of the callsign:
Behavior: action phrase following the callsign (climb/descend/left turn/pointing/hold, etc.)
Flight level: number extracted from the phrase “FL310”/“level three zero zero zero”,only output number,such as "three zero zero zero" (NULL if not present)
Position: NAV waypoint (3-5 letter code) or airport name (NULL if none)
Time: UTC timestamp (if explicitly mentioned) (NULL if not present)
Format Rules:
Multiple entries for same callsign must be listed separately
Output must be ordered: [id,callsign,behavior,flight_level,location,time]
Use NULL for missing information
Strict CSV format inside result brackets with NO additional text
Processing Logic:
Analyze context before/after callsign mentions
Preserve original callsign wording exactly
Treat multi-line interactions as single context when connected
Example Input:
"APP-pZiJZT", "AT: lufthansa two juliett whiskey climb FL250\nPI: turn left direct kolad lufthansa two juliett whiskey"
Example Output:
result:{{[APP-pZiJZT,lufthansa two juliett whiskey,climbing,two five zero,NULL,NULL],[APP-pZiJZT,lufthansa two juliett whiskey,turn left,NULL,kolad,NULL]}}/result
Process this transmission:
ID:{cid}
Text:{ctext}"""

    def _call_model(self, model_name: str, prompt: str) -> str:
        """
        Send the prompt to one Ollama model and return its extracted payload.

        :return: the ``{...}`` payload found by extract_result, or "" when
            the call fails or the reply has no payload
        """
        try:
            response = ollama.generate(model=model_name, prompt=prompt)
            return extract_result(response["response"])
        except Exception as e:
            # Best-effort: a failing model must not abort the whole batch
            print(f"模型调用错误 ({model_name}): {e}")
            return ""

    def _dual_model_inference(self, prompt: str) -> Tuple[str, str]:
        """Run both models concurrently and return (model1 payload, model2 payload)."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future1 = executor.submit(self._call_model, self.model1, prompt)
            future2 = executor.submit(self._call_model, self.model2, prompt)
            return future1.result(), future2.result()

    def _model_inference(self, prompt: str) -> str:
        """Run model2 alone and return its payload string."""
        # A single blocking call needs no thread pool
        return self._call_model(self.model2, prompt)

    @staticmethod
    def _load_json_object(raw: str) -> Dict[str, Any]:
        """Decode a JSON object; anything empty, invalid or non-object gives {}."""
        if not isinstance(raw, str) or not raw.strip():
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return {}
        # Lists, numbers and strings are valid JSON but cannot be merged as mappings
        return parsed if isinstance(parsed, dict) else {}

    def _merge_results(self, result1: str, result2: str) -> Dict[str, str]:
        """
        Merge the JSON objects decoded from the two model payloads.

        All keys from both payloads are kept; when a key appears in both,
        the value from ``result2`` wins.

        NOTE(review): the original comment asked for model 1 to win on
        conflict, which contradicts the code's behaviour kept here. Also, the
        bracketed-list format requested by the prompt is not JSON, so such
        payloads decode to {} - confirm the intended merge strategy.
        """
        merged: Dict[str, str] = {}
        merged.update(self._load_json_object(result1))
        merged.update(self._load_json_object(result2))
        return merged

    def extract_call_signs_DoubleModel(self, cid: str, ctext: str) -> Dict[str, str]:
        """
        Run the full extraction with both models and merge their results.

        :param cid: identifier of the transmission, inserted into the prompt
        :param ctext: transmission text, inserted into the prompt
        :return: merged dictionary produced by _merge_results
        """
        prompt = self.prompt_template.format(cid=cid, ctext=ctext)
        result1, result2 = self._dual_model_inference(prompt)
        return self._merge_results(result1, result2)

    def extract_call_signs(self, cid: str, ctext: str) -> str:
        """
        Run the full extraction with model2 only.

        :param cid: identifier of the transmission, inserted into the prompt
        :param ctext: transmission text, inserted into the prompt
        :return: the raw ``{...}`` payload string, or "" on failure
        """
        prompt = self.prompt_template.format(cid=cid, ctext=ctext)
        return self._model_inference(prompt)
# ==================== Main execution flow ====================
def main():
    """
    Read (id, text) pairs from the input table, extract call-sign records
    from each text with the model, and store all records in "CsExtraction".
    """
    # Load the two columns of the input table
    id_data = read_from_table(input_table_name, "id")
    text_data = read_from_table(input_table_name, "text")
    # NOTE(review): the columns come from two separate queries without an
    # ORDER BY, so pairing by position assumes both return rows in the same
    # order - confirm, or read both columns in a single query.
    if len(id_data) != len(text_data):
        # E.g. one of the reads failed and returned []; pair what lines up
        print(
            f"Warning: got {len(id_data)} ids but {len(text_data)} texts; "
            "only the common prefix is processed"
        )

    # Set up the extractor
    extractor = CallSignExtractor()
    all_result = []
    max_fields = len(HANDER)

    # Run the extraction record by record
    for raw_id, raw_text in zip(id_data, text_data):
        result = extractor.extract_call_signs(raw_id, raw_text)
        # Convert the payload string into rows of fields
        result_list = parse_input(result)
        print(result_list)
        for row in result_list:
            if not any(str(field).strip() for field in row):
                # Blank row from an empty/failed reply; nothing to store
                continue
            if len(row) > max_fields:
                # A row wider than the table would make the whole batch
                # insert fail, so drop it and keep the rest
                print(f"Skipping malformed row for id {raw_id}: {row}")
                continue
            all_result.append(row)

    write_to_table(all_result, HANDER, "CsExtraction")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()