import os
import asyncio
import inspect
import logging
import logging.config
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
from lightrag.kg.shared_storage import initialize_pipeline_status
from dotenv import load_dotenv

load_dotenv(dotenv_path=".env", override=False)
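
# Optional environment variables read by this demo: LOG_DIR, LOG_MAX_BYTES,
# LOG_BACKUP_COUNT, VERBOSE_DEBUG, LLM_MODEL, LLM_BINDING_API_KEY (or
# OPENAI_API_KEY), LLM_BINDING_HOST, EMBEDDING_DIM, MAX_EMBED_TOKENS,
# EMBEDDING_MODEL, EMBEDDING_BINDING_HOST and EMBEDDING_BINDING_API_KEY.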
WORKING_DIR = "./dickens"


def configure_logging():
    """Configure logging for the application"""

    # Reset any existing handlers to ensure clean configuration
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        logger_instance = logging.getLogger(logger_name)
        logger_instance.handlers = []
        logger_instance.filters = []

    # Get log directory path from environment variable or use current directory
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(
        os.path.join(log_dir, "lightrag_compatible_demo.log")
    )

    print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
    # Ensure the log directory itself exists
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    # Get log file max size and backup count from environment variables
    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
            },
        }
    )

    # Set the logger level to INFO
    logger.setLevel(logging.INFO)
    # Enable verbose debug if needed
    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
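
# WORKING_DIR holds all persisted state (the KV stores, vector DB files and the
# entity/relation graph listed in files_to_delete below); remove it to rebuild
# the index from scratch.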


async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        os.getenv("LLM_MODEL", "deepseek-chat"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages or [],
        api_key=os.getenv("LLM_BINDING_API_KEY") or os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("LLM_BINDING_HOST", "https://api.deepseek.com"),
        **kwargs,
    )
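
# A minimal smoke test for the LLM binding (assumes LLM_BINDING_API_KEY or
# OPENAI_API_KEY is set to a valid key for the configured endpoint):
#   print(asyncio.run(llm_model_func("Say hello in one word.")))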


async def print_stream(stream):
    async for chunk in stream:
        if chunk:
            print(chunk, end="", flush=True)
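
# rag.aquery() returns an async generator when stream=True and a plain string
# otherwise, hence the inspect.isasyncgen() checks before printing in main().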


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
            max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
            func=lambda texts: siliconcloud_embedding(
                texts,
                model=os.getenv("EMBEDDING_MODEL", "netease-youdao/bce-embedding-base_v1"),
                base_url=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434") + "/embeddings",
                max_token_size=512,
                api_key=os.getenv("EMBEDDING_BINDING_API_KEY", None),
            ),
        ),
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag
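
# Note: siliconcloud_embedding posts to an OpenAI-style /embeddings endpoint,
# which is why "/embeddings" is appended to the host above. The localhost:11434
# default targets a local OpenAI-compatible server; to use the hosted
# SiliconFlow API instead, pointing EMBEDDING_BINDING_HOST at
# https://api.siliconflow.cn/v1 should work (an assumption -- verify against
# your provider's docs) along with a valid EMBEDDING_BINDING_API_KEY.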


async def main():
    rag = None
    try:
        # Clear old data files
        # files_to_delete = [
        #     "graph_chunk_entity_relation.graphml",
        #     "kv_store_doc_status.json",
        #     "kv_store_full_docs.json",
        #     "kv_store_text_chunks.json",
        #     "vdb_chunks.json",
        #     "vdb_entities.json",
        #     "vdb_relationships.json",
        # ]
        # for file in files_to_delete:
        #     file_path = os.path.join(WORKING_DIR, file)
        #     if os.path.exists(file_path):
        #         os.remove(file_path)
        #         print(f"Deleting old file: {file_path}")

        # Initialize RAG instance
        rag = await initialize_rag()

        # Test embedding function
        # test_text = ["This is a test string for embedding."]
        # embedding = await rag.embedding_func(test_text)
        # embedding_dim = embedding.shape[1]
        # print("\n=======================")
        # print("Test embedding function")
        # print("========================")
        # print(f"Test text: {test_text}")
        # print(f"Detected embedding dimension: {embedding_dim}\n\n")

        # with open("./book.txt", "r", encoding="utf-8") as f:
        #     await rag.ainsert(f.read())

        # # Perform naive search
        # print("\n=====================")
        # print("Query mode: naive")
        # print("=====================")
        # resp = await rag.aquery(
        #     "What are the main themes of this story?",
        #     param=QueryParam(mode="naive", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # # Perform local search
        # print("\n=====================")
        # print("Query mode: local")
        # print("=====================")
        # resp = await rag.aquery(
        #     "Who is Zhu Yuanzhang in this story?",
        #     param=QueryParam(mode="local", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # # Perform global search
        # print("\n=====================")
        # print("Query mode: global")
        # print("=====================")
        # resp = await rag.aquery(
        #     "Who is Sun Wukong connected to in this story?",
        #     param=QueryParam(mode="global", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # # Perform naive search
        # print("\n=====================")
        # print("Query mode: naive")
        # print("=====================")
        # resp = await rag.aquery(
        #     "What is the relationship between Zhu Yuanzhang and Miss Ma in this story?",
        #     param=QueryParam(mode="naive", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # print("\n=====================")
        # print("Query mode: bypass")
        # print("=====================")
        # resp = await rag.aquery(
        #     "What is the relationship between Zhu Yuanzhang and Miss Ma in this story?",
        #     param=QueryParam(mode="bypass", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # print("\n=====================")
        # print("Query mode: local")
        # print("=====================")
        # resp = await rag.aquery(
        #     "What is the relationship between Zhu Yuanzhang and Miss Ma in this story?",
        #     param=QueryParam(mode="local", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # print("\n=====================")
        # print("Query mode: global")
        # print("=====================")
        # resp = await rag.aquery(
        #     "What is the relationship between Zhu Yuanzhang and Miss Ma in this story?",
        #     param=QueryParam(mode="global", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        # print("\n=====================")
        # print("Query mode: hybrid")
        # print("=====================")
        # resp = await rag.aquery(
        #     "What is the relationship between Zhu Yuanzhang and Miss Ma in this story?",
        #     param=QueryParam(mode="hybrid", stream=True),
        # )
        # if inspect.isasyncgen(resp):
        #     await print_stream(resp)
        # else:
        #     print(resp)

        print("\n=====================")
        print("Query mode: bypass")
        print("=====================")
        resp = await rag.aquery(
            "If Meng Gu and Duan Yu were to meet, who would they have to be introduced through?",
            param=QueryParam(mode="bypass", stream=True),
        )
        if inspect.isasyncgen(resp):
            await print_stream(resp)
        else:
            print(resp)
    except Exception as e:
        print(f"An error occurred: {e}")
    finally:
        if rag:
            await rag.finalize_storages()


if __name__ == "__main__":
    # Configure logging before running the main function
    configure_logging()
    asyncio.run(main())
    print("\nDone!")