|
|
# P2P Network Communication - Security Module Tests
|
|
|
"""
|
|
|
安全模块测试
|
|
|
测试传输加密、本地数据加密和密钥管理功能
|
|
|
|
|
|
需求: 10.1, 10.2, 10.3
|
|
|
"""
|
|
|
|
|
|
import os
import ssl
import tempfile
from pathlib import Path

import pytest
|
|
|
|
|
|
from shared.security import (
|
|
|
AESCipher,
|
|
|
EncryptedData,
|
|
|
MessageEncryptor,
|
|
|
FileEncryptor,
|
|
|
KeyManager,
|
|
|
LocalDataEncryptor,
|
|
|
TLSManager,
|
|
|
EncryptionError,
|
|
|
DecryptionError,
|
|
|
create_message_encryptor,
|
|
|
create_file_encryptor,
|
|
|
create_local_data_encryptor,
|
|
|
encrypt_message,
|
|
|
decrypt_message,
|
|
|
)
|
|
|
|
|
|
|
|
|
class TestAESCipher:
    """Tests for the AES-256-GCM cipher wrapper."""

    def test_generate_key(self):
        """generate_key returns a 256-bit key that differs between calls."""
        key = AESCipher.generate_key()
        assert len(key) == 32  # 256 bits

        # Two independently generated keys must not collide.
        key2 = AESCipher.generate_key()
        assert key != key2

    def test_generate_iv(self):
        """generate_iv returns a 96-bit IV that differs between calls."""
        iv = AESCipher.generate_iv()
        assert len(iv) == 12  # 96 bits

        # GCM security collapses if a nonce is ever reused under one key,
        # so consecutive IVs must not repeat.
        iv2 = AESCipher.generate_iv()
        assert iv != iv2

    def test_encrypt_decrypt(self):
        """Encrypting then decrypting returns the original plaintext."""
        cipher = AESCipher()
        plaintext = b"Hello, World! This is a test message."

        encrypted = cipher.encrypt(plaintext)
        # A pass-through "cipher" would satisfy the round trip alone, so make
        # sure the plaintext is not stored verbatim.
        assert encrypted.ciphertext != plaintext

        decrypted = cipher.decrypt(encrypted)
        assert decrypted == plaintext

    def test_encrypt_decrypt_empty(self):
        """Empty input round-trips to empty output."""
        cipher = AESCipher()
        plaintext = b""

        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)

        assert decrypted == plaintext

    def test_encrypt_decrypt_large_data(self):
        """A 1 MiB random payload round-trips intact."""
        cipher = AESCipher()
        plaintext = os.urandom(1024 * 1024)  # 1MB

        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)

        assert decrypted == plaintext

    def test_encrypt_decrypt_unicode(self):
        """UTF-8 encoded non-ASCII text (CJK and emoji) round-trips intact."""
        cipher = AESCipher()
        plaintext = "你好,世界!这是一条测试消息。🎉".encode('utf-8')

        encrypted = cipher.encrypt(plaintext)
        decrypted = cipher.decrypt(encrypted)

        assert decrypted == plaintext

    def test_different_keys_produce_different_ciphertext(self):
        """Two independently keyed ciphers produce different ciphertexts."""
        cipher1 = AESCipher()
        cipher2 = AESCipher()
        plaintext = b"Same message"

        encrypted1 = cipher1.encrypt(plaintext)
        encrypted2 = cipher2.encrypt(plaintext)

        # NOTE(review): if encrypt() draws a fresh IV per call (presumably it
        # does), this would also pass with identical keys; the wrong-key
        # decryption test below is the stronger check of key dependence.
        assert encrypted1.ciphertext != encrypted2.ciphertext

    def test_wrong_key_fails_decryption(self):
        """Decrypting with a different key raises DecryptionError."""
        cipher1 = AESCipher()
        cipher2 = AESCipher()
        plaintext = b"Secret message"

        encrypted = cipher1.encrypt(plaintext)

        with pytest.raises(DecryptionError):
            cipher2.decrypt(encrypted)

    def test_derive_key_from_password(self):
        """Password-based derivation is deterministic per salt and salted freshly."""
        password = "my_secure_password"

        key1, salt1 = AESCipher.derive_key_from_password(password)
        assert len(key1) == 32
        assert len(salt1) == 16

        # Same password + same salt must reproduce the same key.
        key2, _ = AESCipher.derive_key_from_password(password, salt1)
        assert key1 == key2

        # Omitting the salt must generate a new one, giving a different key.
        key3, salt3 = AESCipher.derive_key_from_password(password)
        assert salt1 != salt3
        assert key1 != key3

    def test_encrypt_with_password(self):
        """Password encryption embeds a salt and round-trips with the same password."""
        cipher = AESCipher()
        password = "test_password"
        plaintext = b"Secret data"

        encrypted = cipher.encrypt_with_password(plaintext, password)
        assert encrypted.salt  # the salt must travel with the ciphertext

        decrypted = cipher.decrypt_with_password(encrypted, password)
        assert decrypted == plaintext

    def test_wrong_password_fails(self):
        """Decrypting with the wrong password raises DecryptionError."""
        cipher = AESCipher()
        plaintext = b"Secret data"

        encrypted = cipher.encrypt_with_password(plaintext, "correct_password")

        with pytest.raises(DecryptionError):
            cipher.decrypt_with_password(encrypted, "wrong_password")
|
|
|
|
|
|
|
|
|
class TestEncryptedData:
    """Serialization round-trip tests for the EncryptedData container."""

    def test_to_bytes_from_bytes(self):
        """to_bytes followed by from_bytes preserves every field, salt included."""
        source = EncryptedData(
            ciphertext=b"encrypted_content",
            iv=b"123456789012",
            tag=b"1234567890123456",
            salt=b"salt_value_here!",
        )

        clone = EncryptedData.from_bytes(source.to_bytes())

        for field_name in ("ciphertext", "iv", "tag", "salt"):
            assert getattr(clone, field_name) == getattr(source, field_name)

    def test_to_base64_from_base64(self):
        """to_base64 followed by from_base64 preserves ciphertext, IV and tag."""
        source = EncryptedData(
            ciphertext=b"encrypted_content",
            iv=b"123456789012",
            tag=b"1234567890123456",
        )

        clone = EncryptedData.from_base64(source.to_base64())

        for field_name in ("ciphertext", "iv", "tag"):
            assert getattr(clone, field_name) == getattr(source, field_name)
|
|
|
|
|
|
|
|
|
class TestMessageEncryptor:
    """Tests for MessageEncryptor."""

    def test_encrypt_decrypt_message(self):
        """A message decrypts back to its original bytes with the same encryptor."""
        enc = MessageEncryptor()
        payload = b'{"type": "text", "content": "Hello!"}'

        round_tripped = enc.decrypt_message(enc.encrypt_message(payload))

        assert round_tripped == payload

    def test_set_key(self):
        """set_key replaces the key exposed through the .key attribute."""
        new_key = AESCipher.generate_key()
        enc = MessageEncryptor()

        enc.set_key(new_key)

        assert enc.key == new_key

    def test_shared_key_encryption(self):
        """Two encryptors built from the same key can read each other's messages."""
        shared = AESCipher.generate_key()
        sender = MessageEncryptor(AESCipher(shared))
        receiver = MessageEncryptor(AESCipher(shared))
        payload = b"Shared secret message"

        sealed = sender.encrypt_message(payload)

        assert receiver.decrypt_message(sealed) == payload
|
|
|
|
|
|
|
|
|
class TestFileEncryptor:
    """Tests for FileEncryptor (whole-file and per-chunk encryption)."""

    def test_encrypt_decrypt_file(self):
        """A file survives an encrypt_file/decrypt_file round trip, and the
        intermediate encrypted file does not expose the plaintext."""
        encryptor = FileEncryptor()

        with tempfile.TemporaryDirectory() as tmpdir:
            # Create the test input file.
            workdir = Path(tmpdir)
            input_file = workdir / "test.txt"
            encrypted_file = workdir / "test.enc"
            output_file = workdir / "test_decrypted.txt"

            line = b"This is test file content.\n"
            original_content = line * 100
            input_file.write_bytes(original_content)

            # Encrypt.
            encryptor.encrypt_file(str(input_file), str(encrypted_file))
            assert encrypted_file.exists()
            # A pass-through implementation would satisfy the round trip
            # below, so check that no plaintext line leaks into the output.
            assert line not in encrypted_file.read_bytes()

            # Decrypt.
            encryptor.decrypt_file(str(encrypted_file), str(output_file))
            assert output_file.exists()

            # Verify content.
            decrypted_content = output_file.read_bytes()
            assert decrypted_content == original_content

    def test_encrypt_decrypt_chunk(self):
        """A single 64 KiB random chunk round-trips intact."""
        encryptor = FileEncryptor()
        chunk_data = os.urandom(64 * 1024)  # 64KB

        encrypted = encryptor.encrypt_chunk(chunk_data)
        decrypted = encryptor.decrypt_chunk(encrypted)

        assert decrypted == chunk_data
|
|
|
|
|
|
|
|
|
class TestKeyManager:
    """Tests for KeyManager: key pairs, persisted keys and session keys."""

    def test_generate_key_pair(self):
        """generate_key_pair returns key material carrying PRIVATE/PUBLIC KEY markers."""
        with tempfile.TemporaryDirectory() as tmpdir:
            manager = KeyManager(tmpdir)
            private_key, public_key = manager.generate_key_pair()

            assert b"PRIVATE KEY" in private_key
            assert b"PUBLIC KEY" in public_key

    def test_save_load_key(self):
        """A key saved without a password can be loaded back unchanged."""
        with tempfile.TemporaryDirectory() as tmpdir:
            manager = KeyManager(tmpdir)
            key_data = AESCipher.generate_key()

            # Save.
            assert manager.save_key("test_key", key_data)

            # Load.
            loaded = manager.load_key("test_key")
            assert loaded == key_data

    def test_save_load_key_with_password(self):
        """A password-protected key loads back unchanged with the same password."""
        with tempfile.TemporaryDirectory() as tmpdir:
            manager = KeyManager(tmpdir)
            key_data = AESCipher.generate_key()
            password = "secure_password"

            # Save.
            assert manager.save_key("protected_key", key_data, password)

            # Load.
            loaded = manager.load_key("protected_key", password)
            assert loaded == key_data

    def test_delete_key(self):
        """delete_key removes a stored key so load_key returns None afterwards."""
        with tempfile.TemporaryDirectory() as tmpdir:
            manager = KeyManager(tmpdir)
            key_data = AESCipher.generate_key()

            assert manager.save_key("to_delete", key_data)
            # Confirm the key really exists first; otherwise the final
            # `is None` check could pass without anything being deleted.
            assert manager.load_key("to_delete") == key_data

            assert manager.delete_key("to_delete")

            # Verify it is gone.
            assert manager.load_key("to_delete") is None

    def test_session_key_management(self):
        """Session keys are cached per session id and regenerated after clearing."""
        with tempfile.TemporaryDirectory() as tmpdir:
            manager = KeyManager(tmpdir)

            # First access creates a 256-bit session key.
            key1 = manager.get_or_create_session_key("session1")
            assert len(key1) == 32

            # A second access returns the cached key.
            key2 = manager.get_or_create_session_key("session1")
            assert key1 == key2

            # A different session gets its own key.
            key3 = manager.get_or_create_session_key("session2")
            assert key1 != key3

            # After clearing, the same session id gets a fresh key.
            manager.clear_session_key("session1")
            key4 = manager.get_or_create_session_key("session1")
            assert key1 != key4
|
|
|
|
|
|
|
|
|
class TestLocalDataEncryptor:
    """Tests for the password-based LocalDataEncryptor."""

    def test_encrypt_decrypt_string(self):
        """A str value round-trips and comes back as UTF-8 bytes."""
        local = LocalDataEncryptor(password="test_password")
        text = "Hello, World!"

        restored = local.decrypt_data(local.encrypt_data(text))

        assert restored.decode('utf-8') == text

    def test_encrypt_decrypt_dict(self):
        """A nested dict round-trips when decrypted with as_json=True."""
        local = LocalDataEncryptor(password="test_password")
        payload = {"name": "Test", "value": 123, "nested": {"key": "value"}}

        restored = local.decrypt_data(local.encrypt_data(payload), as_json=True)

        assert restored == payload

    def test_encrypt_decrypt_chat_history(self):
        """A list of chat message dicts round-trips through the chat-history helpers."""
        local = LocalDataEncryptor(password="chat_password")
        history = [
            {"sender": "user1", "content": "Hello", "timestamp": 1234567890},
            {"sender": "user2", "content": "Hi there!", "timestamp": 1234567891},
        ]

        blob = local.encrypt_chat_history(history)

        assert local.decrypt_chat_history(blob) == history

    def test_save_load_encrypted_file(self):
        """Data written with save_encrypted_file loads back unchanged."""
        with tempfile.TemporaryDirectory() as tmpdir:
            local = LocalDataEncryptor(password="file_password")
            target = Path(tmpdir) / "encrypted_data.enc"
            payload = {"key": "value", "list": [1, 2, 3]}

            # Persist to disk.
            saved_ok = local.save_encrypted_file(payload, str(target))
            assert saved_ok
            assert target.exists()

            # Read it back.
            restored = local.load_encrypted_file(str(target), as_json=True)
            assert restored == payload
|
|
|
|
|
|
|
|
|
class TestConvenienceFunctions:
    """Tests for the module-level factory and shortcut helpers."""

    def test_create_message_encryptor(self):
        """create_message_encryptor returns an encryptor holding a 256-bit key."""
        made = create_message_encryptor()

        assert made is not None
        assert len(made.key) == 32

    def test_create_file_encryptor(self):
        """create_file_encryptor returns an encryptor holding a 256-bit key."""
        made = create_file_encryptor()

        assert made is not None
        assert len(made.key) == 32

    def test_create_local_data_encryptor(self):
        """create_local_data_encryptor builds an encryptor from a password."""
        made = create_local_data_encryptor("password")

        assert made is not None

    def test_encrypt_decrypt_message_functions(self):
        """encrypt_message/decrypt_message round-trip with the same key."""
        secret = AESCipher.generate_key()
        plain = b"Quick encryption test"

        sealed = encrypt_message(plain, secret)

        assert decrypt_message(sealed, secret) == plain
|
|
|
|
|
|
|
|
|
class TestTLSManager:
    """Tests for TLSManager client-side SSL context creation."""

    def test_create_client_ssl_context_no_verify(self):
        """verify=False yields an SSLContext that skips certificate verification."""
        manager = TLSManager()
        context = manager.create_client_ssl_context(verify=False)

        assert isinstance(context, ssl.SSLContext)
        # Compare against the ssl enum member instead of its string name so
        # the check does not depend on enum spelling.
        assert context.verify_mode == ssl.CERT_NONE
        # ssl.SSLContext rejects CERT_NONE while check_hostname is enabled,
        # so an unverified context must also have hostname checking off.
        assert not context.check_hostname

    def test_create_client_ssl_context_with_verify(self):
        """verify=True yields an SSLContext that requires a valid peer certificate."""
        manager = TLSManager()
        context = manager.create_client_ssl_context(verify=True)

        assert isinstance(context, ssl.SSLContext)
        assert context.verify_mode == ssl.CERT_REQUIRED
|