diff --git a/.gitignore b/.gitignore index ef6693c..86aa95f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ logs/ *.log # 上传文件临时目录 -uploads/ \ No newline at end of file +uploads/ + +.github/ \ No newline at end of file diff --git a/src/backend/app/controllers/task_controller.py b/src/backend/app/controllers/task_controller.py index f534fec..f49e668 100644 --- a/src/backend/app/controllers/task_controller.py +++ b/src/backend/app/controllers/task_controller.py @@ -7,7 +7,7 @@ from flask import Blueprint, request, jsonify, current_app from flask_jwt_extended import jwt_required, get_jwt_identity from werkzeug.utils import secure_filename from app import db -from app.database import User, Batch, Image, ImageType, UserConfig, FinetuneBatch, FinetuneConfig +from app.database import User, Role, PerturbationConfig, FinetuneConfig, UserConfig, Image, ImageType, DataType, TaskType, TaskStatus, Task, Perturbation, Finetune, EvaluationResult, Evaluate, Heatmap from app.services.task_service import TaskService from app.services.image_service import ImageService from app.utils.file_utils import allowed_file, save_uploaded_file diff --git a/src/backend/app/database/__init__.py b/src/backend/app/database/__init__.py index eabbe24..b17c250 100644 --- a/src/backend/app/database/__init__.py +++ b/src/backend/app/database/__init__.py @@ -7,25 +7,43 @@ from datetime import datetime from app import db from werkzeug.security import generate_password_hash, check_password_hash from enum import Enum as PyEnum +from sqlalchemy import Integer, String, Text, DateTime, Boolean, ForeignKey, Float, BigInteger +# ---------------------------- +# 1. 角色表 (role) +# ---------------------------- +class Role(db.Model): + """权限/角色表""" + __tablename__ = 'role' + + role_id = db.Column(Integer, primary_key=True, comment='角色ID') + name = db.Column(String(50), nullable=False, comment='角色名称 (e.g., Administrator, VIP, User)') + max_concurrent_tasks = db.Column(Integer, default=1, comment='最大并发任务数') + description = db.Column(Text, comment='角色描述') + + def __repr__(self): + return f'' + +# ---------------------------- +# 2. 用户表 (users) +# ---------------------------- class User(db.Model): """用户表""" __tablename__ = 'users' - id = db.Column(db.BigInteger, primary_key=True) - username = db.Column(db.String(50), unique=True, nullable=False) - password_hash = db.Column(db.String(255), nullable=False) - email = db.Column(db.String(100)) - role = db.Column(db.Enum('user', 'admin'), default='user') - max_concurrent_tasks = db.Column(db.Integer, nullable=False, default=0) - created_at = db.Column(db.DateTime, default=datetime.utcnow) - updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - is_active = db.Column(db.Boolean, default=True) - + user_id = db.Column(Integer, primary_key=True, autoincrement=True, comment='用户ID') + username = db.Column(String(50), unique=True, nullable=False, comment='用户名') + password_hash = db.Column(String(255), nullable=False, comment='密码哈希') + email = db.Column(String(100), unique=True, nullable=False, index=True, comment='邮箱') + role_id = db.Column(Integer, ForeignKey('role.role_id'), nullable=False, comment='外键关联role表') + is_active = db.Column(Boolean, default=True, comment='是否激活') + created_at = db.Column(DateTime, default=datetime.utcnow, comment='创建时间') + updated_at = db.Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment='更新时间') + # 关系 - batches = db.relationship('Batch', backref='user', lazy='dynamic', cascade='all, delete-orphan') - images = db.relationship('Image', backref='user', lazy='dynamic', cascade='all, delete-orphan') - user_config = db.relationship('UserConfig', backref='user', uselist=False, cascade='all, delete-orphan') + role = db.relationship('Role', backref=db.backref('users', lazy='dynamic')) + user_config = db.relationship('UserConfig', uselist=False, backref='user', cascade='all, delete-orphan') + tasks = db.relationship('Task', backref='user', lazy='dynamic', cascade='all, delete-orphan') def set_password(self, password): """设置密码""" @@ -34,245 +52,253 @@ class User(db.Model): def check_password(self, password): """验证密码""" return check_password_hash(self.password_hash, password) - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'username': self.username, - 'email': self.email, - 'role': self.role, - 'max_concurrent_tasks': self.max_concurrent_tasks, - 'created_at': self.created_at.isoformat() if self.created_at else None, - 'is_active': self.is_active - } -class ImageType(db.Model): - """图片类型表""" - __tablename__ = 'image_types' - - id = db.Column(db.BigInteger, primary_key=True) - type_code = db.Column(db.Enum('original', 'perturbed', 'original_generate', 'perturbed_generate'), - unique=True, nullable=False) - type_name = db.Column(db.String(100), nullable=False) - description = db.Column(db.Text) - - # 关系 - images = db.relationship('Image', backref='image_type', lazy='dynamic') + def __repr__(self): + return f'' + +# ---------------------------- +# 3. 配置字典表 (Configs & Types) +# ---------------------------- class PerturbationConfig(db.Model): - """加噪算法表""" + """加噪算法配置表""" __tablename__ = 'perturbation_configs' + perturbation_configs_id = db.Column(Integer, primary_key=True, autoincrement=True) + perturbation_code = db.Column(String(50), nullable=False, comment='算法代号') + perturbation_name = db.Column(String(100), nullable=False, comment='算法名称') + description = db.Column(Text) + + def __repr__(self): + return f'' + +class DataType(db.Model): + """数据集类型表""" + __tablename__ = 'data_type' + data_type_id = db.Column(Integer, primary_key=True) + data_type_code = db.Column(String(50), nullable=False) + data_type_prompt = db.Column(Text, comment='数据集相关的Prompt') + description = db.Column(Text) - id = db.Column(db.BigInteger, primary_key=True) - method_code = db.Column(db.String(50), unique=True, nullable=False) - method_name = db.Column(db.String(100), nullable=False) - description = db.Column(db.Text, nullable=False) - default_epsilon = db.Column(db.Numeric(5, 2), nullable=False) - - # 关系 - batches = db.relationship('Batch', backref='perturbation_config', lazy='dynamic') - user_configs = db.relationship('UserConfig', backref='preferred_perturbation_config', lazy='dynamic') + def __repr__(self): + return f'' class FinetuneConfig(db.Model): - """微调方式表""" + """微调方式配置表""" __tablename__ = 'finetune_configs' - - id = db.Column(db.BigInteger, primary_key=True) - method_code = db.Column(db.String(50), unique=True, nullable=False) - method_name = db.Column(db.String(100), nullable=False) - description = db.Column(db.Text, nullable=False) - + finetune_configs_id = db.Column(Integer, primary_key=True, autoincrement=True) + finetune_code = db.Column(String(50), nullable=False) + finetune_name = db.Column(String(100), nullable=False) + description = db.Column(Text) + + def __repr__(self): + return f'' + +class TaskType(db.Model): + """任务类型表""" + __tablename__ = 'task_type' + task_type_id = db.Column(Integer, primary_key=True) + task_type_code = db.Column(String(50), nullable=False, comment='任务类型代码') + task_type_name = db.Column(String(100), nullable=False) + description = db.Column(Text) + + def __repr__(self): + return f'' + +class TaskStatus(db.Model): + """任务状态表""" + __tablename__ = 'task_status' + task_status_id = db.Column(Integer, primary_key=True, autoincrement=True) + task_status_code = db.Column(String(50), nullable=False, comment='状态代码 (Pending, Processing, Done, Failed)') + task_status_name = db.Column(String(100), nullable=False) + description = db.Column(Text) + + def __repr__(self): + return f'' + +class ImageType(db.Model): + """图片类型表 (X, X'', Y, Y''等)""" + __tablename__ = 'image_types' + image_types_id = db.Column(Integer, primary_key=True, autoincrement=True) + image_code = db.Column(String(50), nullable=False) + image_name = db.Column(String(100), nullable=False) + description = db.Column(Text) + + def __repr__(self): + return f'' + +# ---------------------------- +# 4. 用户配置表 (user_configs) +# ---------------------------- +class UserConfig(db.Model): + """用户偏好配置表""" + __tablename__ = 'user_configs' + user_configs_id = db.Column(Integer, primary_key=True, autoincrement=True) + user_id = db.Column(Integer, ForeignKey('users.user_id'), unique=True, nullable=False, index=True) + data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='默认数据集') + perturbation_configs_id = db.Column(Integer, ForeignKey('perturbation_configs.perturbation_configs_id'), default=None, comment='默认加噪算法') + perturbation_intensity = db.Column(Float, default=None, comment='默认扰动强度') + finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), default=None, comment='默认微调方式') + created_at = db.Column(DateTime, default=datetime.utcnow) + updated_at = db.Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + # 关系 - finetune_tasks = db.relationship('FinetuneBatch', backref='finetune_config', lazy='dynamic') - user_configs = db.relationship('UserConfig', backref='preferred_finetune_config', lazy='dynamic') - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'method_code': self.method_code, - 'method_name': self.method_name, - 'description': self.description - } - -class Batch(db.Model): - """加噪批次表""" - __tablename__ = 'batch' - - id = db.Column(db.BigInteger, primary_key=True) - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) - batch_name = db.Column(db.String(128)) - - # 加噪配置 - perturbation_config_id = db.Column(db.BigInteger, db.ForeignKey('perturbation_configs.id'), - nullable=False, default=1) - preferred_epsilon = db.Column(db.Numeric(5, 2), nullable=False, default=8.0) - - # 净化配置 - use_strong_protection = db.Column(db.Boolean, nullable=False, default=False) + data_type = db.relationship('DataType') + perturbation_config = db.relationship('PerturbationConfig') + finetune_config = db.relationship('FinetuneConfig') + + def __repr__(self): + return f'' + +# ---------------------------- +# 5. 核心任务表 (tasks) +# ---------------------------- +class Task(db.Model): + """任务总表""" + __tablename__ = 'tasks' + tasks_id = db.Column(BigInteger, primary_key=True, autoincrement=True, comment='任务ID') + tasks_type_id = db.Column(Integer, ForeignKey('task_type.task_type_id'), nullable=False, comment='任务类型') + user_id = db.Column(Integer, ForeignKey('users.user_id'), nullable=False, index=True, comment='归属用户') + tasks_status_id = db.Column(Integer, ForeignKey('task_status.task_status_id'), nullable=False, comment='任务状态ID') + created_at = db.Column(DateTime, default=datetime.utcnow) + started_at = db.Column(DateTime, default=None) + finished_at = db.Column(DateTime, default=None) + error_message = db.Column(Text, comment='错误信息') + description = db.Column(Text, comment='任务描述') + + # 关系 + task_type = db.relationship('TaskType', backref='tasks') + task_status = db.relationship('TaskStatus', backref='tasks') + images = db.relationship('Image', backref='task', lazy='dynamic', cascade='all, delete-orphan') - # 任务状态 - status = db.Column(db.Enum('pending', 'queued', 'processing', 'completed', 'failed'), default='pending') - created_at = db.Column(db.DateTime, default=datetime.utcnow) - started_at = db.Column(db.DateTime) - completed_at = db.Column(db.DateTime) - error_message = db.Column(db.Text) - result_path = db.Column(db.String(500)) + # --- 变更部分 --- + # 与子表的一对一关系 (perturbation, heatmap) + perturbation = db.relationship('Perturbation', uselist=False, back_populates='task', cascade='all, delete-orphan') + heatmap = db.relationship('Heatmap', uselist=False, back_populates='task', cascade='all, delete-orphan') + # 与子表的一对多关系 (finetune, evaluate) + finetunes = db.relationship('Finetune', back_populates='task', cascade='all, delete-orphan') + evaluations = db.relationship('Evaluate', back_populates='task', cascade='all, delete-orphan') + # --- 变更结束 --- + + def __repr__(self): + return f'' + +# ---------------------------- +# 6. 任务子表:加噪任务 (perturbation) +# ---------------------------- +class Perturbation(db.Model): + """加噪任务详情表""" + __tablename__ = 'perturbation' + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表1:1关联') + data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), nullable=False, comment='所选数据集') + perturbation_name = db.Column(String(100), comment='加噪任务自定义名称') + perturbation_configs_id = db.Column(Integer, ForeignKey('perturbation_configs.perturbation_configs_id'), nullable=False, comment='使用的算法') + perturbation_intensity = db.Column(Float, nullable=False, comment='扰动强度') + # 关系 - images = db.relationship('Image', backref='batch', lazy='dynamic') - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'batch_name': self.batch_name, - 'status': self.status, - 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None, - 'use_strong_protection': self.use_strong_protection, - 'created_at': self.created_at.isoformat() if self.created_at else None, - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'error_message': self.error_message, - 'perturbation_config': self.perturbation_config.method_name if self.perturbation_config else None - } + task = db.relationship('Task', back_populates='perturbation') + data_type = db.relationship('DataType') + perturbation_config = db.relationship('PerturbationConfig') -class Image(db.Model): - """图片表""" - __tablename__ = 'images' - - id = db.Column(db.BigInteger, primary_key=True) - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) - batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id')) - father_id = db.Column(db.BigInteger, db.ForeignKey('images.id')) - original_filename = db.Column(db.String(255)) - stored_filename = db.Column(db.String(255), unique=True, nullable=False) - file_path = db.Column(db.String(500), nullable=False) - file_size = db.Column(db.BigInteger) - image_type_id = db.Column(db.BigInteger, db.ForeignKey('image_types.id'), nullable=False) - width = db.Column(db.Integer) - height = db.Column(db.Integer) - upload_time = db.Column(db.DateTime, default=datetime.utcnow) - - # 自引用关系 - children = db.relationship('Image', backref=db.backref('parent', remote_side=[id]), lazy='dynamic') - - # 评估结果关系 - reference_evaluations = db.relationship('EvaluationResult', - foreign_keys='EvaluationResult.reference_image_id', - backref='reference_image', lazy='dynamic') - target_evaluations = db.relationship('EvaluationResult', - foreign_keys='EvaluationResult.target_image_id', - backref='target_image', lazy='dynamic') + def __repr__(self): + return f'' + +# ---------------------------- +# 7. 任务子表:微调任务 (finetune) - [已更新为复合主键] +# ---------------------------- +class Finetune(db.Model): + """微调任务详情表""" + __tablename__ = 'finetune' + # --- 变更部分:复合主键 --- + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True, comment='与tasks表关联') + finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='微调配置ID') + # --- 变更结束 --- + data_type_id = db.Column(Integer, ForeignKey('data_type.data_type_id'), default=None, comment='微调所用数据集') + finetune_name = db.Column(String(100), comment='微调任务名称') - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'original_filename': self.original_filename, - 'stored_filename': self.stored_filename, - 'file_path': self.file_path, - 'file_size': self.file_size, - 'width': self.width, - 'height': self.height, - 'upload_time': self.upload_time.isoformat() if self.upload_time else None, - 'image_type': self.image_type.type_name if self.image_type else None, - 'batch_id': self.batch_id - } + # --- 变更部分:更新 back_populates --- + task = db.relationship('Task', back_populates='finetunes') + # --- 变更结束 --- + finetune_config = db.relationship('FinetuneConfig') + data_type = db.relationship('DataType') + def __repr__(self): + return f'' + +# ---------------------------- +# 8. 评估结果表 (evaluation_results) +# ---------------------------- class EvaluationResult(db.Model): - """评估结果表""" + """评估结果数据表""" __tablename__ = 'evaluation_results' - - id = db.Column(db.BigInteger, primary_key=True) - reference_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) - target_image_id = db.Column(db.BigInteger, db.ForeignKey('images.id'), nullable=False) - evaluation_type = db.Column(db.Enum('image_quality', 'model_generation'), nullable=False) - purification_applied = db.Column(db.Boolean, default=False) - fid_score = db.Column(db.Numeric(8, 4)) - lpips_score = db.Column(db.Numeric(8, 4)) - ssim_score = db.Column(db.Numeric(8, 4)) - psnr_score = db.Column(db.Numeric(8, 4)) - heatmap_path = db.Column(db.String(500)) - evaluated_at = db.Column(db.DateTime, default=datetime.utcnow) - - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'evaluation_type': self.evaluation_type, - 'purification_applied': self.purification_applied, - 'fid_score': float(self.fid_score) if self.fid_score else None, - 'lpips_score': float(self.lpips_score) if self.lpips_score else None, - 'ssim_score': float(self.ssim_score) if self.ssim_score else None, - 'psnr_score': float(self.psnr_score) if self.psnr_score else None, - 'heatmap_path': self.heatmap_path, - 'evaluated_at': self.evaluated_at.isoformat() if self.evaluated_at else None - } - -class FinetuneBatch(db.Model): - """微调任务表""" - __tablename__ = 'finetune_batch' - - id = db.Column(db.BigInteger, primary_key=True) - batch_id = db.Column(db.BigInteger, db.ForeignKey('batch.id'), nullable=False) - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), nullable=False) - finetune_config_id = db.Column(db.BigInteger, db.ForeignKey('finetune_configs.id')) - - # 任务状态 - status = db.Column(db.Enum('pending', 'queued', 'processing', 'completed', 'failed'), default='pending') - created_at = db.Column(db.DateTime, default=datetime.utcnow) - started_at = db.Column(db.DateTime) - completed_at = db.Column(db.DateTime) - error_message = db.Column(db.Text) - - # 任务ID(用于RQ任务追踪) - original_job_id = db.Column(db.String(255)) # 原始图片微调任务ID - perturbed_job_id = db.Column(db.String(255)) # 扰动图片微调任务ID - + evaluation_results_id = db.Column(BigInteger, primary_key=True, autoincrement=True) + fid_score = db.Column(Float, default=None, comment='FID指标') + lpips_score = db.Column(Float, default=None, comment='LPIPS指标') + ssim_score = db.Column(Float, default=None, comment='SSIM指标') + psnr_score = db.Column(Float, default=None, comment='PSNR指标') + + def __repr__(self): + return f'' + +# ---------------------------- +# 9. 任务子表:评估任务 (evaluate) - [已更新为复合主键] +# ---------------------------- +class Evaluate(db.Model): + """指标计算任务表""" + __tablename__ = 'evaluate' + # --- 变更部分:复合主键 --- + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True) + finetune_configs_id = db.Column(Integer, ForeignKey('finetune_configs.finetune_configs_id'), primary_key=True, comment='关联的微调配置(如果是针对微调的评估)') + # --- 变更结束 --- + evaluate_name = db.Column(String(100)) + evaluation_results_id = db.Column(BigInteger, ForeignKey('evaluation_results.evaluation_results_id'), unique=True, default=None, comment='关联的结果ID') + + # --- 变更部分:更新 back_populates --- + task = db.relationship('Task', back_populates='evaluations') + # --- 变更结束 --- + finetune_config = db.relationship('FinetuneConfig') + evaluation_result = db.relationship('EvaluationResult', backref='evaluate_task', uselist=False) + + def __repr__(self): + return f'' + +# ---------------------------- +# 10. 任务子表:热力图计算任务 (heatmap) +# ---------------------------- +class Heatmap(db.Model): + """热力图计算任务表""" + __tablename__ = 'heatmap' + tasks_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), primary_key=True) + heatmap_name = db.Column(String(100)) + # 关系 - batch = db.relationship('Batch', backref='finetune_tasks', lazy=True) - user = db.relationship('User', backref='finetune_tasks', lazy=True) + task = db.relationship('Task', back_populates='heatmap') - def to_dict(self): - """转换为字典""" - return { - 'id': self.id, - 'batch_id': self.batch_id, - 'user_id': self.user_id, - 'finetune_config_id': self.finetune_config_id, - 'finetune_config': self.finetune_config.method_name if self.finetune_config else None, - 'status': self.status, - 'created_at': self.created_at.isoformat() if self.created_at else None, - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'error_message': self.error_message, - 'original_job_id': self.original_job_id, - 'perturbed_job_id': self.perturbed_job_id - } + def __repr__(self): + return f'' -class UserConfig(db.Model): - """用户配置表""" - __tablename__ = 'user_configs' +# ---------------------------- +# 11. 图片表 (images) +# ---------------------------- +class Image(db.Model): + """图片资源表""" + __tablename__ = 'images' + images_id = db.Column(BigInteger, primary_key=True, autoincrement=True) + task_id = db.Column(BigInteger, ForeignKey('tasks.tasks_id', ondelete='CASCADE'), nullable=False, index=True, comment='关联的任务ID') + image_types_id = db.Column(Integer, ForeignKey('image_types.image_types_id'), nullable=False, comment='图片类型(原图/加噪图等)') + father_id = db.Column(BigInteger, ForeignKey('images.images_id'), default=None, index=True, comment='父级图片ID (用于溯源)') + stored_filename = db.Column(String(255), nullable=False, comment='存储文件名') + file_path = db.Column(String(255), nullable=False, comment='完整路径') + file_size = db.Column(BigInteger, default=None, comment='文件大小(Bytes)') + width = db.Column(Integer, default=None) + height = db.Column(Integer, default=None) + frequency = db.Column(Text, default=None, comment='频域信息/分析数据路径') - user_id = db.Column(db.BigInteger, db.ForeignKey('users.id'), primary_key=True) - preferred_perturbation_config_id = db.Column(db.BigInteger, - db.ForeignKey('perturbation_configs.id'), default=1) - preferred_epsilon = db.Column(db.Numeric(5, 2), default=8.0) - preferred_finetune_config_id = db.Column(db.BigInteger, - db.ForeignKey('finetune_configs.id'), default=1) - preferred_purification = db.Column(db.Boolean, default=False) - created_at = db.Column(db.DateTime, default=datetime.utcnow) - updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + # 关系 + image_type = db.relationship('ImageType') - def to_dict(self): - """转换为字典""" - return { - 'user_id': self.user_id, - 'preferred_epsilon': float(self.preferred_epsilon) if self.preferred_epsilon else None, - 'preferred_purification': self.preferred_purification, - 'preferred_perturbation_config': self.preferred_perturbation_config.method_name if self.preferred_perturbation_config else None, - 'preferred_finetune_config': self.preferred_finetune_config.method_name if self.preferred_finetune_config else None, - 'updated_at': self.updated_at.isoformat() if self.updated_at else None - } + # 自我引用关系 + father_image = db.relationship('Image', remote_side=[images_id], backref=db.backref('child_images', lazy='dynamic')) + + def __repr__(self): + return f'' diff --git a/src/backend/init_db.py b/src/backend/init_db.py index 258a634..1507886 100644 --- a/src/backend/init_db.py +++ b/src/backend/init_db.py @@ -13,12 +13,38 @@ def init_database(): # 创建所有表 db.create_all() + # 初始化角色数据 + roles = [ + {'role_id': 0, 'name': '管理员', 'max_concurrent_tasks': 15, 'description': '系统管理员,拥有最高权限'}, + {'role_id': 1, 'name': 'VIP用户', 'max_concurrent_tasks': 10, 'description': '付费用户,享有较高的资源使用权限'}, + {'role_id': 2, 'name': '普通用户', 'max_concurrent_tasks': 5, 'description': '免费用户,享有基本的资源使用权限'} + ] + for role_data in roles: + existing = Role.query.filter_by(role_id=role_data['role_id']).first() + if not existing: + new_role = Role(**role_data) + db.session.add(new_role) + + # 初始化任务状态数据 + task_status = [ + {'status_code': 'pending', 'status_name': '待处理', 'description': '任务已创建,等待处理'}, + {'status_code': 'in_progress', 'status_name': '进行中', 'description': '任务正在处理中'}, + {'status_code': 'completed', 'status_name': '已完成', 'description':'任务已成功完成'}, + {'status_code': 'failed', 'status_name': '失败', 'description': '任务处理失败'} + ] + for status in task_status: + existing = TaskStatus.query.filter_by(status_code=status['status_code']).first() + if not existing: + new_status = TaskStatus(**status) + db.session.add(new_status) + # 初始化图片类型数据 image_types = [ - {'type_code': 'original', 'type_name': '原始图片', 'description': '用户上传的原始图像文件'}, - {'type_code': 'perturbed', 'type_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'}, - {'type_code': 'original_generate', 'type_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'}, - {'type_code': 'perturbed_generate', 'type_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'} + {'image_code': 'original', 'image_name': '原始图片', 'description': '用户上传的原始图像文件'}, + {'image_code': 'perturbed', 'image_name': '加噪后图片', 'description': '经过扰动算法处理后的防护图像'}, + {'image_code': 'original_generate', 'image_name': '原始图像生成图片', 'description': '利用原始图像训练模型后模型生成图片'}, + {'image_code': 'perturbed_generate', 'image_name': '加噪后图像生成图片', 'description': '利用加噪后图像训练模型后模型生成图片'}, + {'image_code': 'heatmap', 'image_name': '生成的热力图', 'description': '热力图'} ] for img_type in image_types: @@ -29,10 +55,10 @@ def init_database(): # 初始化加噪算法数据 perturbation_configs = [ - {'method_code': 'aspl', 'method_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense', 'default_epsilon': 6.0}, - {'method_code': 'simac', 'method_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy', 'default_epsilon': 8.0}, - {'method_code': 'caat', 'method_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck', 'default_epsilon': 16.0}, - {'method_code': 'pid', 'method_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models', 'default_epsilon': 4.0} + {'perturbation_code': 'aspl', 'perturbation_name': 'ASPL算法', 'description': 'Advanced Semantic Protection Layer for Enhanced Privacy Defense'}, + {'perturbation_code': 'simac', 'perturbation_name': 'SimAC算法', 'description': 'Simple Anti-Customization Method for Protecting Face Privacy'}, + {'perturbation_code': 'caat', 'perturbation_name': 'CAAT算法', 'description': 'Perturbing Attention Gives You More Bang for the Buck'}, + {'perturbation_code': 'pid', 'perturbation_name': 'PID算法', 'description': 'Prompt-Independent Data Protection Against Latent Diffusion Models'} ] for config in perturbation_configs: @@ -43,9 +69,9 @@ def init_database(): # 初始化微调方式数据 finetune_configs = [ - {'method_code': 'dreambooth', 'method_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'}, - {'method_code': 'lora', 'method_name': 'LoRA', 'description': '低秩适应(Low-Rank Adaptation)微调方法'}, - {'method_code': 'textual_inversion', 'method_name': 'Textual Inversion', 'description': '文本反转个性化方法'} + {'finetune_code': 'dreambooth', 'finetune_name': 'DreamBooth', 'description': 'DreamBooth个性化文本到图像生成'}, + {'finetune_code': 'lora', 'finetune_name': 'LoRA', 'description': '低秩适应(Low-Rank Adaptation)微调方法'}, + {'finetune_code': 'textual_inversion', 'finetune_name': 'Textual Inversion', 'description': '文本反转个性化方法'} ] for config in finetune_configs: @@ -53,12 +79,31 @@ def init_database(): if not existing: new_config = FinetuneConfig(**config) db.session.add(new_config) - + + # 初始化数据集类型数据 + dataset_types = [ + {'data_type_id': 0, 'dataset_code': 'facial', 'dataset_name': '人脸数据集', 'description': '人脸类型的数据集'}, + {'data_type_id': 1, 'dataset_code': 'art', 'dataset_name': '艺术品数据集', 'description': '艺术品类型的数据集'} + ] + for dataset in dataset_types: + existing = DataType.query.filter_by(data_type_id=dataset['data_type_id']).first() + if not existing: + new_dataset = DataType(**dataset) + db.session.add(new_dataset) + + # 初始化任务类型数据 + task_types = [ + {'task_type_id': 0, 'task_code': 'perturbation', 'task_name': '加噪任务', 'description': '对图像进行加噪处理的任务'}, + {'task_type_id': 1, 'task_code': 'finetune', 'task_name': '微调任务', 'description': '对模型进行微调训练的任务'}, + {'task_type_id': 2, 'task_code': 'generation', 'task_name': '生成任务', 'description': '利用微调后模型进行图像生成的任务'} + {'task_type_id': 3, 'task_code': 'heatmap', 'task_name': '热力图任务', 'description': '计算X和X’的热力图的任务'} + ] + # 创建默认管理员用户 admin_users = [ - {'username': 'admin1', 'email': 'admin1@museguard.com', 'role': 'admin'}, - {'username': 'admin2', 'email': 'admin2@museguard.com', 'role': 'admin'}, - {'username': 'admin3', 'email': 'admin3@museguard.com', 'role': 'admin'} + {'username': 'admin1', 'email': 'admin1@museguard.com', 'role_id': 0}, + {'username': 'admin2', 'email': 'admin2@museguard.com', 'role_id': 0}, + {'username': 'admin3', 'email': 'admin3@museguard.com', 'role_id': 0} ] for admin_data in admin_users: