push the test

main
icter99 8 months ago
parent 55c85c50da
commit 313919503e

@ -1,18 +1,19 @@
from neo4j import GraphDatabase
from neo4j import GraphDatabase, AsyncGraphDatabase
from app.config import settings
class Neo4jConnection:
def __init__(self, uri, user, password):
# 初始化连接,只负责连接
self.driver = GraphDatabase.driver(uri, auth=(user, password))
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password), max_connection_pool_size=100)
def close(self):
async def close(self):
# 关闭连接
self.driver.close()
await self.driver.close()
def get_session(self):
# 提供 session 对象,供外部查询使用
return self.driver.session()
return self.driver.session()
# 初始化全局 Neo4j 连接
neo4j_conn = Neo4jConnection(

@ -0,0 +1,24 @@
import random
from locust import HttpUser, task, between
class APIUser(HttpUser):
# 设置模拟用户之间的等待时间(例如:每个用户之间随机等待 1-3 秒)
wait_time = between(1, 3)
# 定义一个包含多个章节 ID 的列表
chapter_ids = ["1", "3", "6", "8", "15"]
@task
def get_chapter_relations(self):
# 从章节 ID 列表中随机选择一个章节 ID
chapter_id = random.choice(self.chapter_ids)
self.client.get(f"/chapters/{chapter_id}/relations")
@task
def get_all_relations(self):
self.client.get("/chapters/relations/")
@task
def get_relations_by_level(self):
level = random.randint(1, 5) # 随机选择一个层级,范围为 1 到 5
self.client.get(f"/chapters/relations/level/{level}")

@ -1,8 +1,10 @@
from fastapi import APIRouter
from typing import List
from fastapi import APIRouter, HTTPException
# from app.routers import get_cached_data
from app.schemas.user import ChapterRelationCreate, ChapterRelationResponse
from app.services.user_service import create_chapter_relation, get_graph_relations, search_chapters, \
from app.services.user_service import get_graph_relations, search_chapters, \
get_relations_to_level_simple, get_chapter_relations
router = APIRouter(
@ -10,46 +12,55 @@ router = APIRouter(
tags=["Chapters"],
)
@router.post("/relation/", response_model=ChapterRelationResponse)
def add_chapter_relation(relation: ChapterRelationCreate):
"""
创建章节之间的关系
"""
return create_chapter_relation(relation)
@router.get("/relations/", response_model=List[ChapterRelationResponse])
def list_chapter_relations():
async def list_chapter_relations():
# cache_key = "all_chapter_relations"
# cached_data = get_cached_data(cache_key) # 从缓存获取
# if cached_data:
# return [ChapterRelationResponse(**data) for data in cached_data] # 缓存命中,返回数据
"""
查询所有章节关系
"""
return get_graph_relations()
return await get_graph_relations()
@router.get("/search_chapters", response_model=List[ChapterRelationResponse])
def search_chapters_endpoint(q: str):
async def search_chapters_endpoint(q: str):
"""
API 端点根据搜索关键字模糊查询章节及其相关章节
:param q: 搜索关键字
:return: 匹配的章节关系列表
"""
# cache_key = f"search_chapters:{q}"
# # 尝试从缓存中获取数据
# cached_data = get_cached_data(cache_key)
# if cached_data:
# # 如果缓存命中,直接返回缓存中的数据
# return [ChapterRelationResponse(**data) for data in cached_data]
if not q:
raise HTTPException(status_code=400, detail="搜索关键字不能为空")
results = search_chapters(q)
results = await search_chapters(q)
if not results:
raise HTTPException(status_code=404, detail="未找到匹配的章节关系")
return results
@router.get("/relations/level/{level}", response_model=List[ChapterRelationResponse])
def get_relations_by_level_simple(level: int):
async def get_relations_by_level_simple(level: int):
"""
简化版根据目标层级查询从 Root 到目标层级的所有节点和关系
:param level: 目标层级1 -> Root & Subject, 2 -> Root, Subject, Topic, ..., 5 -> Problem
:return: 对应层级的节点和关系
"""
# cache_key = f"level_{level}_relations"
# cached_data = get_cached_data(cache_key) # 从缓存获取
# if cached_data:
# return [ChapterRelationResponse(**data) for data in cached_data] # 缓存命中,返回数据
if level < 0 or level > 5:
raise HTTPException(status_code=400, detail="层级参数必须在 0 到 5 之间")
try:
relations = get_relations_to_level_simple(level)
relations = await get_relations_to_level_simple(level)
if not relations:
raise HTTPException(status_code=404, detail="未找到相关节点关系")
return relations
@ -62,10 +73,11 @@ async def get_relations_by_chapter(chapter: str):
:param chapter_name: 章节名称
:return: 章节关系列表
"""
try:
relations = get_chapter_relations(chapter)
if not relations:
raise HTTPException(status_code=404, detail=f"No relations found for chapter {chapter}")
return relations
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching relations: {str(e)}")
# cache_key = f"chapter_relations:{chapter}" # 不使用分页,直接根据章节名称生成缓存键
# cached_data = get_cached_data(cache_key) # 从缓存中获取数据
# if cached_data:
# return [ChapterRelationResponse(**data) for data in cached_data] # 如果缓存命中,直接返回缓存数据
relations = await get_chapter_relations(chapter)
if not relations:
raise HTTPException(status_code=404, detail=f"No relations found for chapter {chapter}")
return relations

@ -1,39 +1,15 @@
import logging
from typing import List
from fastapi import HTTPException
from app.database import neo4j_conn
from app.schemas.user import ChapterRelationCreate, ChapterRelationResponse
def create_chapter_relation(relation_data: ChapterRelationCreate):
"""
创建章节关系
:param relation_data: 包含起始章节结束章节和关系类型
:return: 创建的关系数据
"""
query = """
MERGE (start:Chapter {name: $start_chapter})
MERGE (end:Chapter {name: $end_chapter})
CREATE (start)-[r:%s]->(end)
RETURN start.name AS start_chapter, TYPE(r) AS relation, end.name AS end_chapter
""" % relation_data.relation.upper() # 使用动态关系类型
# 获取 session 进行查询
with neo4j_conn.get_session() as session:
result = session.run(query, parameters={
"start_chapter": relation_data.start_chapter,
"end_chapter": relation_data.end_chapter,
})
record = result.single()
return {
"start_chapter": record["start_chapter"],
"relation": record["relation"],
"end_chapter": record["end_chapter"],
}
def get_graph_relations():
async def get_graph_relations():
"""
查询知识图谱中的所有节点及关系的详细信息
:return: 列表每一项为 ChapterRelationResponse 对象
@ -47,9 +23,9 @@ def get_graph_relations():
properties(end) AS end_properties,
labels(end) AS end_labels
"""
with neo4j_conn.get_session() as session:
result = session.run(query)
records = result.data()
async with neo4j_conn.get_session() as session:
result = await session.run(query)
records = await result.data()
responses = []
for record in records:
@ -69,7 +45,7 @@ def get_graph_relations():
return responses
def search_chapters(search_term: str) -> List[ChapterRelationResponse]:
async def search_chapters(search_term: str) -> List[ChapterRelationResponse]:
"""
根据搜索关键字模糊查询章节及其相关章节
:param search_term: 搜索关键字
@ -88,9 +64,9 @@ def search_chapters(search_term: str) -> List[ChapterRelationResponse]:
labels(end) AS end_labels
"""
try:
with neo4j_conn.get_session() as session:
result = session.run(query, parameters={"search_term": search_term})
records = result.data()
async with neo4j_conn.get_session() as session:
result = await session.run(query, parameters={"search_term": search_term})
records = await result.data()
responses = [
ChapterRelationResponse(
@ -109,7 +85,7 @@ def search_chapters(search_term: str) -> List[ChapterRelationResponse]:
return []
def get_relations_to_level_simple(level: int) -> List[ChapterRelationResponse]:
async def get_relations_to_level_simple(level: int) -> List[ChapterRelationResponse]:
"""
简化版根据目标层级查询从 Root 到目标层级的所有节点和关系
:param level: 目标层级1 -> Root & Subject, 2 -> Root, Subject, Topic, ..., 5 -> Problem
@ -130,9 +106,9 @@ def get_relations_to_level_simple(level: int) -> List[ChapterRelationResponse]:
"""
try:
with neo4j_conn.get_session() as session:
result = session.run(query)
records = result.data()
async with neo4j_conn.get_session() as session:
result = await session.run(query)
records = await result.data()
return [
ChapterRelationResponse(
@ -163,9 +139,9 @@ def get_relations_to_level_simple(level: int) -> List[ChapterRelationResponse]:
"""
try:
with neo4j_conn.get_session() as session:
result = session.run(query, parameters={"target_labels": target_labels})
records = result.data()
async with neo4j_conn.get_session() as session:
result = await session.run(query, parameters={"target_labels": target_labels})
records = await result.data()
return [
ChapterRelationResponse(
@ -180,7 +156,7 @@ def get_relations_to_level_simple(level: int) -> List[ChapterRelationResponse]:
except Exception as e:
logging.error(f"Error during get_relations_to_level_simple: {e}")
return []
def get_chapter_relations(chapter: str) -> List[ChapterRelationResponse]:
async def get_chapter_relations(chapter: str) -> List[ChapterRelationResponse]:
"""
根据章节名称查找该章节及其相关节点和关系
:param chapter_name: 章节名称
@ -197,10 +173,9 @@ def get_chapter_relations(chapter: str) -> List[ChapterRelationResponse]:
labels(end) AS end_labels
"""
try:
with neo4j_conn.get_session() as session:
result = session.run(query, parameters={"chapter": chapter})
records = result.data()
async with neo4j_conn.get_session() as session:
result = await session.run(query, parameters={"chapter": chapter})
records = await result.data()
return [
ChapterRelationResponse(
start_labels=record.get("start_labels", []),

@ -0,0 +1,61 @@
import logging
import random
import httpx
import pytest # 使用 pytest 来运行异步测试
from app.main import app # 替换为你的 FastAPI 应用实例
@pytest.mark.asyncio # 指定测试是异步的
async def test_list_chapter_relations():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/chapters/relations/")
assert response.status_code == 200
assert isinstance(response.json(), list) # 检查返回的数据类型
@pytest.mark.asyncio
async def test_search_chapters_valid():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/chapters/search_chapters?q=1")
assert response.status_code == 200
assert isinstance(response.json(), list)
assert len(response.json()) > 0 # 假设数据库中有相关数据
@pytest.mark.asyncio
async def test_search_chapters_invalid():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/chapters/search_chapters?q=") # 空查询
assert response.status_code == 400
assert response.json() == {"detail": "搜索关键字不能为空"}
@pytest.mark.asyncio
async def test_get_relations_by_level_valid():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/chapters/relations/level/3")
print(response.json())
assert response.status_code == 200
assert isinstance(response.json(), list)
assert len(response.json()) > 0
@pytest.mark.asyncio
async def test_get_relations_by_level_invalid():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/chapters/relations/level/6")
assert response.status_code == 400
assert response.json() == {"detail": "层级参数必须在 0 到 5 之间"}
@pytest.mark.asyncio
async def test_get_relations_by_chapter_valid():
chapter_id = random.choice(["1", "3", "6", "8", "15"])
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get(f"/chapters/{chapter_id}/relations")
print(response.json())
assert response.status_code == 200
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_get_relations_by_chapter_not_found():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/chapters/10/relations")
print(response.json())
assert response.status_code == 404
assert response.json() == {"detail": "No relations found for chapter 10"}
Loading…
Cancel
Save