@ -0,0 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MuseGuard 测试包
|
||||
包含单元测试和集成测试
|
||||
"""
|
||||
@ -0,0 +1,258 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试数据工厂
|
||||
使用 factory_boy 生成测试数据
|
||||
"""
|
||||
|
||||
import factory
|
||||
from factory.alchemy import SQLAlchemyModelFactory
|
||||
from faker import Faker
|
||||
from datetime import datetime
|
||||
|
||||
from app import db
|
||||
from app.database import (
|
||||
User, Role, Task, TaskType, TaskStatus, Image, ImageType,
|
||||
UserConfig, PerturbationConfig, FinetuneConfig, DataType,
|
||||
Perturbation, Finetune, Heatmap, Evaluate, EvaluationResult
|
||||
)
|
||||
|
||||
fake = Faker('zh_CN')
|
||||
|
||||
|
||||
class BaseFactory(SQLAlchemyModelFactory):
|
||||
"""基础工厂类"""
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
sqlalchemy_session = db.session
|
||||
sqlalchemy_session_persistence = 'commit'
|
||||
|
||||
|
||||
class RoleFactory(BaseFactory):
|
||||
"""角色工厂"""
|
||||
|
||||
class Meta:
|
||||
model = Role
|
||||
|
||||
role_code = factory.Sequence(lambda n: f'role_{n}')
|
||||
name = factory.LazyAttribute(lambda obj: f'角色_{obj.role_code}')
|
||||
max_concurrent_tasks = factory.Faker('random_int', min=1, max=10)
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class UserFactory(BaseFactory):
|
||||
"""用户工厂"""
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
|
||||
username = factory.Sequence(lambda n: f'user_{n}')
|
||||
email = factory.LazyAttribute(lambda obj: f'{obj.username}@example.com')
|
||||
role_id = 3 # 默认普通用户
|
||||
is_active = True
|
||||
|
||||
@factory.lazy_attribute
|
||||
def password_hash(self):
|
||||
from werkzeug.security import generate_password_hash
|
||||
return generate_password_hash('defaultpassword123')
|
||||
|
||||
@classmethod
|
||||
def _create(cls, model_class, *args, **kwargs):
|
||||
"""创建用户时设置密码"""
|
||||
password = kwargs.pop('password', 'defaultpassword123')
|
||||
obj = super()._create(model_class, *args, **kwargs)
|
||||
obj.set_password(password)
|
||||
return obj
|
||||
|
||||
|
||||
class AdminUserFactory(UserFactory):
|
||||
"""管理员用户工厂"""
|
||||
|
||||
username = factory.Sequence(lambda n: f'admin_{n}')
|
||||
role_id = 1
|
||||
|
||||
|
||||
class VipUserFactory(UserFactory):
|
||||
"""VIP用户工厂"""
|
||||
|
||||
username = factory.Sequence(lambda n: f'vip_{n}')
|
||||
role_id = 2
|
||||
|
||||
|
||||
class TaskTypeFactory(BaseFactory):
|
||||
"""任务类型工厂"""
|
||||
|
||||
class Meta:
|
||||
model = TaskType
|
||||
|
||||
task_type_code = factory.Sequence(lambda n: f'type_{n}')
|
||||
task_type_name = factory.LazyAttribute(lambda obj: f'任务类型_{obj.task_type_code}')
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class TaskStatusFactory(BaseFactory):
|
||||
"""任务状态工厂"""
|
||||
|
||||
class Meta:
|
||||
model = TaskStatus
|
||||
|
||||
task_status_code = factory.Sequence(lambda n: f'status_{n}')
|
||||
task_status_name = factory.LazyAttribute(lambda obj: f'状态_{obj.task_status_code}')
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class ImageTypeFactory(BaseFactory):
|
||||
"""图片类型工厂"""
|
||||
|
||||
class Meta:
|
||||
model = ImageType
|
||||
|
||||
image_code = factory.Sequence(lambda n: f'img_type_{n}')
|
||||
image_name = factory.LazyAttribute(lambda obj: f'图片类型_{obj.image_code}')
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class DataTypeFactory(BaseFactory):
|
||||
"""数据类型工厂"""
|
||||
|
||||
class Meta:
|
||||
model = DataType
|
||||
|
||||
data_type_code = factory.Sequence(lambda n: f'data_{n}')
|
||||
instance_prompt = factory.Faker('sentence', locale='zh_CN')
|
||||
class_prompt = factory.Faker('sentence', locale='zh_CN')
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class PerturbationConfigFactory(BaseFactory):
|
||||
"""加噪配置工厂"""
|
||||
|
||||
class Meta:
|
||||
model = PerturbationConfig
|
||||
|
||||
perturbation_code = factory.Sequence(lambda n: f'pert_{n}')
|
||||
perturbation_name = factory.LazyAttribute(lambda obj: f'加噪算法_{obj.perturbation_code}')
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class FinetuneConfigFactory(BaseFactory):
|
||||
"""微调配置工厂"""
|
||||
|
||||
class Meta:
|
||||
model = FinetuneConfig
|
||||
|
||||
finetune_code = factory.Sequence(lambda n: f'ft_{n}')
|
||||
finetune_name = factory.LazyAttribute(lambda obj: f'微调方式_{obj.finetune_code}')
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class TaskFactory(BaseFactory):
|
||||
"""任务工厂"""
|
||||
|
||||
class Meta:
|
||||
model = Task
|
||||
|
||||
flow_id = factory.Sequence(lambda n: 1000000 + n)
|
||||
tasks_type_id = 1 # perturbation
|
||||
user_id = factory.LazyAttribute(lambda obj: UserFactory().user_id)
|
||||
tasks_status_id = 1 # waiting
|
||||
description = factory.Faker('sentence', locale='zh_CN')
|
||||
created_at = factory.LazyFunction(datetime.utcnow)
|
||||
|
||||
|
||||
class PerturbationTaskFactory(TaskFactory):
|
||||
"""加噪任务工厂"""
|
||||
|
||||
tasks_type_id = 1
|
||||
|
||||
@factory.post_generation
|
||||
def perturbation(obj, create, extracted, **kwargs):
|
||||
if not create:
|
||||
return
|
||||
|
||||
if extracted:
|
||||
return extracted
|
||||
|
||||
return PerturbationFactory(tasks_id=obj.tasks_id)
|
||||
|
||||
|
||||
class FinetuneTaskFactory(TaskFactory):
|
||||
"""微调任务工厂"""
|
||||
|
||||
tasks_type_id = 2
|
||||
|
||||
@factory.post_generation
|
||||
def finetune(obj, create, extracted, **kwargs):
|
||||
if not create:
|
||||
return
|
||||
|
||||
if extracted:
|
||||
return extracted
|
||||
|
||||
return FinetuneFactory(tasks_id=obj.tasks_id)
|
||||
|
||||
|
||||
class PerturbationFactory(BaseFactory):
|
||||
"""加噪详情工厂"""
|
||||
|
||||
class Meta:
|
||||
model = Perturbation
|
||||
|
||||
tasks_id = factory.LazyAttribute(lambda obj: TaskFactory(tasks_type_id=1).tasks_id)
|
||||
data_type_id = 1
|
||||
perturbation_configs_id = 1
|
||||
perturbation_intensity = factory.Faker('pyfloat', min_value=0.1, max_value=1.0)
|
||||
perturbation_name = factory.Faker('word', locale='zh_CN')
|
||||
|
||||
|
||||
class FinetuneFactory(BaseFactory):
|
||||
"""微调详情工厂"""
|
||||
|
||||
class Meta:
|
||||
model = Finetune
|
||||
|
||||
tasks_id = factory.LazyAttribute(lambda obj: TaskFactory(tasks_type_id=2).tasks_id)
|
||||
finetune_configs_id = 1
|
||||
data_type_id = 1
|
||||
finetune_name = factory.Faker('word', locale='zh_CN')
|
||||
custom_prompt = factory.Faker('sentence', locale='zh_CN')
|
||||
|
||||
|
||||
class ImageFactory(BaseFactory):
|
||||
"""图片工厂"""
|
||||
|
||||
class Meta:
|
||||
model = Image
|
||||
|
||||
task_id = factory.LazyAttribute(lambda obj: TaskFactory().tasks_id)
|
||||
image_types_id = 1 # original
|
||||
stored_filename = factory.Sequence(lambda n: f'image_{n}.png')
|
||||
file_path = factory.LazyAttribute(lambda obj: f'/tmp/{obj.stored_filename}')
|
||||
file_size = factory.Faker('random_int', min=1024, max=1024*1024)
|
||||
width = factory.Faker('random_int', min=256, max=2048)
|
||||
height = factory.Faker('random_int', min=256, max=2048)
|
||||
|
||||
|
||||
class UserConfigFactory(BaseFactory):
|
||||
"""用户配置工厂"""
|
||||
|
||||
class Meta:
|
||||
model = UserConfig
|
||||
|
||||
user_id = factory.LazyAttribute(lambda obj: UserFactory().user_id)
|
||||
data_type_id = 1
|
||||
perturbation_configs_id = 1
|
||||
perturbation_intensity = factory.Faker('pyfloat', min_value=0.1, max_value=1.0)
|
||||
finetune_configs_id = 1
|
||||
|
||||
|
||||
class EvaluationResultFactory(BaseFactory):
|
||||
"""评估结果工厂"""
|
||||
|
||||
class Meta:
|
||||
model = EvaluationResult
|
||||
|
||||
fid_score = factory.Faker('pyfloat', min_value=0.0, max_value=100.0)
|
||||
lpips_score = factory.Faker('pyfloat', min_value=0.0, max_value=1.0)
|
||||
ssim_score = factory.Faker('pyfloat', min_value=0.0, max_value=1.0)
|
||||
psnr_score = factory.Faker('pyfloat', min_value=10.0, max_value=50.0)
|
||||
@ -0,0 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
集成测试包
|
||||
测试API端点和组件间的交互
|
||||
"""
|
||||
@ -0,0 +1,251 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
认证API集成测试
|
||||
测试用户注册、登录、密码修改等接口
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestAuthRegister:
|
||||
"""用户注册接口测试"""
|
||||
|
||||
@patch('app.controllers.auth_controller.VerificationService')
|
||||
def test_register_成功注册(self, mock_verification, client, init_database):
|
||||
"""测试成功注册新用户"""
|
||||
# Mock 验证码服务
|
||||
mock_service = MagicMock()
|
||||
mock_service.verify_code.return_value = True
|
||||
mock_verification.return_value = mock_service
|
||||
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'newuser',
|
||||
'password': 'newpassword123',
|
||||
'email': 'newuser@example.com',
|
||||
'code': '123456'
|
||||
})
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.get_json()
|
||||
assert 'message' in data
|
||||
assert data['message'] == '注册成功'
|
||||
assert 'user' in data
|
||||
assert data['user']['username'] == 'newuser'
|
||||
|
||||
def test_register_缺少必填字段(self, client, init_database):
|
||||
"""测试注册时缺少必填字段"""
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'newuser'
|
||||
# 缺少 password 和 email
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
@patch('app.controllers.auth_controller.VerificationService')
|
||||
def test_register_用户名已存在(self, mock_verification, client, init_database, test_user):
|
||||
"""测试注册已存在的用户名"""
|
||||
mock_service = MagicMock()
|
||||
mock_service.verify_code.return_value = True
|
||||
mock_verification.return_value = mock_service
|
||||
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser', # 已存在
|
||||
'password': 'password123',
|
||||
'email': 'another@example.com',
|
||||
'code': '123456'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert '用户名已存在' in data['error']
|
||||
|
||||
@patch('app.controllers.auth_controller.VerificationService')
|
||||
def test_register_邮箱已存在(self, mock_verification, client, init_database, test_user):
|
||||
"""测试注册已存在的邮箱"""
|
||||
mock_service = MagicMock()
|
||||
mock_service.verify_code.return_value = True
|
||||
mock_verification.return_value = mock_service
|
||||
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'anotheruser',
|
||||
'password': 'password123',
|
||||
'email': 'test@example.com', # 已存在
|
||||
'code': '123456'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert '邮箱' in data['error']
|
||||
|
||||
def test_register_邮箱格式错误(self, client, init_database):
|
||||
"""测试注册时邮箱格式错误"""
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'newuser',
|
||||
'password': 'password123',
|
||||
'email': 'invalid-email',
|
||||
'code': '123456'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert '邮箱格式' in data['error']
|
||||
|
||||
|
||||
class TestAuthLogin:
|
||||
"""用户登录接口测试"""
|
||||
|
||||
def test_login_成功登录(self, client, init_database, test_user):
|
||||
"""测试成功登录"""
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'testpassword123'
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'access_token' in data
|
||||
assert 'user' in data
|
||||
assert data['user']['username'] == 'testuser'
|
||||
|
||||
def test_login_错误密码(self, client, init_database, test_user):
|
||||
"""测试使用错误密码登录"""
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'wrongpassword'
|
||||
})
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.get_json()
|
||||
assert '用户名或密码错误' in data['error']
|
||||
|
||||
def test_login_不存在的用户(self, client, init_database):
|
||||
"""测试登录不存在的用户"""
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'nonexistent',
|
||||
'password': 'password123'
|
||||
})
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.get_json()
|
||||
assert '用户名或密码错误' in data['error']
|
||||
|
||||
def test_login_缺少字段(self, client, init_database):
|
||||
"""测试登录时缺少字段"""
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser'
|
||||
# 缺少 password
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
def test_login_禁用账户(self, init_database, client):
|
||||
"""测试登录被禁用的账户"""
|
||||
from app.database import User
|
||||
from app import db
|
||||
|
||||
# 创建禁用用户
|
||||
user = User(
|
||||
username='disableduser',
|
||||
email='disabled@example.com',
|
||||
role_id=3,
|
||||
is_active=False
|
||||
)
|
||||
user.set_password('password123')
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'disableduser',
|
||||
'password': 'password123'
|
||||
})
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.get_json()
|
||||
assert '禁用' in data['error']
|
||||
|
||||
|
||||
class TestAuthProfile:
|
||||
"""用户信息接口测试"""
|
||||
|
||||
def test_get_profile_获取用户信息(self, client, auth_headers):
|
||||
"""测试获取当前用户信息"""
|
||||
response = client.get('/api/auth/profile', headers=auth_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'user' in data
|
||||
assert data['user']['username'] == 'testuser'
|
||||
|
||||
def test_get_profile_未认证(self, client, init_database):
|
||||
"""测试未认证时获取用户信息"""
|
||||
response = client.get('/api/auth/profile')
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestAuthChangePassword:
|
||||
"""修改密码接口测试"""
|
||||
|
||||
def test_change_password_成功修改(self, client, auth_headers, test_user):
|
||||
"""测试成功修改密码"""
|
||||
response = client.post('/api/auth/change-password',
|
||||
headers=auth_headers,
|
||||
json={
|
||||
'old_password': 'testpassword123',
|
||||
'new_password': 'newpassword456'
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert '成功' in data['message']
|
||||
|
||||
# 验证新密码可以登录
|
||||
login_response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'newpassword456'
|
||||
})
|
||||
assert login_response.status_code == 200
|
||||
|
||||
def test_change_password_旧密码错误(self, client, auth_headers):
|
||||
"""测试使用错误的旧密码修改"""
|
||||
response = client.post('/api/auth/change-password',
|
||||
headers=auth_headers,
|
||||
json={
|
||||
'old_password': 'wrongoldpassword',
|
||||
'new_password': 'newpassword456'
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.get_json()
|
||||
assert '旧密码错误' in data['error']
|
||||
|
||||
def test_change_password_缺少字段(self, client, auth_headers):
|
||||
"""测试修改密码时缺少字段"""
|
||||
response = client.post('/api/auth/change-password',
|
||||
headers=auth_headers,
|
||||
json={
|
||||
'old_password': 'testpassword123'
|
||||
# 缺少 new_password
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
class TestAuthLogout:
|
||||
"""用户登出接口测试"""
|
||||
|
||||
def test_logout_成功登出(self, client, auth_headers):
|
||||
"""测试成功登出"""
|
||||
response = client.post('/api/auth/logout', headers=auth_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert '登出成功' in data['message']
|
||||
@ -0,0 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试包
|
||||
测试独立的函数和类方法
|
||||
"""
|
||||
@ -0,0 +1,232 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Repository 层单元测试
|
||||
测试数据访问层的 CRUD 操作
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.database import User, Task, Image, UserConfig
|
||||
from app.repositories.user_repository import UserRepository, UserConfigRepository
|
||||
from app.repositories.task_repository import TaskRepository, PerturbationRepository
|
||||
from app.repositories.image_repository import ImageRepository
|
||||
|
||||
|
||||
class TestUserRepository:
|
||||
"""用户 Repository 单元测试"""
|
||||
|
||||
def test_get_by_username_存在的用户(self, init_database, test_user):
|
||||
"""测试根据用户名获取存在的用户"""
|
||||
repo = UserRepository()
|
||||
user = repo.get_by_username('testuser')
|
||||
|
||||
assert user is not None
|
||||
assert user.username == 'testuser'
|
||||
assert user.email == 'test@example.com'
|
||||
|
||||
def test_get_by_username_不存在的用户(self, init_database):
|
||||
"""测试根据用户名获取不存在的用户"""
|
||||
repo = UserRepository()
|
||||
user = repo.get_by_username('nonexistent')
|
||||
|
||||
assert user is None
|
||||
|
||||
def test_get_by_email_存在的邮箱(self, init_database, test_user):
|
||||
"""测试根据邮箱获取用户"""
|
||||
repo = UserRepository()
|
||||
user = repo.get_by_email('test@example.com')
|
||||
|
||||
assert user is not None
|
||||
assert user.username == 'testuser'
|
||||
|
||||
def test_username_exists_检查用户名是否存在(self, init_database, test_user):
|
||||
"""测试检查用户名是否存在"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.username_exists('testuser') is True
|
||||
assert repo.username_exists('nonexistent') is False
|
||||
|
||||
def test_email_exists_检查邮箱是否存在(self, init_database, test_user):
|
||||
"""测试检查邮箱是否存在"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.email_exists('test@example.com') is True
|
||||
assert repo.email_exists('nonexistent@example.com') is False
|
||||
|
||||
def test_authenticate_正确凭据(self, init_database, test_user):
|
||||
"""测试使用正确凭据认证"""
|
||||
repo = UserRepository()
|
||||
user = repo.authenticate('testuser', 'testpassword123')
|
||||
|
||||
assert user is not None
|
||||
assert user.username == 'testuser'
|
||||
|
||||
def test_authenticate_错误密码(self, init_database, test_user):
|
||||
"""测试使用错误密码认证"""
|
||||
repo = UserRepository()
|
||||
user = repo.authenticate('testuser', 'wrongpassword')
|
||||
|
||||
assert user is None
|
||||
|
||||
def test_authenticate_不存在的用户(self, init_database):
|
||||
"""测试认证不存在的用户"""
|
||||
repo = UserRepository()
|
||||
user = repo.authenticate('nonexistent', 'password')
|
||||
|
||||
assert user is None
|
||||
|
||||
def test_is_admin_管理员用户(self, init_database, admin_user):
|
||||
"""测试判断管理员用户"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.is_admin(admin_user) is True
|
||||
|
||||
def test_is_admin_普通用户(self, init_database, test_user):
|
||||
"""测试判断普通用户不是管理员"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.is_admin(test_user) is False
|
||||
|
||||
def test_is_vip_VIP用户(self, init_database, vip_user):
|
||||
"""测试判断VIP用户"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.is_vip(vip_user) is True
|
||||
|
||||
def test_get_max_concurrent_tasks_不同角色(self, init_database, test_user, admin_user, vip_user):
|
||||
"""测试获取不同角色的最大并发任务数"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.get_max_concurrent_tasks(admin_user) == 10
|
||||
assert repo.get_max_concurrent_tasks(vip_user) == 5
|
||||
assert repo.get_max_concurrent_tasks(test_user) == 2
|
||||
|
||||
|
||||
class TestTaskRepository:
|
||||
"""任务 Repository 单元测试"""
|
||||
|
||||
def test_get_by_user_获取用户任务(self, init_database, sample_task, test_user):
|
||||
"""测试获取用户的所有任务"""
|
||||
repo = TaskRepository()
|
||||
tasks = repo.get_by_user(test_user.user_id)
|
||||
|
||||
assert len(tasks) == 1
|
||||
assert tasks[0].tasks_id == sample_task.tasks_id
|
||||
|
||||
def test_get_by_user_无任务用户(self, init_database, admin_user):
|
||||
"""测试获取无任务用户的任务列表"""
|
||||
repo = TaskRepository()
|
||||
tasks = repo.get_by_user(admin_user.user_id)
|
||||
|
||||
assert len(tasks) == 0
|
||||
|
||||
def test_is_owner_验证任务归属(self, init_database, sample_task, test_user, admin_user):
|
||||
"""测试验证任务归属"""
|
||||
repo = TaskRepository()
|
||||
|
||||
assert repo.is_owner(sample_task, test_user.user_id) is True
|
||||
assert repo.is_owner(sample_task, admin_user.user_id) is False
|
||||
|
||||
def test_get_for_user_获取用户任务带权限验证(self, init_database, sample_task, test_user, admin_user):
|
||||
"""测试获取用户任务(带权限验证)"""
|
||||
repo = TaskRepository()
|
||||
|
||||
# 任务所有者可以获取
|
||||
task = repo.get_for_user(sample_task.tasks_id, test_user.user_id)
|
||||
assert task is not None
|
||||
|
||||
# 非所有者无法获取
|
||||
task = repo.get_for_user(sample_task.tasks_id, admin_user.user_id)
|
||||
assert task is None
|
||||
|
||||
def test_get_type_code_获取任务类型代码(self, init_database, sample_task):
|
||||
"""测试获取任务类型代码"""
|
||||
repo = TaskRepository()
|
||||
type_code = repo.get_type_code(sample_task)
|
||||
|
||||
assert type_code == 'perturbation'
|
||||
|
||||
def test_is_type_判断任务类型(self, init_database, sample_task):
|
||||
"""测试判断任务类型"""
|
||||
repo = TaskRepository()
|
||||
|
||||
assert repo.is_type(sample_task, 'perturbation') is True
|
||||
assert repo.is_type(sample_task, 'finetune') is False
|
||||
|
||||
|
||||
class TestImageRepository:
|
||||
"""图片 Repository 单元测试"""
|
||||
|
||||
def test_get_by_task_获取任务图片(self, init_database, sample_image, sample_task):
|
||||
"""测试获取任务的所有图片"""
|
||||
repo = ImageRepository()
|
||||
images = repo.get_by_task(sample_task.tasks_id)
|
||||
|
||||
assert len(images) == 1
|
||||
assert images[0].images_id == sample_image.images_id
|
||||
|
||||
def test_count_by_task_统计任务图片数量(self, init_database, sample_image, sample_task):
|
||||
"""测试统计任务的图片数量"""
|
||||
repo = ImageRepository()
|
||||
count = repo.count_by_task(sample_task.tasks_id)
|
||||
|
||||
assert count == 1
|
||||
|
||||
def test_is_owner_验证图片归属(self, init_database, sample_image, test_user, admin_user):
|
||||
"""测试验证图片归属(通过任务)"""
|
||||
repo = ImageRepository()
|
||||
|
||||
assert repo.is_owner(sample_image, test_user.user_id) is True
|
||||
assert repo.is_owner(sample_image, admin_user.user_id) is False
|
||||
|
||||
def test_get_by_task_and_type_获取指定类型图片(self, init_database, sample_image, sample_task):
|
||||
"""测试获取任务指定类型的图片"""
|
||||
repo = ImageRepository()
|
||||
|
||||
# 获取原图类型
|
||||
images = repo.get_by_task_and_type(sample_task.tasks_id, 'original')
|
||||
assert len(images) == 1
|
||||
|
||||
# 获取不存在的类型
|
||||
images = repo.get_by_task_and_type(sample_task.tasks_id, 'perturbed')
|
||||
assert len(images) == 0
|
||||
|
||||
|
||||
class TestBaseRepository:
|
||||
"""基础 Repository 单元测试"""
|
||||
|
||||
def test_get_by_id_获取实体(self, init_database, test_user):
|
||||
"""测试根据ID获取实体"""
|
||||
repo = UserRepository()
|
||||
user = repo.get_by_id(test_user.user_id)
|
||||
|
||||
assert user is not None
|
||||
assert user.user_id == test_user.user_id
|
||||
|
||||
def test_get_by_id_不存在的实体(self, init_database):
|
||||
"""测试获取不存在的实体"""
|
||||
repo = UserRepository()
|
||||
user = repo.get_by_id(99999)
|
||||
|
||||
assert user is None
|
||||
|
||||
def test_exists_检查实体存在(self, init_database, test_user):
|
||||
"""测试检查实体是否存在"""
|
||||
repo = UserRepository()
|
||||
|
||||
assert repo.exists(test_user.user_id) is True
|
||||
assert repo.exists(99999) is False
|
||||
|
||||
def test_count_统计数量(self, init_database, test_user, admin_user):
|
||||
"""测试统计实体数量"""
|
||||
repo = UserRepository()
|
||||
count = repo.count()
|
||||
|
||||
assert count == 2 # test_user 和 admin_user
|
||||
|
||||
def test_find_by_条件查询(self, init_database, test_user):
|
||||
"""测试根据条件查询"""
|
||||
repo = UserRepository()
|
||||
users = repo.find_by(is_active=True)
|
||||
|
||||
assert len(users) >= 1
|
||||
assert all(u.is_active for u in users)
|
||||
Loading…
Reference in new issue