|
|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from typing import final
|
|
|
|
|
from typing import final, Optional
|
|
|
|
|
import configparser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1214,6 +1214,232 @@ class Neo4JStorage(BaseGraphStorage):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
async def get_knowledge_graph_note(
    self,
    node_label: str,
    max_depth: int = 3,
    max_nodes: Optional[int] = None,
    note_ids: Optional[list[str]] = None,
) -> KnowledgeGraph:
    """
    Retrieve a connected subgraph, optionally restricted to a set of notes.

    Two query modes are used, selected by `node_label`:

    * `node_label == "*"`: take up to `max_nodes` nodes of the workspace,
      ordered by relationship count (highest first), plus the relationships
      between the selected nodes. When `note_ids` is given, only nodes whose
      `file_path` property is contained in `note_ids` are considered.
    * any other value: start from the node whose `entity_id` equals
      `node_label` and expand with `apoc.path.subgraphAll` up to `max_depth`
      levels. If the expansion yields more than `max_nodes` nodes, the query
      is re-run with APOC's `limit` option set to `max_nodes`.
      `note_ids` is NOT applied in this mode.

    Args:
        node_label: `entity_id` of the starting node, `*` means all nodes
        max_depth: Maximum depth of the subgraph, defaults to 3
        max_nodes: Maximum number of nodes to return. Defaults to the
            `max_graph_nodes` entry of `global_config` (1000 when that entry
            is missing); larger values are clamped to that same limit
        note_ids: `file_path` values to filter nodes by (wildcard mode only);
            None disables the filter

    Returns:
        KnowledgeGraph object containing nodes and edges, with an is_truncated
        flag indicating whether the graph was truncated due to the max_nodes
        limit. When no start node matches `node_label`, an empty graph is
        returned.

    Notes:
        A `neo4jExceptions.ClientError` raised by any query is logged as an
        APOC error. For a specific `node_label` the result of
        `self._robust_fallback(node_label, max_depth, max_nodes)` is returned
        instead; for the wildcard query the result collected so far is
        returned.
    """
    # Resolve the node budget: the configured ceiling is both the default
    # and the upper bound for a caller-supplied value.
    node_ceiling = self.global_config.get("max_graph_nodes", 1000)
    if max_nodes is None:
        max_nodes = node_ceiling
    else:
        max_nodes = min(max_nodes, node_ceiling)

    workspace_label = self._get_workspace_label()
    result = KnowledgeGraph()
    seen_nodes = set()
    seen_edges = set()

    # NOTE(review): `note_ids` is only passed to the wildcard queries below.
    # The entity-rooted APOC queries and the `_robust_fallback` call do not
    # receive it, so the filter is silently ignored for a specific
    # `node_label` — confirm whether that is intended.

    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            if node_label == "*":
                # First check total node count to determine if graph is truncated
                count_query = f"""
                    MATCH (n:`{workspace_label}`)
                    WHERE $note_ids IS NULL OR n.file_path IN $note_ids
                    RETURN count(n) as total
                    """
                count_result = None
                try:
                    count_result = await session.run(
                        count_query, {"note_ids": note_ids}
                    )
                    count_record = await count_result.single()

                    if count_record and count_record["total"] > max_nodes:
                        result.is_truncated = True
                        logger.info(
                            f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                        )
                finally:
                    # Always release the result, even if reading it failed.
                    if count_result is not None:
                        await count_result.consume()

                # Run main query to get nodes with highest degree
                main_query = f"""
                    MATCH (n:`{workspace_label}`)
                    WHERE $note_ids IS NULL OR n.file_path IN $note_ids
                    OPTIONAL MATCH (n)-[r]-()
                    WITH n, COALESCE(count(r), 0) AS degree
                    ORDER BY degree DESC
                    LIMIT $max_nodes
                    WITH collect({{node: n}}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    OPTIONAL MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                    collect(DISTINCT r) AS relationships
                    """
                result_set = None
                try:
                    result_set = await session.run(
                        main_query,
                        {"max_nodes": max_nodes, "note_ids": note_ids},
                    )
                    record = await result_set.single()
                finally:
                    if result_set is not None:
                        await result_set.consume()

            else:
                # First try without limit to check if we need to truncate
                full_query = f"""
                    MATCH (start:`{workspace_label}`)
                    WHERE start.entity_id = $entity_id
                    WITH start
                    CALL apoc.path.subgraphAll(start, {{
                    relationshipFilter: '',
                    labelFilter: '{workspace_label}',
                    minLevel: 0,
                    maxLevel: $max_depth,
                    bfs: true
                    }})
                    YIELD nodes, relationships
                    WITH nodes, relationships, size(nodes) AS total_nodes
                    UNWIND nodes AS node
                    WITH collect({{node: node}}) AS node_info, relationships, total_nodes
                    RETURN node_info, relationships, total_nodes
                    """

                # Try to get full result
                full_result = None
                try:
                    full_result = await session.run(
                        full_query,
                        {
                            "entity_id": node_label,
                            "max_depth": max_depth,
                        },
                    )
                    full_record = await full_result.single()

                    # If no record found, return empty KnowledgeGraph
                    if not full_record:
                        logger.debug(f"No nodes found for entity_id: {node_label}")
                        return result

                    # If record found, check node count
                    total_nodes = full_record["total_nodes"]

                    if total_nodes <= max_nodes:
                        # If node count is within limit, use full result directly
                        logger.debug(
                            f"Using full result with {total_nodes} nodes (no truncation needed)"
                        )
                        record = full_record
                    else:
                        # If node count exceeds limit, set truncated flag and run limited query
                        result.is_truncated = True
                        logger.info(
                            f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
                        )

                        # Run limited query: identical to full_query except for
                        # the `limit` option and the omitted total_nodes column.
                        limited_query = f"""
                            MATCH (start:`{workspace_label}`)
                            WHERE start.entity_id = $entity_id
                            WITH start
                            CALL apoc.path.subgraphAll(start, {{
                            relationshipFilter: '',
                            labelFilter: '{workspace_label}',
                            minLevel: 0,
                            maxLevel: $max_depth,
                            limit: $max_nodes,
                            bfs: true
                            }})
                            YIELD nodes, relationships
                            UNWIND nodes AS node
                            WITH collect({{node: node}}) AS node_info, relationships
                            RETURN node_info, relationships
                            """
                        result_set = None
                        try:
                            result_set = await session.run(
                                limited_query,
                                {
                                    "entity_id": node_label,
                                    "max_depth": max_depth,
                                    "max_nodes": max_nodes,
                                },
                            )
                            record = await result_set.single()
                        finally:
                            if result_set is not None:
                                await result_set.consume()
                finally:
                    if full_result is not None:
                        await full_result.consume()

            if record:
                # Handle nodes (compatible with multi-label cases)
                for node_info in record["node_info"]:
                    node = node_info["node"]
                    node_id = node.id
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=f"{node_id}",
                                labels=[node.get("entity_id")],
                                properties=dict(node),
                            )
                        )
                        seen_nodes.add(node_id)

                # Handle relationships (including direction information)
                for rel in record["relationships"]:
                    edge_id = rel.id
                    if edge_id not in seen_edges:
                        start = rel.start_node
                        end = rel.end_node
                        result.edges.append(
                            KnowledgeGraphEdge(
                                id=f"{edge_id}",
                                type=rel.type,
                                source=f"{start.id}",
                                target=f"{end.id}",
                                properties=dict(rel),
                            )
                        )
                        seen_edges.add(edge_id)

                logger.info(
                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )

        except neo4jExceptions.ClientError as e:
            # Every ClientError raised above is reported as an APOC problem,
            # including errors from the wildcard queries, which do not call APOC.
            logger.warning(f"APOC plugin error: {str(e)}")
            if node_label != "*":
                logger.warning(
                    "Neo4j: falling back to basic Cypher recursive search..."
                )
                return await self._robust_fallback(node_label, max_depth, max_nodes)
            else:
                logger.warning(
                    "Neo4j: APOC plugin error with wildcard query, returning empty result"
                )

    return result
|
|
|
|
|
|
|
|
|
|
async def _robust_fallback(
|
|
|
|
|
self, node_label: str, max_depth: int, max_nodes: int
|
|
|
|
|
|