diff --git a/src/backend/app/controllers/demo_controller.py b/src/backend/app/controllers/demo_controller.py index fba7c97..2dec6f8 100644 --- a/src/backend/app/controllers/demo_controller.py +++ b/src/backend/app/controllers/demo_controller.py @@ -20,7 +20,7 @@ def list_demo_images(): # 获取演示原始图片 - 修正路径构建 # 获取项目根目录(backend目录) project_root = os.path.dirname(current_app.root_path) - original_folder = os.path.join(project_root, current_au6pp.config['DEMO_ORIGINAL_FOLDER']) + original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER']) if os.path.exists(original_folder): original_files = glob.glob(os.path.join(original_folder, '*')) @@ -105,7 +105,7 @@ def get_demo_algorithms(): try: # 从数据库获取扰动算法 perturbation_algorithms = [] - perturbation_configs = PerturbationConfig.query.all() + perturbation_configs = Perturbation.query.all() for config in perturbation_configs: perturbation_algorithms.append({ 'id': config.id, @@ -113,12 +113,11 @@ def get_demo_algorithms(): 'name': config.method_name, 'type': 'perturbation', 'description': config.description, - 'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None }) # 从数据库获取微调算法 finetune_algorithms = [] - finetune_configs = FinetuneConfig.query.all() + finetune_configs = Finetune.query.all() for config in finetune_configs: finetune_algorithms.append({ 'id': config.id, @@ -157,8 +156,8 @@ def get_demo_stats(): comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0 # 统计数据库中的算法数量 - perturbation_count = PerturbationConfig.query.count() - finetune_count = FinetuneConfig.query.count() + perturbation_count = Perturbation.query.count() + finetune_count = Finetune.query.count() total_algorithms = perturbation_count + finetune_count return jsonify({ diff --git a/src/backend/app/controllers/user_controller.py b/src/backend/app/controllers/user_controller.py index 2b680ed..3d99fda 100644 --- a/src/backend/app/controllers/user_controller.py +++ b/src/backend/app/controllers/user_controller.py @@ -6,7 +6,7 @@ from flask import Blueprint, request, jsonify from flask_jwt_extended import jwt_required from app import db -from app.database import User, UserConfig, PerturbationConfig, FinetuneConfig +from app.database import User, UserConfig, Perturbation, Finetune from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器 user_bp = Blueprint('user', __name__) @@ -46,21 +46,18 @@ def update_user_config(current_user_id): db.session.add(user_config) # 更新配置字段 - if 'preferred_perturbation_config_id' in data: - user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id'] + if 'perturbation_configs_id' in data: + user_config.perturbation_configs_id = data['perturbation_configs_id'] - if 'preferred_epsilon' in data: - epsilon = float(data['preferred_epsilon']) + if 'perturbation_intensity' in data: + intensity = float(data['perturbation_intensity']) if 0 < epsilon <= 255: - user_config.preferred_epsilon = epsilon + user_config.perturbation_intensity = intensity else: return jsonify({'error': '扰动强度必须在0-255之间'}), 400 - if 'preferred_finetune_config_id' in data: - user_config.preferred_finetune_config_id = data['preferred_finetune_config_id'] - - if 'preferred_purification' in data: - user_config.preferred_purification = bool(data['preferred_purification']) + if 'finetune_config_id' in data: + user_config.finetune_config_id = data['finetune_config_id'] db.session.commit() @@ -78,8 +75,8 @@ def update_user_config(current_user_id): def get_available_algorithms(): """获取可用的算法列表""" try: - perturbation_configs = PerturbationConfig.query.all() - finetune_configs = FinetuneConfig.query.all() + perturbation_configs = Perturbation.query.all() + finetune_configs = Finetune.query.all() return jsonify({ 'perturbation_algorithms': [ @@ -88,7 +85,6 @@ def get_available_algorithms(): 'method_code': config.method_code, 'method_name': config.method_name, 'description': config.description, - 'default_epsilon': float(config.default_epsilon) } for config in perturbation_configs ], 'finetune_methods': [ @@ -109,15 +105,15 @@ def get_available_algorithms(): def get_user_stats(current_user_id): """获取用户统计信息""" try: - from app.database import Batch, Image + from app.database import Task, Image # 统计用户的任务和图片数量 - total_tasks = Batch.query.filter_by(user_id=current_user_id).count() - completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count() - processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count() - failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count() + total_tasks = Task.query.filter_by(user_id=current_user_id).count() + completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count() + processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count() + failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count() - total_images = Image.query.filter_by(user_id=current_user_id).count() + total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count() return jsonify({ 'stats': { diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index d69d9a6..8349728 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -171,6 +171,19 @@ class UserConfig(db.Model): def __repr__(self): return f'' + def to_dict(self): + return { + 'user_configs_id': getattr(self, 'user_configs_id', None), + 'user_id': self.user_id, + 'data_type_id': self.data_type_id, + 'perturbation_configs_id': self.perturbation_configs_id, + 'perturbation_intensity': float(self.perturbation_intensity) if self.perturbation_intensity is not None else None, + 'finetune_configs_id': self.finetune_configs_id, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'updated_at': self.updated_at.isoformat() if self.updated_at else None, + + } + # ---------------------------- # 5. 核心任务表 (tasks) # ----------------------------