You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1038 lines
29 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Security Module
"""
安全模块
负责传输加密、本地数据加密和密钥管理
需求: 10.1, 10.2, 10.3
"""
import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
import ssl
import struct
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Union
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
# Module-level logger used by the classes in this module.
logger: logging.Logger = logging.getLogger(__name__)
class SecurityError(Exception):
    """Root of the security module's exception hierarchy."""
class EncryptionError(SecurityError):
    """Raised when data cannot be encrypted or a key has an invalid size."""
class DecryptionError(SecurityError):
    """Raised when data cannot be decrypted."""
class KeyManagementError(SecurityError):
    """Error related to key management."""
class CertificateError(SecurityError):
    """Raised when an SSL context cannot be built from the configured certificates."""
@dataclass
class EncryptedData:
    """Ciphertext plus the parameters required to decrypt it.

    Binary layout used by ``to_bytes``/``from_bytes``; every length prefix is
    an unsigned 16-bit big-endian integer:

        [iv_len(2)][iv][tag_len(2)][tag][salt_len(2)][salt][ciphertext]
    """

    ciphertext: bytes
    iv: bytes  # GCM nonce (initialization vector)
    tag: bytes  # GCM authentication tag
    salt: bytes = field(default_factory=bytes)  # key-derivation salt; empty when unused

    def to_bytes(self) -> bytes:
        """Serialize into the length-prefixed layout described in the class docstring."""
        header = b"".join(
            struct.pack('!H', len(part)) + part
            for part in (self.iv, self.tag, self.salt)
        )
        return header + self.ciphertext

    @classmethod
    def from_bytes(cls, data: bytes) -> "EncryptedData":
        """Parse the layout produced by ``to_bytes``.

        Everything after the salt field is treated as ciphertext. A missing
        2-byte length prefix raises ``struct.error``.
        """
        cursor = 0

        def take_field() -> bytes:
            # Read one length prefix and return the field that follows it.
            nonlocal cursor
            (size,) = struct.unpack('!H', data[cursor:cursor + 2])
            start = cursor + 2
            cursor = start + size
            return data[start:cursor]

        iv = take_field()
        tag = take_field()
        salt = take_field()
        return cls(ciphertext=data[cursor:], iv=iv, tag=tag, salt=salt)

    def to_base64(self) -> str:
        """Return the serialized form encoded as Base64 text."""
        encoded = base64.b64encode(self.to_bytes())
        return encoded.decode('utf-8')

    @classmethod
    def from_base64(cls, data: str) -> "EncryptedData":
        """Inverse of ``to_base64``."""
        raw = base64.b64decode(data)
        return cls.from_bytes(raw)
class AESCipher:
    """AES-256-GCM cipher.

    Used for encrypted transport of messages and files (requirements 10.1 and
    10.3) and for encrypted local storage (requirement 10.2).
    """

    # AES-256 key length in bytes (256 bits).
    KEY_SIZE = 32
    # GCM nonce length in bytes (96 bits, the recommended size for GCM).
    IV_SIZE = 12
    # GCM authentication tag length in bytes (128 bits).
    TAG_SIZE = 16
    # PBKDF2 salt length in bytes (128 bits).
    SALT_SIZE = 16
    # PBKDF2 iteration count. The count is not stored with the ciphertext,
    # so changing it makes previously encrypted data undecryptable.
    PBKDF2_ITERATIONS = 100000

    def __init__(self, key: Optional[bytes] = None):
        """Create a cipher.

        Args:
            key: 32-byte AES-256 key. A fresh random key is generated when None.

        Raises:
            EncryptionError: ``key`` is not exactly ``KEY_SIZE`` bytes long.
        """
        if key is None:
            self._key = self.generate_key()
        else:
            if len(key) != self.KEY_SIZE:
                raise EncryptionError(f"Key must be {self.KEY_SIZE} bytes")
            self._key = key

    @property
    def key(self) -> bytes:
        """The raw AES key used by this cipher."""
        return self._key

    @staticmethod
    def generate_key() -> bytes:
        """Return a new random 32-byte key from the ``secrets`` CSPRNG."""
        return secrets.token_bytes(AESCipher.KEY_SIZE)

    @staticmethod
    def generate_iv() -> bytes:
        """Return a new random 12-byte GCM nonce."""
        return secrets.token_bytes(AESCipher.IV_SIZE)

    @staticmethod
    def derive_key_from_password(password: str, salt: Optional[bytes] = None) -> Tuple[bytes, bytes]:
        """Derive a 32-byte key from a password with PBKDF2-HMAC-SHA256 (requirement 10.2).

        Args:
            password: User password; encoded as UTF-8 before derivation.
            salt: Salt to use. A new random ``SALT_SIZE``-byte salt is
                generated when None.

        Returns:
            ``(key, salt)``. The salt must be kept to derive the same key again.
        """
        if salt is None:
            salt = secrets.token_bytes(AESCipher.SALT_SIZE)
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=AESCipher.KEY_SIZE,
            salt=salt,
            iterations=AESCipher.PBKDF2_ITERATIONS,
            backend=default_backend(),
        )
        return kdf.derive(password.encode('utf-8')), salt

    def encrypt(self, plaintext: bytes) -> "EncryptedData":
        """Encrypt with AES-256-GCM under a fresh random nonce (requirements 10.1, 10.3).

        Args:
            plaintext: Data to encrypt.

        Returns:
            EncryptedData holding the ciphertext, nonce and tag. The salt is empty.

        Raises:
            EncryptionError: Any failure inside the encryption step.
        """
        try:
            iv = self.generate_iv()
            encryptor = Cipher(
                algorithms.AES(self._key),
                modes.GCM(iv),
                backend=default_backend(),
            ).encryptor()
            ciphertext = encryptor.update(plaintext) + encryptor.finalize()
            return EncryptedData(ciphertext=ciphertext, iv=iv, tag=encryptor.tag)
        except Exception as e:
            raise EncryptionError(f"Encryption failed: {e}") from e

    def decrypt(self, encrypted_data: "EncryptedData") -> bytes:
        """Decrypt and authenticate AES-256-GCM data.

        Args:
            encrypted_data: Ciphertext, nonce and tag to decrypt.

        Returns:
            The plaintext.

        Raises:
            DecryptionError: Any failure, including tag verification failure.
        """
        try:
            decryptor = Cipher(
                algorithms.AES(self._key),
                modes.GCM(encrypted_data.iv, encrypted_data.tag),
                backend=default_backend(),
            ).decryptor()
            return decryptor.update(encrypted_data.ciphertext) + decryptor.finalize()
        except Exception as e:
            # Exceptions such as cryptography's InvalidTag may carry no
            # message; fall back to the type name so the error stays useful.
            detail = str(e) or type(e).__name__
            raise DecryptionError(f"Decryption failed: {detail}") from e

    def encrypt_with_password(self, plaintext: bytes, password: str) -> "EncryptedData":
        """Encrypt with a key derived from ``password`` (requirement 10.2).

        The instance's own key is neither used nor modified.

        Args:
            plaintext: Data to encrypt.
            password: Password to derive the key from (fresh random salt).

        Returns:
            EncryptedData whose ``salt`` holds the salt needed for decryption.
        """
        key, salt = self.derive_key_from_password(password)
        # Use a separate cipher for the derived key. Swapping self._key in
        # place would expose the wrong key to any other user of this instance
        # while the operation runs.
        encrypted = type(self)(key).encrypt(plaintext)
        encrypted.salt = salt
        return encrypted

    def decrypt_with_password(self, encrypted_data: "EncryptedData", password: str) -> bytes:
        """Decrypt data produced by ``encrypt_with_password``.

        The key is re-derived from ``password`` and ``encrypted_data.salt``.
        The instance's own key is neither used nor modified.

        Args:
            encrypted_data: Encrypted payload, including its salt.
            password: Password used at encryption time.

        Returns:
            The plaintext.
        """
        key, _ = self.derive_key_from_password(password, encrypted_data.salt)
        return type(self)(key).decrypt(encrypted_data)
class TLSManager:
    """TLS/SSL context factory (requirement 10.1).

    Produces server and client ``ssl.SSLContext`` objects that require TLS 1.2
    or newer and limit TLS 1.2 cipher suites to ECDHE/DHE with AES-GCM or
    ChaCha20.
    """

    # OpenSSL cipher string passed to SSLContext.set_ciphers(). Per the ssl
    # module documentation, it governs TLS 1.2 and below, not TLS 1.3 suites.
    _CIPHERS = 'ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20'

    def __init__(self, cert_file: Optional[str] = None,
                 key_file: Optional[str] = None,
                 ca_file: Optional[str] = None):
        """
        Args:
            cert_file: Path to the PEM certificate presented by a server.
            key_file: Path to the private key matching ``cert_file``.
            ca_file: CA file used by clients to verify the peer. The system
                default CAs are used when None.
        """
        self.cert_file = cert_file
        self.key_file = key_file
        self.ca_file = ca_file

    @staticmethod
    def _new_context(protocol) -> ssl.SSLContext:
        """Return a context for ``protocol`` that refuses anything below TLS 1.2."""
        context = ssl.SSLContext(protocol)
        # minimum_version already rules out SSLv2/SSLv3/TLS 1.0/TLS 1.1. The
        # ssl.OP_NO_SSL*/OP_NO_TLS* flags are deprecated since Python 3.10.
        # Setting them here would be redundant and emit DeprecationWarning.
        context.minimum_version = ssl.TLSVersion.TLSv1_2
        return context

    def create_server_ssl_context(self) -> ssl.SSLContext:
        """Build the server-side SSL context (requirement 10.1).

        When no certificate/key pair is configured, a self-signed certificate
        is generated under ``certs/`` and used instead (development/testing only).

        Returns:
            A configured server ``ssl.SSLContext``.

        Raises:
            CertificateError: The certificate chain could not be loaded, a
                certificate file is missing, or the cipher string was rejected.
        """
        try:
            context = self._new_context(ssl.PROTOCOL_TLS_SERVER)
            if not (self.cert_file and self.key_file):
                logger.warning("No certificate provided, using self-signed certificate")
                # Sets self.cert_file / self.key_file to the generated files.
                self._generate_self_signed_cert()
            context.load_cert_chain(self.cert_file, self.key_file)
            context.set_ciphers(self._CIPHERS)
            return context
        except ssl.SSLError as e:
            raise CertificateError(f"Failed to create server SSL context: {e}") from e
        except FileNotFoundError as e:
            raise CertificateError(f"Certificate file not found: {e}") from e

    def create_client_ssl_context(self, verify: bool = True) -> ssl.SSLContext:
        """Build the client-side SSL context (requirement 10.1).

        Args:
            verify: Verify the server certificate and hostname. Passing False
                disables both and should be used for development/testing only.

        Returns:
            A configured client ``ssl.SSLContext``.

        Raises:
            CertificateError: OpenSSL rejected the CA data or cipher string.
        """
        try:
            context = self._new_context(ssl.PROTOCOL_TLS_CLIENT)
            if verify:
                context.verify_mode = ssl.CERT_REQUIRED
                context.check_hostname = True
                if self.ca_file:
                    context.load_verify_locations(self.ca_file)
                else:
                    # A bare SSLContext starts with no trusted CAs; load the system ones.
                    context.load_default_certs()
            else:
                # check_hostname must be cleared before verify_mode may be CERT_NONE.
                context.check_hostname = False
                context.verify_mode = ssl.CERT_NONE
                logger.warning("SSL certificate verification disabled")
            context.set_ciphers(self._CIPHERS)
            return context
        except ssl.SSLError as e:
            raise CertificateError(f"Failed to create client SSL context: {e}") from e

    def _generate_self_signed_cert(self) -> None:
        """Generate a self-signed certificate (development/testing only).

        Writes ``certs/server.crt`` and ``certs/server.key`` relative to the
        working directory, overwriting existing files. It then points
        ``self.cert_file`` / ``self.key_file`` at them. The certificate covers
        ``localhost`` and ``127.0.0.1`` and is valid for 365 days.
        """
        import ipaddress
        from datetime import timedelta, timezone

        from cryptography import x509
        from cryptography.x509.oid import NameOID

        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )
        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Beijing"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "Beijing"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "P2P Chat"),
            x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
        ])
        # Timezone-aware UTC time; datetime.utcnow() is deprecated since Python 3.12.
        now = datetime.now(timezone.utc)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=365))
            .add_extension(
                x509.SubjectAlternativeName([
                    x509.DNSName("localhost"),
                    x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
                ]),
                critical=False,
            )
            .sign(private_key, hashes.SHA256(), default_backend())
        )
        cert_dir = Path("certs")
        cert_dir.mkdir(exist_ok=True)
        self.cert_file = str(cert_dir / "server.crt")
        self.key_file = str(cert_dir / "server.key")
        with open(self.cert_file, "wb") as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))
        key_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption()
        )
        # The key is stored unencrypted. Create it owner-read/write only.
        # os.open's mode applies only to new files, so chmod covers an
        # existing one.
        fd = os.open(self.key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(key_pem)
        os.chmod(self.key_file, 0o600)
        logger.info(f"Generated self-signed certificate: {self.cert_file}")
class MessageEncryptor:
    """Encrypts serialized messages for transport (requirement 10.1).

    ``encrypt_message`` emits the ``EncryptedData.to_bytes()`` layout, and
    ``decrypt_message`` accepts the same layout.
    """

    def __init__(self, cipher: Optional["AESCipher"] = None):
        """
        Args:
            cipher: Cipher to use. A new ``AESCipher`` with a random key is
                created when None.
        """
        self._cipher = cipher or AESCipher()

    @property
    def key(self) -> bytes:
        """The AES key of the underlying cipher."""
        return self._cipher.key

    def set_key(self, key: bytes) -> None:
        """Replace the underlying cipher with a new ``AESCipher`` for ``key`` (32 bytes)."""
        self._cipher = AESCipher(key)

    def encrypt_message(self, message_data: bytes) -> bytes:
        """Encrypt serialized message bytes (requirement 10.1).

        Args:
            message_data: Serialized message.

        Returns:
            The encrypted payload in ``EncryptedData.to_bytes()`` layout.
        """
        return self._cipher.encrypt(message_data).to_bytes()

    def decrypt_message(self, encrypted_data: bytes) -> bytes:
        """Decrypt a payload produced by ``encrypt_message``.

        Args:
            encrypted_data: Encrypted payload.

        Returns:
            The original message bytes.

        Raises:
            DecryptionError: The payload is too short to hold its length
                prefixes, or the cipher fails to decrypt/authenticate it.
        """
        try:
            encrypted = EncryptedData.from_bytes(encrypted_data)
        except struct.error as e:
            # from_bytes raises struct.error when a 2-byte length prefix is
            # missing. Report it like any other undecryptable message rather
            # than leaking a low-level parsing error to callers.
            raise DecryptionError(f"Malformed encrypted message: {e}") from e
        return self._cipher.decrypt(encrypted)
class FileEncryptor:
    """Encrypts whole files or individual chunks with AES-GCM (requirement 10.3)."""

    # 64 KiB block size. Not used by encrypt_file/decrypt_file, which read
    # the entire file into memory.
    BLOCK_SIZE = 64 * 1024

    def __init__(self, cipher: Optional["AESCipher"] = None):
        """
        Args:
            cipher: Cipher to use. A new ``AESCipher`` with a random key is
                created when None.
        """
        self._cipher = cipher or AESCipher()

    @property
    def key(self) -> bytes:
        """The AES key of the underlying cipher."""
        return self._cipher.key

    def set_key(self, key: bytes) -> None:
        """Replace the underlying cipher with a new ``AESCipher`` for ``key`` (32 bytes)."""
        self._cipher = AESCipher(key)

    def encrypt_file(self, input_path: str, output_path: str) -> bool:
        """Encrypt a whole file into ``output_path`` (requirement 10.3).

        Args:
            input_path: File to encrypt. It is read fully into memory.
            output_path: Destination for the ``EncryptedData.to_bytes()`` payload.

        Returns:
            True. Failures raise rather than return False.

        Raises:
            EncryptionError: Reading, encrypting or writing failed.
        """
        try:
            with open(input_path, 'rb') as f:
                plaintext = f.read()
            encrypted = self._cipher.encrypt(plaintext)
            with open(output_path, 'wb') as f:
                f.write(encrypted.to_bytes())
            logger.info(f"File encrypted: {input_path} -> {output_path}")
            return True
        except Exception as e:
            raise EncryptionError(f"Failed to encrypt file: {e}") from e

    def decrypt_file(self, input_path: str, output_path: str) -> bool:
        """Decrypt a file written by ``encrypt_file``.

        The output file is only opened after decryption succeeded.

        Args:
            input_path: Encrypted file.
            output_path: Destination for the plaintext.

        Returns:
            True. Failures raise rather than return False.

        Raises:
            DecryptionError: Reading, parsing, decrypting or writing failed.
        """
        try:
            with open(input_path, 'rb') as f:
                encrypted_data = f.read()
            encrypted = EncryptedData.from_bytes(encrypted_data)
            plaintext = self._cipher.decrypt(encrypted)
            with open(output_path, 'wb') as f:
                f.write(plaintext)
            logger.info(f"File decrypted: {input_path} -> {output_path}")
            return True
        except Exception as e:
            raise DecryptionError(f"Failed to decrypt file: {e}") from e

    def encrypt_chunk(self, chunk_data: bytes) -> bytes:
        """Encrypt one chunk for chunked transfer.

        Each chunk gets its own nonce and tag.

        Args:
            chunk_data: Raw chunk bytes.

        Returns:
            The encrypted chunk in ``EncryptedData.to_bytes()`` layout.
        """
        return self._cipher.encrypt(chunk_data).to_bytes()

    def decrypt_chunk(self, encrypted_chunk: bytes) -> bytes:
        """Decrypt one chunk produced by ``encrypt_chunk``.

        Raises:
            DecryptionError: The chunk is too short to hold its length
                prefixes, or the cipher fails to decrypt/authenticate it.
        """
        try:
            encrypted = EncryptedData.from_bytes(encrypted_chunk)
        except struct.error as e:
            # Same error type as decrypt_file for malformed input.
            raise DecryptionError(f"Malformed encrypted chunk: {e}") from e
        return self._cipher.decrypt(encrypted)
class KeyManager:
    """Stores keys as files and keeps session keys in memory (requirement 10.2)."""

    # Default key directory, relative to the working directory.
    KEY_DIR = "keys"

    def __init__(self, key_dir: Optional[str] = None):
        """
        Args:
            key_dir: Directory for ``<key_id>.key`` files; ``KEY_DIR`` when None.
                Created (with parents) if it does not exist.
        """
        self._key_dir = Path(key_dir or self.KEY_DIR)
        self._key_dir.mkdir(parents=True, exist_ok=True)
        # In-memory session keys, indexed by session id.
        self._key_cache: dict = {}

    @staticmethod
    def _write_private(path: Path, data: bytes) -> None:
        """Write ``data`` to ``path`` with owner-only (0o600) permissions."""
        # Key files may hold unencrypted secrets, so never leave them
        # readable by other users. os.open's mode only applies to a newly
        # created file, hence the chmod for a pre-existing one.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'wb') as f:
            f.write(data)
        os.chmod(path, 0o600)

    def generate_key_pair(self) -> Tuple[bytes, bytes]:
        """Generate a 2048-bit RSA key pair.

        Returns:
            ``(private_pem, public_pem)``. The private key is unencrypted
            PKCS#8; the public key is SubjectPublicKeyInfo.
        """
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )
        private_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        public_pem = private_key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        return private_pem, public_pem

    def save_key(self, key_id: str, key_data: bytes, password: Optional[str] = None) -> bool:
        """Save a key to ``<key_dir>/<key_id>.key`` (requirement 10.2).

        Args:
            key_id: Key identifier, used as the file name stem.
            key_data: Key bytes.
            password: When given (non-empty), the key is encrypted with a
                password-derived key. The salt is stored in the file.

        Returns:
            True on success, False on failure (the error is logged).
        """
        try:
            key_file = self._key_dir / f"{key_id}.key"
            if password:
                payload = AESCipher().encrypt_with_password(key_data, password).to_bytes()
            else:
                # Stored as-is (not recommended for production use).
                payload = key_data
            self._write_private(key_file, payload)
            logger.info(f"Key saved: {key_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to save key {key_id}: {e}")
            return False

    def load_key(self, key_id: str, password: Optional[str] = None) -> Optional[bytes]:
        """Load a key saved by ``save_key``.

        Args:
            key_id: Key identifier.
            password: Password used when saving. Leave empty for keys stored as-is.

        Returns:
            The key bytes, or None if the file is missing or loading/decryption
            failed (the error is logged).
        """
        try:
            key_file = self._key_dir / f"{key_id}.key"
            if not key_file.exists():
                logger.error(f"Key file not found: {key_id}")
                return None
            with open(key_file, 'rb') as f:
                data = f.read()
            if not password:
                return data
            encrypted = EncryptedData.from_bytes(data)
            return AESCipher().decrypt_with_password(encrypted, password)
        except Exception as e:
            logger.error(f"Failed to load key {key_id}: {e}")
            return None

    def delete_key(self, key_id: str) -> bool:
        """Delete a key file (if present) and any cached session key with the same id.

        Returns:
            True on success, including when no file existed. False on failure
            (the error is logged).
        """
        try:
            key_file = self._key_dir / f"{key_id}.key"
            if key_file.exists():
                key_file.unlink()
                logger.info(f"Key deleted: {key_id}")
            self._key_cache.pop(key_id, None)
            return True
        except Exception as e:
            logger.error(f"Failed to delete key {key_id}: {e}")
            return False

    def get_or_create_session_key(self, session_id: str) -> bytes:
        """Return the cached session key for ``session_id``, creating a random one if needed.

        Session keys are kept in memory only.
        """
        if session_id in self._key_cache:
            return self._key_cache[session_id]
        key = AESCipher.generate_key()
        self._key_cache[session_id] = key
        return key

    def clear_session_key(self, session_id: str) -> None:
        """Forget the cached session key for ``session_id``, if any."""
        self._key_cache.pop(session_id, None)
class LocalDataEncryptor:
    """Encrypts data for local storage with AES-256 (requirement 10.2).

    With a password, the key is derived via PBKDF2 using a salt generated per
    instance. That salt is embedded in every payload this instance produces, so
    a later instance created with the same password can re-derive the key from
    the stored salt and decrypt it.
    """

    def __init__(self, password: Optional[str] = None, key: Optional[bytes] = None):
        """
        Args:
            password: Password to derive the key from. When ``key`` is also
                given, ``key`` is used for encryption, and the password only
                serves to decrypt salted payloads.
            key: Raw 32-byte AES key; takes precedence over ``password``.

        When neither is given, a random key is generated.
        """
        self._password = password
        self._salt: Optional[bytes] = None
        # Ciphers for password-derived keys, indexed by salt, so PBKDF2 runs
        # at most once per distinct foreign salt seen while decrypting.
        self._salt_ciphers: dict = {}
        if key:
            self._cipher = AESCipher(key)
        elif password:
            derived_key, self._salt = AESCipher.derive_key_from_password(password)
            self._cipher = AESCipher(derived_key)
        else:
            self._cipher = AESCipher()

    def _cipher_for(self, salt: bytes) -> "AESCipher":
        """Return the cipher able to decrypt a payload stored with ``salt``."""
        if not salt or not self._password or salt == self._salt:
            return self._cipher
        cipher = self._salt_ciphers.get(salt)
        if cipher is None:
            # Payload written by another instance (e.g. an earlier session):
            # re-derive its key from the password and the stored salt.
            derived_key, _ = AESCipher.derive_key_from_password(self._password, salt)
            cipher = AESCipher(derived_key)
            self._salt_ciphers[salt] = cipher
        return cipher

    def encrypt_data(self, data: Union[str, bytes, dict, list]) -> str:
        """Encrypt data for local storage (requirement 10.2).

        Args:
            data: ``str`` (UTF-8 encoded), ``dict``/``list`` (JSON, UTF-8) or
                ``bytes`` (used as-is).

        Returns:
            Base64 text of the encrypted payload. It includes the salt when
            the key is password-derived.
        """
        if isinstance(data, str):
            plaintext = data.encode('utf-8')
        elif isinstance(data, (dict, list)):
            plaintext = json.dumps(data, ensure_ascii=False).encode('utf-8')
        else:
            plaintext = data
        encrypted = self._cipher.encrypt(plaintext)
        if self._salt:
            # Embed the salt so the key can be re-derived from the password later.
            encrypted.salt = self._salt
        return encrypted.to_base64()

    def decrypt_data(self, encrypted_data: str, as_json: bool = False) -> Union[bytes, dict, list]:
        """Decrypt a payload produced by ``encrypt_data``.

        Payloads written by another instance that used the same password are
        supported: their stored salt is used to re-derive the key.

        Args:
            encrypted_data: Base64 text returned by ``encrypt_data``.
            as_json: Decode the plaintext as UTF-8 JSON.

        Returns:
            Plaintext bytes, or the parsed JSON value when ``as_json`` is True.

        Raises:
            DecryptionError: The payload is too short to hold its length
                prefixes, or decryption/authentication fails.
        """
        try:
            encrypted = EncryptedData.from_base64(encrypted_data)
        except struct.error as e:
            raise DecryptionError(f"Malformed encrypted data: {e}") from e
        plaintext = self._cipher_for(encrypted.salt).decrypt(encrypted)
        if as_json:
            return json.loads(plaintext.decode('utf-8'))
        return plaintext

    def encrypt_chat_history(self, messages: list) -> str:
        """Encrypt a list of chat messages as JSON (requirement 10.2)."""
        return self.encrypt_data(messages)

    def decrypt_chat_history(self, encrypted_data: str) -> list:
        """Decrypt chat history produced by ``encrypt_chat_history``."""
        return self.decrypt_data(encrypted_data, as_json=True)

    def save_encrypted_file(self, data: Union[str, bytes, dict], file_path: str) -> bool:
        """Encrypt ``data`` and write the Base64 payload to ``file_path``.

        Missing parent directories are created.

        Returns:
            True on success, False on failure (the error is logged).
        """
        try:
            encrypted = self.encrypt_data(data)
            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(encrypted)
            return True
        except Exception as e:
            logger.error(f"Failed to save encrypted file: {e}")
            return False

    def load_encrypted_file(self, file_path: str, as_json: bool = False) -> Optional[Union[bytes, dict, list]]:
        """Read and decrypt a file written by ``save_encrypted_file``.

        Args:
            file_path: File to read.
            as_json: Decode the plaintext as UTF-8 JSON.

        Returns:
            The decrypted data, or None on any failure (the error is logged).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                encrypted = f.read()
            return self.decrypt_data(encrypted, as_json=as_json)
        except FileNotFoundError:
            logger.error(f"Encrypted file not found: {file_path}")
            return None
        except Exception as e:
            logger.error(f"Failed to load encrypted file: {e}")
            return None
# 需要导入ipaddress模块用于自签名证书
import ipaddress
# 便捷函数
def create_message_encryptor(key: Optional[bytes] = None) -> MessageEncryptor:
    """Create a MessageEncryptor.

    Args:
        key: 32-byte AES key, or None to generate a random one.

    Raises:
        EncryptionError: ``key`` is given but is not 32 bytes long. An empty
            ``b""`` is rejected rather than silently replaced by a random key.
    """
    # AESCipher generates a random key only for None, so any other invalid
    # value (including b"") is reported instead of masked.
    return MessageEncryptor(AESCipher(key))
def create_file_encryptor(key: Optional[bytes] = None) -> FileEncryptor:
    """Create a FileEncryptor.

    Args:
        key: 32-byte AES key, or None to generate a random one.

    Raises:
        EncryptionError: ``key`` is given but is not 32 bytes long. An empty
            ``b""`` is rejected rather than silently replaced by a random key.
    """
    # AESCipher generates a random key only for None, so any other invalid
    # value (including b"") is reported instead of masked.
    return FileEncryptor(AESCipher(key))
def create_local_data_encryptor(password: str) -> LocalDataEncryptor:
    """Build a LocalDataEncryptor for ``password``.

    The key is derived via PBKDF2 when the password is non-empty.
    """
    encryptor = LocalDataEncryptor(password=password)
    return encryptor
def encrypt_message(message_data: bytes, key: bytes) -> bytes:
    """Encrypt ``message_data`` with the 32-byte ``key`` in a single call."""
    return MessageEncryptor(AESCipher(key)).encrypt_message(message_data)
def decrypt_message(encrypted_data: bytes, key: bytes) -> bytes:
    """Decrypt a payload from ``encrypt_message`` with the 32-byte ``key`` in a single call."""
    return MessageEncryptor(AESCipher(key)).decrypt_message(encrypted_data)