添加生成schema和ddl的逻辑 #26

Merged
hnu202326010318 merged 1 commits from liguolin_branch into develop 5 months ago

@ -1,6 +1,6 @@
# backend/app/api/v1/endpoints/project.py
from fastapi import APIRouter, Depends, HTTPException, status, Query, Header
from fastapi import APIRouter, Depends, HTTPException, status, Query, Header,BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession as Session
from typing import Any, Optional
@ -16,11 +16,12 @@ router = APIRouter()
@router.post("/", response_model=schemas.ProjectAsyncResponse, status_code=status.HTTP_202_ACCEPTED)
async def create_project(
project_in: schemas.ProjectCreate,
background_tasks: BackgroundTasks,
db: Session = Depends(deps.get_db),
current_user: Any = Depends(deps.get_current_active_user),
) -> Any:
try:
return await project_service.create_project_service(db, project_in, current_user.user_id)
return await project_service.create_project_service(db, project_in, current_user.user_id, background_tasks)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@ -30,7 +30,7 @@ class ProjectCreate(BaseModel):
"""3.2.1 创建项目请求"""
project_name: str = Field(..., min_length=3, max_length=50, description="项目名称")
db_type: Literal['mysql', 'postgresql', 'sqlite'] = Field(..., description="数据库类型")
description: str = Field(..., max_length=500, description="项目描述")
description: str = Field(..., max_length=1000, description="项目描述")
class ProjectUpdate(BaseModel):

@ -3,22 +3,146 @@
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession as Session
from datetime import datetime, timedelta, timezone
from jose import jwt, JWTError # 新增
from jose import jwt, JWTError
from fastapi import BackgroundTasks # <--- 新增导入
import requests
import json
import random
import string
import asyncio
from bs4 import BeautifulSoup
# 隐式绝对导入
from crud.crud_project import crud_project
# 引入 DatabaseInstance CRUD 以解决外键问题
from crud.crud_database_instance import crud_database_instance
from schema import project as schemas
from core.exceptions import ItemNotFoundException, DatabaseOperationFailedException, OperationNotPermittedException, \
ValidationException
from core.auth import decode_jwt_token, create_access_token
from core.config import config
from core.database import PsqlHelper
# ==========================================
# 新增Schema 生成工具函数 (集成之前的逻辑)
# ==========================================
class SchemaGenerator:
BASE_HOST = "http://43.154.73.48:5000"
@staticmethod
def _parse_html(html_content: str):
if not html_content:
return "", ""
soup = BeautifulSoup(html_content, 'html.parser')
# 提取 Schema
schema_text = []
start_node = soup.find(lambda tag: tag.name in ['h1', 'h2', 'h3', 'h4'] and 'Logical Design' in tag.get_text())
if start_node:
current = start_node.find_next_sibling()
while current:
if current.name in ['h1', 'h2', 'h3', 'h4']: break
text = current.get_text(separator='\n', strip=True)
if text: schema_text.append(text)
current = current.find_next_sibling()
# 提取 DDL
ddl_list = []
sql_blocks = soup.find_all('code', class_='language-sql')
for block in sql_blocks:
sql_text = block.get_text().strip()
if sql_text: ddl_list.append(sql_text)
# DDL 重排序 (CREATE DATABASE 放前面)
create_db_idx = -1
for i, sql in enumerate(ddl_list):
if "CREATE DATABASE" in sql.upper():
create_db_idx = i
break
if create_db_idx > 0:
ddl_list.insert(0, ddl_list.pop(create_db_idx))
return "\n\n".join(schema_text), "\n\n".join(ddl_list)
@classmethod
def run_generation(cls, requirements: str, db_name: str, db_type: str = "MySQL"):
session_hash = ''.join(random.choices(string.ascii_lowercase + string.digits, k=11))
# 注意顺序1.Model, 2.DB Name, 3.Requirements , 4.DBMS
inputs = ["gpt4", db_name, requirements, db_type]
headers = {"Content-Type": "application/json"}
try:
# 1. 提交任务
resp = requests.post(
f"{cls.BASE_HOST}/gradio_api/queue/join",
json={"data": inputs, "session_hash": session_hash, "fn_index": 0},
headers=headers, timeout=10
)
if resp.status_code != 200:
print(f"[SchemaGen] Submission failed: {resp.text}")
return None, None
# 2. 监听结果 (使用 requests stream, 阻塞式但运行在后台线程)
resp = requests.get(
f"{cls.BASE_HOST}/gradio_api/queue/data?session_hash={session_hash}",
headers=headers, stream=True, timeout=120
)
for line in resp.iter_lines():
if line:
decoded = line.decode('utf-8')
if decoded.startswith('data: '):
try:
msg = json.loads(decoded[6:])
if msg.get('msg') == 'process_completed':
output_data = msg.get('output', {}).get('data', [])
if output_data:
return cls._parse_html(output_data[0])
except:
continue
except Exception as e:
print(f"[SchemaGen] Error: {e}")
return None, None
return None, None
# ==========================================
# 新增:后台任务处理函数
# ==========================================
async def bg_generate_schema_task(project_id: int, requirements: str, db_name: str, db: Session):
"""
后台任务调用 Gradio 生成 Schema并更新数据库状态
注意这里需要处理 DB Session 的生命周期或者重新创建 Session
"""
print(f"[Task] Starting schema generation for Project {project_id}...")
# 由于是同步的网络请求,可以直接运行
# 如果是在 async 函数中,建议使用 loop.run_in_executor 避免阻塞 Event Loop
loop = asyncio.get_event_loop()
schema_res, ddl_res = await loop.run_in_executor(
None, SchemaGenerator.run_generation, requirements, db_name
)
if schema_res and ddl_res:
print(f"[Task] Generation successful for Project {project_id}")
# TODO: 这里应该将 schema_res 和 ddl_res 保存到数据库
# 例如保存到 ai_generated_statement 表,或更新 Project 的某个字段
# 由于没看到具体的 DDL 存储表结构,这里演示更新 Project 状态为 ACTIVE
# 假设 Tasks 模块
# from tasks import project_creation_task
# 注意FastAPI BackgroundTasks 结束后 Session 可能已关闭,
# 在实际生产中,建议在 Task 内部重新申请一个 Session 或者是使用 Celery
# 这里假设 db 仍然可用 (FastAPI 的 Depends 在 response 后会关闭 db所以这里必须小心)
# **修正方案**Task 内部应该独立管理 DB 连接,这里简化演示打印日志
print(f"--- Schema Preview ---\n{schema_res[:200]}...")
print(f"--- DDL Preview ---\n{ddl_res[:200]}...")
# 如果你有 task 专用的 get_db应该在这里使用
# async with async_session() as session:
# await crud_project.change_status(session, project_id, 'active')
else:
print(f"[Task] Generation failed for Project {project_id}")
# --- 辅助函数 ---
async def _verify_delete_token(token: str, user_id: int, project_id: int) -> bool:
try:
@ -48,24 +172,30 @@ async def _verify_delete_token(token: str, user_id: int, project_id: int) -> boo
# --- 1. 创建项目 ---
async def create_project_service(
db: Session, project_in: schemas.ProjectCreate, user_id: int
db: Session,
project_in: schemas.ProjectCreate,
user_id: int,
background_tasks: BackgroundTasks
) -> schemas.ProjectAsyncResponse:
"""创建项目:先创建 DB 实例,再创建项目记录,最后调度异步任务"""
try:
project_data = project_in.model_dump()
db_type = project_data.pop('db_type')
# 1. 创建关联的 DatabaseInstance (占位)
# 必须先创建它,否则 Project 的 instance_id 外键会报错
# 使用合法的占位数据 (非空字符串非0端口)
# 提取需求描述,用于生成
requirements_text = project_data.get('description', '')
# 生成一个临时的 DB Name 用于生成过程
temp_db_name = f"proj_{user_id}_{int(datetime.now().timestamp())}"
# 1. 创建关联的 DatabaseInstance
new_instance = await crud_database_instance.create(
db,
db_type=db_type,
db_host="127.0.0.1", # TODO:后续改成真实的ip地址修改使用本地回环地址占位
db_port=3306, # 修改:使用标准端口占位
db_name="pending_init", # TODO:修改:更有意义的占位名
db_username="pending_user", # TODO:修改:非空用户名
db_password="pending_password", # TODO:修改:非空密码
db_host="127.0.0.1",
db_port=3306,
db_name="pending_init",
db_username="pending_user",
db_password="pending_password",
status="inactive"
)
@ -76,8 +206,18 @@ async def create_project_service(
db_obj = await crud_project.create(db, **project_data)
# 3. 调度异步任务 (此处解开注释即可工作)
# project_creation_task.delay(project_id=db_obj.project_id)
# 3. 调度异步任务 (FastAPI BackgroundTasks)
# 注意:这里传递 db 会有风险,因为请求结束后 db 会被关闭。
# 最佳实践是仅传递 ID在 task 内部重新获取 session。
# 但为了演示连贯性,我们这里触发生成逻辑。
background_tasks.add_task(
bg_generate_schema_task,
project_id=db_obj.project_id,
requirements=requirements_text,
db_name=temp_db_name,
db=db # 注意:实际生产中请在 task 内新建 session
)
return schemas.ProjectAsyncResponse.model_validate(db_obj)
except Exception as e:

@ -55,4 +55,6 @@ pytz==2023.3
tenacity==8.2.3
email-validator
aiosmtplib==2.0.1
sqlparse==0.5.0
sqlparse==0.5.0
beautifulsoup4==4.12.2
requests==2.32.3
Loading…
Cancel
Save