|
|
|
|
@ -6,7 +6,7 @@
|
|
|
|
|
from flask import Blueprint, request, jsonify
|
|
|
|
|
from flask_jwt_extended import jwt_required
|
|
|
|
|
from app import db
|
|
|
|
|
from app.database import User, UserConfig, PerturbationConfig, FinetuneConfig
|
|
|
|
|
from app.database import User, UserConfig, Perturbation, Finetune
|
|
|
|
|
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
|
|
|
|
|
|
|
|
|
|
user_bp = Blueprint('user', __name__)
|
|
|
|
|
@ -46,21 +46,18 @@ def update_user_config(current_user_id):
|
|
|
|
|
db.session.add(user_config)
|
|
|
|
|
|
|
|
|
|
# 更新配置字段
|
|
|
|
|
if 'preferred_perturbation_config_id' in data:
|
|
|
|
|
user_config.preferred_perturbation_config_id = data['preferred_perturbation_config_id']
|
|
|
|
|
if 'perturbation_configs_id' in data:
|
|
|
|
|
user_config.perturbation_configs_id = data['perturbation_configs_id']
|
|
|
|
|
|
|
|
|
|
if 'preferred_epsilon' in data:
|
|
|
|
|
epsilon = float(data['preferred_epsilon'])
|
|
|
|
|
if 'perturbation_intensity' in data:
|
|
|
|
|
intensity = float(data['perturbation_intensity'])
|
|
|
|
|
if 0 < epsilon <= 255:
|
|
|
|
|
user_config.preferred_epsilon = epsilon
|
|
|
|
|
user_config.perturbation_intensity = intensity
|
|
|
|
|
else:
|
|
|
|
|
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
|
|
|
|
|
|
|
|
|
if 'preferred_finetune_config_id' in data:
|
|
|
|
|
user_config.preferred_finetune_config_id = data['preferred_finetune_config_id']
|
|
|
|
|
|
|
|
|
|
if 'preferred_purification' in data:
|
|
|
|
|
user_config.preferred_purification = bool(data['preferred_purification'])
|
|
|
|
|
if 'finetune_config_id' in data:
|
|
|
|
|
user_config.finetune_config_id = data['finetune_config_id']
|
|
|
|
|
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
@ -78,8 +75,8 @@ def update_user_config(current_user_id):
|
|
|
|
|
def get_available_algorithms():
|
|
|
|
|
"""获取可用的算法列表"""
|
|
|
|
|
try:
|
|
|
|
|
perturbation_configs = PerturbationConfig.query.all()
|
|
|
|
|
finetune_configs = FinetuneConfig.query.all()
|
|
|
|
|
perturbation_configs = Perturbation.query.all()
|
|
|
|
|
finetune_configs = Finetune.query.all()
|
|
|
|
|
|
|
|
|
|
return jsonify({
|
|
|
|
|
'perturbation_algorithms': [
|
|
|
|
|
@ -88,7 +85,6 @@ def get_available_algorithms():
|
|
|
|
|
'method_code': config.method_code,
|
|
|
|
|
'method_name': config.method_name,
|
|
|
|
|
'description': config.description,
|
|
|
|
|
'default_epsilon': float(config.default_epsilon)
|
|
|
|
|
} for config in perturbation_configs
|
|
|
|
|
],
|
|
|
|
|
'finetune_methods': [
|
|
|
|
|
@ -109,15 +105,15 @@ def get_available_algorithms():
|
|
|
|
|
def get_user_stats(current_user_id):
|
|
|
|
|
"""获取用户统计信息"""
|
|
|
|
|
try:
|
|
|
|
|
from app.database import Batch, Image
|
|
|
|
|
from app.database import Task, Image
|
|
|
|
|
|
|
|
|
|
# 统计用户的任务和图片数量
|
|
|
|
|
total_tasks = Batch.query.filter_by(user_id=current_user_id).count()
|
|
|
|
|
completed_tasks = Batch.query.filter_by(user_id=current_user_id, status='completed').count()
|
|
|
|
|
processing_tasks = Batch.query.filter_by(user_id=current_user_id, status='processing').count()
|
|
|
|
|
failed_tasks = Batch.query.filter_by(user_id=current_user_id, status='failed').count()
|
|
|
|
|
total_tasks = Task.query.filter_by(user_id=current_user_id).count()
|
|
|
|
|
completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count()
|
|
|
|
|
processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count()
|
|
|
|
|
failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count()
|
|
|
|
|
|
|
|
|
|
total_images = Image.query.filter_by(user_id=current_user_id).count()
|
|
|
|
|
total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count()
|
|
|
|
|
|
|
|
|
|
return jsonify({
|
|
|
|
|
'stats': {
|
|
|
|
|
|