user_controller修改完成+为UserConfig表添加to_dict方法 #5

Merged
ppy4sjqvf merged 3 commits from ybw-branch into develop 2 months ago

@ -20,7 +20,7 @@ def list_demo_images():
# 获取演示原始图片 - 修正路径构建
# 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_au6pp.config['DEMO_ORIGINAL_FOLDER'])
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
if os.path.exists(original_folder):
original_files = glob.glob(os.path.join(original_folder, '*'))
@ -105,7 +105,7 @@ def get_demo_algorithms():
try:
# 从数据库获取扰动算法
perturbation_algorithms = []
perturbation_configs = PerturbationConfig.query.all()
perturbation_configs = Perturbation.query.all()
for config in perturbation_configs:
perturbation_algorithms.append({
'id': config.id,
@ -113,12 +113,11 @@ def get_demo_algorithms():
'name': config.method_name,
'type': 'perturbation',
'description': config.description,
'default_epsilon': float(config.default_epsilon) if config.default_epsilon else None
})
# 从数据库获取微调算法
finetune_algorithms = []
finetune_configs = FinetuneConfig.query.all()
finetune_configs = Finetune.query.all()
for config in finetune_configs:
finetune_algorithms.append({
'id': config.id,
@ -157,8 +156,8 @@ def get_demo_stats():
comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0
# 统计数据库中的算法数量
perturbation_count = PerturbationConfig.query.count()
finetune_count = FinetuneConfig.query.count()
perturbation_count = Perturbation.query.count()
finetune_count = Finetune.query.count()
total_algorithms = perturbation_count + finetune_count
return jsonify({

@ -6,7 +6,7 @@
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required
from app import db
from app.database import User, UserConfig, PerturbationConfig, FinetuneConfig
from app.database import User, UserConfig, Perturbation, Finetune
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
user_bp = Blueprint('user', __name__)
@ -46,21 +46,18 @@ def update_user_config(current_user_id):
db.session.add(user_config)
# 更新配置字段
if 'preferred_perturbation_config_id' in data:
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
if 'perturbation_configs_id' in data:
user_config.perturbation_configs_id = data['perturbation_configs_id']
if 'preferred_epsilon' in data:
epsilon = float(data['preferred_epsilon'])
if 'perturbation_intensity' in data:
intensity = float(data['perturbation_intensity'])
if 0 < epsilon <= 255:
user_config.preferred_epsilon = epsilon
user_config.perturbation_intensity = intensity
else:
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
if 'preferred_finetune_config_id' in data:
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
if 'preferred_purification' in data:
user_config.preferred_purification = bool(data['preferred_purification'])
if 'finetune_config_id' in data:
user_config.finetune_config_id = data['finetune_config_id']
db.session.commit()
@ -78,8 +75,8 @@ def update_user_config(current_user_id):
def get_available_algorithms():
"""获取可用的算法列表"""
try:
perturbation_configs = PerturbationConfig.query.all()
finetune_configs = FinetuneConfig.query.all()
perturbation_configs = Perturbation.query.all()
finetune_configs = Finetune.query.all()
return jsonify({
'perturbation_algorithms': [
@ -88,7 +85,6 @@ def get_available_algorithms():
'method_code': config.method_code,
'method_name': config.method_name,
'description': config.description,
'default_epsilon': float(config.default_epsilon)
} for config in perturbation_configs
],
'finetune_methods': [
@ -109,15 +105,15 @@ def get_available_algorithms():
def get_user_stats(current_user_id):
"""获取用户统计信息"""
try:
from app.database import Batch, Image
from app.database import Task, Image
# 统计用户的任务和图片数量
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
total_tasks = Task.query.filter_by(user_id=current_user_id).count()
completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count()
processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count()
failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count()
total_images = Image.query.filter_by(user_id=current_user_id).count()
total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count()
return jsonify({
'stats': {

@ -171,6 +171,19 @@ class UserConfig(db.Model):
def __repr__(self):
return f'<UserConfig for UserID {self.user_id}>'
def to_dict(self):
return {
'user_configs_id': getattr(self, 'user_configs_id', None),
'user_id': self.user_id,
'data_type_id': self.data_type_id,
'perturbation_configs_id': self.perturbation_configs_id,
'perturbation_intensity': float(self.perturbation_intensity) if self.perturbation_intensity is not None else None,
'finetune_configs_id': self.finetune_configs_id,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
}
# ----------------------------
# 5. 核心任务表 (tasks)
# ----------------------------

Loading…
Cancel
Save