import pandas as pd
from scipy.io.arff import loadarff
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

from torchmetrics.functional.classification import precision, recall, f1_score, auroc


# ---- 1. Load the ECG5000 dataset (ARFF format) ----
traindata, trainmeta = loadarff('./ECG5000/ECG5000_TRAIN.arff')
testdata, testmeta = loadarff('./ECG5000/ECG5000_TEST.arff')
train_df = pd.DataFrame(traindata, columns=trainmeta.names())
test_df = pd.DataFrame(testdata, columns=testmeta.names())
df = pd.concat([train_df, test_df])
print(train_df.shape, test_df.shape, df.shape)

# Class b'1' is the normal heartbeat; every other class is treated as anomalous.
normal = df[df.iloc[:, -1] == b'1']
abnormal = df[df.iloc[:, -1] != b'1']


# ---- 2. Preprocessing: the autoencoder is trained on normal samples only ----
X_normal = normal.iloc[:, :-1].values
X_abnormal = abnormal.iloc[:, :-1].values

# Convert to tensors with a trailing channel dimension: (N, seq_len, 1).
normal_tensor = torch.tensor(data=X_normal, dtype=torch.float).unsqueeze(-1)
abnormal_tensor = torch.tensor(data=X_abnormal, dtype=torch.float).unsqueeze(-1)
print(normal_tensor.shape, abnormal_tensor.shape)

# 80/20 split of the normal samples into training / validation indices.
dataset = TensorDataset(normal_tensor, normal_tensor)
train_idx = list(range(len(dataset) * 4 // 5))
val_idx = list(range(len(dataset) * 4 // 5, len(dataset)))
print(len(train_idx), len(val_idx))

# Test set = held-out normal samples (label 0) + all abnormal samples (label 1).
x_val_tensor = normal_tensor[val_idx]
x_test_tensor = torch.cat((x_val_tensor, abnormal_tensor), dim=0)
y_test_tensor = torch.cat(
    (
        torch.zeros(len(x_val_tensor), dtype=torch.long),
        torch.ones(len(abnormal_tensor), dtype=torch.long),
    ),
    dim=0,
)
print(x_test_tensor.shape, y_test_tensor.shape)

train_sampler = SubsetRandomSampler(indices=train_idx)
val_sampler = SubsetRandomSampler(indices=val_idx)
train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=128, sampler=val_sampler)


class Encoder(nn.Module):
    """Compress a (batch, context_len, n_variables) sequence into an embedding."""

    def __init__(self, context_len, n_variables, embedding_dim=64):
        super(Encoder, self).__init__()
        # context_len: number of time steps; n_variables: features per step.
        self.context_len, self.n_variables = context_len, n_variables
        self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dim
        self.lstm1 = nn.LSTM(
            input_size=self.n_variables,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=embedding_dim,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x, (_, _) = self.lstm1(x)
        # The final hidden state of the second LSTM is the fixed-size embedding.
        x, (hidden_n, _) = self.lstm2(x)
        return hidden_n.reshape((batch_size, self.embedding_dim))


class Decoder(nn.Module):
    """Expand an embedding back into a (batch, context_len, n_variables) sequence."""

    def __init__(self, context_len, n_variables=1, input_dim=64):
        super(Decoder, self).__init__()
        self.context_len, self.input_dim = context_len, input_dim
        self.hidden_dim, self.n_variables = 2 * input_dim, n_variables
        self.lstm1 = nn.LSTM(
            input_size=input_dim, hidden_size=input_dim, num_layers=1, batch_first=True
        )
        self.lstm2 = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.output_layer = nn.Linear(self.hidden_dim, self.n_variables)

    def forward(self, x):
        batch_size = x.shape[0]
        # Repeat the embedding once per time step before running the LSTMs.
        x = x.repeat(self.context_len, self.n_variables)
        x = x.reshape((batch_size, self.context_len, self.input_dim))
        x, (hidden_n, cell_n) = self.lstm1(x)
        x, (hidden_n, cell_n) = self.lstm2(x)
        x = x.reshape((batch_size, self.context_len, self.hidden_dim))

        return self.output_layer(x)


class LSTMAutoencoder(nn.Module):
    """Encoder/decoder pair trained to reconstruct its input sequence."""

    def __init__(self, context_len, n_variables, embedding_dim):
        super().__init__()
        self.encoder = Encoder(context_len, n_variables, embedding_dim)
        self.decoder = Decoder(context_len, n_variables, embedding_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


automodel = LSTMAutoencoder(context_len=140, n_variables=1, embedding_dim=64)
optimizer = torch.optim.Adam(params=automodel.parameters(), lr=1e-4)
criterion = nn.MSELoss()


def train_one_epoch(model, iterator):
    """Run one optimisation pass over `iterator`; return the mean batch loss.

    Renamed from `train` so it no longer shadows the `train` DataFrame loaded
    at the top of the script.
    """
    model.train()
    epoch_loss = 0

    for batch_idx, (data, target) in enumerate(iterable=iterator):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(iterator)

    return avg_loss


def evaluate(model, iterator):
    """Compute the mean reconstruction loss without updating weights.

    Used for both validation and testing.
    """
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(iterable=iterator):
            output = model(data)
            loss = criterion(output, target)
            epoch_loss += loss.item()

    avg_loss = epoch_loss / len(iterator)

    return avg_loss


class EarlyStopping:
    """Stop training once the validation loss stops improving."""

    def __init__(self, patience=5, delta=0.0):
        self.patience = patience        # allowed consecutive non-improving epochs
        self.delta = delta              # minimum improvement that counts
        self.counter = 0                # non-improving epochs seen so far
        self.best_loss = float('inf')   # best validation loss observed
        self.early_stop = False         # set once patience is exhausted

    def __call__(self, val_loss, model):
        if val_loss < (self.best_loss - self.delta):
            self.best_loss = val_loss
            self.counter = 0
            # Checkpoint the best weights so they can be restored after training.
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


early_stopper = EarlyStopping(patience=10, delta=0.00001)


def main():
    train_losses = []
    val_losses = []

    for epoch in range(300):
        train_loss = train_one_epoch(model=automodel, iterator=train_loader)
        val_loss = evaluate(model=automodel, iterator=val_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f'Epoch: {epoch + 1:02}, Train MSELoss: {train_loss:.5f}, Val. MSELoss: {val_loss:.5f}')

        # Early-stopping check.
        early_stopper(val_loss, model=automodel)
        if early_stopper.early_stop:
            print(f"Early stopping at epoch {epoch}")
            break

    # Bug fix: restore the best checkpoint saved by EarlyStopping. Previously
    # the anomaly threshold and test metrics were computed with the
    # last-epoch weights instead of the best validation weights.
    automodel.load_state_dict(torch.load('best_model.pth'))

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSELoss')
    plt.title('Training and Validation Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()


main()


# ---- 5. Anomaly detection ----
def detect_anomalies(model, x):
    """Return the per-sample reconstruction MSE of `x` under `model`."""
    model.eval()
    with torch.no_grad():
        reconstructions = model(x)
        mse = torch.mean((x - reconstructions) ** 2, dim=(1, 2))
    return mse


# Reconstruction error on the mixed (normal + abnormal) test set.
test_mse = detect_anomalies(automodel, x_test_tensor)

# Threshold = 95th percentile of the validation (normal-only) errors.
val_mse = detect_anomalies(automodel, x_val_tensor)
threshold = torch.quantile(val_mse, 0.95)

# Anything above the threshold is flagged as an anomaly.
y_pred = (test_mse > threshold).long()


print(f'Threshold: {threshold:.4f}')
print(y_pred.dtype)
print(y_pred.shape)


pre = precision(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Precision: {pre:.5f}")

rec = recall(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Recall: {rec:.5f}")

f1 = f1_score(preds=y_pred, target=y_test_tensor, task="binary")
print(f"F1 Score: {f1:.5f}")

auc = auroc(preds=test_mse, target=y_test_tensor, task="binary")
print(f"AUC: {auc:.5f}")
b/src/AccountManager.java @@ -0,0 +1,53 @@ +import java.util.HashMap; +import java.util.Map; + + +/** + * 负责账号验证与学段映射. + */ +public final class AccountManager { + + + private final Map userToPassword = new HashMap<>(); + private final Map userToLevel = new HashMap<>(); + + /** + * 用户姓名密码信息初始化. + */ + public AccountManager() { + + register("张三1", "123", "小学"); + register("张三2", "123", "小学"); + register("张三3", "123", "小学"); + + register("李四1", "123", "初中"); + register("李四2", "123", "初中"); + register("李四3", "123", "初中"); + + register("王五1", "123", "高中"); + register("王五2", "123", "高中"); + register("王五3", "123", "高中"); + } + + + private void register(String user, String pwd, String level) { + userToPassword.put(user, pwd); + userToLevel.put(user, level); + } + + /** + * 账号密码匹配. + */ + public boolean validate(String user, String pwd) { + return user != null && pwd != null && pwd.equals(userToPassword.get(user)); + } + + /** + * 获取用户初始等级. + */ + public String getLevel(String user) { + return userToLevel.get(user); + } + + +} \ No newline at end of file diff --git a/src/BaseQuestionGenerator.java b/src/BaseQuestionGenerator.java new file mode 100644 index 0000000..a85191e --- /dev/null +++ b/src/BaseQuestionGenerator.java @@ -0,0 +1,79 @@ +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; + +/** + * 抽象题目生成器基类. 子类需要实现 {@link #generateQuestion()} 来提供不同学段的题目生成逻辑. + */ +public abstract class BaseQuestionGenerator { + + /** + * 用于生成随机数的随机数生成器. + */ + protected final Random random = new Random(); + + /** + * 生成一组不重复的数学题目列表. + * + * @param count 要生成的题目数量. + * @param existingQuestions 已经存在的题目集合. + * @return 生成的题目列表. + */ + public List generate(int count, Set existingQuestions) { + List result = new ArrayList<>(); + Set seen = new HashSet<>(existingQuestions == null ? 
0 : existingQuestions.size()); + if (existingQuestions != null) { + seen.addAll(existingQuestions); + } + + int attempts = 0; + while (result.size() < count && attempts < count * 20) { + String question = generateQuestion(); + if (!seen.contains(question)) { + result.add(question); + seen.add(question); + } + attempts++; + } + + return result; + } + + /** + * 生成一个指定范围内的随机整数. + * + * @param a 下界. + * @param b 上界. + * @return 返回一个位于 a 和 b 之间的随机整数. + */ + protected int randInt(int a, int b) { + return a + random.nextInt(b - a + 1); + } + + /** + * 随机返回一个基础运算符: "+", "-", "*", 或 "/". + * + * @return 随机运算符字符串. + */ + protected String randomBasicOperator() { + return switch (random.nextInt(4)) { + case 0 -> "+"; + case 1 -> "-"; + case 2 -> "*"; + default -> "/"; + }; + } + + /** + * 生成单个题目. 子类必须实现该方法以提供不同学段的题目生成逻辑. + * + * @return 返回生成的数学题目字符串. + */ + protected abstract String generateQuestion(); +} + + + + diff --git a/src/Constants.java b/src/Constants.java new file mode 100644 index 0000000..1f297f7 --- /dev/null +++ b/src/Constants.java @@ -0,0 +1,16 @@ + +/** 常量定义. 
*/ +public final class Constants { + public static final String BASE_OUTPUT_DIR = "output"; // 主输出文件夹 + public static final int MIN_QUESTIONS = 10; + public static final int MAX_QUESTIONS = 30; + public static final int MIN_OPERANDS = 1; + public static final int MAX_OPERANDS = 5; + public static final int MIN_VALUE = 1; + public static final int MAX_VALUE = 100; + + + private Constants() { + // no instances + } +} \ No newline at end of file diff --git a/src/FileManager.java b/src/FileManager.java new file mode 100644 index 0000000..b466569 --- /dev/null +++ b/src/FileManager.java @@ -0,0 +1,99 @@ +import java.io.BufferedWriter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * 负责与磁盘文件交互:读取历史题目、保存新生成的试卷文件. 所有生成的卷子会保存在 {@code Constants.BASE_OUTPUT_DIR} 目录下, + * 其中按照用户账号创建子目录。每次保存的文件名为 {@code yyyy-MM-dd-HH-mm-ss.txt} 格式. + */ +public final class FileManager { + + /** + * 试卷保存的根目录路径. + */ + private final Path baseDir; + + /** + * 构造方法。初始化输出目录. 如果 {@link Constants#BASE_OUTPUT_DIR} 指定的目录不存在, 则会自动创建. + * + * @throws RuntimeException 如果目录无法创建时抛出运行时异常. + */ + public FileManager() { + this.baseDir = Paths.get(Constants.BASE_OUTPUT_DIR); + try { + Files.createDirectories(baseDir); + } catch (IOException e) { + throw new RuntimeException("无法创建输出目录: " + baseDir, e); + } + } + + /** + * 查重功能 读取指定用户历史上已生成的所有题目内容,用于避免重复出题. 该方法会遍历用户目录下所有以 .txt 结尾的文件, 过滤掉形如 + * "第x题:"的题号前缀,只收集真实题干文本. + * + * @param username 用户名(用于定位其专属文件夹) + * @return 包含用户所有历史题目的字符串集合;如果目录不存在或无文件则返回空集合. 
+ */ + public Set loadExistingQuestions(String username) { + Path userDir = baseDir.resolve(username); + Set existing = new HashSet<>(); + if (!Files.exists(userDir)) { + return existing; + } + try { + Files.list(userDir) + .filter(p -> p.toString().endsWith(".txt")) + .forEach(p -> { + try { + List lines = Files.readAllLines(p, StandardCharsets.UTF_8); + for (String line : lines) { + String t = line.trim(); + // 只保留题目正文 + if (!t.isEmpty() && !t.matches("^第\\d+题.*$")) { + existing.add(t); + } + } + } catch (IOException ignore) { + // 忽略单个文件失败 + } + }); + } catch (IOException ignore) { + // 遍历目录失败时忽略 + } + return existing; + } + + /** + * 将生成的试卷保存到指定用户目录中. 保存路径格式: BASE_OUTPUT_DIR/用户名/yyyy-MM-dd-HH-mm-ss.txt 每道题会以“第n题: + * 题目”的形式写入文件,并在题目间留一行空行. + * + * @param username 当前登录的用户名,用于确定子目录 + * @param questions 题目列表,每个元素是一道题的字符串 + * @return 保存的试卷文件的完整 {@link Path} + * @throws IOException 当文件写入或目录创建失败时抛出. + */ + public Path saveExam(String username, List questions) throws IOException { + Path userDir = baseDir.resolve(username); + Files.createDirectories(userDir); + + String name = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss").format(new Date()) + ".txt"; + Path out = userDir.resolve(name); + + try (BufferedWriter w = Files.newBufferedWriter(out, StandardCharsets.UTF_8)) { + for (int i = 0; i < questions.size(); i++) { + w.write(String.format("第%d题: %s", i + 1, questions.get(i))); + w.newLine(); + w.newLine(); + } + } + return out; + } +} diff --git a/src/HighQuestionGenerator.java b/src/HighQuestionGenerator.java new file mode 100644 index 0000000..004f890 --- /dev/null +++ b/src/HighQuestionGenerator.java @@ -0,0 +1,47 @@ +import java.util.ArrayList; +import java.util.List; + +/** + * 高中题目生成器. 
+ */ +class HighQuestionGenerator extends BaseQuestionGenerator { + + @Override + protected String generateQuestion() { + int ops = randInt(Constants.MIN_OPERANDS, Constants.MAX_OPERANDS); + List tokens = new ArrayList<>(); + boolean hasTrig = false; + + for (int i = 0; i < ops; i++) { + if (!hasTrig && random.nextDouble() < 0.5) { + tokens.add(randomTrig()); + hasTrig = true; + } else if (random.nextDouble() < 0.2) { + tokens.add("(" + randInt(Constants.MIN_VALUE, Constants.MAX_VALUE) + "^2)"); + } else { + tokens.add(String.valueOf(randInt(Constants.MIN_VALUE, Constants.MAX_VALUE))); + } + if (i < ops - 1) { + tokens.add(randomBasicOperator()); + } + } + + if (!hasTrig) { + int pos = Math.max(0, random.nextInt(tokens.size() / 2 + 1) * 2); + tokens.add(pos, randomTrig()); + if (pos + 1 < tokens.size()) { + tokens.add(pos + 1, "+"); + } + } + + return String.join(" ", tokens); + } + + private String randomTrig() { + return switch (random.nextInt(3)) { + case 0 -> "sin(" + randInt(0, 90) + ")"; + case 1 -> "cos(" + randInt(0, 90) + ")"; + default -> "tan(" + randInt(0, 60) + ")"; + }; + } +} diff --git a/src/MathExamGenerator.java b/src/MathExamGenerator.java new file mode 100644 index 0000000..f38714c --- /dev/null +++ b/src/MathExamGenerator.java @@ -0,0 +1,155 @@ +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Scanner; +import java.util.Set; + +/** + * 程序入口类,提供命令行交互以生成中小学数学卷子. + */ +public final class MathExamGenerator { + + private final AccountManager accountManager = new AccountManager(); + private final FileManager fileManager = new FileManager(); + + /** + * 程序主入口. + * + * @param unused 未使用的命令行参数 + */ + public static void main(String[] unused) { + new MathExamGenerator().run(); + } + + /** + * 运行程序的主循环,包含登录、题目生成、文件保存等逻辑. 
+ */ + private void run() { + try (Scanner scanner = new Scanner(System.in)) { + while (true) { + String currentUser = login(scanner); + String currentLevel = accountManager.getLevel(currentUser); + + while (true) { + System.out.printf( + "准备生成 %s 数学题目,请输入生成题目数量(输入-1将退出当前用户,重新登录):", + currentLevel); + String line = scanner.nextLine().trim(); + + if ("-1".equals(line)) { + break; + } + + if (line.startsWith("切换为")) { + currentLevel = handleLevelSwitch(line, currentLevel); + continue; + } + + handleQuestionGeneration(currentUser, currentLevel, line); + } + } + } + } + + /** + * 登录流程:验证用户名和密码. + * + * @param scanner 控制台输入流. + * @return 成功登录的用户名. + */ + private String login(Scanner scanner) { + while (true) { + System.out.print("请输入用户名 密码(用空格分隔):"); + String line = scanner.nextLine().trim(); + if (line.isEmpty()) { + continue; + } + + String[] parts = line.split("\\s+"); + if (parts.length < 2) { + System.out.println("请输入正确的用户名、密码"); + continue; + } + + if (accountManager.validate(parts[0], parts[1])) { + String username = parts[0]; + System.out.println("当前选择为 " + accountManager.getLevel(username) + " 出题"); + return username; + } else { + System.out.println("用户名或密码错误,请重新输入"); + } + } + } + + /** + * 处理学段切换命令. + * + * @param line 用户输入 + * @param currentLevel 当前学段 + * @return 切换后的学段(若输入无效则返回原学段) + */ + private String handleLevelSwitch(String line, String currentLevel) { + String opt = line.replaceFirst("切换为", "").trim(); + if (!"小学".equals(opt) && !"初中".equals(opt) && !"高中".equals(opt)) { + System.out.println("请输入小学、初中或高中三个选项中的一个"); + return currentLevel; + } + System.out.println("当前选择为 " + opt + " 出题"); + return opt; + } + + /** + * 生成指定数量的题目并保存为试卷文件. 
+ * + * @param username 用户名 + * @param level 学段 + * @param line 用户输入的题目数量 + */ + private void handleQuestionGeneration(String username, String level, String line) { + int num; + try { + num = Integer.parseInt(line); + } catch (NumberFormatException e) { + System.out.println("请输入有效的整数数量或-1退出"); + return; + } + + if (num < Constants.MIN_QUESTIONS || num > Constants.MAX_QUESTIONS) { + System.out.printf( + "题目数量的有效输入范围是 %d-%d(含)%n", Constants.MIN_QUESTIONS, Constants.MAX_QUESTIONS); + return; + } + + BaseQuestionGenerator generator = getGenerator(level); + Set existing = fileManager.loadExistingQuestions(username); + List questions = generator.generate(num, existing); + + if (questions.size() < num) { + System.out.println("无法生成足够不重复的题目,请删除部分历史题目或减少生成数量"); + return; + } + + try { + Path path = fileManager.saveExam(username, questions); + // 转为 String 打印相对路径 + System.out.println("试卷已保存:" + path); + } catch (IOException e) { + System.out.println("保存文件时出错:" + e.getMessage()); + } + } + + /** + * 根据学段返回对应的题目生成器. + * + * @param level 学段 + * @return 对应的 BaseQuestionGenerator 实例 + */ + private BaseQuestionGenerator getGenerator(String level) { + return switch (level) { + case "小学" -> new PrimaryQuestionGenerator(); + case "初中" -> new MiddleQuestionGenerator(); + case "高中" -> new HighQuestionGenerator(); + default -> new PrimaryQuestionGenerator(); + }; + } +} diff --git a/src/MiddleQuestionGenerator.java b/src/MiddleQuestionGenerator.java new file mode 100644 index 0000000..57c81c4 --- /dev/null +++ b/src/MiddleQuestionGenerator.java @@ -0,0 +1,41 @@ +import java.util.ArrayList; +import java.util.List; + +/** + * 初中题目生成器. 
+ */ +class MiddleQuestionGenerator extends BaseQuestionGenerator { + + @Override + protected String generateQuestion() { + int ops = randInt(Constants.MIN_OPERANDS, Constants.MAX_OPERANDS); + List tokens = new ArrayList<>(); + boolean hasPowerOrSqrt = false; + + for (int i = 0; i < ops; i++) { + if (!hasPowerOrSqrt && random.nextDouble() < 0.4) { + if (random.nextBoolean()) { + tokens.add("(" + randInt(Constants.MIN_VALUE, Constants.MAX_VALUE) + "^2)"); + } else { + tokens.add("sqrt(" + randInt(Constants.MIN_VALUE, Constants.MAX_VALUE) + ")"); + } + hasPowerOrSqrt = true; + } else { + tokens.add(String.valueOf(randInt(Constants.MIN_VALUE, Constants.MAX_VALUE))); + } + if (i < ops - 1) { + tokens.add(randomBasicOperator()); + } + } + + if (!hasPowerOrSqrt) { + int pos = Math.max(0, random.nextInt(tokens.size() / 2 + 1) * 2); + tokens.add(pos, "sqrt(4)"); + if (pos + 1 < tokens.size()) { + tokens.add(pos + 1, "+"); + } + } + + return String.join(" ", tokens); + } +} \ No newline at end of file diff --git a/src/PrimaryQuestionGenerator.java b/src/PrimaryQuestionGenerator.java new file mode 100644 index 0000000..b5e21e5 --- /dev/null +++ b/src/PrimaryQuestionGenerator.java @@ -0,0 +1,94 @@ +import java.util.ArrayList; +import java.util.List; + +class PrimaryQuestionGenerator extends BaseQuestionGenerator { + + @Override + protected String generateQuestion() { + List numbers = generateNumbers(); + List operators = generateOperators(numbers); + List tokens = combineTokens(numbers, operators); + addRandomBrackets(tokens); + return String.join(" ", tokens); + } + + /** + * 生成操作数列表. + */ + private List generateNumbers() { + int ops = randInt(2, Constants.MAX_OPERANDS); + List numbers = new ArrayList<>(); + for (int i = 0; i < ops; i++) { + numbers.add(String.valueOf(randInt(1, Constants.MAX_VALUE))); + } + return numbers; + } + + /** + * 生成操作符列表,保证减法不产生负数. 
+ */ + private List generateOperators(List numbers) { + List operators = new ArrayList<>(); + String current = numbers.getFirst(); + for (int i = 1; i < numbers.size(); i++) { + String op = randomBasicOperator(); + if ("-".equals(op)) { + int left = Integer.parseInt(current.replaceAll("[^0-9]", "")); + int right = Integer.parseInt(numbers.get(i)); + if (left < right) { + op = random.nextBoolean() ? "+" : "*"; + } + } + operators.add(op); + current = computeIntermediate(current, numbers.get(i), op); + } + return operators; + } + + /** + * 合成数字和运算符为 tokens 列表. + */ + private List combineTokens(List numbers, List operators) { + List tokens = new ArrayList<>(); + tokens.add(numbers.getFirst()); + for (int i = 1; i < numbers.size(); i++) { + tokens.add(operators.get(i - 1)); + tokens.add(numbers.get(i)); + } + return tokens; + } + + /** + * 随机为子表达式添加一对括号,不包裹整个表达式. + */ + private void addRandomBrackets(List tokens) { + int ops = (tokens.size() + 1) / 2; + if (ops < 2 || !random.nextBoolean()) { + return; + } + + int pairCount = ops - 1; + int chosenOpIndex = random.nextInt(pairCount); + int leftIndex = chosenOpIndex * 2; + int rightIndex = chosenOpIndex * 2 + 2; + + if (!(leftIndex == 0 && rightIndex == tokens.size() - 1)) { + tokens.add(leftIndex, "("); + tokens.add(rightIndex + 2, ")"); + } + } + + /** + * 计算当前中间值,用于保证减法不产生负数. + */ + private String computeIntermediate(String current, String next, String op) { + int c = Integer.parseInt(current); + int n = Integer.parseInt(next); + return switch (op) { + case "*" -> String.valueOf(c * n); + case "+" -> String.valueOf(c + n); + case "/" -> n != 0 ? String.valueOf(c / n) : String.valueOf(c); + default -> String.valueOf(c); // 避免减法导致负数 + }; + } +} -- 2.34.1