将lianghao_branch合并到develop #8

Merged
ppy4sjqvf merged 7 commits from lianghao_branch into develop 1 month ago

@ -42,14 +42,12 @@ def create_app(config_name=None):
from app.controllers.task_controller import task_bp
from app.controllers.image_controller import image_bp
from app.controllers.admin_controller import admin_bp
from app.controllers.demo_controller import demo_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(task_bp, url_prefix='/api/task')
app.register_blueprint(image_bp, url_prefix='/api/image')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
app.register_blueprint(demo_bp, url_prefix='/api/demo')
# 注册错误处理器
@app.errorhandler(404)

@ -1,176 +0,0 @@
"""
演示图片控制器
处理预设图像对比图的展示功能
"""
from flask import Blueprint, send_file, jsonify, current_app
from flask_jwt_extended import jwt_required
from app.database import Perturbation, Finetune
import os
import glob
demo_bp = Blueprint('demo', __name__)
@demo_bp.route('/images', methods=['GET'])
def list_demo_images():
"""获取所有演示图片列表"""
try:
demo_images = []
# 获取演示原始图片 - 修正路径构建
# 获取项目根目录backend目录
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
if os.path.exists(original_folder):
original_files = glob.glob(os.path.join(original_folder, '*'))
for file_path in original_files:
if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
filename = os.path.basename(file_path)
name_without_ext = os.path.splitext(filename)[0]
# 查找对应的加噪图片
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
perturbed_files = glob.glob(os.path.join(perturbed_folder, f"{name_without_ext}*"))
# 查找对应的对比图
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
comparison_files = glob.glob(os.path.join(comparison_folder, f"{name_without_ext}*"))
demo_image = {
'id': name_without_ext,
'name': name_without_ext,
'original': f"/api/demo/image/original/{filename}",
'perturbed': [f"/api/demo/image/perturbed/{os.path.basename(f)}" for f in perturbed_files],
'comparisons': [f"/api/demo/image/comparison/{os.path.basename(f)}" for f in comparison_files]
}
demo_images.append(demo_image)
return jsonify({
'demo_images': demo_images,
'total': len(demo_images)
}), 200
except Exception as e:
return jsonify({'error': f'获取演示图片列表失败: {str(e)}'}), 500
@demo_bp.route('/image/original/<filename>', methods=['GET'])
def get_demo_original_image(filename):
"""获取演示原始图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取原始图片失败: {str(e)}'}), 500
@demo_bp.route('/image/perturbed/<filename>', methods=['GET'])
def get_demo_perturbed_image(filename):
"""获取演示加噪图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取加噪图片失败: {str(e)}'}), 500
@demo_bp.route('/image/comparison/<filename>', methods=['GET'])
def get_demo_comparison_image(filename):
"""获取演示对比图片"""
try:
project_root = os.path.dirname(current_app.root_path)
file_path = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'], filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, as_attachment=False)
except Exception as e:
return jsonify({'error': f'获取对比图片失败: {str(e)}'}), 500
@demo_bp.route('/algorithms', methods=['GET'])
def get_demo_algorithms():
"""获取演示算法信息"""
try:
# 从数据库获取扰动算法
perturbation_algorithms = []
perturbation_configs = Perturbation.query.all()
for config in perturbation_configs:
perturbation_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'perturbation',
'description': config.description,
})
# 从数据库获取微调算法
finetune_algorithms = []
finetune_configs = Finetune.query.all()
for config in finetune_configs:
finetune_algorithms.append({
'id': config.id,
'code': config.method_code,
'name': config.method_name,
'type': 'finetune',
'description': config.description
})
return jsonify({
'perturbation_algorithms': perturbation_algorithms,
'finetune_algorithms': finetune_algorithms,
'evaluation_metrics': [
{'name': 'FID', 'description': 'Fréchet Inception Distance - 衡量图像质量的指标'},
{'name': 'LPIPS', 'description': 'Learned Perceptual Image Patch Similarity - 感知相似度'},
{'name': 'SSIM', 'description': 'Structural Similarity Index - 结构相似性指标'},
{'name': 'PSNR', 'description': 'Peak Signal-to-Noise Ratio - 峰值信噪比'}
]
}), 200
except Exception as e:
return jsonify({'error': f'获取算法信息失败: {str(e)}'}), 500
@demo_bp.route('/stats', methods=['GET'])
def get_demo_stats():
"""获取演示统计信息"""
try:
# 统计演示图片数量
project_root = os.path.dirname(current_app.root_path)
original_folder = os.path.join(project_root, current_app.config['DEMO_ORIGINAL_FOLDER'])
perturbed_folder = os.path.join(project_root, current_app.config['DEMO_PERTURBED_FOLDER'])
comparison_folder = os.path.join(project_root, current_app.config['DEMO_COMPARISONS_FOLDER'])
original_count = len(glob.glob(os.path.join(original_folder, '*'))) if os.path.exists(original_folder) else 0
perturbed_count = len(glob.glob(os.path.join(perturbed_folder, '*'))) if os.path.exists(perturbed_folder) else 0
comparison_count = len(glob.glob(os.path.join(comparison_folder, '*'))) if os.path.exists(comparison_folder) else 0
# 统计数据库中的算法数量
perturbation_count = Perturbation.query.count()
finetune_count = Finetune.query.count()
total_algorithms = perturbation_count + finetune_count
return jsonify({
'demo_stats': {
'original_images': original_count,
'perturbed_images': perturbed_count,
'comparison_images': comparison_count,
'supported_algorithms': total_algorithms,
'perturbation_algorithms': perturbation_count,
'finetune_algorithms': finetune_count,
'evaluation_metrics': 4
}
}), 200
except Exception as e:
return jsonify({'error': f'获取统计信息失败: {str(e)}'}), 500

@ -8,7 +8,7 @@ from flask import Blueprint, request, jsonify
from app import db
from app.controllers.auth_controller import int_jwt_required
from app.database import (
Task,
Task, TaskType, TaskStatus,
Perturbation, Finetune, Heatmap, Evaluate,
PerturbationConfig, FinetuneConfig, DataType,
Image
@ -21,6 +21,40 @@ task_bp = Blueprint('task', __name__)
# ==================== 通用任务接口 ====================
@task_bp.route('', methods=['GET'])
@int_jwt_required
def list_tasks(current_user_id):
"""根据任务类型与状态筛选通用任务列表"""
task_type_code = request.args.get('task_type', 'all') or 'all'
status_code = request.args.get('task_status', 'all') or 'all'
query = Task.query.filter(Task.user_id == current_user_id)
if task_type_code != 'all':
task_type = TaskType.query.filter_by(task_type_code=task_type_code).first()
if not task_type:
return TaskService.json_error('任务类型不存在', 400)
query = query.filter(Task.tasks_type_id == task_type.task_type_id)
if status_code != 'all':
task_status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not task_status:
return TaskService.json_error('任务状态不存在', 400)
query = query.filter(Task.tasks_status_id == task_status.task_status_id)
tasks = query.order_by(Task.created_at.desc()).all()
return jsonify({'tasks': [TaskService.serialize_task(task) for task in tasks]}), 200
@task_bp.route('/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_task(task_id, current_user_id):
"""获取当前用户的单个任务详情"""
task = TaskService.load_task_for_user(task_id, current_user_id)
if not task:
return TaskService.json_error('任务不存在或无权限', 404)
return jsonify({'task': TaskService.serialize_task(task)}), 200
@task_bp.route('/<int:task_id>/status', methods=['GET'])
@int_jwt_required
def get_task_status(task_id, current_user_id):
@ -42,6 +76,25 @@ def cancel_task(task_id, current_user_id):
return TaskService.json_error('取消任务失败', 500)
@task_bp.route('/quota', methods=['GET'])
@int_jwt_required
def get_task_quota(current_user_id):
user = TaskService.get_user(current_user_id)
if not user:
return TaskService.json_error('用户不存在', 404)
role = user.role
max_tasks = role.max_concurrent_tasks if role and role.max_concurrent_tasks is not None else 0
current_count = Task.query.filter_by(user_id=current_user_id).count()
remaining = max(max_tasks - current_count, 0)
return jsonify({
'max_tasks': max_tasks,
'current_tasks': current_count,
'remaining_tasks': remaining
}), 200
# ==================== 加噪任务 ====================
@task_bp.route('/perturbation/configs', methods=['GET'])
@ -70,8 +123,16 @@ def create_perturbation_task(current_user_id):
if not all([data_type_id, perturbation_configs_id, intensity]):
return TaskService.json_error('缺少必要的任务参数')
if not DataType.query.get(data_type_id):
user = TaskService.get_user(current_user_id)
if not user:
return TaskService.json_error('用户不存在', 404)
data_type = DataType.query.get(data_type_id)
if not data_type:
return TaskService.json_error('数据集类型不存在')
role_code = user.role.role_code if user.role else 'user'
if role_code in ('user', 'normal') and data_type.data_type_code != 'facial':
return TaskService.json_error('普通用户仅可使用人脸数据集', 403)
if not PerturbationConfig.query.get(perturbation_configs_id):
return TaskService.json_error('加噪配置不存在')
@ -358,9 +419,17 @@ def create_finetune_from_upload(current_user_id):
if not finetune_configs_id:
return TaskService.json_error('缺少必要参数: finetune_configs_id')
if not FinetuneConfig.query.get(finetune_configs_id):
finetune_config = FinetuneConfig.query.get(finetune_configs_id)
if not finetune_config:
return TaskService.json_error('微调配置不存在')
data_type_id = data.get('data_type_id')
if not data_type_id:
return TaskService.json_error('缺少必要参数: data_type_id')
data_type = DataType.query.get(data_type_id)
if not data_type:
return TaskService.json_error('数据集类型不存在')
try:
flow_id = data.get('flow_id')
if flow_id is not None:
@ -393,7 +462,7 @@ def create_finetune_from_upload(current_user_id):
finetune = Finetune(
tasks_id=task.tasks_id,
finetune_configs_id=finetune_configs_id,
data_type_id=data.get('data_type_id'),
data_type_id=data_type_id,
finetune_name=data.get('finetune_name')
)
db.session.add(finetune)
@ -464,6 +533,9 @@ def create_evaluate_task(current_user_id):
if not finetune_task:
return TaskService.json_error('微调任务不存在或无权限', 404)
if finetune_task.finetune and finetune_task.finetune.evaluation:
return TaskService.json_error('该微调任务已存在评估,请勿重复创建', 400)
# 仅允许基于加噪微调创建评估
if TaskService.determine_finetune_source(finetune_task) != 'perturbation':
return TaskService.json_error('数值评估仅支持基于加噪任务的微调结果')
@ -490,7 +562,7 @@ def create_evaluate_task(current_user_id):
evaluate = Evaluate(
tasks_id=task.tasks_id,
finetune_configs_id=finetune_task.finetune.finetune_configs_id,
finetune_task_id=finetune_task_id,
evaluate_name=data.get('evaluate_name')
)
db.session.add(evaluate)

@ -7,7 +7,7 @@
from flask import Blueprint, request, jsonify
from app import db
from app.controllers.auth_controller import int_jwt_required
from app.database import UserConfig, Task, TaskType, TaskStatus
from app.services.user_service import UserService
user_bp = Blueprint('user', __name__)
@ -17,55 +17,17 @@ def _json_error(message, status_code=400):
return jsonify({'error': message}), status_code
def _get_or_create_user_config(user_id):
config = UserConfig.query.filter_by(user_id=user_id).first()
if not config:
config = UserConfig(user_id=user_id)
db.session.add(config)
db.session.commit()
return config
def _serialize_config(config):
return {
'user_configs_id': config.user_configs_id,
'user_id': config.user_id,
'data_type_id': config.data_type_id,
'perturbation_configs_id': config.perturbation_configs_id,
'perturbation_intensity': config.perturbation_intensity,
'finetune_configs_id': config.finetune_configs_id,
'created_at': config.created_at.isoformat() if config.created_at else None,
'updated_at': config.updated_at.isoformat() if config.updated_at else None,
}
def _serialize_task(task):
status_code = task.task_status.task_status_code if task.task_status else None
task_type_code = task.task_type.task_type_code if task.task_type else None
return {
'task_id': task.tasks_id,
'flow_id': task.flow_id,
'task_type': task_type_code,
'status': status_code,
'created_at': task.created_at.isoformat() if task.created_at else None,
'started_at': task.started_at.isoformat() if task.started_at else None,
'finished_at': task.finished_at.isoformat() if task.finished_at else None,
'description': task.description,
'error_message': task.error_message
}
@user_bp.route('/config', methods=['GET'])
@int_jwt_required
def get_user_config(current_user_id):
config = _get_or_create_user_config(current_user_id)
return jsonify({'config': _serialize_config(config)}), 200
config = UserService.get_or_create_user_config(current_user_id)
return jsonify({'config': UserService.serialize_config(config)}), 200
@user_bp.route('/config', methods=['PUT'])
@int_jwt_required
def update_user_config(current_user_id):
config = _get_or_create_user_config(current_user_id)
config = UserService.get_or_create_user_config(current_user_id)
data = request.get_json() or {}
allowed_fields = {'data_type_id', 'perturbation_configs_id', 'perturbation_intensity', 'finetune_configs_id'}
@ -80,40 +42,8 @@ def update_user_config(current_user_id):
try:
db.session.commit()
return jsonify({'message': '配置已更新', 'config': _serialize_config(config)}), 200
return jsonify({'message': '配置已更新', 'config': UserService.serialize_config(config)}), 200
except Exception as exc:
db.session.rollback()
return _json_error(f'更新配置失败: {exc}', 500)
@user_bp.route('/tasks', methods=['GET'])
@int_jwt_required
def list_user_tasks(current_user_id):
task_type_code = request.args.get('type')
status_code = request.args.get('status')
query = Task.query.filter_by(user_id=current_user_id)
if task_type_code:
task_type = TaskType.query.filter_by(task_type_code=task_type_code).first()
if not task_type:
return _json_error('任务类型不存在', 404)
query = query.filter(Task.tasks_type_id == task_type.task_type_id)
if status_code:
status = TaskStatus.query.filter_by(task_status_code=status_code).first()
if not status:
return _json_error('任务状态不存在', 404)
query = query.filter(Task.tasks_status_id == status.task_status_id)
tasks = query.order_by(Task.created_at.desc()).all()
return jsonify({'tasks': [_serialize_task(task) for task in tasks]}), 200
@user_bp.route('/tasks/<int:task_id>', methods=['GET'])
@int_jwt_required
def get_user_task(task_id, current_user_id):
task = Task.query.filter_by(tasks_id=task_id, user_id=current_user_id).first()
if not task:
return _json_error('任务不存在或无权限', 404)
return jsonify({'task': _serialize_task(task)}), 200

@ -249,6 +249,7 @@ class Finetune(db.Model):
task = db.relationship('Task', back_populates='finetune')
finetune_config = db.relationship('FinetuneConfig')
data_type = db.relationship('DataType')
evaluation = db.relationship('Evaluate', uselist=False, back_populates='finetune_task', cascade='all, delete-orphan')
def __repr__(self):
return f'<Finetune TaskID={self.tasks_id}>'
@ -275,12 +276,12 @@ class Evaluate(db.Model):
"""指标计算任务表"""
__tablename__ = 'evaluate'
tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联')
finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), nullable=False, comment='关联的微调配置')
finetune_task_id = db.Column(BigInteger, ForeignKey('finetune.tasks_id', ondelete='CASCADE'), unique=True, nullable=False, comment='关联的微调任务ID')
evaluate_name = db.Column(String(100))
evaluation_results_id = db.Column(BigInteger, ForeignKey('evaluation_results.evaluation_results_id'), unique=True, default=None, comment='关联的结果ID')
task = db.relationship('Task', back_populates='evaluation')
finetune_config = db.relationship('FinetuneConfig')
finetune_task = db.relationship('Finetune', back_populates='evaluation')
evaluation_result = db.relationship('EvaluationResult', backref='evaluate_task', uselist=False)
def __repr__(self):

@ -233,7 +233,7 @@ class TaskService:
}
elif task_type == 'evaluate' and task.evaluation:
base['evaluate'] = {
'finetune_configs_id': task.evaluation.finetune_configs_id,
'finetune_task_id': task.evaluation.finetune_task_id,
'evaluate_name': task.evaluation.evaluate_name,
'evaluation_results_id': task.evaluation.evaluation_results_id
}
@ -709,48 +709,23 @@ class TaskService:
logger.error(f"Evaluate task {task_id} not found")
return None
# 获取用户ID
sample_image = Image.query.filter_by(tasks_id=task_id).first()
if not sample_image:
logger.error(f"No images found for task {task_id}")
return None
user_id = sample_image.user_id
# 查找相同flow_id的Finetune任务
finetune_tasks = Task.query.filter(
Task.flow_id == task.flow_id,
Task.tasks_type_id == 2, # finetune类型
Task.tasks_id != task_id
).all()
if not finetune_tasks:
logger.error(f"No finetune task found for flow {task.flow_id}")
return None
# 从Evaluate任务获取需要的微调配置ID
if not evaluate.finetune_configs_id:
logger.error(f"Evaluate task {task_id} has no finetune_configs_id")
finetune = Finetune.query.get(evaluate.finetune_task_id)
if not finetune:
logger.error(f"Finetune task {evaluate.finetune_task_id} not found for evaluation {task_id}")
return None
# 查找使用相同微调算法的任务
target_finetune_task = None
for ft_task in finetune_tasks:
ft = Finetune.query.get(ft_task.tasks_id)
if ft and ft.finetune_configs_id == evaluate.finetune_configs_id:
target_finetune_task = ft_task
break
if not target_finetune_task:
logger.error(f"No finetune task with config {evaluate.finetune_configs_id} found for flow {task.flow_id}")
finetune_task = finetune.task
if not finetune_task:
logger.error(f"Finetune task {evaluate.finetune_task_id} missing Task relation")
return None
finetune_task = target_finetune_task
user_id = finetune_task.user_id
# 路径配置
clean_ref_dir = TaskService.get_original_images_path(user_id, task.flow_id)
clean_output_dir = TaskService.get_original_generated_path(user_id, task.flow_id, finetune_task.tasks_id)
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, task.flow_id, finetune_task.tasks_id)
output_dir = TaskService.get_evaluate_path(user_id, task.flow_id, task_id)
clean_ref_dir = TaskService.get_original_images_path(user_id, finetune_task.flow_id)
clean_output_dir = TaskService.get_original_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id)
perturbed_output_dir = TaskService.get_perturbed_generated_path(user_id, finetune_task.flow_id, finetune_task.tasks_id)
output_dir = TaskService.get_evaluate_path(user_id, finetune_task.flow_id, task_id)
# 加入RQ队列
from app.workers.evaluate_worker import run_evaluate_task

@ -0,0 +1,32 @@
"""用户服务层,封装用户配置相关逻辑。"""
from app import db
from app.database import UserConfig
class UserService:
"""用户相关业务逻辑"""
@staticmethod
def get_or_create_user_config(user_id):
"""获取或创建用户配置"""
config = UserConfig.query.filter_by(user_id=user_id).first()
if not config:
config = UserConfig(user_id=user_id)
db.session.add(config)
db.session.commit()
return config
@staticmethod
def serialize_config(config):
"""用户配置序列化"""
return {
'user_configs_id': config.user_configs_id,
'user_id': config.user_id,
'data_type_id': config.data_type_id,
'perturbation_configs_id': config.perturbation_configs_id,
'perturbation_intensity': config.perturbation_intensity,
'finetune_configs_id': config.finetune_configs_id,
'created_at': config.created_at.isoformat() if config.created_at else None,
'updated_at': config.updated_at.isoformat() if config.updated_at else None,
}
Loading…
Cancel
Save