|
|
|
|
@ -1,5 +1,3 @@
|
|
|
|
|
# src/main.py
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
中小学数学卷子自动生成程序(命令行版)
|
|
|
|
|
功能要点:
|
|
|
|
|
@ -10,13 +8,19 @@
|
|
|
|
|
- 生成题目时避免与该账号已有文件中的题目重复(查重)。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# 标准库导入
|
|
|
|
|
import os
|
|
|
|
|
import sqlite3
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from typing import Callable, List, Set
|
|
|
|
|
from typing import List, Set
|
|
|
|
|
import sys # 导入 sys 模块以支持程序退出
|
|
|
|
|
|
|
|
|
|
# 从 questions 模块导入题目生成函数
|
|
|
|
|
from .questions import generate_primary_question, generate_middle_question, generate_high_question
|
|
|
|
|
# 本地应用模块导入
|
|
|
|
|
from .questions import (
|
|
|
|
|
HighQuestionGenerator,
|
|
|
|
|
MiddleQuestionGenerator,
|
|
|
|
|
PrimaryQuestionGenerator,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 数据库管理
|
|
|
|
|
@ -25,11 +29,20 @@ from .questions import generate_primary_question, generate_middle_question, gene
|
|
|
|
|
DB_NAME = "accounts.db"
|
|
|
|
|
VALID_LEVELS = ["小学", "初中", "高中"]
|
|
|
|
|
|
|
|
|
|
# 将难度级别与对应的生成器类进行映射
|
|
|
|
|
QUESTION_GENERATORS = {
|
|
|
|
|
"小学": PrimaryQuestionGenerator(),
|
|
|
|
|
"初中": MiddleQuestionGenerator(),
|
|
|
|
|
"高中": HighQuestionGenerator(),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db():
|
|
|
|
|
"""初始化数据库和账户表,并导入预设数据。"""
|
|
|
|
|
conn = sqlite3.connect(DB_NAME)
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
|
|
|
|
# 创建用户表
|
|
|
|
|
cursor.execute("""
|
|
|
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
|
|
|
username TEXT PRIMARY KEY,
|
|
|
|
|
@ -37,6 +50,16 @@ def init_db():
|
|
|
|
|
level TEXT NOT NULL
|
|
|
|
|
)
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
# 新增:创建题目表,用于查重
|
|
|
|
|
cursor.execute("""
|
|
|
|
|
CREATE TABLE IF NOT EXISTS questions (
|
|
|
|
|
question TEXT PRIMARY KEY,
|
|
|
|
|
username TEXT NOT NULL,
|
|
|
|
|
level TEXT NOT NULL,
|
|
|
|
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
|
|
|
)
|
|
|
|
|
""")
|
|
|
|
|
conn.commit()
|
|
|
|
|
|
|
|
|
|
predefined_accounts = {
|
|
|
|
|
@ -49,14 +72,16 @@ def init_db():
|
|
|
|
|
for username, password in users.items():
|
|
|
|
|
cursor.execute("SELECT 1 FROM users WHERE username = ?", (username,))
|
|
|
|
|
if cursor.fetchone() is None:
|
|
|
|
|
cursor.execute("INSERT INTO users (username, password, level) VALUES (?, ?, ?)",
|
|
|
|
|
(username, password, level))
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"INSERT INTO users (username, password, level) VALUES (?, ?, ?)",
|
|
|
|
|
(username, password, level)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
conn.commit()
|
|
|
|
|
conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def authenticate_user(username, password):
|
|
|
|
|
def authenticate_user(username: str, password: str) -> (str, str):
|
|
|
|
|
"""
|
|
|
|
|
通过查询数据库验证用户名和密码。
|
|
|
|
|
返回: (level, username) 或 (None, None)
|
|
|
|
|
@ -64,16 +89,17 @@ def authenticate_user(username, password):
|
|
|
|
|
conn = sqlite3.connect(DB_NAME)
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
|
|
|
|
cursor.execute("SELECT level FROM users WHERE username = ? AND password = ?",
|
|
|
|
|
(username, password))
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"SELECT level FROM users WHERE username = ? AND password = ?",
|
|
|
|
|
(username, password)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result = cursor.fetchone()
|
|
|
|
|
conn.close()
|
|
|
|
|
|
|
|
|
|
if result:
|
|
|
|
|
return result[0], username
|
|
|
|
|
else:
|
|
|
|
|
return None, None
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 文件与查重相关函数
|
|
|
|
|
@ -83,70 +109,67 @@ def _ensure_dir(path: str) -> None:
|
|
|
|
|
"""确保目录存在,不存在则创建。"""
|
|
|
|
|
os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
def _read_existing_questions(folder: str) -> Set[str]:
|
|
|
|
|
def _read_existing_questions_from_db(username: str) -> Set[str]:
|
|
|
|
|
"""
|
|
|
|
|
读取指定文件夹下所有文本文件,返回其中已存在的题目集合。
|
|
|
|
|
从数据库中读取指定用户已存在的题目集合。
|
|
|
|
|
"""
|
|
|
|
|
questions = set()
|
|
|
|
|
if not os.path.isdir(folder):
|
|
|
|
|
return questions
|
|
|
|
|
for fname in os.listdir(folder):
|
|
|
|
|
if not fname.lower().endswith(".txt"):
|
|
|
|
|
continue
|
|
|
|
|
fpath = os.path.join(folder, fname)
|
|
|
|
|
conn = sqlite3.connect(DB_NAME)
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
cursor.execute("SELECT question FROM questions WHERE username = ?", (username,))
|
|
|
|
|
existing_questions = {row[0] for row in cursor.fetchall()}
|
|
|
|
|
conn.close()
|
|
|
|
|
return existing_questions
|
|
|
|
|
|
|
|
|
|
def _save_new_questions_to_db(new_questions: List[str], username: str, level: str):
|
|
|
|
|
"""
|
|
|
|
|
将新生成的题目保存到数据库中。
|
|
|
|
|
"""
|
|
|
|
|
conn = sqlite3.connect(DB_NAME)
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
for question in new_questions:
|
|
|
|
|
try:
|
|
|
|
|
with open(fpath, "r", encoding="utf-8") as f:
|
|
|
|
|
for line in f:
|
|
|
|
|
s = line.strip()
|
|
|
|
|
if not s:
|
|
|
|
|
continue
|
|
|
|
|
if s.split(".", 1)[0].isdigit() and s.count(".") >= 1:
|
|
|
|
|
parts = s.split(".", 1)
|
|
|
|
|
if len(parts) == 2:
|
|
|
|
|
content = parts[1].strip()
|
|
|
|
|
else:
|
|
|
|
|
content = s
|
|
|
|
|
else:
|
|
|
|
|
content = s
|
|
|
|
|
if content:
|
|
|
|
|
questions.add(content)
|
|
|
|
|
except Exception:
|
|
|
|
|
continue
|
|
|
|
|
return questions
|
|
|
|
|
cursor.execute(
|
|
|
|
|
"INSERT OR IGNORE INTO questions (question, username, level) VALUES (?, ?, ?)",
|
|
|
|
|
(question, username, level)
|
|
|
|
|
)
|
|
|
|
|
except sqlite3.Error as exc:
|
|
|
|
|
print(f"Error saving question to database: {exc}")
|
|
|
|
|
conn.rollback()
|
|
|
|
|
conn.commit()
|
|
|
|
|
conn.close()
|
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 生成卷子并保存
|
|
|
|
|
# ------------------------------
|
|
|
|
|
|
|
|
|
|
def generate_paper(level: str, count: int, folder: str) -> str:
|
|
|
|
|
def generate_paper(
|
|
|
|
|
level: str, count: int, folder: str, username: str
|
|
|
|
|
) -> str:
|
|
|
|
|
"""
|
|
|
|
|
根据 level 和题目数量 count 生成一个卷子并保存。
|
|
|
|
|
"""
|
|
|
|
|
_ensure_dir(folder)
|
|
|
|
|
|
|
|
|
|
if level == "小学":
|
|
|
|
|
gen_func: Callable[[], str] = generate_primary_question
|
|
|
|
|
elif level == "初中":
|
|
|
|
|
gen_func = generate_middle_question
|
|
|
|
|
elif level == "高中":
|
|
|
|
|
gen_func = generate_high_question
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("level 必须是:小学、初中或高中")
|
|
|
|
|
# 使用多态性,动态获取生成器对象
|
|
|
|
|
generator = QUESTION_GENERATORS.get(level)
|
|
|
|
|
if not generator:
|
|
|
|
|
raise ValueError("无效的题目难度级别")
|
|
|
|
|
|
|
|
|
|
existing = _read_existing_questions(folder)
|
|
|
|
|
existing = _read_existing_questions_from_db(username)
|
|
|
|
|
|
|
|
|
|
new_questions: List[str] = []
|
|
|
|
|
attempt_limit = count * 20
|
|
|
|
|
attempts = 0
|
|
|
|
|
while len(new_questions) < count and attempts < attempt_limit:
|
|
|
|
|
attempts += 1
|
|
|
|
|
q = gen_func()
|
|
|
|
|
# 调用统一的 generate_question 方法
|
|
|
|
|
q = generator.generate_question()
|
|
|
|
|
q_norm = " ".join(q.split())
|
|
|
|
|
if q_norm not in existing and q_norm not in new_questions:
|
|
|
|
|
new_questions.append(q_norm)
|
|
|
|
|
|
|
|
|
|
if len(new_questions) < count:
|
|
|
|
|
print("提示:未能生成足够的不重复题目,请稍后再试或清理旧卷子。")
|
|
|
|
|
print("提示:未能生成足够的不重复题目,请稍后再试或清理已有卷子后重试。")
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
filename = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".txt"
|
|
|
|
|
@ -156,6 +179,9 @@ def generate_paper(level: str, count: int, folder: str) -> str:
|
|
|
|
|
with open(filepath, "w", encoding="utf-8") as f:
|
|
|
|
|
for idx, q in enumerate(new_questions, start=1):
|
|
|
|
|
f.write(f"{idx}. {q}\n\n")
|
|
|
|
|
|
|
|
|
|
_save_new_questions_to_db(new_questions, username, level)
|
|
|
|
|
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
print("保存文件失败:", exc)
|
|
|
|
|
return ""
|
|
|
|
|
@ -173,7 +199,13 @@ def login_prompt() -> (str, str):
|
|
|
|
|
要求格式:"用户名 密码"(用空格分开)。
|
|
|
|
|
"""
|
|
|
|
|
while True:
|
|
|
|
|
raw = input("请输入用户名和密码(空格隔开):").strip()
|
|
|
|
|
raw = input("请输入用户名和密码(空格隔开 或 输入“-2”退出):").strip()
|
|
|
|
|
|
|
|
|
|
# 新增的退出选项
|
|
|
|
|
if raw == "-2":
|
|
|
|
|
print("已退出。")
|
|
|
|
|
sys.exit(0)
|
|
|
|
|
|
|
|
|
|
parts = raw.split()
|
|
|
|
|
if len(parts) != 2:
|
|
|
|
|
print("输入格式错误,请输入:用户名 密码 (中间以空格隔开)")
|
|
|
|
|
@ -186,8 +218,7 @@ def login_prompt() -> (str, str):
|
|
|
|
|
if level:
|
|
|
|
|
print(f"当前选择为 {level} 出题")
|
|
|
|
|
return level, auth_username
|
|
|
|
|
else:
|
|
|
|
|
print("请输入正确的用户名、密码")
|
|
|
|
|
print("请输入正确的用户名、密码")
|
|
|
|
|
|
|
|
|
|
def main_loop() -> None:
|
|
|
|
|
"""主循环:登录->出题->可切换/退出登录"""
|
|
|
|
|
@ -196,11 +227,15 @@ def main_loop() -> None:
|
|
|
|
|
user_folder = os.path.join("papers", username)
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
prompt = f"准备生成 {level} 数学题目,请输入生成题目数量(输入 -1 退出当前用户,或输入'切换为 XX'切换类型):"
|
|
|
|
|
prompt = (
|
|
|
|
|
f"准备生成 {level} 数学题目,请输入生成题目数量"
|
|
|
|
|
"(输入 -1 退出当前用户,或输入'切换为 XX'切换类型):"
|
|
|
|
|
)
|
|
|
|
|
inp = input(prompt).strip()
|
|
|
|
|
|
|
|
|
|
if inp.startswith("切换为"):
|
|
|
|
|
new_level = inp.replace("切换为", "").strip()
|
|
|
|
|
# 严格处理切换命令,只允许“切换为 高中”这种格式
|
|
|
|
|
if inp.startswith("切换为 ") and len(inp.split()) == 2:
|
|
|
|
|
new_level = inp.split()[1]
|
|
|
|
|
if new_level in VALID_LEVELS:
|
|
|
|
|
level = new_level
|
|
|
|
|
print(f"已切换,准备生成 {level} 数学题目")
|
|
|
|
|
@ -222,7 +257,7 @@ def main_loop() -> None:
|
|
|
|
|
print("题目数量必须在 10-30 之间(包含 10 和 30)")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
res = generate_paper(level, count, user_folder)
|
|
|
|
|
res = generate_paper(level, count, user_folder, username)
|
|
|
|
|
if not res:
|
|
|
|
|
print("生成失败,请尝试更小的题目数量或清理已有卷子后重试。")
|
|
|
|
|
else:
|
|
|
|
|
@ -236,5 +271,5 @@ def main() -> None:
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
print("\n程序已被用户中断,退出。")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|
|
|
|
|
|
|
|
|
|
main()
|