基本代码

main
林名赫 6 months ago
parent 733141f86a
commit d95dc6ca19

Before

Width:  |  Height:  |  Size: 179 KiB

After

Width:  |  Height:  |  Size: 179 KiB

@ -11,41 +11,11 @@ from lightrag.kg.shared_storage import initialize_pipeline_status
from dotenv import load_dotenv
# from lightrag.api.pg_reader import get_field_relation_types # 导入数据库查询异步函数
from typing import List, Optional
from typing import Tuple, List
load_dotenv(dotenv_path=".env", override=False)
WORKING_DIR = "./dickens"
# async def get_entity_types_async() -> Tuple[List[str], List[str]]:
# """异步获取实体类型列表(直接在当前事件循环中执行)
# Returns:
# Tuple[List[str], List[str]]: 第一个列表是field_types,第二个列表是relation_types
# """
# try:
# # 直接 await 异步函数(在 main() 的事件循环中执行)
# result = await get_field_relation_types()
# if result is not None:
# field_types = result.get("unique_field_types", [])
# relation_types = result.get("unique_relation_types", [])
# return field_types, relation_types
# else:
# return [], []
# except Exception as e:
# logger.warning(f"获取实体类型失败: {e}")
# return [], []
def configure_logging():
"""Configure logging for the application"""
@ -135,32 +105,7 @@ async def print_stream(stream):
print(chunk, end="", flush=True)
# async def initialize_rag():
# rag = LightRAG(
# working_dir=WORKING_DIR,
# llm_model_func=llm_model_func,
# embedding_func=EmbeddingFunc(
# embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
# max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
# func=lambda texts: siliconcloud_embedding(
# texts,
# model=os.getenv("EMBEDDING_MODEL", "netease-youdao/bce-embedding-base_v1"),
# base_url=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434")+"/embeddings",
# max_token_size=512,
# api_key=os.getenv("EMBEDDING_BINDING_API_KEY", None)
# ),
# ),
# )
# await rag.initialize_storages()
# await initialize_pipeline_status()
# return rag
async def initialize_rag():
# # 获取字段类型和关系类型列表
# field_types, relation_types = await get_entity_types_async()
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
@ -175,11 +120,6 @@ async def initialize_rag():
api_key=os.getenv("EMBEDDING_BINDING_API_KEY", None)
),
),
# # 新增addon_params包含字段类型和关系类型
# addon_params={
# "entity_types": field_types, # 注入字段类型列表
# "relation_keywords": relation_types # 注入关系类型列表
# }
)
await rag.initialize_storages()
@ -343,7 +283,7 @@ async def main():
print("Query mode: ")
print("=====================")
resp = await rag.aquery(
"林深是谁",
"如果梦姑和段誉认识需要通过谁来认识",
param=QueryParam(mode="bypass", stream=True),
)
if inspect.isasyncgen(resp):

@ -2,11 +2,6 @@
LightRAG FastAPI Server
"""
# 顶部导入区添加
# from lightrag.api.pg_reader import get_entitytype_values # 导入数据库查询异步函数
from typing import List, Optional
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
@ -78,19 +73,6 @@ config.read("config.ini")
auth_configured = bool(auth_handler.accounts)
# # 新增异步函数:获取实体类型列表
# async def get_entity_types_async() -> List[List[str]]:
# """异步查询并返回实体类型列表"""
# try:
# entity_types = await get_entitytype_values()
# return entity_types if entity_types is not None else []
# except Exception as e:
# logger.warning(f"获取实体类型失败: {e}")
# return []
def create_app(args):
# Setup logging
logger.setLevel(args.log_level)
@ -309,12 +291,6 @@ def create_app(args):
),
)
# # 获取事件循环FastAPI 启动时已创建循环)
# loop = asyncio.get_event_loop()
# # 异步执行并等待结果(关键:用 loop.run_until_complete 执行异步函数)
# entity_types = loop.run_until_complete(get_entity_types_async())
# Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG(
@ -351,10 +327,7 @@ def create_app(args):
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
# "entity_types": entity_types # 新增实体类型
}, # 新增实体类型
addon_params={"language": args.summary_language},
)
else: # azure_openai
rag = LightRAG(
@ -382,10 +355,7 @@ def create_app(args):
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
# "entity_types": entity_types # 新增实体类型
},
addon_params={"language": args.summary_language},
)
# Add routes

@ -19,7 +19,6 @@ from fastapi import (
File,
HTTPException,
UploadFile,
Form,
)
from pydantic import BaseModel, Field, field_validator
@ -748,212 +747,6 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
logger.error(f"Error deleting file {file_path}: {str(e)}")
return False
async def pipeline_enqueue_file_uploadType(rag: LightRAG, file_path: Path, uploadType: Dict[str, Any] = None) -> bool:
    """Add a file to the processing queue together with upload metadata.

    Reads the file as raw bytes, extracts text according to the file
    extension (plain-text formats, PDF, DOCX, PPTX, XLSX), then enqueues
    the extracted content into the RAG pipeline with the metadata attached.

    Args:
        rag: LightRAG instance used for enqueueing.
        file_path: Path of the saved file to read and enqueue.
        uploadType: Metadata dict stored alongside the document.
            NOTE(review): despite the name this is a dict, not a type string —
            callers appear to pass {"field_type": ..., "relation_type": ...};
            confirm against the upload route.

    Returns:
        bool: True if the file was successfully enqueued, False otherwise.
    """
    try:
        content = ""
        ext = file_path.suffix.lower()

        file = None
        # Read raw bytes; text decoding (if any) is decided per extension below.
        async with aiofiles.open(file_path, "rb") as f:
            file = await f.read()

        # Dispatch on file extension.
        match ext:
            case (
                ".txt"
                | ".md"
                | ".html"
                | ".htm"
                | ".tex"
                | ".json"
                | ".xml"
                | ".yaml"
                | ".yml"
                | ".rtf"
                | ".odt"
                | ".epub"
                | ".csv"
                | ".log"
                | ".conf"
                | ".ini"
                | ".properties"
                | ".sql"
                | ".bat"
                | ".sh"
                | ".c"
                | ".cpp"
                | ".py"
                | ".java"
                | ".js"
                | ".ts"
                | ".swift"
                | ".go"
                | ".rb"
                | ".php"
                | ".css"
                | ".scss"
                | ".less"
            ):
                try:
                    # Attempt UTF-8 decoding; anything else is rejected below.
                    content = file.decode("utf-8")

                    # Reject empty / whitespace-only files.
                    if not content or len(content.strip()) == 0:
                        logger.error(f"文件内容为空: {file_path.name}")
                        return False

                    # Guard against a bytes repr accidentally saved as text
                    # (e.g. str(b"...") written to disk).
                    if content.startswith("b'") or content.startswith('b"'):
                        logger.error(
                            f"文件 {file_path.name} 似乎包含二进制数据的字符串表示而不是文本"
                        )
                        return False
                except UnicodeDecodeError:
                    logger.error(
                        f"文件 {file_path.name} 不是有效的UTF-8编码文本。请在处理前将其转换为UTF-8。"
                    )
                    return False
            case ".pdf":
                # DOCLING engine converts to markdown; otherwise fall back to PyPDF2.
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore

                    converter = DocumentConverter()
                    result = converter.convert(file_path)
                    content = result.document.export_to_markdown()
                else:
                    if not pm.is_installed("pypdf2"):  # type: ignore
                        pm.install("pypdf2")
                    from PyPDF2 import PdfReader  # type: ignore
                    from io import BytesIO

                    pdf_file = BytesIO(file)
                    reader = PdfReader(pdf_file)
                    for page in reader.pages:
                        content += page.extract_text() + "\n"
            case ".docx":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore

                    converter = DocumentConverter()
                    result = converter.convert(file_path)
                    content = result.document.export_to_markdown()
                else:
                    if not pm.is_installed("python-docx"):  # type: ignore
                        try:
                            pm.install("python-docx")
                        except Exception:
                            # Fallback package name if python-docx install fails.
                            pm.install("docx")
                    from docx import Document  # type: ignore
                    from io import BytesIO

                    docx_file = BytesIO(file)
                    doc = Document(docx_file)
                    # Paragraph text only — tables/headers are not extracted here.
                    content = "\n".join(
                        [paragraph.text for paragraph in doc.paragraphs]
                    )
            case ".pptx":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore

                    converter = DocumentConverter()
                    result = converter.convert(file_path)
                    content = result.document.export_to_markdown()
                else:
                    if not pm.is_installed("python-pptx"):  # type: ignore
                        pm.install("pptx")
                    from pptx import Presentation  # type: ignore
                    from io import BytesIO

                    pptx_file = BytesIO(file)
                    prs = Presentation(pptx_file)
                    # Collect text from every shape that carries text.
                    for slide in prs.slides:
                        for shape in slide.shapes:
                            if hasattr(shape, "text"):
                                content += shape.text + "\n"
            case ".xlsx":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore

                    converter = DocumentConverter()
                    result = converter.convert(file_path)
                    content = result.document.export_to_markdown()
                else:
                    if not pm.is_installed("openpyxl"):  # type: ignore
                        pm.install("openpyxl")
                    from openpyxl import load_workbook  # type: ignore
                    from io import BytesIO

                    xlsx_file = BytesIO(file)
                    wb = load_workbook(xlsx_file)
                    # Render each sheet as tab-separated rows under a title line.
                    for sheet in wb:
                        content += f"Sheet: {sheet.title}\n"
                        for row in sheet.iter_rows(values_only=True):
                            content += (
                                "\t".join(
                                    str(cell) if cell is not None else ""
                                    for cell in row
                                )
                                + "\n"
                            )
                        content += "\n"
            case _:
                logger.error(
                    f"不支持的文件类型: {file_path.name} (扩展名 {ext})"
                )
                return False

        # Enqueue the extracted content together with the metadata.
        if content:
            # Whitespace-only content is enqueued anyway, but logged as a warning.
            if not content.strip():
                logger.warning(
                    f"文件只包含空白字符。file_paths={file_path.name}"
                )

            # Project-specific enqueue variant that also stores uploadType.
            await rag.apipeline_enqueue_documents_with_uploadType(
                content,
                file_paths=file_path.name,
                uploadType=uploadType
            )
            logger.info(f"成功获取并入队文件: {file_path.name} 及其元数据")
            return True
        else:
            logger.error(f"无法从文件中提取内容: {file_path.name}")
    except Exception as e:
        logger.error(f"处理或入队文件时出错 {file_path.name} 及其元数据: {str(e)}")
        logger.error(traceback.format_exc())
    finally:
        # Remove the file if it was written with the temporary-upload prefix.
        if file_path.name.startswith(temp_prefix):
            try:
                file_path.unlink()
            except Exception as e:
                logger.error(f"删除文件时出错 {file_path}: {str(e)}")
    return False
async def pipeline_index_file(rag: LightRAG, file_path: Path):
"""Index a file
@ -963,10 +756,6 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
file_path: Path to the saved file
"""
try:
"""
pipeline_enqueue_file该函数负责读取文件内容根据文件类型进行不同的处理然后将处理后的内容添加到rag队列中
它支持多种文件类型包括文本文件PDFWord文档PowerPoint演示文稿Excel电子表格等
"""
if await pipeline_enqueue_file(rag, file_path):
await rag.apipeline_process_enqueue_documents()
@ -974,23 +763,6 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
logger.error(f"Error indexing file {file_path.name}: {str(e)}")
logger.error(traceback.format_exc())
async def pipeline_index_file_uploadType(rag: LightRAG, file_path: Path, uploadType: Dict[str, Any] = None):
    """Enqueue *file_path* with its uploadType metadata, then run the pipeline.

    Args:
        rag: LightRAG instance used for enqueueing and processing.
        file_path: Path to the saved file on disk.
        uploadType: Optional metadata dictionary stored alongside the document.
    """
    try:
        enqueued = await pipeline_enqueue_file_uploadType(rag, file_path, uploadType)
        # Only kick off processing when the file actually made it into the queue.
        if enqueued:
            await rag.apipeline_process_enqueue_documents()
    except Exception as e:
        logger.error(f"Error indexing file {file_path.name} with uploadType: {str(e)}")
        logger.error(traceback.format_exc())
async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
"""Index multiple files sequentially to avoid high CPU load
@ -1272,7 +1044,7 @@ def create_document_routes(
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...), field_type: str=Form(None) ,relation_type: str=Form(None)
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""
Upload a file to the input directory and index it.
@ -1314,13 +1086,7 @@ def create_document_routes(
shutil.copyfileobj(file.file, buffer)
# Add to background tasks
# 创建元数据字典存储额外参数
uploadType = {
"field_type": field_type,
"relation_type": relation_type
}
background_tasks.add_task(pipeline_index_file_uploadType, rag, file_path, uploadType)
background_tasks.add_task(pipeline_index_file, rag, file_path)
return InsertResponse(
status="success",
@ -1711,7 +1477,6 @@ def create_document_routes(
error=doc_status.error,
metadata=doc_status.metadata,
file_path=doc_status.file_path,
uploadType=doc_status.uploadType,
)
)
return response

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save