|
|
# P2P Network Communication - Security Module
|
|
|
"""
|
|
|
安全模块
|
|
|
负责传输加密、本地数据加密和密钥管理
|
|
|
|
|
|
需求: 10.1, 10.2, 10.3
|
|
|
"""
|
|
|
|
|
|
import base64
|
|
|
import hashlib
|
|
|
import hmac
|
|
|
import json
|
|
|
import logging
|
|
|
import os
|
|
|
import secrets
|
|
|
import ssl
|
|
|
import struct
|
|
|
from dataclasses import dataclass, field
|
|
|
from datetime import datetime
|
|
|
from pathlib import Path
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
|
|
from cryptography.hazmat.primitives import hashes, padding, serialization
|
|
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding
|
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
|
from cryptography.hazmat.backends import default_backend
|
|
|
from cryptography.x509 import load_pem_x509_certificate
|
|
|
|
|
|
|
|
|
# 设置日志
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class SecurityError(Exception):
|
|
|
"""安全错误基类"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class EncryptionError(SecurityError):
|
|
|
"""加密错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class DecryptionError(SecurityError):
|
|
|
"""解密错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class KeyManagementError(SecurityError):
|
|
|
"""密钥管理错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
class CertificateError(SecurityError):
|
|
|
"""证书错误"""
|
|
|
pass
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class EncryptedData:
|
|
|
"""加密数据结构"""
|
|
|
ciphertext: bytes
|
|
|
iv: bytes # 初始化向量
|
|
|
tag: bytes # 认证标签 (GCM模式)
|
|
|
salt: bytes = field(default_factory=bytes) # 用于密钥派生
|
|
|
|
|
|
def to_bytes(self) -> bytes:
|
|
|
"""序列化为字节流"""
|
|
|
# 格式: [iv_len(2)][iv][tag_len(2)][tag][salt_len(2)][salt][ciphertext]
|
|
|
result = struct.pack('!H', len(self.iv)) + self.iv
|
|
|
result += struct.pack('!H', len(self.tag)) + self.tag
|
|
|
result += struct.pack('!H', len(self.salt)) + self.salt
|
|
|
result += self.ciphertext
|
|
|
return result
|
|
|
|
|
|
@classmethod
|
|
|
def from_bytes(cls, data: bytes) -> "EncryptedData":
|
|
|
"""从字节流反序列化"""
|
|
|
offset = 0
|
|
|
|
|
|
# 读取IV
|
|
|
iv_len = struct.unpack('!H', data[offset:offset+2])[0]
|
|
|
offset += 2
|
|
|
iv = data[offset:offset+iv_len]
|
|
|
offset += iv_len
|
|
|
|
|
|
# 读取tag
|
|
|
tag_len = struct.unpack('!H', data[offset:offset+2])[0]
|
|
|
offset += 2
|
|
|
tag = data[offset:offset+tag_len]
|
|
|
offset += tag_len
|
|
|
|
|
|
# 读取salt
|
|
|
salt_len = struct.unpack('!H', data[offset:offset+2])[0]
|
|
|
offset += 2
|
|
|
salt = data[offset:offset+salt_len]
|
|
|
offset += salt_len
|
|
|
|
|
|
# 剩余为密文
|
|
|
ciphertext = data[offset:]
|
|
|
|
|
|
return cls(ciphertext=ciphertext, iv=iv, tag=tag, salt=salt)
|
|
|
|
|
|
def to_base64(self) -> str:
|
|
|
"""转换为Base64字符串"""
|
|
|
return base64.b64encode(self.to_bytes()).decode('utf-8')
|
|
|
|
|
|
@classmethod
|
|
|
def from_base64(cls, data: str) -> "EncryptedData":
|
|
|
"""从Base64字符串创建"""
|
|
|
return cls.from_bytes(base64.b64decode(data))
|
|
|
|
|
|
|
|
|
|
|
|
class AESCipher:
|
|
|
"""
|
|
|
AES-256-GCM 加密器
|
|
|
|
|
|
用于消息和文件的加密传输 (需求 10.1, 10.3)
|
|
|
以及本地数据的加密存储 (需求 10.2)
|
|
|
"""
|
|
|
|
|
|
# AES-256 密钥长度
|
|
|
KEY_SIZE = 32 # 256 bits
|
|
|
|
|
|
# GCM IV 长度
|
|
|
IV_SIZE = 12 # 96 bits (推荐)
|
|
|
|
|
|
# GCM 认证标签长度
|
|
|
TAG_SIZE = 16 # 128 bits
|
|
|
|
|
|
# PBKDF2 盐长度
|
|
|
SALT_SIZE = 16 # 128 bits
|
|
|
|
|
|
# PBKDF2 迭代次数
|
|
|
PBKDF2_ITERATIONS = 100000
|
|
|
|
|
|
def __init__(self, key: Optional[bytes] = None):
|
|
|
"""
|
|
|
初始化AES加密器
|
|
|
|
|
|
Args:
|
|
|
key: 256位密钥,如果为None则生成新密钥
|
|
|
"""
|
|
|
if key is None:
|
|
|
self._key = self.generate_key()
|
|
|
else:
|
|
|
if len(key) != self.KEY_SIZE:
|
|
|
raise EncryptionError(f"Key must be {self.KEY_SIZE} bytes")
|
|
|
self._key = key
|
|
|
|
|
|
@property
|
|
|
def key(self) -> bytes:
|
|
|
"""获取密钥"""
|
|
|
return self._key
|
|
|
|
|
|
@staticmethod
|
|
|
def generate_key() -> bytes:
|
|
|
"""
|
|
|
生成随机256位密钥
|
|
|
|
|
|
Returns:
|
|
|
32字节随机密钥
|
|
|
"""
|
|
|
return secrets.token_bytes(AESCipher.KEY_SIZE)
|
|
|
|
|
|
@staticmethod
|
|
|
def generate_iv() -> bytes:
|
|
|
"""
|
|
|
生成随机IV
|
|
|
|
|
|
Returns:
|
|
|
12字节随机IV
|
|
|
"""
|
|
|
return secrets.token_bytes(AESCipher.IV_SIZE)
|
|
|
|
|
|
@staticmethod
|
|
|
def derive_key_from_password(password: str, salt: Optional[bytes] = None) -> Tuple[bytes, bytes]:
|
|
|
"""
|
|
|
从密码派生密钥 (PBKDF2)
|
|
|
|
|
|
实现密钥管理 (需求 10.2)
|
|
|
|
|
|
Args:
|
|
|
password: 用户密码
|
|
|
salt: 盐值,如果为None则生成新盐
|
|
|
|
|
|
Returns:
|
|
|
(密钥, 盐值) 元组
|
|
|
"""
|
|
|
if salt is None:
|
|
|
salt = secrets.token_bytes(AESCipher.SALT_SIZE)
|
|
|
|
|
|
kdf = PBKDF2HMAC(
|
|
|
algorithm=hashes.SHA256(),
|
|
|
length=AESCipher.KEY_SIZE,
|
|
|
salt=salt,
|
|
|
iterations=AESCipher.PBKDF2_ITERATIONS,
|
|
|
backend=default_backend()
|
|
|
)
|
|
|
|
|
|
key = kdf.derive(password.encode('utf-8'))
|
|
|
return key, salt
|
|
|
|
|
|
def encrypt(self, plaintext: bytes) -> EncryptedData:
|
|
|
"""
|
|
|
加密数据 (AES-256-GCM)
|
|
|
|
|
|
实现消息加密传输 (需求 10.1)
|
|
|
实现文件加密传输 (需求 10.3)
|
|
|
|
|
|
Args:
|
|
|
plaintext: 明文数据
|
|
|
|
|
|
Returns:
|
|
|
加密数据对象
|
|
|
|
|
|
Raises:
|
|
|
EncryptionError: 加密失败
|
|
|
"""
|
|
|
try:
|
|
|
iv = self.generate_iv()
|
|
|
|
|
|
cipher = Cipher(
|
|
|
algorithms.AES(self._key),
|
|
|
modes.GCM(iv),
|
|
|
backend=default_backend()
|
|
|
)
|
|
|
encryptor = cipher.encryptor()
|
|
|
|
|
|
ciphertext = encryptor.update(plaintext) + encryptor.finalize()
|
|
|
|
|
|
return EncryptedData(
|
|
|
ciphertext=ciphertext,
|
|
|
iv=iv,
|
|
|
tag=encryptor.tag
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
raise EncryptionError(f"Encryption failed: {e}")
|
|
|
|
|
|
def decrypt(self, encrypted_data: EncryptedData) -> bytes:
|
|
|
"""
|
|
|
解密数据 (AES-256-GCM)
|
|
|
|
|
|
Args:
|
|
|
encrypted_data: 加密数据对象
|
|
|
|
|
|
Returns:
|
|
|
解密后的明文
|
|
|
|
|
|
Raises:
|
|
|
DecryptionError: 解密失败
|
|
|
"""
|
|
|
try:
|
|
|
cipher = Cipher(
|
|
|
algorithms.AES(self._key),
|
|
|
modes.GCM(encrypted_data.iv, encrypted_data.tag),
|
|
|
backend=default_backend()
|
|
|
)
|
|
|
decryptor = cipher.decryptor()
|
|
|
|
|
|
plaintext = decryptor.update(encrypted_data.ciphertext) + decryptor.finalize()
|
|
|
|
|
|
return plaintext
|
|
|
|
|
|
except Exception as e:
|
|
|
raise DecryptionError(f"Decryption failed: {e}")
|
|
|
|
|
|
def encrypt_with_password(self, plaintext: bytes, password: str) -> EncryptedData:
|
|
|
"""
|
|
|
使用密码加密数据
|
|
|
|
|
|
实现 AES-256 加密存储 (需求 10.2)
|
|
|
|
|
|
Args:
|
|
|
plaintext: 明文数据
|
|
|
password: 用户密码
|
|
|
|
|
|
Returns:
|
|
|
加密数据对象(包含盐值)
|
|
|
"""
|
|
|
key, salt = self.derive_key_from_password(password)
|
|
|
|
|
|
# 临时使用派生密钥
|
|
|
original_key = self._key
|
|
|
self._key = key
|
|
|
|
|
|
try:
|
|
|
encrypted = self.encrypt(plaintext)
|
|
|
encrypted.salt = salt
|
|
|
return encrypted
|
|
|
finally:
|
|
|
self._key = original_key
|
|
|
|
|
|
def decrypt_with_password(self, encrypted_data: EncryptedData, password: str) -> bytes:
|
|
|
"""
|
|
|
使用密码解密数据
|
|
|
|
|
|
Args:
|
|
|
encrypted_data: 加密数据对象
|
|
|
password: 用户密码
|
|
|
|
|
|
Returns:
|
|
|
解密后的明文
|
|
|
"""
|
|
|
key, _ = self.derive_key_from_password(password, encrypted_data.salt)
|
|
|
|
|
|
# 临时使用派生密钥
|
|
|
original_key = self._key
|
|
|
self._key = key
|
|
|
|
|
|
try:
|
|
|
return self.decrypt(encrypted_data)
|
|
|
finally:
|
|
|
self._key = original_key
|
|
|
|
|
|
|
|
|
|
|
|
class TLSManager:
|
|
|
"""
|
|
|
TLS/SSL 连接管理器
|
|
|
|
|
|
实现 TLS/SSL 连接 (需求 10.1)
|
|
|
"""
|
|
|
|
|
|
def __init__(self, cert_file: Optional[str] = None,
|
|
|
key_file: Optional[str] = None,
|
|
|
ca_file: Optional[str] = None):
|
|
|
"""
|
|
|
初始化TLS管理器
|
|
|
|
|
|
Args:
|
|
|
cert_file: 证书文件路径
|
|
|
key_file: 私钥文件路径
|
|
|
ca_file: CA证书文件路径
|
|
|
"""
|
|
|
self.cert_file = cert_file
|
|
|
self.key_file = key_file
|
|
|
self.ca_file = ca_file
|
|
|
|
|
|
def create_server_ssl_context(self) -> ssl.SSLContext:
|
|
|
"""
|
|
|
创建服务器SSL上下文
|
|
|
|
|
|
实现 TLS/SSL 连接 (需求 10.1)
|
|
|
|
|
|
Returns:
|
|
|
SSL上下文对象
|
|
|
|
|
|
Raises:
|
|
|
CertificateError: 证书配置错误
|
|
|
"""
|
|
|
try:
|
|
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
|
|
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
|
|
|
|
|
# 设置安全选项
|
|
|
context.options |= ssl.OP_NO_SSLv2
|
|
|
context.options |= ssl.OP_NO_SSLv3
|
|
|
context.options |= ssl.OP_NO_TLSv1
|
|
|
context.options |= ssl.OP_NO_TLSv1_1
|
|
|
|
|
|
# 加载证书和私钥
|
|
|
if self.cert_file and self.key_file:
|
|
|
context.load_cert_chain(self.cert_file, self.key_file)
|
|
|
else:
|
|
|
# 生成自签名证书(仅用于开发/测试)
|
|
|
logger.warning("No certificate provided, using self-signed certificate")
|
|
|
self._generate_self_signed_cert()
|
|
|
if self.cert_file and self.key_file:
|
|
|
context.load_cert_chain(self.cert_file, self.key_file)
|
|
|
|
|
|
# 设置密码套件
|
|
|
context.set_ciphers('ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20')
|
|
|
|
|
|
return context
|
|
|
|
|
|
except ssl.SSLError as e:
|
|
|
raise CertificateError(f"Failed to create server SSL context: {e}")
|
|
|
except FileNotFoundError as e:
|
|
|
raise CertificateError(f"Certificate file not found: {e}")
|
|
|
|
|
|
def create_client_ssl_context(self, verify: bool = True) -> ssl.SSLContext:
|
|
|
"""
|
|
|
创建客户端SSL上下文
|
|
|
|
|
|
实现 TLS/SSL 连接 (需求 10.1)
|
|
|
|
|
|
Args:
|
|
|
verify: 是否验证服务器证书
|
|
|
|
|
|
Returns:
|
|
|
SSL上下文对象
|
|
|
"""
|
|
|
try:
|
|
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
|
|
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
|
|
|
|
|
# 设置安全选项
|
|
|
context.options |= ssl.OP_NO_SSLv2
|
|
|
context.options |= ssl.OP_NO_SSLv3
|
|
|
context.options |= ssl.OP_NO_TLSv1
|
|
|
context.options |= ssl.OP_NO_TLSv1_1
|
|
|
|
|
|
if verify:
|
|
|
context.verify_mode = ssl.CERT_REQUIRED
|
|
|
context.check_hostname = True
|
|
|
|
|
|
# 加载CA证书
|
|
|
if self.ca_file:
|
|
|
context.load_verify_locations(self.ca_file)
|
|
|
else:
|
|
|
# 使用系统默认CA
|
|
|
context.load_default_certs()
|
|
|
else:
|
|
|
# 不验证证书(仅用于开发/测试)
|
|
|
context.check_hostname = False
|
|
|
context.verify_mode = ssl.CERT_NONE
|
|
|
logger.warning("SSL certificate verification disabled")
|
|
|
|
|
|
# 设置密码套件
|
|
|
context.set_ciphers('ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20')
|
|
|
|
|
|
return context
|
|
|
|
|
|
except ssl.SSLError as e:
|
|
|
raise CertificateError(f"Failed to create client SSL context: {e}")
|
|
|
|
|
|
def _generate_self_signed_cert(self) -> None:
|
|
|
"""
|
|
|
生成自签名证书(仅用于开发/测试)
|
|
|
"""
|
|
|
from cryptography import x509
|
|
|
from cryptography.x509.oid import NameOID
|
|
|
from datetime import timedelta
|
|
|
|
|
|
# 生成私钥
|
|
|
private_key = rsa.generate_private_key(
|
|
|
public_exponent=65537,
|
|
|
key_size=2048,
|
|
|
backend=default_backend()
|
|
|
)
|
|
|
|
|
|
# 创建证书
|
|
|
subject = issuer = x509.Name([
|
|
|
x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"),
|
|
|
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Beijing"),
|
|
|
x509.NameAttribute(NameOID.LOCALITY_NAME, "Beijing"),
|
|
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "P2P Chat"),
|
|
|
x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
|
|
|
])
|
|
|
|
|
|
cert = (
|
|
|
x509.CertificateBuilder()
|
|
|
.subject_name(subject)
|
|
|
.issuer_name(issuer)
|
|
|
.public_key(private_key.public_key())
|
|
|
.serial_number(x509.random_serial_number())
|
|
|
.not_valid_before(datetime.utcnow())
|
|
|
.not_valid_after(datetime.utcnow() + timedelta(days=365))
|
|
|
.add_extension(
|
|
|
x509.SubjectAlternativeName([
|
|
|
x509.DNSName("localhost"),
|
|
|
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
|
|
|
]),
|
|
|
critical=False,
|
|
|
)
|
|
|
.sign(private_key, hashes.SHA256(), default_backend())
|
|
|
)
|
|
|
|
|
|
# 保存证书和私钥
|
|
|
cert_dir = Path("certs")
|
|
|
cert_dir.mkdir(exist_ok=True)
|
|
|
|
|
|
self.cert_file = str(cert_dir / "server.crt")
|
|
|
self.key_file = str(cert_dir / "server.key")
|
|
|
|
|
|
with open(self.cert_file, "wb") as f:
|
|
|
f.write(cert.public_bytes(serialization.Encoding.PEM))
|
|
|
|
|
|
with open(self.key_file, "wb") as f:
|
|
|
f.write(private_key.private_bytes(
|
|
|
encoding=serialization.Encoding.PEM,
|
|
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
|
|
encryption_algorithm=serialization.NoEncryption()
|
|
|
))
|
|
|
|
|
|
logger.info(f"Generated self-signed certificate: {self.cert_file}")
|
|
|
|
|
|
|
|
|
|
|
|
class MessageEncryptor:
|
|
|
"""
|
|
|
消息加密器
|
|
|
|
|
|
实现消息加密传输 (需求 10.1)
|
|
|
"""
|
|
|
|
|
|
def __init__(self, cipher: Optional[AESCipher] = None):
|
|
|
"""
|
|
|
初始化消息加密器
|
|
|
|
|
|
Args:
|
|
|
cipher: AES加密器,如果为None则创建新的
|
|
|
"""
|
|
|
self._cipher = cipher or AESCipher()
|
|
|
|
|
|
@property
|
|
|
def key(self) -> bytes:
|
|
|
"""获取加密密钥"""
|
|
|
return self._cipher.key
|
|
|
|
|
|
def set_key(self, key: bytes) -> None:
|
|
|
"""
|
|
|
设置加密密钥
|
|
|
|
|
|
Args:
|
|
|
key: 256位密钥
|
|
|
"""
|
|
|
self._cipher = AESCipher(key)
|
|
|
|
|
|
def encrypt_message(self, message_data: bytes) -> bytes:
|
|
|
"""
|
|
|
加密消息数据
|
|
|
|
|
|
实现消息加密传输 (需求 10.1)
|
|
|
|
|
|
Args:
|
|
|
message_data: 消息数据(序列化后的字节流)
|
|
|
|
|
|
Returns:
|
|
|
加密后的字节流
|
|
|
"""
|
|
|
encrypted = self._cipher.encrypt(message_data)
|
|
|
return encrypted.to_bytes()
|
|
|
|
|
|
def decrypt_message(self, encrypted_data: bytes) -> bytes:
|
|
|
"""
|
|
|
解密消息数据
|
|
|
|
|
|
Args:
|
|
|
encrypted_data: 加密的字节流
|
|
|
|
|
|
Returns:
|
|
|
解密后的消息数据
|
|
|
"""
|
|
|
encrypted = EncryptedData.from_bytes(encrypted_data)
|
|
|
return self._cipher.decrypt(encrypted)
|
|
|
|
|
|
|
|
|
class FileEncryptor:
|
|
|
"""
|
|
|
文件加密器
|
|
|
|
|
|
实现文件加密传输 (需求 10.3)
|
|
|
"""
|
|
|
|
|
|
# 文件加密块大小
|
|
|
BLOCK_SIZE = 64 * 1024 # 64KB
|
|
|
|
|
|
def __init__(self, cipher: Optional[AESCipher] = None):
|
|
|
"""
|
|
|
初始化文件加密器
|
|
|
|
|
|
Args:
|
|
|
cipher: AES加密器
|
|
|
"""
|
|
|
self._cipher = cipher or AESCipher()
|
|
|
|
|
|
@property
|
|
|
def key(self) -> bytes:
|
|
|
"""获取加密密钥"""
|
|
|
return self._cipher.key
|
|
|
|
|
|
def set_key(self, key: bytes) -> None:
|
|
|
"""
|
|
|
设置加密密钥
|
|
|
|
|
|
Args:
|
|
|
key: 256位密钥
|
|
|
"""
|
|
|
self._cipher = AESCipher(key)
|
|
|
|
|
|
def encrypt_file(self, input_path: str, output_path: str) -> bool:
|
|
|
"""
|
|
|
加密文件
|
|
|
|
|
|
实现文件加密传输 (需求 10.3)
|
|
|
|
|
|
Args:
|
|
|
input_path: 输入文件路径
|
|
|
output_path: 输出文件路径
|
|
|
|
|
|
Returns:
|
|
|
加密成功返回True
|
|
|
|
|
|
Raises:
|
|
|
EncryptionError: 加密失败
|
|
|
"""
|
|
|
try:
|
|
|
# 读取整个文件
|
|
|
with open(input_path, 'rb') as f:
|
|
|
plaintext = f.read()
|
|
|
|
|
|
# 加密
|
|
|
encrypted = self._cipher.encrypt(plaintext)
|
|
|
|
|
|
# 写入加密文件
|
|
|
with open(output_path, 'wb') as f:
|
|
|
f.write(encrypted.to_bytes())
|
|
|
|
|
|
logger.info(f"File encrypted: {input_path} -> {output_path}")
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
raise EncryptionError(f"Failed to encrypt file: {e}")
|
|
|
|
|
|
def decrypt_file(self, input_path: str, output_path: str) -> bool:
|
|
|
"""
|
|
|
解密文件
|
|
|
|
|
|
Args:
|
|
|
input_path: 加密文件路径
|
|
|
output_path: 输出文件路径
|
|
|
|
|
|
Returns:
|
|
|
解密成功返回True
|
|
|
|
|
|
Raises:
|
|
|
DecryptionError: 解密失败
|
|
|
"""
|
|
|
try:
|
|
|
# 读取加密文件
|
|
|
with open(input_path, 'rb') as f:
|
|
|
encrypted_data = f.read()
|
|
|
|
|
|
# 解密
|
|
|
encrypted = EncryptedData.from_bytes(encrypted_data)
|
|
|
plaintext = self._cipher.decrypt(encrypted)
|
|
|
|
|
|
# 写入解密文件
|
|
|
with open(output_path, 'wb') as f:
|
|
|
f.write(plaintext)
|
|
|
|
|
|
logger.info(f"File decrypted: {input_path} -> {output_path}")
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
raise DecryptionError(f"Failed to decrypt file: {e}")
|
|
|
|
|
|
def encrypt_chunk(self, chunk_data: bytes) -> bytes:
|
|
|
"""
|
|
|
加密文件块
|
|
|
|
|
|
用于分块传输时的加密
|
|
|
|
|
|
Args:
|
|
|
chunk_data: 文件块数据
|
|
|
|
|
|
Returns:
|
|
|
加密后的数据
|
|
|
"""
|
|
|
encrypted = self._cipher.encrypt(chunk_data)
|
|
|
return encrypted.to_bytes()
|
|
|
|
|
|
def decrypt_chunk(self, encrypted_chunk: bytes) -> bytes:
|
|
|
"""
|
|
|
解密文件块
|
|
|
|
|
|
Args:
|
|
|
encrypted_chunk: 加密的文件块
|
|
|
|
|
|
Returns:
|
|
|
解密后的数据
|
|
|
"""
|
|
|
encrypted = EncryptedData.from_bytes(encrypted_chunk)
|
|
|
return self._cipher.decrypt(encrypted)
|
|
|
|
|
|
|
|
|
|
|
|
class KeyManager:
|
|
|
"""
|
|
|
密钥管理器
|
|
|
|
|
|
实现密钥管理 (需求 10.2)
|
|
|
"""
|
|
|
|
|
|
# 密钥存储目录
|
|
|
KEY_DIR = "keys"
|
|
|
|
|
|
def __init__(self, key_dir: Optional[str] = None):
|
|
|
"""
|
|
|
初始化密钥管理器
|
|
|
|
|
|
Args:
|
|
|
key_dir: 密钥存储目录
|
|
|
"""
|
|
|
self._key_dir = Path(key_dir or self.KEY_DIR)
|
|
|
self._key_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
# 内存中的密钥缓存
|
|
|
self._key_cache: dict = {}
|
|
|
|
|
|
def generate_key_pair(self) -> Tuple[bytes, bytes]:
|
|
|
"""
|
|
|
生成RSA密钥对
|
|
|
|
|
|
Returns:
|
|
|
(私钥, 公钥) 元组(PEM格式)
|
|
|
"""
|
|
|
private_key = rsa.generate_private_key(
|
|
|
public_exponent=65537,
|
|
|
key_size=2048,
|
|
|
backend=default_backend()
|
|
|
)
|
|
|
|
|
|
private_pem = private_key.private_bytes(
|
|
|
encoding=serialization.Encoding.PEM,
|
|
|
format=serialization.PrivateFormat.PKCS8,
|
|
|
encryption_algorithm=serialization.NoEncryption()
|
|
|
)
|
|
|
|
|
|
public_pem = private_key.public_key().public_bytes(
|
|
|
encoding=serialization.Encoding.PEM,
|
|
|
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
|
)
|
|
|
|
|
|
return private_pem, public_pem
|
|
|
|
|
|
def save_key(self, key_id: str, key_data: bytes, password: Optional[str] = None) -> bool:
|
|
|
"""
|
|
|
保存密钥到文件
|
|
|
|
|
|
实现密钥管理 (需求 10.2)
|
|
|
|
|
|
Args:
|
|
|
key_id: 密钥标识
|
|
|
key_data: 密钥数据
|
|
|
password: 可选的加密密码
|
|
|
|
|
|
Returns:
|
|
|
保存成功返回True
|
|
|
"""
|
|
|
try:
|
|
|
key_file = self._key_dir / f"{key_id}.key"
|
|
|
|
|
|
if password:
|
|
|
# 使用密码加密密钥
|
|
|
cipher = AESCipher()
|
|
|
encrypted = cipher.encrypt_with_password(key_data, password)
|
|
|
with open(key_file, 'wb') as f:
|
|
|
f.write(encrypted.to_bytes())
|
|
|
else:
|
|
|
# 直接保存(不推荐用于生产环境)
|
|
|
with open(key_file, 'wb') as f:
|
|
|
f.write(key_data)
|
|
|
|
|
|
logger.info(f"Key saved: {key_id}")
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to save key {key_id}: {e}")
|
|
|
return False
|
|
|
|
|
|
def load_key(self, key_id: str, password: Optional[str] = None) -> Optional[bytes]:
|
|
|
"""
|
|
|
从文件加载密钥
|
|
|
|
|
|
Args:
|
|
|
key_id: 密钥标识
|
|
|
password: 解密密码
|
|
|
|
|
|
Returns:
|
|
|
密钥数据,如果加载失败返回None
|
|
|
"""
|
|
|
try:
|
|
|
key_file = self._key_dir / f"{key_id}.key"
|
|
|
|
|
|
if not key_file.exists():
|
|
|
logger.error(f"Key file not found: {key_id}")
|
|
|
return None
|
|
|
|
|
|
with open(key_file, 'rb') as f:
|
|
|
data = f.read()
|
|
|
|
|
|
if password:
|
|
|
# 解密密钥
|
|
|
cipher = AESCipher()
|
|
|
encrypted = EncryptedData.from_bytes(data)
|
|
|
return cipher.decrypt_with_password(encrypted, password)
|
|
|
else:
|
|
|
return data
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to load key {key_id}: {e}")
|
|
|
return None
|
|
|
|
|
|
def delete_key(self, key_id: str) -> bool:
|
|
|
"""
|
|
|
删除密钥
|
|
|
|
|
|
Args:
|
|
|
key_id: 密钥标识
|
|
|
|
|
|
Returns:
|
|
|
删除成功返回True
|
|
|
"""
|
|
|
try:
|
|
|
key_file = self._key_dir / f"{key_id}.key"
|
|
|
|
|
|
if key_file.exists():
|
|
|
key_file.unlink()
|
|
|
logger.info(f"Key deleted: {key_id}")
|
|
|
|
|
|
# 从缓存中移除
|
|
|
self._key_cache.pop(key_id, None)
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to delete key {key_id}: {e}")
|
|
|
return False
|
|
|
|
|
|
def get_or_create_session_key(self, session_id: str) -> bytes:
|
|
|
"""
|
|
|
获取或创建会话密钥
|
|
|
|
|
|
Args:
|
|
|
session_id: 会话标识
|
|
|
|
|
|
Returns:
|
|
|
会话密钥
|
|
|
"""
|
|
|
if session_id in self._key_cache:
|
|
|
return self._key_cache[session_id]
|
|
|
|
|
|
# 生成新的会话密钥
|
|
|
key = AESCipher.generate_key()
|
|
|
self._key_cache[session_id] = key
|
|
|
|
|
|
return key
|
|
|
|
|
|
def clear_session_key(self, session_id: str) -> None:
|
|
|
"""
|
|
|
清除会话密钥
|
|
|
|
|
|
Args:
|
|
|
session_id: 会话标识
|
|
|
"""
|
|
|
self._key_cache.pop(session_id, None)
|
|
|
|
|
|
|
|
|
|
|
|
class LocalDataEncryptor:
|
|
|
"""
|
|
|
本地数据加密器
|
|
|
|
|
|
实现 AES-256 加密存储 (需求 10.2)
|
|
|
"""
|
|
|
|
|
|
def __init__(self, password: Optional[str] = None, key: Optional[bytes] = None):
|
|
|
"""
|
|
|
初始化本地数据加密器
|
|
|
|
|
|
Args:
|
|
|
password: 用户密码(用于派生密钥)
|
|
|
key: 直接提供的密钥
|
|
|
"""
|
|
|
self._cipher = AESCipher()
|
|
|
self._password = password
|
|
|
self._salt: Optional[bytes] = None
|
|
|
|
|
|
if key:
|
|
|
self._cipher = AESCipher(key)
|
|
|
elif password:
|
|
|
# 从密码派生密钥
|
|
|
key, self._salt = AESCipher.derive_key_from_password(password)
|
|
|
self._cipher = AESCipher(key)
|
|
|
|
|
|
def encrypt_data(self, data: Union[str, bytes, dict, list]) -> str:
|
|
|
"""
|
|
|
加密数据
|
|
|
|
|
|
实现对本地数据进行加密存储 (需求 10.2)
|
|
|
|
|
|
Args:
|
|
|
data: 要加密的数据(字符串、字节、字典或列表)
|
|
|
|
|
|
Returns:
|
|
|
Base64编码的加密数据
|
|
|
"""
|
|
|
# 转换为字节
|
|
|
if isinstance(data, str):
|
|
|
plaintext = data.encode('utf-8')
|
|
|
elif isinstance(data, (dict, list)):
|
|
|
plaintext = json.dumps(data, ensure_ascii=False).encode('utf-8')
|
|
|
else:
|
|
|
plaintext = data
|
|
|
|
|
|
encrypted = self._cipher.encrypt(plaintext)
|
|
|
|
|
|
# 如果有盐值,添加到加密数据中
|
|
|
if self._salt:
|
|
|
encrypted.salt = self._salt
|
|
|
|
|
|
return encrypted.to_base64()
|
|
|
|
|
|
def decrypt_data(self, encrypted_data: str, as_json: bool = False) -> Union[bytes, dict]:
|
|
|
"""
|
|
|
解密数据
|
|
|
|
|
|
Args:
|
|
|
encrypted_data: Base64编码的加密数据
|
|
|
as_json: 是否解析为JSON
|
|
|
|
|
|
Returns:
|
|
|
解密后的数据
|
|
|
"""
|
|
|
encrypted = EncryptedData.from_base64(encrypted_data)
|
|
|
plaintext = self._cipher.decrypt(encrypted)
|
|
|
|
|
|
if as_json:
|
|
|
return json.loads(plaintext.decode('utf-8'))
|
|
|
|
|
|
return plaintext
|
|
|
|
|
|
def encrypt_chat_history(self, messages: list) -> str:
|
|
|
"""
|
|
|
加密聊天记录
|
|
|
|
|
|
实现对本地数据进行加密存储 (需求 10.2)
|
|
|
|
|
|
Args:
|
|
|
messages: 聊天消息列表
|
|
|
|
|
|
Returns:
|
|
|
加密后的数据
|
|
|
"""
|
|
|
return self.encrypt_data(messages)
|
|
|
|
|
|
def decrypt_chat_history(self, encrypted_data: str) -> list:
|
|
|
"""
|
|
|
解密聊天记录
|
|
|
|
|
|
Args:
|
|
|
encrypted_data: 加密的聊天记录
|
|
|
|
|
|
Returns:
|
|
|
聊天消息列表
|
|
|
"""
|
|
|
return self.decrypt_data(encrypted_data, as_json=True)
|
|
|
|
|
|
def save_encrypted_file(self, data: Union[str, bytes, dict], file_path: str) -> bool:
|
|
|
"""
|
|
|
保存加密数据到文件
|
|
|
|
|
|
Args:
|
|
|
data: 要保存的数据
|
|
|
file_path: 文件路径
|
|
|
|
|
|
Returns:
|
|
|
保存成功返回True
|
|
|
"""
|
|
|
try:
|
|
|
encrypted = self.encrypt_data(data)
|
|
|
|
|
|
# 确保目录存在
|
|
|
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
with open(file_path, 'w', encoding='utf-8') as f:
|
|
|
f.write(encrypted)
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to save encrypted file: {e}")
|
|
|
return False
|
|
|
|
|
|
def load_encrypted_file(self, file_path: str, as_json: bool = False) -> Optional[Union[bytes, dict]]:
|
|
|
"""
|
|
|
从文件加载加密数据
|
|
|
|
|
|
Args:
|
|
|
file_path: 文件路径
|
|
|
as_json: 是否解析为JSON
|
|
|
|
|
|
Returns:
|
|
|
解密后的数据,如果失败返回None
|
|
|
"""
|
|
|
try:
|
|
|
with open(file_path, 'r', encoding='utf-8') as f:
|
|
|
encrypted = f.read()
|
|
|
|
|
|
return self.decrypt_data(encrypted, as_json=as_json)
|
|
|
|
|
|
except FileNotFoundError:
|
|
|
logger.error(f"Encrypted file not found: {file_path}")
|
|
|
return None
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to load encrypted file: {e}")
|
|
|
return None
|
|
|
|
|
|
|
|
|
# 需要导入ipaddress模块用于自签名证书
|
|
|
import ipaddress
|
|
|
|
|
|
|
|
|
# 便捷函数
|
|
|
def create_message_encryptor(key: Optional[bytes] = None) -> MessageEncryptor:
|
|
|
"""创建消息加密器"""
|
|
|
cipher = AESCipher(key) if key else AESCipher()
|
|
|
return MessageEncryptor(cipher)
|
|
|
|
|
|
|
|
|
def create_file_encryptor(key: Optional[bytes] = None) -> FileEncryptor:
|
|
|
"""创建文件加密器"""
|
|
|
cipher = AESCipher(key) if key else AESCipher()
|
|
|
return FileEncryptor(cipher)
|
|
|
|
|
|
|
|
|
def create_local_data_encryptor(password: str) -> LocalDataEncryptor:
|
|
|
"""创建本地数据加密器"""
|
|
|
return LocalDataEncryptor(password=password)
|
|
|
|
|
|
|
|
|
def encrypt_message(message_data: bytes, key: bytes) -> bytes:
|
|
|
"""快速加密消息"""
|
|
|
encryptor = MessageEncryptor(AESCipher(key))
|
|
|
return encryptor.encrypt_message(message_data)
|
|
|
|
|
|
|
|
|
def decrypt_message(encrypted_data: bytes, key: bytes) -> bytes:
|
|
|
"""快速解密消息"""
|
|
|
encryptor = MessageEncryptor(AESCipher(key))
|
|
|
return encryptor.decrypt_message(encrypted_data)
|