|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
Neo4j 数据库层 + 内存缓存
|
|
|
"""
|
|
|
import os
|
|
|
import json
|
|
|
import time
|
|
|
import functools
|
|
|
from neo4j import GraphDatabase
|
|
|
from neo4j.exceptions import ServiceUnavailable, SessionExpired, AuthError
|
|
|
from flask import jsonify, request
|
|
|
|
|
|
# ── Configuration ────────────────────────────────────────
# Connection settings are read from environment variables, each falling back
# to a built-in default; looked up in the order URI, user, password.
# NOTE(review): the password fallback is a literal credential committed to
# source control; consider requiring NEO4J_PASSWORD to be set instead.
NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD = (
    os.environ.get(env_name, fallback)
    for env_name, fallback in (
        ("NEO4J_URI", "bolt://localhost:7687"),
        ("NEO4J_USER", "neo4j"),
        ("NEO4J_PASSWORD", "xh3.1415926"),
    )
)

# Warn when the password variable is missing or empty.
# NOTE(review): if it is set but empty, NEO4J_PASSWORD is "" (not the
# fallback), even though the message claims the default is being used.
if not os.environ.get("NEO4J_PASSWORD"):
    print(
        "[警告] 未设置 NEO4J_PASSWORD 环境变量,使用默认值。"
        "建议: export NEO4J_PASSWORD=your_password"
    )
|
|
|
|
|
|
# ── Neo4j driver ─────────────────────────────────────────
# A single driver object is created when this module is imported and is
# shared by every helper below; each helper opens its own session from it.
driver = GraphDatabase.driver(
    NEO4J_URI,
    auth=(NEO4J_USER, NEO4J_PASSWORD),
)
|
|
|
|
|
|
|
|
|
def run_query(cypher, params=None):
    """Run a read-only Cypher query and return its rows as a list of dicts.

    The session is opened with ``driver.session()`` and no arguments, so
    nothing here enforces read-only access; "read-only" describes intended use.

    Args:
        cypher: Cypher query text.
        params: optional parameter mapping; any falsy value is sent as ``{}``.

    Returns:
        A list with one ``dict`` per returned record, built before the
        session is closed.

    Raises:
        RuntimeError: wraps the underlying error (chained as ``__cause__``),
            with a message distinguishing an unavailable database, an
            authentication failure, or any other query failure.
    """
    try:
        with driver.session() as session:
            cursor = session.run(cypher, params if params else {})
            return list(map(dict, cursor))
    except (ServiceUnavailable, SessionExpired) as exc:
        raise RuntimeError(f"Neo4j 数据库不可用: {exc}") from exc
    except AuthError as exc:
        raise RuntimeError(f"Neo4j 认证失败: {exc}") from exc
    except Exception as exc:
        raise RuntimeError(f"查询执行失败: {exc}") from exc
|
|
|
|
|
|
|
|
|
def run_write(cypher, params=None):
    """Execute a single Cypher write statement through ``session.run``.

    Args:
        cypher: Cypher statement text.
        params: optional parameter mapping; any falsy value is sent as ``{}``.

    Returns:
        None.

    Raises:
        RuntimeError: wraps the underlying error (chained as ``__cause__``),
            with a message distinguishing an unavailable database, an
            authentication failure, or any other write failure.
    """
    try:
        with driver.session() as session:
            bound_params = params if params else {}
            # NOTE(review): the result is never consumed explicitly; this relies
            # on the session close flushing it and surfacing server errors.
            session.run(cypher, bound_params)
    except (ServiceUnavailable, SessionExpired) as exc:
        raise RuntimeError(f"Neo4j 数据库不可用: {exc}") from exc
    except AuthError as exc:
        raise RuntimeError(f"Neo4j 认证失败: {exc}") from exc
    except Exception as exc:
        raise RuntimeError(f"写入执行失败: {exc}") from exc
|
|
|
|
|
|
|
|
|
def run_write_batch(cypher, param_list):
    """Run ``cypher`` once per parameter mapping, all inside one transaction.

    ``tx.commit()`` is only reached after every statement has run, so an error
    part-way through never commits; the transaction context manager is
    presumably responsible for rolling back in that case (driver behavior).
    An empty or falsy ``param_list`` is a no-op and opens no session.

    Raises:
        RuntimeError: wraps the underlying error (chained as ``__cause__``),
            with a message distinguishing an unavailable database, an
            authentication failure, or any other batch failure.
    """
    if param_list:
        try:
            with driver.session() as session, session.begin_transaction() as tx:
                for single_params in param_list:
                    tx.run(cypher, single_params)
                tx.commit()
        except (ServiceUnavailable, SessionExpired) as exc:
            raise RuntimeError(f"Neo4j 数据库不可用: {exc}") from exc
        except AuthError as exc:
            raise RuntimeError(f"Neo4j 认证失败: {exc}") from exc
        except Exception as exc:
            raise RuntimeError(f"批量写入失败: {exc}") from exc
|
|
|
|
|
|
|
|
|
# ── UNWIND batch writes (used by import_data) ────────────
def run_unwind_batch(cypher, rows):
    """Write many rows with a single ``UNWIND`` query sent in one request.

    ``cypher`` is the clause that follows ``UNWIND $rows AS row`` and may refer
    to each element as ``row``. It is spliced into the query text verbatim, so
    it must be trusted code; the row data itself travels as the ``$rows``
    parameter. Nothing is sent when ``rows`` is empty or falsy.

    Unlike ``run_query``/``run_write``/``run_write_batch``, driver exceptions
    are not wrapped in ``RuntimeError`` here; they propagate unchanged.
    """
    if rows:
        with driver.session() as session:
            statement = f"UNWIND $rows AS row {cypher}"
            session.run(statement, {"rows": rows})
|
|
|
|
|
|
|
|
|
# ── In-memory cache (with TTL) ───────────────────────────
# Process-local store shared by ttl_cache and clear_cache:
#   cache key -> {"value": <cached payload>, "expire_at": <time.time() deadline>}
_cache_store = dict()
|
|
|
|
|
|
|
|
|
def _purge_expired_cache_entries(now):
    """Drop every cache entry whose ``expire_at`` is at or before ``now``.

    Without this sweep, keys built from arbitrary query strings accumulate
    forever: an expired entry is otherwise only replaced when the exact same
    key is requested again.
    """
    # list(...) snapshots the items so a concurrent insert cannot break iteration.
    for cache_key, entry in list(_cache_store.items()):
        if entry["expire_at"] <= now:
            _cache_store.pop(cache_key, None)


def _cacheable_payload(response):
    """Return ``(True, payload)`` for a 200 JSON response, else ``(False, None)``.

    Only successful JSON responses are cached. A cached error body would be
    replayed through ``jsonify`` with status 200. A non-JSON body parses to
    ``None``, which previously crashed the cache-hit path.
    """
    if getattr(response, "status_code", None) != 200:
        return False, None
    try:
        if hasattr(response, "get_json"):
            if not getattr(response, "is_json", True):
                return False, None
            return True, response.get_json()
        if hasattr(response, "get_data"):
            return True, json.loads(response.get_data(as_text=True))
    except ValueError:
        # The body does not parse as JSON: serve it, but do not cache it.
        return False, None
    return False, None


def ttl_cache(ttl_seconds=300):
    """Cache decorator for Flask API views; entries expire after ``ttl_seconds`` (default 5 minutes).

    The cached value is the parsed JSON payload (the data before ``jsonify``),
    not the Response object, because a Response cannot be reused safely. On a
    hit, a fresh response is rebuilt with ``jsonify`` and tagged
    ``X-Cache: HIT``. On a miss, the view's own response is tagged
    ``X-Cache: MISS``.

    The key is built from:

    * the view's ``__name__``;
    * every query-string parameter, including repeated ones;
    * the positional/keyword arguments passed to the view.

    Only 200 JSON responses are stored. Anything else (error statuses, tuples,
    non-JSON bodies) is returned as-is without being cached. A view returning
    a dict or list is converted with ``jsonify`` first. Headers the view set on
    its original response are not replayed on a hit.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # items(multi=True) keeps every value of a repeated parameter
            # (?id=1&id=2); plain items() yields only the first value, which
            # made different queries collide on one cache entry. The
            # `if request` guard avoids touching request.args when no request
            # is bound to the proxy.
            query_part = str(sorted(request.args.items(multi=True))) if request else ""
            key = "|".join(
                [
                    func.__name__,
                    query_part,
                    *(str(a) for a in args),
                    str(sorted(kwargs.items())),
                ]
            )

            now = time.time()
            entry = _cache_store.get(key)
            if entry is not None and entry["expire_at"] > now:
                resp = jsonify(entry["value"])
                resp.headers["X-Cache"] = "HIT"
                return resp

            result = func(*args, **kwargs)
            if isinstance(result, (dict, list)):
                # Serialize here (as Flask would) so a header can be attached
                # and the payload cached.
                result = jsonify(result)

            cacheable, payload = _cacheable_payload(result)
            if cacheable:
                _purge_expired_cache_entries(now)
                _cache_store[key] = {
                    "value": payload,
                    "expire_at": now + ttl_seconds,
                }
            # Tuples and other non-Response returns have no headers to tag.
            if hasattr(result, "headers"):
                result.headers["X-Cache"] = "MISS"
            return result

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
def clear_cache(pattern=None):
    """Remove entries from the in-memory API cache.

    Args:
        pattern: ``None`` clears every entry. Otherwise only entries whose key
            contains ``pattern`` as a substring are removed. Cache keys begin
            with the decorated view's function name, so passing that name
            targets a single endpoint.
    """
    if pattern is None:
        _cache_store.clear()
        return
    # Snapshot the keys with list(); under CPython's GIL this copies them in one
    # call, whereas a comprehension over the live dict could raise "dictionary
    # changed size during iteration" if a request thread inserts concurrently.
    for key in [k for k in list(_cache_store) if pattern in k]:
        # pop() with a default tolerates a key already removed by another thread.
        _cache_store.pop(key, None)
|