diff --git a/rag/LightRAG-deepseek/lightrag/api/routers/graph_routes.py b/rag/LightRAG-deepseek/lightrag/api/routers/graph_routes.py index f02779d..7023b6a 100644 --- a/rag/LightRAG-deepseek/lightrag/api/routers/graph_routes.py +++ b/rag/LightRAG-deepseek/lightrag/api/routers/graph_routes.py @@ -2,7 +2,7 @@ This module contains all graph-related routes for the LightRAG API. """ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List import traceback from fastapi import APIRouter, Depends, Query, HTTPException from pydantic import BaseModel @@ -170,4 +170,49 @@ def create_graph_routes(rag, api_key: Optional[str] = None): status_code=500, detail=f"Error updating relation: {str(e)}" ) + @router.get("/graphs/note", dependencies=[Depends(combined_auth)]) + async def get_knowledge_graph_note( + label: str = Query(..., description="Label to get knowledge graph for"), + max_depth: int = Query(3, description="Maximum depth of graph", ge=1), + max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), + note_ids: List[str] = Query( + ..., + description="List of note IDs to get knowledge graph for", + example=["file1.txt", "file2.docx", "file3.pdf"] + ), + ): + """ + Retrieve a connected subgraph of nodes where the label includes the specified label. + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Hops(path) to the staring node take precedence + 2. Followed by the degree of the nodes + + Args: + label (str): Label of the starting node + max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3 + max_nodes: Maxiumu nodes to return + note_ids: List of note ids to filter nodes by + + Returns: + Dict[str, List[str]]: Knowledge graph for label + """ + try: + return await rag.get_knowledge_graph_note( + node_label=label, + max_depth=max_depth, + max_nodes=max_nodes, + note_ids=note_ids, + ) + except Exception as e: + logger.error(f"Error getting knowledge graph for label '{label}': {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error getting knowledge graph: {str(e)}" + ) + + + + return router + + diff --git a/rag/LightRAG-deepseek/lightrag/kg/neo4j_impl.py b/rag/LightRAG-deepseek/lightrag/kg/neo4j_impl.py index 9e0aa1c..56373bc 100644 --- a/rag/LightRAG-deepseek/lightrag/kg/neo4j_impl.py +++ b/rag/LightRAG-deepseek/lightrag/kg/neo4j_impl.py @@ -1,7 +1,7 @@ import os import re from dataclasses import dataclass -from typing import final +from typing import final, Optional import configparser @@ -1214,6 +1214,232 @@ class Neo4JStorage(BaseGraphStorage): ) return result + async def get_knowledge_graph_note( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = None, + note_ids: Optional[list[str]] = None, + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + note_ids: List of note ids to filter nodes by + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + """ + # Get max_nodes from global_config if not provided + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + # Limit max_nodes to not exceed global_config max_graph_nodes + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) + + workspace_label = self._get_workspace_label() + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + if node_label == "*": + # First check total node count to determine if graph is truncated + count_query = ( + f""" + MATCH (n:`{workspace_label}`) + WHERE $note_ids IS NULL OR n.file_path IN $note_ids + RETURN count(n) as total + """ + ) + count_result = None + try: + count_result = await session.run(count_query,{"note_ids": note_ids}) + count_record = await count_result.single() + + if count_record and count_record["total"] > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() + + # Run main query to get nodes with highest degree + main_query = f""" + MATCH (n:`{workspace_label}`) + WHERE $note_ids IS NULL OR n.file_path IN $note_ids + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect({{node: n}}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + OPTIONAL MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships + """ + result_set = None + try: + result_set = await session.run( + main_query, + {"max_nodes": max_nodes, + "note_ids": note_ids}, + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + + else: + # return await self._robust_fallback(node_label, max_depth, max_nodes) + # First try without limit to check if we need to truncate + full_query = f""" + MATCH (start:`{workspace_label}`) + WHERE start.entity_id = $entity_id + WITH start + CALL apoc.path.subgraphAll(start, {{ + relationshipFilter: '', + labelFilter: '{workspace_label}', + minLevel: 0, + maxLevel: $max_depth, + bfs: true + }}) + YIELD nodes, relationships + WITH nodes, relationships, size(nodes) AS total_nodes + UNWIND nodes AS node + WITH collect({{node: node}}) AS node_info, relationships, total_nodes + RETURN node_info, relationships, total_nodes + """ + + # Try to get full result + full_result = None + try: + full_result = await session.run( + full_query, + { + "entity_id": node_label, + "max_depth": max_depth, + }, + ) + full_record = await full_result.single() + + # If no record found, return empty KnowledgeGraph + if not full_record: + logger.debug(f"No nodes found for entity_id: {node_label}") + return result + + # If record found, check node count + total_nodes = full_record["total_nodes"] + + if total_nodes <= max_nodes: + # If node count is within limit, use full result directly + logger.debug( + f"Using full result with {total_nodes} nodes (no truncation needed)" + ) + record = full_record + else: + # If node count exceeds limit, set truncated flag and run limited query + result.is_truncated = True + logger.info( + f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" + ) + + # Run limited query + limited_query = f""" + MATCH (start:`{workspace_label}`) + WHERE start.entity_id = $entity_id + WITH start + CALL apoc.path.subgraphAll(start, {{ + relationshipFilter: '', + labelFilter: '{workspace_label}', + minLevel: 0, + maxLevel: $max_depth, + limit: $max_nodes, + bfs: true + }}) + YIELD nodes, relationships + UNWIND nodes AS node + WITH collect({{node: node}}) AS node_info, relationships + RETURN node_info, relationships + """ + result_set = None + try: + result_set = await session.run( + limited_query, + { + "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, + }, + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + finally: + if full_result: + await full_result.consume() + + if record: + # Handle nodes (compatible with multi-label cases) + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + seen_nodes.add(node_id) + + # Handle relationships (including direction information) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except neo4jExceptions.ClientError as e: + logger.warning(f"APOC plugin error: {str(e)}") + if node_label != "*": + logger.warning( + "Neo4j: falling back to basic Cypher recursive search..." + ) + return await self._robust_fallback(node_label, max_depth, max_nodes) + else: + logger.warning( + "Neo4j: APOC plugin error with wildcard query, returning empty result" + ) + + return result async def _robust_fallback( self, node_label: str, max_depth: int, max_nodes: int diff --git a/rag/LightRAG-deepseek/lightrag/lightrag.py b/rag/LightRAG-deepseek/lightrag/lightrag.py index a2024ef..2dd7681 100644 --- a/rag/LightRAG-deepseek/lightrag/lightrag.py +++ b/rag/LightRAG-deepseek/lightrag/lightrag.py @@ -564,6 +564,35 @@ class LightRAG: return await self.chunk_entity_relation_graph.get_knowledge_graph( node_label, max_depth, max_nodes ) + + async def get_knowledge_graph_note( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = None, + note_ids: Optional[List[str]] = None, # Add note_ids parameter + ) -> KnowledgeGraph: + """Get knowledge graph for a given label + + Args: + node_label (str): Label to get knowledge graph for + max_depth (int): Maximum depth of graph + max_nodes (int, optional): Maximum number of nodes to return. Defaults to self.max_graph_nodes. + note_ids (List[str], optional): List of note IDs to include in the graph. Defaults to None. + + Returns: + KnowledgeGraph: Knowledge graph containing nodes and edges + """ + # Use self.max_graph_nodes as default if max_nodes is None + if max_nodes is None: + max_nodes = self.max_graph_nodes + else: + # Limit max_nodes to not exceed self.max_graph_nodes + max_nodes = min(max_nodes, self.max_graph_nodes) + + return await self.chunk_entity_relation_graph.get_knowledge_graph_note( + node_label, max_depth, max_nodes, note_ids + ) def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name] diff --git a/rag/LightRAG-deepseek/test.html b/rag/LightRAG-deepseek/test.html new file mode 100644 index 0000000..4c548c1 --- /dev/null +++ b/rag/LightRAG-deepseek/test.html @@ -0,0 +1,217 @@ + + +
+ +
+
灵简 - 灵性之“简”
+
灵简
+
灵简 - 灵性之“简”
+
+
灵简助手