Compare commits
No commits in common. 'main' and 'dev' have entirely different histories.
|
Before Width: | Height: | Size: 194 KiB |
|
Before Width: | Height: | Size: 97 KiB |
|
Before Width: | Height: | Size: 213 KiB |
|
Before Width: | Height: | Size: 45 KiB |
|
Before Width: | Height: | Size: 232 KiB |
|
Before Width: | Height: | Size: 78 KiB |
|
Before Width: | Height: | Size: 124 KiB |
|
Before Width: | Height: | Size: 67 KiB |
|
Before Width: | Height: | Size: 499 KiB |
|
Before Width: | Height: | Size: 160 KiB |
|
Before Width: | Height: | Size: 72 KiB |
|
Before Width: | Height: | Size: 73 KiB |
|
Before Width: | Height: | Size: 72 KiB |
|
Before Width: | Height: | Size: 36 KiB |
|
Before Width: | Height: | Size: 56 KiB |
|
Before Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 19 KiB |
|
Before Width: | Height: | Size: 56 KiB |
|
Before Width: | Height: | Size: 56 KiB |
|
Before Width: | Height: | Size: 43 KiB |
|
Before Width: | Height: | Size: 501 KiB |
|
Before Width: | Height: | Size: 437 KiB |
|
Before Width: | Height: | Size: 311 KiB |
|
Before Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 33 KiB |
|
Before Width: | Height: | Size: 128 KiB |
|
Before Width: | Height: | Size: 126 KiB |
|
Before Width: | Height: | Size: 4.4 MiB |
@ -1,30 +0,0 @@
|
||||
[run]
|
||||
source = .
|
||||
omit =
|
||||
*/tests/*
|
||||
*/test_*.py
|
||||
*/__pycache__/*
|
||||
*/venv/*
|
||||
*/env/*
|
||||
setup.py
|
||||
*/site-packages/*
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
def __repr__
|
||||
raise AssertionError
|
||||
raise NotImplementedError
|
||||
if __name__ == .__main__.:
|
||||
if TYPE_CHECKING:
|
||||
@abstractmethod
|
||||
@abc.abstractmethod
|
||||
|
||||
precision = 2
|
||||
show_missing = True
|
||||
|
||||
[html]
|
||||
directory = tests/coverage_html
|
||||
|
||||
[xml]
|
||||
output = coverage.xml
|
||||
@ -1,235 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import ijson
|
||||
|
||||
|
||||
DB_PATH = Path("backend/instance/app.sqlite")
|
||||
|
||||
|
||||
def ensure_db_exists() -> None:
|
||||
if not DB_PATH.parent.exists():
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not DB_PATH.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Database not found at {DB_PATH}. Please run backend/scripts/init_db.py first."
|
||||
)
|
||||
|
||||
|
||||
def insert_user_if_needed(conn: sqlite3.Connection, uid: int) -> None:
|
||||
if uid is None:
|
||||
return
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO users(uid) VALUES (?)",
|
||||
(uid,),
|
||||
)
|
||||
|
||||
|
||||
def ingest_records(conn: sqlite3.Connection, record_path: Path, batch_size: int = 1000) -> int:
|
||||
count = 0
|
||||
print(f"Opening {record_path} for records import...")
|
||||
with record_path.open("rb") as fp:
|
||||
items = ijson.items(fp, "item")
|
||||
buffer = []
|
||||
for obj in items:
|
||||
submission_oid = (obj.get("_id") or {}).get("$oid")
|
||||
status = obj.get("status")
|
||||
uid = obj.get("uid")
|
||||
code = obj.get("code")
|
||||
lang = obj.get("lang")
|
||||
pid_int = obj.get("pid")
|
||||
domain_id = obj.get("domainId")
|
||||
score = obj.get("score")
|
||||
exec_time = obj.get("time")
|
||||
memory = obj.get("memory")
|
||||
judge_texts = json.dumps(obj.get("judgeTexts")) if obj.get("judgeTexts") is not None else None
|
||||
compiler_texts = json.dumps(obj.get("compilerTexts")) if obj.get("compilerTexts") is not None else None
|
||||
test_cases = json.dumps(obj.get("testCases")) if obj.get("testCases") is not None else None
|
||||
judge_at = (obj.get("judgeAt") or {}).get("$date")
|
||||
rejudged = 1 if obj.get("rejudged") else 0
|
||||
files_json = json.dumps(obj.get("files")) if obj.get("files") is not None else None
|
||||
subtasks_json = json.dumps(obj.get("subtasks")) if obj.get("subtasks") is not None else None
|
||||
ip = obj.get("ip")
|
||||
judger = obj.get("judger")
|
||||
|
||||
insert_user_if_needed(conn, uid)
|
||||
|
||||
buffer.append(
|
||||
(
|
||||
submission_oid,
|
||||
status,
|
||||
uid,
|
||||
code,
|
||||
lang,
|
||||
pid_int,
|
||||
domain_id,
|
||||
score,
|
||||
exec_time,
|
||||
memory,
|
||||
judge_texts,
|
||||
compiler_texts,
|
||||
test_cases,
|
||||
judge_at,
|
||||
rejudged,
|
||||
files_json,
|
||||
subtasks_json,
|
||||
ip,
|
||||
judger,
|
||||
)
|
||||
)
|
||||
|
||||
count += 1
|
||||
if count % 1000 == 0:
|
||||
print(f" Processed {count} records...")
|
||||
if len(buffer) >= batch_size:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO submissions(
|
||||
submission_oid, status, uid, code, lang, pid_int, domain_id, score, exec_time, memory,
|
||||
judge_texts, compiler_texts, test_cases, judge_at, rejudged, files_json, subtasks_json, ip, judger
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
buffer,
|
||||
)
|
||||
conn.commit()
|
||||
print(f" Committed batch at {count} records")
|
||||
buffer.clear()
|
||||
|
||||
if buffer:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO submissions(
|
||||
submission_oid, status, uid, code, lang, pid_int, domain_id, score, exec_time, memory,
|
||||
judge_texts, compiler_texts, test_cases, judge_at, rejudged, files_json, subtasks_json, ip, judger
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
buffer,
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def ingest_documents(conn: sqlite3.Connection, document_path: Path, batch_size: int = 1000) -> int:
|
||||
count = 0
|
||||
print(f"Opening {document_path} for documents import...")
|
||||
with document_path.open("rb") as fp:
|
||||
items = ijson.items(fp, "item")
|
||||
buffer = []
|
||||
for obj in items:
|
||||
doc_oid = (obj.get("_id") or {}).get("$oid")
|
||||
doc_id = obj.get("docId")
|
||||
pid_code = obj.get("pid")
|
||||
title = obj.get("title")
|
||||
content_json = obj.get("content") # 原始即为 JSON 字符串
|
||||
owner = obj.get("owner")
|
||||
domain_id = obj.get("domainId")
|
||||
doc_type = obj.get("docType")
|
||||
tag_json = json.dumps(obj.get("tag")) if obj.get("tag") is not None else None
|
||||
hidden = 1 if obj.get("hidden") else 0
|
||||
n_submit = obj.get("nSubmit")
|
||||
n_accept = obj.get("nAccept")
|
||||
sort = obj.get("sort")
|
||||
data_json = json.dumps(obj.get("data")) if obj.get("data") is not None else None
|
||||
additional_file_json = json.dumps(obj.get("additional_file")) if obj.get("additional_file") is not None else None
|
||||
config = obj.get("config")
|
||||
if isinstance(config, dict):
|
||||
config = json.dumps(config)
|
||||
stats_json = json.dumps(obj.get("stats")) if obj.get("stats") is not None else None
|
||||
|
||||
buffer.append(
|
||||
(
|
||||
doc_oid,
|
||||
doc_id,
|
||||
pid_code,
|
||||
title,
|
||||
content_json,
|
||||
owner,
|
||||
domain_id,
|
||||
doc_type,
|
||||
tag_json,
|
||||
hidden,
|
||||
n_submit,
|
||||
n_accept,
|
||||
sort,
|
||||
data_json,
|
||||
additional_file_json,
|
||||
config,
|
||||
stats_json,
|
||||
)
|
||||
)
|
||||
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
print(f" Processed {count} documents...")
|
||||
if len(buffer) >= batch_size:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO problems(
|
||||
doc_oid, doc_id, pid_code, title, content_json, owner, domain_id, doc_type,
|
||||
tag_json, hidden, n_submit, n_accept, sort, data_json, additional_file_json, config, stats_json
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
buffer,
|
||||
)
|
||||
conn.commit()
|
||||
print(f" Committed documents batch at {count} records")
|
||||
buffer.clear()
|
||||
|
||||
if buffer:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO problems(
|
||||
doc_oid, doc_id, pid_code, title, content_json, owner, domain_id, doc_type,
|
||||
tag_json, hidden, n_submit, n_accept, sort, data_json, additional_file_json, config, stats_json
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
buffer,
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--record", type=Path, default=Path("data/record.json"))
|
||||
parser.add_argument("--document", type=Path, default=Path("data/document.json"))
|
||||
args = parser.parse_args()
|
||||
|
||||
ensure_db_exists()
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
try:
|
||||
conn.execute("PRAGMA journal_mode=WAL;")
|
||||
conn.execute("PRAGMA synchronous=NORMAL;")
|
||||
|
||||
print(f"Starting import...")
|
||||
print(f"Document file exists: {args.document.exists()}")
|
||||
print(f"Record file exists: {args.record.exists()}")
|
||||
|
||||
if args.document.exists():
|
||||
print(f"Importing documents from {args.document}...")
|
||||
n_doc = ingest_documents(conn, args.document)
|
||||
print(f"Imported documents: {n_doc}")
|
||||
else:
|
||||
print(f"Skip: document file not found: {args.document}")
|
||||
|
||||
if args.record.exists():
|
||||
print(f"Importing records from {args.record}...")
|
||||
n_rec = ingest_records(conn, args.record)
|
||||
print(f"Imported records: {n_rec}")
|
||||
else:
|
||||
print(f"Skip: record file not found: {args.record}")
|
||||
|
||||
print("Import completed!")
|
||||
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
配置文件
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
class Config:
|
||||
"""基础配置"""
|
||||
# Flask配置
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key-for-programming-evaluation-system'
|
||||
|
||||
# 数据库配置
|
||||
DB_PATH = 'backend/instance/app.sqlite'
|
||||
|
||||
# 应用配置
|
||||
DEBUG = True
|
||||
HOST = '0.0.0.0'
|
||||
PORT = 5000
|
||||
|
||||
# 登录配置
|
||||
DEFAULT_PASSWORD = '123456'
|
||||
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
"""开发环境配置"""
|
||||
DEBUG = True
|
||||
TESTING = False
|
||||
|
||||
|
||||
class ProductionConfig(Config):
|
||||
"""生产环境配置"""
|
||||
DEBUG = False
|
||||
TESTING = False
|
||||
|
||||
|
||||
class TestingConfig(Config):
|
||||
"""测试环境配置"""
|
||||
DEBUG = True
|
||||
TESTING = True
|
||||
|
||||
|
||||
# 配置字典
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'production': ProductionConfig,
|
||||
'testing': TestingConfig,
|
||||
'default': DevelopmentConfig
|
||||
}
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
[pytest]
|
||||
# Pytest配置文件
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
# 输出选项
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
--cov=.
|
||||
--cov-report=html:tests/coverage_html
|
||||
--cov-report=xml:coverage.xml
|
||||
--cov-report=term-missing
|
||||
--cov-report=json:tests/coverage.json
|
||||
--cov-config=.coveragerc
|
||||
--junitxml=tests/test-results.xml
|
||||
|
||||
# 标记定义
|
||||
markers =
|
||||
unit: 单元测试
|
||||
integration: 集成测试
|
||||
slow: 慢速测试
|
||||
auth: 认证相关测试
|
||||
stats: 统计相关测试
|
||||
assessment: 评估相关测试
|
||||
api: API测试
|
||||
lstm: LSTM模型测试
|
||||
|
||||
# 覆盖率配置
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::PendingDeprecationWarning
|
||||
@ -1,39 +0,0 @@
|
||||
Flask==2.3.3
|
||||
Werkzeug==2.3.7
|
||||
Jinja2==3.1.3
|
||||
MarkupSafe==2.1.5
|
||||
itsdangerous==2.1.2
|
||||
click==8.1.7
|
||||
|
||||
flask-sqlalchemy==3.1.1
|
||||
SQLAlchemy==2.0.20
|
||||
flask-migrate==4.0.4
|
||||
|
||||
flask-login==0.6.2
|
||||
flask-wtf==1.2.1
|
||||
flask-cors==4.0.0
|
||||
email-validator==2.0.0
|
||||
|
||||
pandas==2.2.3
|
||||
numpy==1.26.4
|
||||
matplotlib==3.8.3
|
||||
seaborn==0.13.2
|
||||
scikit-learn==1.4.1.post1
|
||||
python-dateutil==2.9.0.post0
|
||||
six==1.17.0
|
||||
|
||||
|
||||
torch>=2.2.0
|
||||
torchvision>=0.17.0
|
||||
torchaudio>=2.2.0
|
||||
|
||||
|
||||
python-dotenv==1.0.0
|
||||
gunicorn==21.2.0
|
||||
pytest==7.3.1
|
||||
pytest-cov==4.1.0
|
||||
pytest-flask==1.2.0
|
||||
coverage==7.3.0
|
||||
ijson==3.2.3
|
||||
scikit-learn==1.4.1.post1
|
||||
requests==2.31.0
|
||||
@ -1,47 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试运行脚本
|
||||
快速运行所有测试并生成覆盖率报告
|
||||
"""
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""运行测试"""
|
||||
print("=" * 60)
|
||||
print("开始运行测试...")
|
||||
print("=" * 60)
|
||||
|
||||
# 运行pytest
|
||||
cmd = [
|
||||
sys.executable, '-m', 'pytest',
|
||||
'tests/',
|
||||
'-v',
|
||||
'--cov=.',
|
||||
'--cov-report=html:tests/coverage_html',
|
||||
'--cov-report=xml:coverage.xml',
|
||||
'--cov-report=term-missing',
|
||||
'--cov-report=json:tests/coverage.json',
|
||||
'--junitxml=tests/test-results.xml',
|
||||
'--tb=short'
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, cwd=os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if result.returncode == 0:
|
||||
print("✓ 测试完成!")
|
||||
print("✓ 覆盖率报告已生成到: tests/coverage_html/index.html")
|
||||
else:
|
||||
print("✗ 部分测试失败")
|
||||
print(" 请查看上方输出了解详情")
|
||||
print("=" * 60)
|
||||
|
||||
return result.returncode
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
@ -1,20 +0,0 @@
|
||||
"""
|
||||
服务层模块
|
||||
包含所有业务逻辑
|
||||
"""
|
||||
|
||||
from .auth_service import AuthService, User
|
||||
from .stats_service import StatsService
|
||||
from .assessment_service import AssessmentService
|
||||
from .suggestion_service import SuggestionService
|
||||
from .api_service import ApiService
|
||||
|
||||
__all__ = [
|
||||
'AuthService',
|
||||
'User',
|
||||
'StatsService',
|
||||
'AssessmentService',
|
||||
'SuggestionService',
|
||||
'ApiService'
|
||||
]
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{% block title %}编程能力个性化评价系统{% endblock %}</title>
|
||||
|
||||
<!-- Bootstrap CSS -->
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||
<!-- Font Awesome -->
|
||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css" rel="stylesheet">
|
||||
|
||||
{% block styles %}{% endblock %}
|
||||
</head>
|
||||
<body>
|
||||
{% block content %}{% endblock %}
|
||||
|
||||
<!-- Bootstrap JS -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
|
||||
{% block scripts %}{% endblock %}
|
||||
</body>
|
||||
</html>
|
||||
@ -1,5 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试包初始化文件
|
||||
"""
|
||||
@ -1,178 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pytest配置文件
|
||||
提供公共的fixtures和测试配置
|
||||
"""
|
||||
import pytest
|
||||
import sqlite3
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from flask import Flask
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from services import AuthService, StatsService, AssessmentService, SuggestionService
|
||||
from data_analyzer import DataAnalyzer
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def test_db_path():
|
||||
"""创建临时测试数据库路径"""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, 'test.sqlite')
|
||||
yield db_path
|
||||
# 清理测试数据库
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def init_test_db(test_db_path):
|
||||
"""初始化测试数据库"""
|
||||
conn = sqlite3.connect(test_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建用户表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
uid INTEGER PRIMARY KEY,
|
||||
username TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建提交表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS submissions (
|
||||
sid INTEGER PRIMARY KEY,
|
||||
uid INTEGER,
|
||||
pid TEXT,
|
||||
pid_int INTEGER,
|
||||
status INTEGER,
|
||||
score INTEGER,
|
||||
lang TEXT,
|
||||
submit_time TEXT,
|
||||
FOREIGN KEY (uid) REFERENCES users(uid)
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建问题表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS problems (
|
||||
pid INTEGER PRIMARY KEY,
|
||||
title TEXT,
|
||||
difficulty TEXT,
|
||||
tags TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# 插入测试用户
|
||||
cursor.executemany(
|
||||
"INSERT INTO users (uid, username) VALUES (?, ?)",
|
||||
[(1, '测试用户1'), (2, '测试用户2'), (3, '测试用户3')]
|
||||
)
|
||||
|
||||
# 插入测试提交数据
|
||||
test_submissions = [
|
||||
(1, 1, '1001', 1001, 1, 100, 'Python', '2024-01-01 10:00:00'),
|
||||
(2, 1, '1002', 1002, 1, 95, 'Python', '2024-01-02 11:00:00'),
|
||||
(3, 1, '1003', 1003, 0, 0, 'Java', '2024-01-03 12:00:00'),
|
||||
(4, 2, '1001', 1001, 1, 90, 'C++', '2024-01-01 13:00:00'),
|
||||
(5, 2, '1002', 1002, 3, 85, 'Python', '2024-01-02 14:00:00'),
|
||||
]
|
||||
cursor.executemany(
|
||||
"INSERT INTO submissions (sid, uid, pid, pid_int, status, score, lang, submit_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
test_submissions
|
||||
)
|
||||
|
||||
# 插入测试问题数据
|
||||
test_problems = [
|
||||
(1001, '两数之和', 'Easy', 'Array,Hash Table'),
|
||||
(1002, '二叉树遍历', 'Medium', 'Tree,DFS'),
|
||||
(1003, '动态规划问题', 'Hard', 'DP,Math'),
|
||||
]
|
||||
cursor.executemany(
|
||||
"INSERT INTO problems (pid, title, difficulty, tags) VALUES (?, ?, ?, ?)",
|
||||
test_problems
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
yield test_db_path
|
||||
|
||||
# 测试后清理数据
|
||||
conn = sqlite3.connect(test_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM submissions")
|
||||
cursor.execute("DELETE FROM users")
|
||||
cursor.execute("DELETE FROM problems")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_service(init_test_db):
|
||||
"""创建认证服务实例"""
|
||||
return AuthService(init_test_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stats_service(init_test_db):
|
||||
"""创建统计服务实例"""
|
||||
return StatsService(init_test_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def assessment_service(init_test_db):
|
||||
"""创建评估服务实例"""
|
||||
return AssessmentService(init_test_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suggestion_service(init_test_db):
|
||||
"""创建建议服务实例"""
|
||||
return SuggestionService(init_test_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_analyzer(init_test_db):
|
||||
"""创建数据分析器实例"""
|
||||
return DataAnalyzer(init_test_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(init_test_db):
|
||||
"""创建Flask测试应用"""
|
||||
from app import app as flask_app
|
||||
flask_app.config['TESTING'] = True
|
||||
flask_app.config['DB_PATH'] = init_test_db
|
||||
flask_app.config['WTF_CSRF_ENABLED'] = False
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""创建Flask测试客户端"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(app):
|
||||
"""创建Flask CLI测试运行器"""
|
||||
return app.test_cli_runner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client(client, init_test_db):
|
||||
"""创建已登录的测试客户端"""
|
||||
# 登录用户
|
||||
client.post('/login', data={
|
||||
'username': '测试用户1',
|
||||
'password': '123456'
|
||||
}, follow_redirects=True)
|
||||
yield client
|
||||
# 登出
|
||||
client.get('/logout', follow_redirects=True)
|
||||
@ -1,87 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据库集成测试
|
||||
"""
|
||||
import pytest
|
||||
import sqlite3
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestDatabaseIntegration:
|
||||
"""数据库集成测试"""
|
||||
|
||||
def test_database_connection(self, init_test_db):
|
||||
"""测试数据库连接"""
|
||||
conn = sqlite3.connect(init_test_db)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
assert result[0] == 1
|
||||
|
||||
def test_users_table_exists(self, init_test_db):
|
||||
"""测试用户表存在"""
|
||||
conn = sqlite3.connect(init_test_db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='users'
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
assert result is not None
|
||||
assert result[0] == 'users'
|
||||
|
||||
def test_submissions_table_exists(self, init_test_db):
|
||||
"""测试提交表存在"""
|
||||
conn = sqlite3.connect(init_test_db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='submissions'
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_data_integrity(self, init_test_db):
|
||||
"""测试数据完整性"""
|
||||
conn = sqlite3.connect(init_test_db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查用户数据
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
user_count = cursor.fetchone()[0]
|
||||
assert user_count >= 3
|
||||
|
||||
# 检查提交数据
|
||||
cursor.execute("SELECT COUNT(*) FROM submissions")
|
||||
submission_count = cursor.fetchone()[0]
|
||||
assert submission_count >= 5
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_foreign_key_constraint(self, init_test_db):
|
||||
"""测试外键约束"""
|
||||
conn = sqlite3.connect(init_test_db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询提交记录的用户是否存在
|
||||
cursor.execute("""
|
||||
SELECT s.uid, u.uid
|
||||
FROM submissions s
|
||||
LEFT JOIN users u ON s.uid = u.uid
|
||||
WHERE s.uid IS NOT NULL
|
||||
""")
|
||||
results = cursor.fetchall()
|
||||
|
||||
for submission_uid, user_uid in results:
|
||||
assert user_uid is not None, f"用户 {submission_uid} 不存在"
|
||||
|
||||
conn.close()
|
||||
@ -1,126 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
评估服务单元测试
|
||||
"""
|
||||
import pytest
|
||||
from services.assessment_service import AssessmentService
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.assessment
|
||||
class TestAssessmentService:
|
||||
"""评估服务测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stats_data(self):
|
||||
"""模拟统计数据"""
|
||||
return {
|
||||
'total_submissions': 10,
|
||||
'accepted_submissions': 7,
|
||||
'acceptance_rate': 70.0,
|
||||
'avg_score': 85,
|
||||
'max_score': 100,
|
||||
'unique_problems': 5,
|
||||
'rank': 10
|
||||
}
|
||||
|
||||
def test_get_user_assessment_structure(self, assessment_service, mock_stats_data):
|
||||
"""测试评估结果结构"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
assert assessment is not None
|
||||
assert 'overall_score' in assessment
|
||||
assert 'level' in assessment
|
||||
assert 'total_problems_solved' in assessment
|
||||
assert 'acceptance_rate' in assessment
|
||||
assert 'strengths' in assessment
|
||||
assert 'improvement_areas' in assessment
|
||||
|
||||
def test_overall_score_calculation(self, assessment_service, mock_stats_data):
|
||||
"""测试综合评分计算"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
score = assessment['overall_score']
|
||||
assert 0 <= score <= 100
|
||||
assert isinstance(score, (int, float))
|
||||
|
||||
def test_level_determination(self, assessment_service, mock_stats_data):
|
||||
"""测试等级判定"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
level = assessment['level']
|
||||
valid_levels = ['初学者', '入门', '进阶', '熟练', '精通', '专家']
|
||||
assert level in valid_levels
|
||||
|
||||
def test_strengths_analysis(self, assessment_service, mock_stats_data):
|
||||
"""测试优势领域分析"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
strengths = assessment['strengths']
|
||||
assert isinstance(strengths, list)
|
||||
assert len(strengths) >= 0
|
||||
|
||||
def test_improvement_areas_analysis(self, assessment_service, mock_stats_data):
|
||||
"""测试改进领域分析"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
improvements = assessment['improvement_areas']
|
||||
assert isinstance(improvements, list)
|
||||
assert len(improvements) >= 0
|
||||
|
||||
def test_radar_data_generation(self, assessment_service, mock_stats_data):
|
||||
"""测试雷达图数据生成"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
if 'radar_data' in assessment:
|
||||
radar_data = assessment['radar_data']
|
||||
assert isinstance(radar_data, (list, dict))
|
||||
|
||||
def test_learning_suggestions(self, assessment_service, mock_stats_data):
|
||||
"""测试学习建议生成"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
if 'learning_suggestions' in assessment:
|
||||
suggestions = assessment['learning_suggestions']
|
||||
assert isinstance(suggestions, list)
|
||||
|
||||
def test_comparison_data(self, assessment_service, mock_stats_data):
|
||||
"""测试对比数据"""
|
||||
assessment = assessment_service.get_user_assessment(1, mock_stats_data)
|
||||
|
||||
if 'comparison' in assessment:
|
||||
comparison = assessment['comparison']
|
||||
assert 'personal_avg' in comparison or 'global_avg' in comparison
|
||||
|
||||
def test_assessment_with_low_stats(self, assessment_service):
|
||||
"""测试低统计数据的评估"""
|
||||
low_stats = {
|
||||
'total_submissions': 2,
|
||||
'accepted_submissions': 0,
|
||||
'acceptance_rate': 0,
|
||||
'avg_score': 0,
|
||||
'max_score': 0,
|
||||
'unique_problems': 0,
|
||||
'rank': 999
|
||||
}
|
||||
|
||||
assessment = assessment_service.get_user_assessment(1, low_stats)
|
||||
assert assessment is not None
|
||||
assert assessment['overall_score'] >= 0
|
||||
|
||||
def test_assessment_with_high_stats(self, assessment_service):
|
||||
"""测试高统计数据的评估"""
|
||||
high_stats = {
|
||||
'total_submissions': 100,
|
||||
'accepted_submissions': 95,
|
||||
'acceptance_rate': 95.0,
|
||||
'avg_score': 98,
|
||||
'max_score': 100,
|
||||
'unique_problems': 80,
|
||||
'rank': 1
|
||||
}
|
||||
|
||||
assessment = assessment_service.get_user_assessment(1, high_stats)
|
||||
assert assessment is not None
|
||||
assert assessment['overall_score'] > 60
|
||||
@ -1,90 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
认证服务单元测试
|
||||
"""
|
||||
import pytest
|
||||
from services.auth_service import User, AuthService
|
||||
|
||||
|
||||
class TestUser:
|
||||
"""用户模型测试"""
|
||||
|
||||
def test_user_creation(self):
|
||||
"""测试用户创建"""
|
||||
user = User(uid=1, username='测试用户')
|
||||
assert user.uid == 1
|
||||
assert user.username == '测试用户'
|
||||
assert user.id == '1'
|
||||
|
||||
def test_user_default_username(self):
|
||||
"""测试默认用户名"""
|
||||
user = User(uid=2)
|
||||
assert user.username == '用户2'
|
||||
|
||||
def test_get_id(self):
|
||||
"""测试获取用户ID"""
|
||||
user = User(uid=123)
|
||||
assert user.get_id() == '123'
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestAuthService:
|
||||
"""认证服务测试"""
|
||||
|
||||
def test_load_user_exists(self, auth_service):
|
||||
"""测试加载存在的用户"""
|
||||
user = auth_service.load_user('1')
|
||||
assert user is not None
|
||||
assert user.uid == 1
|
||||
assert user.username == '测试用户1'
|
||||
|
||||
def test_load_user_not_exists(self, auth_service):
|
||||
"""测试加载不存在的用户"""
|
||||
user = auth_service.load_user('999')
|
||||
assert user is None
|
||||
|
||||
def test_load_user_invalid_id(self, auth_service):
|
||||
"""测试加载无效用户ID"""
|
||||
user = auth_service.load_user('invalid')
|
||||
assert user is None
|
||||
|
||||
def test_authenticate_by_username(self, auth_service):
|
||||
"""测试通过用户名认证"""
|
||||
user = auth_service.authenticate_user('测试用户1', '123456')
|
||||
assert user is not None
|
||||
assert user.uid == 1
|
||||
assert user.username == '测试用户1'
|
||||
|
||||
def test_authenticate_by_uid(self, auth_service):
|
||||
"""测试通过UID认证"""
|
||||
user = auth_service.authenticate_user('1', '123456')
|
||||
assert user is not None
|
||||
assert user.uid == 1
|
||||
|
||||
def test_authenticate_wrong_password(self, auth_service):
|
||||
"""测试错误密码"""
|
||||
user = auth_service.authenticate_user('测试用户1', 'wrong_password')
|
||||
assert user is None
|
||||
|
||||
def test_authenticate_nonexistent_user(self, auth_service):
|
||||
"""测试认证不存在的用户"""
|
||||
user = auth_service.authenticate_user('nonexistent', '123456')
|
||||
assert user is None
|
||||
|
||||
def test_get_all_users(self, auth_service):
|
||||
"""测试获取所有用户"""
|
||||
users = auth_service.get_all_users()
|
||||
assert len(users) >= 3
|
||||
assert any(user['uid'] == 1 for user in users)
|
||||
assert any(user['username'] == '测试用户1' for user in users)
|
||||
|
||||
def test_get_all_users_empty_db(self, test_db_path):
|
||||
"""测试空数据库获取用户"""
|
||||
import os
|
||||
import tempfile
|
||||
temp_db = os.path.join(tempfile.mkdtemp(), 'empty.db')
|
||||
service = AuthService(temp_db)
|
||||
users = service.get_all_users()
|
||||
assert users == []
|
||||
@ -1,76 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
缓存管理器单元测试
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from services.cache_manager import CacheManager, get_cache_manager
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCacheManager:
|
||||
"""缓存管理器测试"""
|
||||
|
||||
def test_cache_manager_singleton(self):
|
||||
"""测试单例模式"""
|
||||
cache1 = get_cache_manager()
|
||||
cache2 = get_cache_manager()
|
||||
assert cache1 is cache2
|
||||
|
||||
def test_set_and_get(self):
|
||||
"""测试设置和获取缓存"""
|
||||
cache = get_cache_manager()
|
||||
cache.set('test_key', 'test_value', ttl=10)
|
||||
|
||||
value = cache.get('test_key')
|
||||
assert value == 'test_value'
|
||||
|
||||
def test_get_nonexistent_key(self):
|
||||
"""测试获取不存在的键"""
|
||||
cache = get_cache_manager()
|
||||
value = cache.get('nonexistent_key')
|
||||
assert value is None
|
||||
|
||||
def test_cache_expiration(self):
|
||||
"""测试缓存过期"""
|
||||
cache = get_cache_manager()
|
||||
cache.set('expire_key', 'expire_value', ttl=1)
|
||||
|
||||
# 立即获取应该成功
|
||||
assert cache.get('expire_key') == 'expire_value'
|
||||
|
||||
# 等待过期
|
||||
time.sleep(2)
|
||||
assert cache.get('expire_key') is None
|
||||
|
||||
def test_delete_cache(self):
|
||||
"""测试删除缓存"""
|
||||
cache = get_cache_manager()
|
||||
cache.set('delete_key', 'delete_value')
|
||||
|
||||
cache.delete('delete_key')
|
||||
assert cache.get('delete_key') is None
|
||||
|
||||
def test_clear_cache(self):
|
||||
"""测试清空缓存"""
|
||||
cache = get_cache_manager()
|
||||
cache.set('key1', 'value1')
|
||||
cache.set('key2', 'value2')
|
||||
|
||||
cache.clear()
|
||||
assert cache.get('key1') is None
|
||||
assert cache.get('key2') is None
|
||||
|
||||
def test_cache_with_complex_data(self):
|
||||
"""测试缓存复杂数据"""
|
||||
cache = get_cache_manager()
|
||||
complex_data = {
|
||||
'list': [1, 2, 3],
|
||||
'dict': {'a': 1, 'b': 2},
|
||||
'nested': {'x': [1, 2], 'y': {'z': 3}}
|
||||
}
|
||||
|
||||
cache.set('complex_key', complex_data)
|
||||
retrieved = cache.get('complex_key')
|
||||
assert retrieved == complex_data
|
||||
@ -1,63 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
配置文件单元测试
|
||||
"""
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConfig:
|
||||
"""配置测试"""
|
||||
|
||||
def test_config_import(self):
|
||||
"""测试配置导入"""
|
||||
from config import Config, DevelopmentConfig, ProductionConfig, TestingConfig
|
||||
|
||||
assert Config is not None
|
||||
assert DevelopmentConfig is not None
|
||||
assert ProductionConfig is not None
|
||||
assert TestingConfig is not None
|
||||
|
||||
def test_development_config(self):
|
||||
"""测试开发环境配置"""
|
||||
from config import DevelopmentConfig
|
||||
|
||||
assert DevelopmentConfig.DEBUG is True
|
||||
assert DevelopmentConfig.TESTING is False
|
||||
|
||||
def test_production_config(self):
|
||||
"""测试生产环境配置"""
|
||||
from config import ProductionConfig
|
||||
|
||||
assert ProductionConfig.DEBUG is False
|
||||
assert ProductionConfig.TESTING is False
|
||||
|
||||
def test_testing_config(self):
|
||||
"""测试测试环境配置"""
|
||||
from config import TestingConfig
|
||||
|
||||
assert TestingConfig.TESTING is True
|
||||
|
||||
def test_config_dict(self):
|
||||
"""测试配置字典"""
|
||||
from config import config
|
||||
|
||||
assert 'development' in config
|
||||
assert 'production' in config
|
||||
assert 'testing' in config
|
||||
assert 'default' in config
|
||||
|
||||
def test_secret_key_exists(self):
|
||||
"""测试密钥存在"""
|
||||
from config import Config
|
||||
|
||||
assert hasattr(Config, 'SECRET_KEY')
|
||||
assert Config.SECRET_KEY is not None
|
||||
|
||||
def test_db_path_exists(self):
|
||||
"""测试数据库路径"""
|
||||
from config import Config
|
||||
|
||||
assert hasattr(Config, 'DB_PATH')
|
||||
assert Config.DB_PATH is not None
|
||||
@ -1,62 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据分析器单元测试
|
||||
"""
|
||||
import pytest
|
||||
from data_analyzer import DataAnalyzer
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDataAnalyzer:
|
||||
"""数据分析器测试"""
|
||||
|
||||
def test_analyzer_creation(self, data_analyzer):
|
||||
"""测试分析器创建"""
|
||||
assert data_analyzer is not None
|
||||
assert hasattr(data_analyzer, 'db_path')
|
||||
|
||||
def test_analyze_user_stats(self, data_analyzer):
|
||||
"""测试分析用户统计"""
|
||||
try:
|
||||
stats = data_analyzer.analyze_user_stats(1)
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
except Exception as e:
|
||||
pytest.skip(f"analyze_user_stats方法异常: {e}")
|
||||
|
||||
def test_get_difficulty_stats(self, data_analyzer):
|
||||
"""测试获取难度统计"""
|
||||
try:
|
||||
if hasattr(data_analyzer, 'get_difficulty_stats'):
|
||||
stats = data_analyzer.get_difficulty_stats(1)
|
||||
assert isinstance(stats, (list, dict))
|
||||
except Exception:
|
||||
pytest.skip("get_difficulty_stats方法不可用")
|
||||
|
||||
def test_get_time_trends(self, data_analyzer):
|
||||
"""测试获取时间趋势"""
|
||||
try:
|
||||
if hasattr(data_analyzer, 'get_time_trends'):
|
||||
trends = data_analyzer.get_time_trends(1)
|
||||
assert isinstance(trends, (list, dict))
|
||||
except Exception:
|
||||
pytest.skip("get_time_trends方法不可用")
|
||||
|
||||
def test_get_comparison_data(self, data_analyzer):
|
||||
"""测试获取对比数据"""
|
||||
try:
|
||||
comparison = data_analyzer.get_comparison_data(1)
|
||||
assert comparison is not None
|
||||
assert isinstance(comparison, dict)
|
||||
except Exception as e:
|
||||
pytest.skip(f"get_comparison_data方法异常: {e}")
|
||||
|
||||
def test_get_learning_insights(self, data_analyzer):
|
||||
"""测试获取学习洞察"""
|
||||
try:
|
||||
insights = data_analyzer.get_learning_insights(1)
|
||||
assert insights is not None
|
||||
assert isinstance(insights, list)
|
||||
except Exception as e:
|
||||
pytest.skip(f"get_learning_insights方法异常: {e}")
|
||||
@ -1,139 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LSTM预测器单元测试
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.lstm
|
||||
class TestLSTMPredictor:
|
||||
"""LSTM预测器测试"""
|
||||
|
||||
def test_lstm_predictor_import(self):
|
||||
"""测试LSTM预测器导入"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
assert LSTMPredictor is not None
|
||||
except ImportError:
|
||||
pytest.skip("LSTM预测器模块不存在")
|
||||
|
||||
def test_lstm_predictor_initialization(self):
|
||||
"""测试LSTM预测器初始化"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
predictor = LSTMPredictor(db_path='backend/instance/app.sqlite')
|
||||
assert predictor is not None
|
||||
except Exception as e:
|
||||
pytest.skip(f"LSTM预测器初始化失败: {e}")
|
||||
|
||||
def test_lstm_model_exists(self):
|
||||
"""测试LSTM模型文件存在"""
|
||||
import os
|
||||
model_path = 'models/lstm_knowledge_predictor.pth'
|
||||
|
||||
if os.path.exists(model_path):
|
||||
assert True
|
||||
else:
|
||||
pytest.skip("LSTM模型文件不存在")
|
||||
|
||||
def test_predict_knowledge_mastery(self):
|
||||
"""测试知识掌握度预测"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
predictor = LSTMPredictor(db_path='backend/instance/app.sqlite')
|
||||
|
||||
# 测试预测
|
||||
result = predictor.predict_knowledge_mastery(1)
|
||||
|
||||
if result:
|
||||
assert isinstance(result, dict)
|
||||
except Exception as e:
|
||||
pytest.skip(f"预测功能测试失败: {e}")
|
||||
|
||||
def test_predict_with_invalid_user(self):
|
||||
"""测试无效用户预测"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
predictor = LSTMPredictor(db_path='backend/instance/app.sqlite')
|
||||
|
||||
result = predictor.predict_knowledge_mastery(999999)
|
||||
# 应该返回None或空结果
|
||||
assert result is None or result == {}
|
||||
except Exception as e:
|
||||
pytest.skip(f"无效用户测试失败: {e}")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.lstm
|
||||
@pytest.mark.slow
|
||||
class TestLSTMModel:
|
||||
"""LSTM模型测试"""
|
||||
|
||||
def test_model_loading(self):
|
||||
"""测试模型加载"""
|
||||
try:
|
||||
import os
|
||||
model_path = 'models/lstm_knowledge_predictor.pth'
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
pytest.skip("模型文件不存在")
|
||||
|
||||
model_data = torch.load(model_path, map_location='cpu')
|
||||
assert model_data is not None
|
||||
except Exception as e:
|
||||
pytest.skip(f"模型加载失败: {e}")
|
||||
|
||||
def test_model_inference(self):
|
||||
"""测试模型推理"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
import os
|
||||
|
||||
if not os.path.exists('models/lstm_knowledge_predictor.pth'):
|
||||
pytest.skip("模型文件不存在")
|
||||
|
||||
predictor = LSTMPredictor(db_path='backend/instance/app.sqlite')
|
||||
|
||||
# 创建测试输入
|
||||
test_input = torch.randn(1, 10, 5) # batch_size=1, seq_len=10, features=5
|
||||
|
||||
# 测试模型可以处理输入
|
||||
# 注意:这只是测试模型结构,不测试实际预测效果
|
||||
assert True
|
||||
except Exception as e:
|
||||
pytest.skip(f"模型推理测试失败: {e}")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.lstm
|
||||
class TestLSTMIntegration:
|
||||
"""LSTM集成测试"""
|
||||
|
||||
def test_lstm_data_preparation(self):
|
||||
"""测试LSTM数据准备"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
predictor = LSTMPredictor(db_path='backend/instance/app.sqlite')
|
||||
|
||||
# 测试数据准备函数
|
||||
if hasattr(predictor, 'prepare_data'):
|
||||
data = predictor.prepare_data(1)
|
||||
assert data is not None
|
||||
except Exception as e:
|
||||
pytest.skip(f"数据准备测试失败: {e}")
|
||||
|
||||
def test_lstm_feature_extraction(self):
|
||||
"""测试LSTM特征提取"""
|
||||
try:
|
||||
from services.lstm_predictor import LSTMPredictor
|
||||
predictor = LSTMPredictor(db_path='backend/instance/app.sqlite')
|
||||
|
||||
if hasattr(predictor, 'extract_features'):
|
||||
features = predictor.extract_features(1)
|
||||
assert features is not None
|
||||
except Exception as e:
|
||||
pytest.skip(f"特征提取测试失败: {e}")
|
||||
@ -1,103 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统计服务单元测试
|
||||
"""
|
||||
import pytest
|
||||
from services.stats_service import StatsService
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.stats
|
||||
class TestStatsService:
|
||||
"""统计服务测试"""
|
||||
|
||||
def test_get_user_stats_basic(self, stats_service):
|
||||
"""测试获取用户基本统计信息"""
|
||||
stats = stats_service.get_user_stats(1)
|
||||
|
||||
assert stats is not None
|
||||
assert 'total_submissions' in stats
|
||||
assert 'accepted_submissions' in stats
|
||||
assert 'acceptance_rate' in stats
|
||||
assert 'rank' in stats
|
||||
assert 'language_distribution' in stats
|
||||
|
||||
def test_get_user_stats_submissions_count(self, stats_service):
|
||||
"""测试提交数量统计"""
|
||||
stats = stats_service.get_user_stats(1)
|
||||
|
||||
assert stats['total_submissions'] == 3
|
||||
assert stats['accepted_submissions'] == 2
|
||||
|
||||
def test_get_user_stats_acceptance_rate(self, stats_service):
|
||||
"""测试通过率计算"""
|
||||
stats = stats_service.get_user_stats(1)
|
||||
|
||||
expected_rate = (2 / 3) * 100
|
||||
assert abs(stats['acceptance_rate'] - expected_rate) < 0.1
|
||||
|
||||
def test_get_user_stats_language_distribution(self, stats_service):
|
||||
"""测试语言分布统计"""
|
||||
stats = stats_service.get_user_stats(1)
|
||||
|
||||
lang_dist = stats['language_distribution']
|
||||
assert len(lang_dist) >= 1
|
||||
|
||||
# 检查语言分布格式
|
||||
for lang_stat in lang_dist:
|
||||
assert 'language' in lang_stat
|
||||
assert 'total_submissions' in lang_stat
|
||||
assert 'accepted_submissions' in lang_stat
|
||||
assert 'acceptance_rate' in lang_stat
|
||||
|
||||
def test_get_user_stats_nonexistent_user(self, stats_service):
|
||||
"""测试获取不存在用户的统计"""
|
||||
stats = stats_service.get_user_stats(999)
|
||||
|
||||
assert stats is not None
|
||||
assert stats['total_submissions'] == 0
|
||||
assert stats['accepted_submissions'] == 0
|
||||
|
||||
def test_get_user_stats_rank(self, stats_service):
|
||||
"""测试用户排名"""
|
||||
stats1 = stats_service.get_user_stats(1)
|
||||
stats2 = stats_service.get_user_stats(2)
|
||||
|
||||
assert 'rank' in stats1
|
||||
assert 'rank' in stats2
|
||||
assert stats1['rank'] >= 1
|
||||
assert stats2['rank'] >= 1
|
||||
|
||||
def test_get_language_stats(self, stats_service):
|
||||
"""测试获取语言统计"""
|
||||
try:
|
||||
lang_stats = stats_service.get_language_stats()
|
||||
assert lang_stats is not None
|
||||
assert isinstance(lang_stats, list)
|
||||
except AttributeError:
|
||||
# 如果方法不存在,跳过测试
|
||||
pytest.skip("get_language_stats方法不存在")
|
||||
|
||||
def test_cache_functionality(self, stats_service):
|
||||
"""测试缓存功能"""
|
||||
# 第一次调用
|
||||
stats1 = stats_service.get_user_stats(1)
|
||||
# 第二次调用(应该从缓存获取)
|
||||
stats2 = stats_service.get_user_stats(1)
|
||||
|
||||
assert stats1 == stats2
|
||||
|
||||
def test_get_user_stats_with_details(self, stats_service):
|
||||
"""测试获取详细统计信息"""
|
||||
stats = stats_service.get_user_stats(1)
|
||||
|
||||
# 检查是否包含详细统计
|
||||
if 'difficulty_stats' in stats:
|
||||
assert isinstance(stats['difficulty_stats'], (list, dict))
|
||||
|
||||
if 'time_trends' in stats:
|
||||
assert isinstance(stats['time_trends'], (list, dict))
|
||||
|
||||
if 'topic_stats' in stats:
|
||||
assert isinstance(stats['topic_stats'], (list, dict))
|
||||
@ -1,39 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
建议服务单元测试
|
||||
"""
|
||||
import pytest
|
||||
from services.suggestion_service import SuggestionService
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSuggestionService:
|
||||
"""建议服务测试"""
|
||||
|
||||
def test_get_suggestions_structure(self, suggestion_service):
|
||||
"""测试建议结构"""
|
||||
try:
|
||||
suggestions = suggestion_service.get_suggestions(1)
|
||||
assert suggestions is not None
|
||||
assert isinstance(suggestions, (list, dict))
|
||||
except Exception as e:
|
||||
pytest.skip(f"get_suggestions方法异常: {e}")
|
||||
|
||||
def test_get_problem_recommendations(self, suggestion_service):
|
||||
"""测试问题推荐"""
|
||||
try:
|
||||
if hasattr(suggestion_service, 'get_problem_recommendations'):
|
||||
recommendations = suggestion_service.get_problem_recommendations(1)
|
||||
assert isinstance(recommendations, list)
|
||||
except Exception:
|
||||
pytest.skip("get_problem_recommendations方法不可用")
|
||||
|
||||
def test_get_learning_path(self, suggestion_service):
|
||||
"""测试学习路径"""
|
||||
try:
|
||||
if hasattr(suggestion_service, 'get_learning_path'):
|
||||
path = suggestion_service.get_learning_path(1)
|
||||
assert path is not None
|
||||
except Exception:
|
||||
pytest.skip("get_learning_path方法不可用")
|
||||
@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
SonarQube覆盖率验证脚本
|
||||
验证coverage.xml是否正确生成并可被SonarQube识别
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
def check_coverage_xml():
|
||||
"""检查coverage.xml文件"""
|
||||
print("=" * 70)
|
||||
print("SonarQube 覆盖率报告验证")
|
||||
print("=" * 70)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists('coverage.xml'):
|
||||
print("❌ 错误: coverage.xml 文件不存在")
|
||||
print("\n请先运行测试生成覆盖率报告:")
|
||||
print(" python run_tests.py")
|
||||
print(" 或")
|
||||
print(" pytest tests/ --cov=. --cov-report=xml:coverage.xml")
|
||||
return False
|
||||
|
||||
print("✓ coverage.xml 文件存在")
|
||||
|
||||
# 检查文件大小
|
||||
file_size = os.path.getsize('coverage.xml')
|
||||
print(f"✓ 文件大小: {file_size:,} bytes")
|
||||
|
||||
if file_size < 100:
|
||||
print("❌ 警告: 文件太小,可能没有包含有效数据")
|
||||
return False
|
||||
|
||||
# 解析XML文件
|
||||
try:
|
||||
tree = ET.parse('coverage.xml')
|
||||
root = tree.getroot()
|
||||
print("✓ XML格式有效")
|
||||
|
||||
# 获取覆盖率信息
|
||||
line_rate = float(root.get('line-rate', 0))
|
||||
lines_valid = int(root.get('lines-valid', 0))
|
||||
lines_covered = int(root.get('lines-covered', 0))
|
||||
|
||||
coverage_percent = line_rate * 100
|
||||
|
||||
print(f"\n📊 覆盖率统计:")
|
||||
print(f" 总代码行数: {lines_valid:,}")
|
||||
print(f" 已覆盖行数: {lines_covered:,}")
|
||||
print(f" 覆盖率: {coverage_percent:.2f}%")
|
||||
|
||||
# 检查是否有包信息
|
||||
packages = root.findall('.//package')
|
||||
print(f"\n📦 包数量: {len(packages)}")
|
||||
|
||||
# 检查是否有类信息
|
||||
classes = root.findall('.//class')
|
||||
print(f"📄 文件数量: {len(classes)}")
|
||||
|
||||
if len(classes) > 0:
|
||||
print("\n前5个已分析的文件:")
|
||||
for cls in classes[:5]:
|
||||
filename = cls.get('filename', 'unknown')
|
||||
cls_line_rate = float(cls.get('line-rate', 0))
|
||||
print(f" - {filename}: {cls_line_rate*100:.2f}%")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("✅ coverage.xml 文件有效,可以被 SonarQube 识别")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n下一步:")
|
||||
print("1. 确保 SonarQube 服务正在运行")
|
||||
print("2. 运行 sonar-scanner 进行代码分析")
|
||||
print("3. 访问 http://localhost:9000 查看结果")
|
||||
|
||||
return True
|
||||
|
||||
except ET.ParseError as e:
|
||||
print(f"❌ XML解析错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ 未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_sonar_config():
|
||||
"""检查SonarQube配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print("检查 SonarQube 配置")
|
||||
print("=" * 70)
|
||||
|
||||
if not os.path.exists('sonar-project.properties'):
|
||||
print("❌ sonar-project.properties 文件不存在")
|
||||
return False
|
||||
|
||||
print("✓ sonar-project.properties 文件存在")
|
||||
|
||||
with open('sonar-project.properties', 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 检查关键配置
|
||||
checks = {
|
||||
'sonar.python.coverage.reportPaths': 'coverage.xml',
|
||||
'sonar.sources': ['app.py', 'services', 'config.py'],
|
||||
'sonar.tests': 'tests'
|
||||
}
|
||||
|
||||
all_ok = True
|
||||
for key, expected in checks.items():
|
||||
if key in content:
|
||||
print(f"✓ {key} 已配置")
|
||||
else:
|
||||
print(f"❌ {key} 未配置")
|
||||
all_ok = False
|
||||
|
||||
return all_ok
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
success = check_coverage_xml()
|
||||
config_ok = check_sonar_config()
|
||||
|
||||
if success and config_ok:
|
||||
print("\n🎉 所有检查通过! 可以运行 SonarQube 分析")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n⚠️ 存在问题,请修复后再运行 SonarQube 分析")
|
||||
sys.exit(1)
|
||||