You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

132 lines
4.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
Neo4j 数据库层 + 内存缓存
"""
import os
import json
import time
import functools
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, SessionExpired, AuthError
from flask import jsonify, request
# ── Configuration (each value can be overridden by an environment variable) ──
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
# Read the variable once. An empty value is treated the same as an unset one,
# so the fallback is really used whenever the warning below says it is
# (previously an empty variable triggered the warning but sent "" as password).
# NOTE(review): a concrete password is committed here as the fallback; anyone
# with access to the source can read it. Consider rotating it and requiring
# the environment variable instead.
_password_from_env = os.environ.get("NEO4J_PASSWORD")
NEO4J_PASSWORD = _password_from_env or "xh3.1415926"
if not _password_from_env:
    print("[警告] 未设置 NEO4J_PASSWORD 环境变量,使用默认值。"
          "建议: export NEO4J_PASSWORD=your_password")
# ── Neo4j driver ─────────────────────────────────────────
# Single module-level driver, created at import time and used by the query
# helpers in this module.
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
def run_query(cypher, params=None):
    """Run a Cypher query and return all of its rows as plain dicts.

    Intended for read queries. Note that the session is opened with
    ``driver.session()`` and no access-mode argument, so read-only access is
    not enforced here.

    Args:
        cypher: Cypher statement text.
        params: Optional mapping of query parameters; a falsy value means
            "no parameters".

    Returns:
        A list containing ``dict(record)`` for every record of the result.

    Raises:
        RuntimeError: Wraps any failure, with the original exception chained
            as ``__cause__``. The message distinguishes an unavailable
            database, an authentication failure and any other error.
    """
    try:
        query_params = params if params else {}
        with driver.session() as sess:
            # The rows are materialised before the session is closed.
            return list(map(dict, sess.run(cypher, query_params)))
    except (ServiceUnavailable, SessionExpired) as exc:
        raise RuntimeError(f"Neo4j 数据库不可用: {exc}") from exc
    except AuthError as exc:
        raise RuntimeError(f"Neo4j 认证失败: {exc}") from exc
    except Exception as exc:
        raise RuntimeError(f"查询执行失败: {exc}") from exc
def run_write(cypher, params=None):
    """Execute a single write Cypher statement via ``session.run``.

    Args:
        cypher: Cypher statement text.
        params: Optional mapping of query parameters; a falsy value means
            "no parameters".

    Returns:
        None.

    Raises:
        RuntimeError: Wraps any failure, with the original exception chained
            as ``__cause__``. The message distinguishes an unavailable
            database, an authentication failure and any other error.
    """
    try:
        with driver.session() as session:
            # Consume the result while the session is still open. This waits
            # for the statement to finish on the server and raises any
            # server-side error (e.g. a constraint violation) here, inside
            # this try block. Otherwise the error is left to session teardown,
            # where, depending on the driver version, it may not surface.
            session.run(cypher, params or {}).consume()
    except (ServiceUnavailable, SessionExpired) as e:
        raise RuntimeError(f"Neo4j 数据库不可用: {e}") from e
    except AuthError as e:
        raise RuntimeError(f"Neo4j 认证失败: {e}") from e
    except Exception as e:
        raise RuntimeError(f"写入执行失败: {e}") from e
def run_write_batch(cypher, param_list):
    """Run one Cypher statement once per parameter set, in one transaction.

    Each entry of ``param_list`` is executed with the same ``cypher`` inside a
    single explicit transaction, which is committed once after the last
    statement. A falsy ``param_list`` is a no-op and opens no session.

    Args:
        cypher: Cypher statement text, executed for every parameter set.
        param_list: Iterable of parameter mappings, one per execution.

    Raises:
        RuntimeError: Wraps any failure, with the original exception chained
            as ``__cause__``. The message distinguishes an unavailable
            database, an authentication failure and any other error.
    """
    if not param_list:
        return
    try:
        with driver.session() as sess, sess.begin_transaction() as tx:
            for row_params in param_list:
                tx.run(cypher, row_params)
            tx.commit()
    except (ServiceUnavailable, SessionExpired) as exc:
        raise RuntimeError(f"Neo4j 数据库不可用: {exc}") from exc
    except AuthError as exc:
        raise RuntimeError(f"Neo4j 认证失败: {exc}") from exc
    except Exception as exc:
        raise RuntimeError(f"批量写入失败: {exc}") from exc
# ── UNWIND batch helper (used by import_data) ─────────────
def run_unwind_batch(cypher, rows):
    """Write many rows in a single request using ``UNWIND``.

    The statement sent to the server is ``UNWIND $rows AS row <cypher>``, so
    ``cypher`` should refer to the current element as ``row``.

    Args:
        cypher: Cypher fragment appended after the UNWIND clause. It is
            concatenated into the query text, so it must come from trusted
            code and never from user input; variable data belongs in ``rows``.
        rows: The batch, sent as the ``$rows`` parameter (presumably a list of
            dicts; confirm with callers). A falsy value makes this a no-op.

    Unlike the other write helpers in this module, driver exceptions are not
    wrapped in RuntimeError; they propagate unchanged.
    """
    if not rows:
        return
    with driver.session() as s:
        # Consume inside the session so the write has completed and any server
        # error is raised to the caller here, not deferred to session teardown.
        s.run(f"UNWIND $rows AS row {cypher}", {"rows": rows}).consume()
# ── In-memory cache (with TTL) ───────────────────────────
# key -> {"value": <dict or list payload>, "expire_at": float (time.time() seconds)}
_cache_store = {}


def _cacheable_payload(result):
    """Return the JSON payload of *result* if it may be cached, else None.

    Plain ``dict``/``list`` return values are cached as-is. Response-like
    objects are cached only when their status is 2xx and their body decodes to
    a dict or list. This way error responses are never replayed as 200 OK,
    and a hit can always be re-wrapped with ``jsonify``. Anything else
    (tuples, strings, non-JSON bodies) is not cached.
    """
    if isinstance(result, (dict, list)):
        return result
    if not 200 <= getattr(result, "status_code", 200) < 300:
        return None
    if hasattr(result, "get_json"):
        raw = result.get_json(silent=True)
    elif hasattr(result, "get_data"):
        try:
            raw = json.loads(result.get_data(as_text=True))
        except ValueError:
            return None
    else:
        return None
    return raw if isinstance(raw, (dict, list)) else None


def _purge_expired(now):
    """Drop every cache entry whose expiry time is at or before *now*."""
    # Iterate over a snapshot so that removing entries is safe.
    for key, entry in list(_cache_store.items()):
        if entry["expire_at"] <= now:
            _cache_store.pop(key, None)


def ttl_cache(ttl_seconds=300):
    """Cache decorator for Flask view functions, with a time-to-live.

    Entries expire ``ttl_seconds`` after they are stored (default 300 s, i.e.
    5 minutes). The cache key combines the function name, the query string of
    the current request (when one is active) and the call arguments.

    The cache stores the decoded JSON payload rather than the Response object,
    because a Response cannot be reused. On a hit the payload is re-wrapped
    with ``jsonify`` and tagged ``X-Cache: HIT``. A freshly computed result
    that has headers is tagged ``X-Cache: MISS``. Only 2xx JSON responses
    with a dict/list body, or plain dict/list return values, are cached; any
    other result is passed through uncached.

    Args:
        ttl_seconds: Lifetime of a cache entry in seconds.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key_parts = [func.__name__]
            # multi=True keeps every value of a repeated query parameter
            # (?id=1&id=2). Plain items() keeps only the first value, so
            # different requests would collide on one key.
            key_parts.append(
                str(sorted(request.args.items(multi=True))) if request else ""
            )
            key_parts.extend(str(a) for a in args)
            key_parts.append(str(sorted(kwargs.items())))
            key = "|".join(key_parts)
            now = time.time()

            entry = _cache_store.get(key)
            if entry and entry["expire_at"] > now:
                # Only dicts/lists are ever stored, so jsonify always applies
                # and always yields an object with headers.
                resp = jsonify(entry["value"])
                resp.headers["X-Cache"] = "HIT"
                return resp

            result = func(*args, **kwargs)
            raw = _cacheable_payload(result)
            if raw is not None:
                # Sweep expired entries whenever something is stored. Otherwise
                # keys built from arbitrary query strings would accumulate
                # forever. A miss already costs a database round trip, so the
                # linear sweep is negligible by comparison.
                _purge_expired(now)
                _cache_store[key] = {
                    "value": raw,
                    "expire_at": now + ttl_seconds,
                }
            # Plain dict/list returns and (body, status) tuples have no headers.
            headers = getattr(result, "headers", None)
            if headers is not None:
                headers["X-Cache"] = "MISS"
            return result
        return wrapper
    return decorator
def clear_cache(pattern=None):
    """Remove entries from the in-memory response cache.

    Args:
        pattern: ``None`` clears the whole cache. Otherwise every entry whose
            key contains ``pattern`` as a substring is removed. Cache keys
            begin with the cached function's name, so passing a view-function
            name clears that view's entries.
    """
    if pattern is None:
        _cache_store.clear()
        return
    # Work on a snapshot of the keys and use pop() with a default. Under a
    # threaded server another request may insert or delete entries at the same
    # time: iterating the live dict could then raise "dictionary changed size
    # during iteration", and `del` could raise KeyError.
    for key in [k for k in list(_cache_store) if pattern in k]:
        _cache_store.pop(key, None)