将lianghao_branch合并到develop #7
Merged
hnu202326010204
merged 15 commits from lianghao_branch into develop 1 month ago
File diff suppressed because it is too large
Load Diff
@ -1,129 +1,119 @@
|
||||
"""
|
||||
用户管理控制器
|
||||
处理用户配置等功能
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask_jwt_extended import jwt_required
|
||||
from app import db
|
||||
from app.database import User, UserConfig, Perturbation, Finetune
|
||||
from app.controllers.auth_controller import int_jwt_required # 导入JWT装饰器
|
||||
|
||||
user_bp = Blueprint('user', __name__)
|
||||
|
||||
@user_bp.route('/config', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_config(current_user_id):
|
||||
"""获取用户配置"""
|
||||
try:
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
# 如果没有配置,创建默认配置
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/config', methods=['PUT'])
|
||||
@int_jwt_required
|
||||
def update_user_config(current_user_id):
|
||||
"""更新用户配置"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
user_config = UserConfig.query.filter_by(user_id=current_user_id).first()
|
||||
|
||||
if not user_config:
|
||||
user_config = UserConfig(user_id=current_user_id)
|
||||
db.session.add(user_config)
|
||||
|
||||
# 更新配置字段
|
||||
if 'perturbation_configs_id' in data:
|
||||
user_config.perturbation_configs_id = data['perturbation_configs_id']
|
||||
|
||||
if 'perturbation_intensity' in data:
|
||||
intensity = float(data['perturbation_intensity'])
|
||||
if 0 < epsilon <= 255:
|
||||
user_config.perturbation_intensity = intensity
|
||||
else:
|
||||
return jsonify({'error': '扰动强度必须在0-255之间'}), 400
|
||||
|
||||
if 'finetune_config_id' in data:
|
||||
user_config.finetune_config_id = data['finetune_config_id']
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'message': '用户配置更新成功',
|
||||
'config': user_config.to_dict()
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return jsonify({'error': f'更新用户配置失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/algorithms', methods=['GET'])
|
||||
@jwt_required()
|
||||
def get_available_algorithms():
|
||||
"""获取可用的算法列表"""
|
||||
try:
|
||||
perturbation_configs = Perturbation.query.all()
|
||||
finetune_configs = Finetune.query.all()
|
||||
|
||||
return jsonify({
|
||||
'perturbation_algorithms': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description,
|
||||
} for config in perturbation_configs
|
||||
],
|
||||
'finetune_methods': [
|
||||
{
|
||||
'id': config.id,
|
||||
'method_code': config.method_code,
|
||||
'method_name': config.method_name,
|
||||
'description': config.description
|
||||
} for config in finetune_configs
|
||||
]
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取算法列表失败: {str(e)}'}), 500
|
||||
|
||||
@user_bp.route('/stats', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_stats(current_user_id):
|
||||
"""获取用户统计信息"""
|
||||
try:
|
||||
from app.database import Task, Image
|
||||
|
||||
# 统计用户的任务和图片数量
|
||||
total_tasks = Task.query.filter_by(user_id=current_user_id).count()
|
||||
completed_tasks = Task.query.filter_by(user_id=current_user_id, status='completed').count()
|
||||
processing_tasks = Task.query.filter_by(user_id=current_user_id, status='processing').count()
|
||||
failed_tasks = Task.query.filter_by(user_id=current_user_id, status='failed').count()
|
||||
|
||||
total_images = Image.query.join(Task, Image.task_id == Task.id).filter(Task.user_id == current_user_id).count()
|
||||
|
||||
return jsonify({
|
||||
'stats': {
|
||||
'total_tasks': total_tasks,
|
||||
'completed_tasks': completed_tasks,
|
||||
'processing_tasks': processing_tasks,
|
||||
'failed_tasks': failed_tasks,
|
||||
'total_images': total_images
|
||||
}
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'获取用户统计失败: {str(e)}'}), 500
|
||||
|
||||
"""
|
||||
用户管理控制器
|
||||
负责用户配置、任务汇总等接口
|
||||
"""
|
||||
|
||||
from flask import Blueprint, request, jsonify
|
||||
from app import db
|
||||
from app.controllers.auth_controller import int_jwt_required
|
||||
from app.database import UserConfig, Task, TaskType, TaskStatus
|
||||
|
||||
|
||||
user_bp = Blueprint('user', __name__)
|
||||
|
||||
|
||||
def _json_error(message, status_code=400):
|
||||
return jsonify({'error': message}), status_code
|
||||
|
||||
|
||||
def _get_or_create_user_config(user_id):
|
||||
config = UserConfig.query.filter_by(user_id=user_id).first()
|
||||
if not config:
|
||||
config = UserConfig(user_id=user_id)
|
||||
db.session.add(config)
|
||||
db.session.commit()
|
||||
return config
|
||||
|
||||
|
||||
def _serialize_config(config):
|
||||
return {
|
||||
'user_configs_id': config.user_configs_id,
|
||||
'user_id': config.user_id,
|
||||
'data_type_id': config.data_type_id,
|
||||
'perturbation_configs_id': config.perturbation_configs_id,
|
||||
'perturbation_intensity': config.perturbation_intensity,
|
||||
'finetune_configs_id': config.finetune_configs_id,
|
||||
'created_at': config.created_at.isoformat() if config.created_at else None,
|
||||
'updated_at': config.updated_at.isoformat() if config.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_task(task):
|
||||
status_code = task.task_status.task_status_code if task.task_status else None
|
||||
task_type_code = task.task_type.task_type_code if task.task_type else None
|
||||
return {
|
||||
'task_id': task.tasks_id,
|
||||
'flow_id': task.flow_id,
|
||||
'task_type': task_type_code,
|
||||
'status': status_code,
|
||||
'created_at': task.created_at.isoformat() if task.created_at else None,
|
||||
'started_at': task.started_at.isoformat() if task.started_at else None,
|
||||
'finished_at': task.finished_at.isoformat() if task.finished_at else None,
|
||||
'description': task.description,
|
||||
'error_message': task.error_message
|
||||
}
|
||||
|
||||
|
||||
@user_bp.route('/config', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_config(current_user_id):
|
||||
config = _get_or_create_user_config(current_user_id)
|
||||
return jsonify({'config': _serialize_config(config)}), 200
|
||||
|
||||
|
||||
@user_bp.route('/config', methods=['PUT'])
|
||||
@int_jwt_required
|
||||
def update_user_config(current_user_id):
|
||||
config = _get_or_create_user_config(current_user_id)
|
||||
data = request.get_json() or {}
|
||||
|
||||
allowed_fields = {'data_type_id', 'perturbation_configs_id', 'perturbation_intensity', 'finetune_configs_id'}
|
||||
for key, value in data.items():
|
||||
if key in allowed_fields:
|
||||
if key == 'perturbation_intensity' and value is not None:
|
||||
try:
|
||||
value = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return _json_error('perturbation_intensity 参数格式不正确')
|
||||
setattr(config, key, value)
|
||||
|
||||
try:
|
||||
db.session.commit()
|
||||
return jsonify({'message': '配置已更新', 'config': _serialize_config(config)}), 200
|
||||
except Exception as exc:
|
||||
db.session.rollback()
|
||||
return _json_error(f'更新配置失败: {exc}', 500)
|
||||
|
||||
|
||||
@user_bp.route('/tasks', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def list_user_tasks(current_user_id):
|
||||
task_type_code = request.args.get('type')
|
||||
status_code = request.args.get('status')
|
||||
|
||||
query = Task.query.filter_by(user_id=current_user_id)
|
||||
|
||||
if task_type_code:
|
||||
task_type = TaskType.query.filter_by(task_type_code=task_type_code).first()
|
||||
if not task_type:
|
||||
return _json_error('任务类型不存在', 404)
|
||||
query = query.filter(Task.tasks_type_id == task_type.task_type_id)
|
||||
|
||||
if status_code:
|
||||
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
|
||||
if not status:
|
||||
return _json_error('任务状态不存在', 404)
|
||||
query = query.filter(Task.tasks_status_id == status.task_status_id)
|
||||
|
||||
tasks = query.order_by(Task.created_at.desc()).all()
|
||||
return jsonify({'tasks': [_serialize_task(task) for task in tasks]}), 200
|
||||
|
||||
|
||||
@user_bp.route('/tasks/<int:task_id>', methods=['GET'])
|
||||
@int_jwt_required
|
||||
def get_user_task(task_id, current_user_id):
|
||||
task = Task.query.filter_by(tasks_id=task_id, user_id=current_user_id).first()
|
||||
if not task:
|
||||
return _json_error('任务不存在或无权限', 404)
|
||||
return jsonify({'task': _serialize_task(task)}), 200
|
||||
|
||||
@ -1,34 +0,0 @@
|
||||
"""
|
||||
认证服务
|
||||
处理用户认证相关逻辑
|
||||
"""
|
||||
|
||||
from app.database import User
|
||||
|
||||
class AuthService:
|
||||
"""认证服务类"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(username, password):
|
||||
"""验证用户凭据"""
|
||||
user = User.query.filter_by(username=username).first()
|
||||
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id):
|
||||
"""根据ID获取用户"""
|
||||
return User.query.get(user_id)
|
||||
|
||||
@staticmethod
|
||||
def is_email_available(email):
|
||||
"""检查邮箱是否可用"""
|
||||
return User.query.filter_by(email=email).first() is None
|
||||
|
||||
@staticmethod
|
||||
def is_username_available(username):
|
||||
"""检查用户名是否可用"""
|
||||
return User.query.filter_by(username=username).first() is None
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue