You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

437 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Security Module Tests
"""
安全模块测试
测试传输加密、本地数据加密和密钥管理功能
需求: 10.1, 10.2, 10.3
"""
import os
import ssl
import tempfile
from pathlib import Path

import pytest

from shared.security import (
    AESCipher,
    EncryptedData,
    MessageEncryptor,
    FileEncryptor,
    KeyManager,
    LocalDataEncryptor,
    TLSManager,
    EncryptionError,
    DecryptionError,
    create_message_encryptor,
    create_file_encryptor,
    create_local_data_encryptor,
    encrypt_message,
    decrypt_message,
)
class TestAESCipher:
    """Tests for the AES-256-GCM cipher (AESCipher)."""

    def test_generate_key(self):
        """generate_key returns a 256-bit key that differs between calls."""
        key = AESCipher.generate_key()
        assert len(key) == 32  # 256 bits
        # Every call must produce a different key.
        key2 = AESCipher.generate_key()
        assert key != key2

    def test_generate_iv(self):
        """generate_iv returns a 96-bit IV that differs between calls."""
        iv = AESCipher.generate_iv()
        assert len(iv) == 12  # 96 bits
        # Reusing a nonce under the same key breaks GCM, so IVs must not repeat.
        assert AESCipher.generate_iv() != iv

    def test_encrypt_decrypt(self):
        """Encrypting then decrypting returns the original plaintext."""
        cipher = AESCipher()
        plaintext = b"Hello, World! This is a test message."
        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)
        assert decrypted == plaintext

    def test_encrypt_decrypt_empty(self):
        """Empty input round-trips correctly."""
        cipher = AESCipher()
        plaintext = b""
        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)
        assert decrypted == plaintext

    def test_encrypt_decrypt_large_data(self):
        """A 1 MiB random payload round-trips correctly."""
        cipher = AESCipher()
        plaintext = os.urandom(1024 * 1024)  # 1MB
        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)
        assert decrypted == plaintext

    def test_encrypt_decrypt_unicode(self):
        """UTF-8 encoded non-ASCII text round-trips correctly."""
        cipher = AESCipher()
        plaintext = "你好,世界!这是一条测试消息。🎉".encode('utf-8')
        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)
        assert decrypted == plaintext

    def test_different_keys_produce_different_ciphertext(self):
        """The same plaintext under two independent keys yields different ciphertext."""
        cipher1 = AESCipher()
        cipher2 = AESCipher()
        plaintext = b"Same message"
        encrypted1 = cipher1.encrypt(plaintext)
        encrypted2 = cipher2.encrypt(plaintext)
        assert encrypted1.ciphertext != encrypted2.ciphertext

    def test_same_key_uses_fresh_iv_per_encryption(self):
        """Repeated encryptions under one key must not reuse the IV (GCM nonce)."""
        cipher = AESCipher()
        plaintext = b"Same message"
        first = cipher.encrypt(plaintext)
        second = cipher.encrypt(plaintext)
        assert first.iv != second.iv
        assert first.ciphertext != second.ciphertext

    def test_wrong_key_fails_decryption(self):
        """Decrypting with a different key raises DecryptionError."""
        cipher1 = AESCipher()
        cipher2 = AESCipher()
        plaintext = b"Secret message"
        encrypted = cipher1.encrypt(plaintext)
        with pytest.raises(DecryptionError):
            cipher2.decrypt(encrypted)

    def test_tampered_ciphertext_fails_decryption(self):
        """A single flipped ciphertext bit must be rejected by GCM authentication."""
        cipher = AESCipher()
        encrypted = cipher.encrypt(b"Secret message")
        # Flip the lowest bit of the first ciphertext byte; the plaintext is
        # non-empty, so the ciphertext has at least one byte to modify.
        flipped = bytes([encrypted.ciphertext[0] ^ 0x01]) + encrypted.ciphertext[1:]
        tampered = EncryptedData(
            ciphertext=flipped,
            iv=encrypted.iv,
            tag=encrypted.tag,
        )
        with pytest.raises(DecryptionError):
            cipher.decrypt(tampered)

    def test_derive_key_from_password(self):
        """Password-based key derivation is deterministic per salt and salts are random."""
        password = "my_secure_password"
        key1, salt1 = AESCipher.derive_key_from_password(password)
        assert len(key1) == 32
        assert len(salt1) == 16
        # The same password and salt must yield the same key.
        key2, _ = AESCipher.derive_key_from_password(password, salt1)
        assert key1 == key2
        # Omitting the salt must generate a fresh one, yielding a different key.
        key3, salt3 = AESCipher.derive_key_from_password(password)
        assert salt1 != salt3
        assert key1 != key3

    def test_encrypt_with_password(self):
        """Password-based encryption records a salt and round-trips."""
        cipher = AESCipher()
        password = "test_password"
        plaintext = b"Secret data"
        encrypted = cipher.encrypt_with_password(plaintext, password)
        assert encrypted.salt  # a salt must be stored alongside the ciphertext
        decrypted = cipher.decrypt_with_password(encrypted, password)
        assert decrypted == plaintext

    def test_wrong_password_fails(self):
        """Decrypting with the wrong password raises DecryptionError."""
        cipher = AESCipher()
        plaintext = b"Secret data"
        encrypted = cipher.encrypt_with_password(plaintext, "correct_password")
        with pytest.raises(DecryptionError):
            cipher.decrypt_with_password(encrypted, "wrong_password")
class TestEncryptedData:
    """Serialization round-trip tests for the EncryptedData container."""

    def test_to_bytes_from_bytes(self):
        """to_bytes/from_bytes preserve ciphertext, IV, tag and salt."""
        source = EncryptedData(
            ciphertext=b"encrypted_content",
            iv=b"123456789012",
            tag=b"1234567890123456",
            salt=b"salt_value_here!",
        )
        roundtripped = EncryptedData.from_bytes(source.to_bytes())
        # Compare each field in turn; order matches the container's fields.
        for field_name in ("ciphertext", "iv", "tag", "salt"):
            assert getattr(roundtripped, field_name) == getattr(source, field_name)

    def test_to_base64_from_base64(self):
        """to_base64/from_base64 preserve ciphertext, IV and tag."""
        source = EncryptedData(
            ciphertext=b"encrypted_content",
            iv=b"123456789012",
            tag=b"1234567890123456",
        )
        roundtripped = EncryptedData.from_base64(source.to_base64())
        for field_name in ("ciphertext", "iv", "tag"):
            assert getattr(roundtripped, field_name) == getattr(source, field_name)
class TestMessageEncryptor:
    """Tests for MessageEncryptor (message-level encryption)."""

    def test_encrypt_decrypt_message(self):
        """A message round-trips through a single encryptor."""
        encryptor = MessageEncryptor()
        message_data = b'{"type": "text", "content": "Hello!"}'
        encrypted = encryptor.encrypt_message(message_data)
        decrypted = encryptor.decrypt_message(encrypted)
        assert decrypted == message_data

    def test_set_key(self):
        """set_key stores the key AND makes the encryptor actually use it."""
        key = AESCipher.generate_key()
        encryptor = MessageEncryptor()
        encryptor.set_key(key)
        assert encryptor.key == key
        # Checking the attribute alone would not catch a set_key that forgets
        # to rekey the underlying cipher: a peer holding the same key must be
        # able to decrypt what this encryptor now produces.
        peer = MessageEncryptor(AESCipher(key))
        message = b"Message encrypted after set_key"
        assert peer.decrypt_message(encryptor.encrypt_message(message)) == message

    def test_shared_key_encryption(self):
        """Two encryptors built from the same key interoperate."""
        key = AESCipher.generate_key()
        encryptor1 = MessageEncryptor(AESCipher(key))
        encryptor2 = MessageEncryptor(AESCipher(key))
        message = b"Shared secret message"
        encrypted = encryptor1.encrypt_message(message)
        decrypted = encryptor2.decrypt_message(encrypted)
        assert decrypted == message
class TestFileEncryptor:
    """Tests for FileEncryptor (whole-file and per-chunk encryption)."""

    def test_encrypt_decrypt_file(self):
        """A file encrypts to an opaque file and decrypts back byte-for-byte."""
        encryptor = FileEncryptor()
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create the input file.
            input_file = Path(tmpdir) / "test.txt"
            encrypted_file = Path(tmpdir) / "test.enc"
            output_file = Path(tmpdir) / "test_decrypted.txt"
            original_content = b"This is test file content.\n" * 100
            input_file.write_bytes(original_content)

            # Encrypt.
            encryptor.encrypt_file(str(input_file), str(encrypted_file))
            assert encrypted_file.exists()
            # A round trip alone would also pass for a plain file copy; make
            # sure the plaintext does not appear in the encrypted output.
            assert b"This is test file content." not in encrypted_file.read_bytes()

            # Decrypt.
            encryptor.decrypt_file(str(encrypted_file), str(output_file))
            assert output_file.exists()

            # Verify the content.
            decrypted_content = output_file.read_bytes()
            assert decrypted_content == original_content

    def test_encrypt_decrypt_chunk(self):
        """A 64 KiB chunk round-trips through encrypt_chunk/decrypt_chunk."""
        encryptor = FileEncryptor()
        chunk_data = os.urandom(64 * 1024)  # 64KB
        encrypted = encryptor.encrypt_chunk(chunk_data)
        decrypted = encryptor.decrypt_chunk(encrypted)
        assert decrypted == chunk_data
class TestKeyManager:
    """Tests for KeyManager: key pairs, stored keys and session keys."""

    def test_generate_key_pair(self):
        """generate_key_pair returns private/public key bytes carrying the expected markers."""
        with tempfile.TemporaryDirectory() as key_dir:
            store = KeyManager(key_dir)
            priv_bytes, pub_bytes = store.generate_key_pair()
            assert b"PRIVATE KEY" in priv_bytes
            assert b"PUBLIC KEY" in pub_bytes

    def test_save_load_key(self):
        """A key saved without a password loads back unchanged."""
        with tempfile.TemporaryDirectory() as key_dir:
            store = KeyManager(key_dir)
            secret = AESCipher.generate_key()
            assert store.save_key("test_key", secret)
            assert store.load_key("test_key") == secret

    def test_save_load_key_with_password(self):
        """A password-protected key loads back unchanged with the same password."""
        with tempfile.TemporaryDirectory() as key_dir:
            store = KeyManager(key_dir)
            secret = AESCipher.generate_key()
            passphrase = "secure_password"
            assert store.save_key("protected_key", secret, passphrase)
            assert store.load_key("protected_key", passphrase) == secret

    def test_delete_key(self):
        """delete_key reports success and the key can no longer be loaded."""
        with tempfile.TemporaryDirectory() as key_dir:
            store = KeyManager(key_dir)
            secret = AESCipher.generate_key()
            store.save_key("to_delete", secret)
            assert store.delete_key("to_delete")
            # After deletion load_key yields None.
            assert store.load_key("to_delete") is None

    def test_session_key_management(self):
        """Session keys are stable per session id, distinct across ids, and reset by clearing."""
        with tempfile.TemporaryDirectory() as key_dir:
            store = KeyManager(key_dir)
            first = store.get_or_create_session_key("session1")
            assert len(first) == 32
            # Asking again for the same session returns the cached key.
            assert store.get_or_create_session_key("session1") == first
            # A different session gets its own key.
            assert store.get_or_create_session_key("session2") != first
            # Once cleared, the session receives a brand-new key.
            store.clear_session_key("session1")
            assert store.get_or_create_session_key("session1") != first
class TestLocalDataEncryptor:
    """Tests for LocalDataEncryptor (password-based encryption of local data)."""

    def test_encrypt_decrypt_string(self):
        """A str payload decrypts back to its UTF-8 bytes."""
        local = LocalDataEncryptor(password="test_password")
        text = "Hello, World!"
        plain_bytes = local.decrypt_data(local.encrypt_data(text))
        assert plain_bytes.decode('utf-8') == text

    def test_encrypt_decrypt_dict(self):
        """A dict payload round-trips when decrypted with as_json=True."""
        local = LocalDataEncryptor(password="test_password")
        record = {"name": "Test", "value": 123, "nested": {"key": "value"}}
        token = local.encrypt_data(record)
        assert local.decrypt_data(token, as_json=True) == record

    def test_encrypt_decrypt_chat_history(self):
        """A list of chat messages round-trips through the chat-history helpers."""
        local = LocalDataEncryptor(password="chat_password")
        history = [
            {"sender": "user1", "content": "Hello", "timestamp": 1234567890},
            {"sender": "user2", "content": "Hi there!", "timestamp": 1234567891},
        ]
        sealed = local.encrypt_chat_history(history)
        assert local.decrypt_chat_history(sealed) == history

    def test_save_load_encrypted_file(self):
        """Data saved to an encrypted file loads back unchanged."""
        with tempfile.TemporaryDirectory() as work_dir:
            local = LocalDataEncryptor(password="file_password")
            target = Path(work_dir) / "encrypted_data.enc"
            record = {"key": "value", "list": [1, 2, 3]}
            assert local.save_encrypted_file(record, str(target))
            assert target.exists()
            restored = local.load_encrypted_file(str(target), as_json=True)
            assert restored == record
class TestConvenienceFunctions:
    """Tests for the module-level factory and one-shot helper functions."""

    def test_create_message_encryptor(self):
        """The factory returns a keyed message encryptor that round-trips."""
        encryptor = create_message_encryptor()
        assert encryptor is not None
        assert len(encryptor.key) == 32
        # Existence and key length alone do not prove the object is usable.
        message = b"Factory-made message encryptor"
        assert encryptor.decrypt_message(encryptor.encrypt_message(message)) == message

    def test_create_file_encryptor(self):
        """The factory returns a keyed file encryptor that round-trips chunks."""
        encryptor = create_file_encryptor()
        assert encryptor is not None
        assert len(encryptor.key) == 32
        chunk = b"Factory-made file encryptor chunk"
        assert encryptor.decrypt_chunk(encryptor.encrypt_chunk(chunk)) == chunk

    def test_create_local_data_encryptor(self):
        """The factory returns a password-based local-data encryptor that round-trips."""
        encryptor = create_local_data_encryptor("password")
        assert encryptor is not None
        payload = {"key": "value"}
        encrypted = encryptor.encrypt_data(payload)
        assert encryptor.decrypt_data(encrypted, as_json=True) == payload

    def test_encrypt_decrypt_message_functions(self):
        """encrypt_message/decrypt_message round-trip with a shared key."""
        key = AESCipher.generate_key()
        message = b"Quick encryption test"
        encrypted = encrypt_message(message, key)
        decrypted = decrypt_message(encrypted, key)
        assert decrypted == message
class TestTLSManager:
    """Tests for TLSManager's client SSL context factory."""

    def test_create_client_ssl_context_no_verify(self):
        """verify=False yields a context that does not verify peer certificates."""
        manager = TLSManager()
        context = manager.create_client_ssl_context(verify=False)
        assert context is not None
        # Compare against the ssl constants rather than enum member-name
        # strings, so a typo cannot silently make the check meaningless.
        assert context.verify_mode == ssl.CERT_NONE

    def test_create_client_ssl_context_with_verify(self):
        """verify=True yields a context that requires a valid peer certificate."""
        manager = TLSManager()
        context = manager.create_client_ssl_context(verify=True)
        assert context is not None
        assert context.verify_mode == ssl.CERT_REQUIRED