|
|
|
|
@ -3,22 +3,146 @@
|
|
|
|
|
from typing import Optional, Dict, Any
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession as Session
|
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
|
from jose import jwt, JWTError # 新增
|
|
|
|
|
from jose import jwt, JWTError
|
|
|
|
|
from fastapi import BackgroundTasks # <--- 新增导入
|
|
|
|
|
import requests
|
|
|
|
|
import json
|
|
|
|
|
import random
|
|
|
|
|
import string
|
|
|
|
|
import asyncio
|
|
|
|
|
from bs4 import BeautifulSoup
|
|
|
|
|
|
|
|
|
|
# 隐式绝对导入
|
|
|
|
|
from crud.crud_project import crud_project
|
|
|
|
|
# 引入 DatabaseInstance CRUD 以解决外键问题
|
|
|
|
|
from crud.crud_database_instance import crud_database_instance
|
|
|
|
|
from schema import project as schemas
|
|
|
|
|
from core.exceptions import ItemNotFoundException, DatabaseOperationFailedException, OperationNotPermittedException, \
|
|
|
|
|
ValidationException
|
|
|
|
|
from core.auth import decode_jwt_token, create_access_token
|
|
|
|
|
from core.config import config
|
|
|
|
|
from core.database import PsqlHelper
|
|
|
|
|
|
|
|
|
|
# ==========================================
|
|
|
|
|
# 新增:Schema 生成工具函数 (集成之前的逻辑)
|
|
|
|
|
# ==========================================
|
|
|
|
|
class SchemaGenerator:
    """Client for a remote Gradio app that generates a schema and SQL DDL.

    ``run_generation`` submits a job to the service's queue API, waits for the
    ``process_completed`` event on the event stream, and hands the returned
    HTML to ``_parse_html``, which extracts the "Logical Design" text and the
    SQL code blocks.
    """

    # Base URL of the Gradio service that performs the generation.
    BASE_HOST = "http://43.154.73.48:5000"

    @staticmethod
    def _parse_html(html_content: str):
        """Extract the logical-design text and the SQL DDL from generated HTML.

        Args:
            html_content: HTML produced by the service; may be empty or None.

        Returns:
            A ``(schema_text, ddl_text)`` tuple of strings. Both are empty when
            ``html_content`` is falsy; either one is empty when the matching
            section is not present in the document.
        """
        if not html_content:
            return "", ""

        soup = BeautifulSoup(html_content, 'html.parser')
        heading_tags = ('h1', 'h2', 'h3', 'h4')

        # Schema: every sibling after the "Logical Design" heading, up to
        # (but not including) the next heading.
        schema_text = []
        start_node = soup.find(
            lambda tag: tag.name in heading_tags and 'Logical Design' in tag.get_text()
        )
        if start_node:
            current = start_node.find_next_sibling()
            while current and current.name not in heading_tags:
                text = current.get_text(separator='\n', strip=True)
                if text:
                    schema_text.append(text)
                current = current.find_next_sibling()

        # DDL: the non-empty contents of every <code class="language-sql"> block.
        ddl_list = []
        for block in soup.find_all('code', class_='language-sql'):
            sql_text = block.get_text().strip()
            if sql_text:
                ddl_list.append(sql_text)

        # Move the first CREATE DATABASE statement to the front so the
        # database exists before the statements that depend on it.
        create_db_idx = next(
            (i for i, sql in enumerate(ddl_list) if "CREATE DATABASE" in sql.upper()),
            -1,
        )
        if create_db_idx > 0:  # index 0 is already in place
            ddl_list.insert(0, ddl_list.pop(create_db_idx))

        return "\n\n".join(schema_text), "\n\n".join(ddl_list)

    @classmethod
    def run_generation(cls, requirements: str, db_name: str, db_type: str = "MySQL"):
        """Submit a generation job and block until its result arrives.

        This performs synchronous HTTP calls, so async callers should run it
        in a worker thread / executor.

        Args:
            requirements: Free-text requirements passed to the service.
            db_name: Database name passed to the service.
            db_type: Target DBMS name passed to the service.

        Returns:
            ``(schema_text, ddl_text)`` as produced by ``_parse_html`` on
            success, or ``(None, None)`` when submission fails, an error
            occurs, or the stream ends without a usable result.
        """
        session_hash = ''.join(random.choices(string.ascii_lowercase + string.digits, k=11))
        # Positional order expected by the service:
        # 1. model, 2. database name, 3. requirements, 4. DBMS.
        inputs = ["gpt4", db_name, requirements, db_type]
        headers = {"Content-Type": "application/json"}

        try:
            # 1. Submit the job to the queue.
            submit_resp = requests.post(
                f"{cls.BASE_HOST}/gradio_api/queue/join",
                json={"data": inputs, "session_hash": session_hash, "fn_index": 0},
                headers=headers, timeout=10
            )
            if submit_resp.status_code != 200:
                print(f"[SchemaGen] Submission failed: {submit_resp.text}")
                return None, None

            # 2. Listen for the result on the event stream. The context
            #    manager guarantees the streamed connection is released even
            #    when we return from inside the loop.
            with requests.get(
                f"{cls.BASE_HOST}/gradio_api/queue/data?session_hash={session_hash}",
                headers=headers, stream=True, timeout=120
            ) as stream_resp:
                for raw_line in stream_resp.iter_lines():
                    if not raw_line:
                        continue
                    decoded = raw_line.decode('utf-8')
                    if not decoded.startswith('data: '):
                        continue
                    try:
                        msg = json.loads(decoded[6:])
                    except json.JSONDecodeError:
                        # Not a JSON payload; ignore and keep listening.
                        continue
                    if not isinstance(msg, dict) or msg.get('msg') != 'process_completed':
                        continue
                    output = msg.get('output')
                    output_data = output.get('data') if isinstance(output, dict) else None
                    if output_data:
                        return cls._parse_html(output_data[0])
        except Exception as e:
            # Boundary for the background job: report and signal failure
            # through the return value instead of raising.
            print(f"[SchemaGen] Error: {e}")
            return None, None

        return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ==========================================
|
|
|
|
|
# 新增:后台任务处理函数
|
|
|
|
|
# ==========================================
|
|
|
|
|
async def bg_generate_schema_task(project_id: int, requirements: str, db_name: str, db: Session,
                                  db_type: str = "MySQL"):
    """Background task: generate a schema/DDL for a project and report the outcome.

    Args:
        project_id: ID of the project the generation is for (used in log output).
        requirements: Requirements text forwarded to ``SchemaGenerator.run_generation``.
        db_name: Database name forwarded to the generator.
        db: Database session supplied by the caller. It is currently NOT used
            here; per the original author's note, a request-scoped session may
            already be closed by the time this task runs, so any persistence
            added later should open its own session inside the task.
        db_type: Target DBMS forwarded to the generator. Defaults to "MySQL",
            which is what the generator used when no type was passed.

    Returns:
        None. The outcome is only printed.
    """
    print(f"[Task] Starting schema generation for Project {project_id}...")

    # run_generation does blocking network I/O, so run it in the default
    # executor to avoid blocking the event loop. We are inside a coroutine,
    # so the running loop is the one to use.
    loop = asyncio.get_running_loop()
    schema_res, ddl_res = await loop.run_in_executor(
        None, SchemaGenerator.run_generation, requirements, db_name, db_type
    )

    if schema_res and ddl_res:
        print(f"[Task] Generation successful for Project {project_id}")

        # TODO: persist schema_res and ddl_res (e.g. to an
        # ai_generated_statement table or a field on Project) and mark the
        # project active. Do this with a session created inside the task,
        # for example:
        #     async with async_session() as session:
        #         await crud_project.change_status(session, project_id, 'active')
        # For now the result is only previewed in the log.
        print(f"--- Schema Preview ---\n{schema_res[:200]}...")
        print(f"--- DDL Preview ---\n{ddl_res[:200]}...")
    else:
        print(f"[Task] Generation failed for Project {project_id}")
|
|
|
|
|
# --- 辅助函数 ---
|
|
|
|
|
async def _verify_delete_token(token: str, user_id: int, project_id: int) -> bool:
|
|
|
|
|
try:
|
|
|
|
|
@ -48,24 +172,30 @@ async def _verify_delete_token(token: str, user_id: int, project_id: int) -> boo
|
|
|
|
|
|
|
|
|
|
# --- 1. 创建项目 ---
|
|
|
|
|
async def create_project_service(
|
|
|
|
|
db: Session, project_in: schemas.ProjectCreate, user_id: int
|
|
|
|
|
db: Session,
|
|
|
|
|
project_in: schemas.ProjectCreate,
|
|
|
|
|
user_id: int,
|
|
|
|
|
background_tasks: BackgroundTasks
|
|
|
|
|
) -> schemas.ProjectAsyncResponse:
|
|
|
|
|
"""创建项目:先创建 DB 实例,再创建项目记录,最后调度异步任务"""
|
|
|
|
|
try:
|
|
|
|
|
project_data = project_in.model_dump()
|
|
|
|
|
db_type = project_data.pop('db_type')
|
|
|
|
|
|
|
|
|
|
# 1. 创建关联的 DatabaseInstance (占位)
|
|
|
|
|
# 必须先创建它,否则 Project 的 instance_id 外键会报错
|
|
|
|
|
# 使用合法的占位数据 (非空字符串,非0端口)
|
|
|
|
|
# 提取需求描述,用于生成
|
|
|
|
|
requirements_text = project_data.get('description', '')
|
|
|
|
|
# 生成一个临时的 DB Name 用于生成过程
|
|
|
|
|
temp_db_name = f"proj_{user_id}_{int(datetime.now().timestamp())}"
|
|
|
|
|
|
|
|
|
|
# 1. 创建关联的 DatabaseInstance
|
|
|
|
|
new_instance = await crud_database_instance.create(
|
|
|
|
|
db,
|
|
|
|
|
db_type=db_type,
|
|
|
|
|
db_host="127.0.0.1", # TODO:后续改成真实的ip地址,修改:使用本地回环地址占位
|
|
|
|
|
db_port=3306, # 修改:使用标准端口占位
|
|
|
|
|
db_name="pending_init", # TODO:修改:更有意义的占位名
|
|
|
|
|
db_username="pending_user", # TODO:修改:非空用户名
|
|
|
|
|
db_password="pending_password", # TODO:修改:非空密码
|
|
|
|
|
db_host="127.0.0.1",
|
|
|
|
|
db_port=3306,
|
|
|
|
|
db_name="pending_init",
|
|
|
|
|
db_username="pending_user",
|
|
|
|
|
db_password="pending_password",
|
|
|
|
|
status="inactive"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -76,8 +206,18 @@ async def create_project_service(
|
|
|
|
|
|
|
|
|
|
db_obj = await crud_project.create(db, **project_data)
|
|
|
|
|
|
|
|
|
|
# 3. 调度异步任务 (此处解开注释即可工作)
|
|
|
|
|
# project_creation_task.delay(project_id=db_obj.project_id)
|
|
|
|
|
# 3. 调度异步任务 (FastAPI BackgroundTasks)
|
|
|
|
|
# 注意:这里传递 db 会有风险,因为请求结束后 db 会被关闭。
|
|
|
|
|
# 最佳实践是仅传递 ID,在 task 内部重新获取 session。
|
|
|
|
|
# 但为了演示连贯性,我们这里触发生成逻辑。
|
|
|
|
|
|
|
|
|
|
background_tasks.add_task(
|
|
|
|
|
bg_generate_schema_task,
|
|
|
|
|
project_id=db_obj.project_id,
|
|
|
|
|
requirements=requirements_text,
|
|
|
|
|
db_name=temp_db_name,
|
|
|
|
|
db=db # 注意:实际生产中请在 task 内新建 session
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return schemas.ProjectAsyncResponse.model_validate(db_obj)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|