You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

527 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# 点对点安全文件传输核心模块 - 更正版
import socket
import os
import json
import threading
from datetime import datetime, timedelta
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding as sym_padding
from cryptography.hazmat.backends import default_backend
from cryptography import x509
from cryptography.x509.oid import NameOID
def generate_key_pair(name, key_size=2048, validity_days=365, cert_dir='certs'):
    """
    Generate an RSA key pair and a matching self-signed X.509 certificate.

    The private key is written unencrypted as PKCS#8 PEM and the certificate
    as PEM, both into ``cert_dir``.

    :param name: certificate name; used as the subject/issuer common name, the
        DNS subjectAltName entry, and the prefix of both output file names
    :param key_size: RSA key size in bits
    :param validity_days: certificate validity period in days, starting now
    :param cert_dir: output directory (created if missing); defaults to 'certs'
    :return: (private key path, certificate path)
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    os.makedirs(cert_dir, exist_ok=True)

    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=key_size,
    )

    # Self-signed certificate: the subject is also the issuer.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"),
        x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Beijing"),
        x509.NameAttribute(NameOID.LOCALITY_NAME, "Beijing"),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Secure File Transfer"),
        x509.NameAttribute(NameOID.COMMON_NAME, name),
    ])

    # Take one timestamp for both bounds so the window is exactly
    # validity_days long. It is timezone-aware because datetime.utcnow() is
    # deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=validity_days))
        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(name)]),
            critical=False,
        )
        .sign(private_key, hashes.SHA256())
    )

    private_key_path = os.path.join(cert_dir, f'{name}_private_key.pem')
    # The key has no passphrase, so restrict it to the owner (0o600) instead of
    # inheriting the process umask. os.open's mode only applies when the file
    # is created; the explicit chmod also tightens a pre-existing file.
    fd = os.open(private_key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'wb') as f:
        os.chmod(private_key_path, 0o600)
        f.write(private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ))

    # The certificate carries the public key and may be shared freely.
    cert_path = os.path.join(cert_dir, f'{name}_cert.pem')
    with open(cert_path, "wb") as f:
        f.write(cert.public_bytes(serialization.Encoding.PEM))

    return private_key_path, cert_path
class P2PFileSender:
    """Sending peer of the secure file transfer.

    Typical use: ``connect()`` fetches the receiver's certificate.
    ``send_file()`` encrypts a file with a fresh symmetric key, wraps that key
    with the receiver's RSA public key (RSA-OAEP/SHA-256, a "digital
    envelope") and sends everything as one JSON message. ``disconnect()``
    closes the socket.

    Every message on the wire is an 8-byte big-endian length prefix followed by
    that many payload bytes.
    """

    def __init__(self, host, port):
        """
        :param host: receiver host name or IP address
        :param port: receiver TCP port
        """
        self.host = host
        self.port = port
        self.socket = None
        # PEM bytes received from the receiver in connect(); None until then.
        self.received_cert = None

    def connect(self):
        """Connect to the receiver, read its certificate and acknowledge it.

        :raises ConnectionError: if the receiver closes the connection before
            the complete certificate frame has arrived (the socket is closed)
        :raises OSError: if the TCP connection cannot be established
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.connect((self.host, self.port))
            cert_data = self._recv_frame(self.socket)
        except Exception:
            # Do not leak the half-open socket when the handshake fails.
            self.disconnect()
            raise
        self.received_cert = cert_data
        # The receiver waits for this acknowledgement before reading the file.
        ack = {"status": "cert_received", "message": "证书已成功接收"}
        self._send_frame(self.socket, json.dumps(ack).encode('utf-8'))

    def _recv_all(self, client_socket, size):
        """Read ``size`` bytes from ``client_socket``.

        :return: the bytes read; shorter than ``size`` only if the peer closed
            the connection first
        """
        # Cap each recv() so a huge expected size does not make every call
        # allocate a buffer of that size.
        chunk_limit = 1 << 20
        chunks = []
        remaining = size
        while remaining > 0:
            packet = client_socket.recv(min(remaining, chunk_limit))
            if not packet:
                break
            chunks.append(packet)
            remaining -= len(packet)
        # Join once: repeated ``bytes +=`` is quadratic for large payloads.
        return b''.join(chunks)

    def _recv_frame(self, sock):
        """Read one length-prefixed frame.

        :raises ConnectionError: if the connection closes mid-frame
        """
        header = self._recv_all(sock, 8)
        if len(header) < 8:
            raise ConnectionError("连接已关闭,未收到完整的数据长度")
        length = int.from_bytes(header, 'big')
        payload = self._recv_all(sock, length)
        if len(payload) < length:
            raise ConnectionError("连接已关闭,数据不完整")
        return payload

    @staticmethod
    def _send_frame(sock, payload):
        """Send ``payload`` (bytes) prefixed with its 8-byte big-endian length."""
        sock.sendall(len(payload).to_bytes(8, 'big') + payload)

    def disconnect(self):
        """Close the connection, if any; safe to call more than once."""
        if self.socket:
            self.socket.close()
            self.socket = None

    @staticmethod
    def _load_public_key(pem_data):
        """Return the RSA public key from a PEM public key or PEM certificate.

        :raises Exception: if the data is neither
        """
        try:
            return serialization.load_pem_public_key(pem_data, backend=default_backend())
        except ValueError:
            # Not a bare public key (e.g. a certificate PEM); try it as a certificate.
            pass
        try:
            cert = x509.load_pem_x509_certificate(pem_data, backend=default_backend())
        except Exception as e:
            raise Exception(f"无法加载接收方证书: {str(e)}") from e
        return cert.public_key()

    @staticmethod
    def _encrypt(file_data, algorithm, mode):
        """Encrypt ``file_data`` under a fresh random 256-bit key.

        :return: (symmetric_key, iv_or_nonce, ciphertext, gcm_tag_or_None)
        :raises ValueError: for an unsupported algorithm or AES mode
        """
        algorithm_name = algorithm.upper()
        if algorithm_name == "AES":
            symmetric_key = os.urandom(32)  # AES-256
            iv = os.urandom(16)  # one AES block; also used as the GCM nonce
            mode_name = mode.upper()
            mode_factories = {
                "CBC": modes.CBC,
                "GCM": modes.GCM,
                "OFB": modes.OFB,
                "CTR": modes.CTR,
            }
            if mode_name not in mode_factories:
                raise ValueError(f"不支持的加密模式: {mode}")
            plaintext = file_data
            if mode_name == "CBC":
                # CBC needs block-aligned input: apply PKCS7 padding.
                padder = sym_padding.PKCS7(128).padder()
                plaintext = padder.update(file_data) + padder.finalize()
            # Use a single encryptor context for update() and finalize().
            encryptor = Cipher(
                algorithms.AES(symmetric_key),
                mode_factories[mode_name](iv),
                backend=default_backend(),
            ).encryptor()
            ciphertext = encryptor.update(plaintext) + encryptor.finalize()
            # The GCM tag is only available after finalize().
            auth_tag = encryptor.tag if mode_name == "GCM" else None
            return symmetric_key, iv, ciphertext, auth_tag

        if algorithm_name == "CHACHA20":
            symmetric_key = os.urandom(32)
            # cryptography's ChaCha20 requires a 16-byte value: a 4-byte
            # little-endian block counter followed by the 12-byte RFC 7539
            # nonce. Start the counter at 0 and randomize the nonce.
            nonce = b"\x00" * 4 + os.urandom(12)
            encryptor = Cipher(
                algorithms.ChaCha20(symmetric_key, nonce),
                mode=None,
                backend=default_backend(),
            ).encryptor()
            ciphertext = encryptor.update(file_data) + encryptor.finalize()
            # The nonce travels in the "iv" field.
            return symmetric_key, nonce, ciphertext, None

        raise ValueError(f"不支持的加密算法: {algorithm}")

    def send_file(self, file_path, cert_path=None, algorithm="AES", mode="CBC"):
        """
        Encrypt a file and send it to the connected receiver.

        :param file_path: path of the file to send
        :param cert_path: receiver certificate path (deprecated and ignored; the
            certificate received in connect() is used)
        :param algorithm: symmetric algorithm, "AES" or "ChaCha20" (case-insensitive)
        :param mode: AES mode "CBC", "GCM", "OFB" or "CTR" (case-insensitive);
            ignored for ChaCha20 but still sent
        :return: True when the receiver reports success
        :raises FileNotFoundError: if file_path does not exist
        :raises ValueError: for an unsupported algorithm or mode
        :raises ConnectionError: if the receiver closes before replying
        :raises Exception: if no certificate was received, it cannot be loaded,
            or the receiver reports an error
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")
        with open(file_path, 'rb') as f:
            file_data = f.read()

        received_cert = getattr(self, 'received_cert', None)
        if received_cert is None:
            raise Exception("未接收到接收方证书")
        public_key = self._load_public_key(received_cert)

        symmetric_key, iv, encrypted_file, auth_tag = self._encrypt(file_data, algorithm, mode)

        # Digital envelope: only the holder of the receiver's private key can
        # recover the symmetric key.
        encrypted_key = public_key.encrypt(
            symmetric_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            ),
        )

        transmission_data = {
            "file_name": os.path.basename(file_path),
            "algorithm": algorithm,
            "mode": mode,
            "encrypted_key": encrypted_key.hex(),
            "iv": iv.hex(),
            "encrypted_file": encrypted_file.hex(),
        }
        if auth_tag is not None:
            transmission_data["auth_tag"] = auth_tag.hex()
        # Length prefix counts encoded bytes, not characters.
        self._send_frame(self.socket, json.dumps(transmission_data).encode('utf-8'))

        response = json.loads(self._recv_frame(self.socket).decode('utf-8'))
        if response.get("status") != "success":
            raise Exception(f"发送失败: {response.get('message', '未知错误')}")
        return True
class P2PFileReceiver:
    """Receiving peer of the secure file transfer.

    Accepts sender connections, stores each encrypted transfer in its own
    sub-directory of ``storage_dir``, and decrypts stored transfers on request
    with the RSA private key.

    Every message on the wire is an 8-byte big-endian length prefix followed by
    that many payload bytes.
    """

    def __init__(self, host, port, storage_dir='received_files', cert_path=None):
        """
        :param host: address to bind the listening socket to
        :param port: TCP port to listen on
        :param storage_dir: directory holding one sub-directory per received
            transfer (created if missing)
        :param cert_path: PEM file sent to every sender before the transfer;
            when None no certificate handshake is performed
        """
        self.host = host
        self.port = port
        self.storage_dir = storage_dir
        self.socket = None
        self.listening = False
        # One dict per stored transfer: {"dir", "info", "status"}, where status
        # is "received" or "decrypted".
        self.received_files = []
        self.cert_path = cert_path
        os.makedirs(storage_dir, exist_ok=True)

    def start_listening(self):
        """Bind, listen and serve connections until stop_listening() is called.

        Blocks the calling thread; each accepted connection is handled by
        handle_client on its own daemon thread.
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((self.host, self.port))
        self.socket.listen(5)
        self.listening = True
        while self.listening:
            try:
                client_socket, address = self.socket.accept()
            except OSError:
                # Raised when stop_listening() shuts the socket down, or when
                # accept() fails; either way stop serving.
                break
            threading.Thread(
                target=self.handle_client,
                args=(client_socket, address),
                daemon=True,
            ).start()

    def stop_listening(self):
        """Stop the accept loop and close the listening socket."""
        self.listening = False
        if self.socket:
            try:
                # On Linux, close() alone does not wake a thread blocked in
                # accept(); shutdown() does.
                self.socket.shutdown(socket.SHUT_RDWR)
            except OSError:
                # Some platforms reject shutdown() on a listening socket.
                pass
            self.socket.close()

    @staticmethod
    def _safe_file_name(name):
        """Reduce a sender-supplied file name to a bare name with no directory parts."""
        base = os.path.basename(str(name).replace('\\', '/'))
        if base in ('', '.', '..'):
            return 'received_file'
        return base

    @staticmethod
    def _make_unique_dir(base_path):
        """Create and return base_path, or base_path_1, base_path_2, ... if taken."""
        candidate = base_path
        suffix = 0
        while True:
            try:
                os.makedirs(candidate)
                return candidate
            except FileExistsError:
                suffix += 1
                candidate = f"{base_path}_{suffix}"

    @staticmethod
    def _send_frame(sock, payload):
        """Send ``payload`` (bytes) prefixed with its 8-byte big-endian length."""
        sock.sendall(len(payload).to_bytes(8, 'big') + payload)

    def handle_client(self, client_socket, address):
        """Serve one sender connection and always close ``client_socket``.

        If cert_path is set, the handler sends the certificate and waits for the
        sender's {"status": "cert_received"} acknowledgement. It then receives
        one JSON transfer and stores it under storage_dir. Finally it replies
        {"status": "success" | "error", "message": ...}. A connection closed
        before the transfer's length prefix gets no reply.

        :param client_socket: connected socket returned by accept()
        :param address: peer address; address[0] is used in the directory name
        """
        try:
            if self.cert_path:
                with open(self.cert_path, 'rb') as cert_file:
                    cert_data = cert_file.read()
                self._send_frame(client_socket, cert_data)
                # Wait until the sender confirms it has the certificate.
                ack_length = int.from_bytes(self._recv_all(client_socket, 8), 'big')
                ack = json.loads(self._recv_all(client_socket, ack_length).decode('utf-8'))
                if ack.get("status") != "cert_received":
                    raise Exception(f"发送方未确认收到证书: {ack.get('message', '未知错误')}")

            length_bytes = self._recv_all(client_socket, 8)
            if not length_bytes:
                return
            data_length = int.from_bytes(length_bytes, 'big')
            json_data = self._recv_all(client_socket, data_length).decode('utf-8')
            transmission_data = json.loads(json_data)

            # The name comes from the network: strip directory parts so it
            # cannot be used to write outside the transfer directory.
            file_name = self._safe_file_name(transmission_data["file_name"])
            algorithm = transmission_data["algorithm"]
            mode = transmission_data["mode"]
            encrypted_key = transmission_data["encrypted_key"]
            iv = transmission_data["iv"]
            encrypted_file = transmission_data["encrypted_file"]
            auth_tag = transmission_data.get("auth_tag")  # present for AES-GCM

            # One directory per transfer. A numeric suffix keeps two transfers
            # from the same address within one second from overwriting each other.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_dir = self._make_unique_dir(
                os.path.join(self.storage_dir, f"{timestamp}_{address[0]}")
            )

            transmission_info = {
                "file_name": file_name,
                "algorithm": algorithm,
                "mode": mode,
                "encrypted_key": encrypted_key,
                "iv": iv,
                "encrypted_file": encrypted_file,
                "timestamp": timestamp,
                "sender": address[0],
            }
            if auth_tag:
                transmission_info["auth_tag"] = auth_tag

            info_path = os.path.join(file_dir, "transmission_info.json")
            with open(info_path, 'w', encoding='utf-8') as f:
                json.dump(transmission_info, f, indent=4)

            encrypted_file_path = os.path.join(file_dir, f"encrypted_{file_name}")
            with open(encrypted_file_path, 'wb') as f:
                f.write(bytes.fromhex(encrypted_file))

            self.received_files.append({
                "dir": file_dir,
                "info": transmission_info,
                "status": "received",
            })

            response = {"status": "success", "message": "文件已成功接收"}
            self._send_frame(client_socket, json.dumps(response).encode('utf-8'))
        except Exception as e:
            try:
                error_response = {"status": "error", "message": str(e)}
                self._send_frame(client_socket, json.dumps(error_response).encode('utf-8'))
            except OSError:
                # Best effort: the peer may already be gone.
                pass
        finally:
            client_socket.close()

    def _recv_all(self, client_socket, size):
        """Read ``size`` bytes from ``client_socket``.

        :return: the bytes read; shorter than ``size`` only if the peer closed
            the connection first
        """
        # Cap each recv() so a huge expected size does not make every call
        # allocate a buffer of that size.
        chunk_limit = 1 << 20
        chunks = []
        remaining = size
        while remaining > 0:
            packet = client_socket.recv(min(remaining, chunk_limit))
            if not packet:
                break
            chunks.append(packet)
            remaining -= len(packet)
        # Join once: repeated ``bytes +=`` is quadratic for large payloads.
        return b''.join(chunks)

    def get_received_files(self):
        """Return the (live) list of stored transfers."""
        return self.received_files

    @staticmethod
    def _decrypt_payload(symmetric_key, algorithm, mode, iv_bytes, ciphertext, auth_tag_bytes):
        """Decrypt ``ciphertext`` with the algorithm/mode recorded by the sender.

        :raises ValueError: for an unsupported algorithm/mode or bad CBC padding
        """
        algorithm_name = algorithm.upper()
        mode_name = mode.upper()
        if algorithm_name == "AES":
            if mode_name == "CBC":
                cipher_mode = modes.CBC(iv_bytes)
            elif mode_name == "GCM":
                cipher_mode = modes.GCM(iv_bytes, auth_tag_bytes)
            elif mode_name == "OFB":
                cipher_mode = modes.OFB(iv_bytes)
            elif mode_name == "CTR":
                cipher_mode = modes.CTR(iv_bytes)
            else:
                raise ValueError(f"不支持的加密模式: {mode}")
            cipher = Cipher(algorithms.AES(symmetric_key), cipher_mode, backend=default_backend())
        elif algorithm_name == "CHACHA20":
            # For ChaCha20 the "iv" field carries the 16-byte counter+nonce value.
            cipher = Cipher(
                algorithms.ChaCha20(symmetric_key, iv_bytes),
                mode=None,
                backend=default_backend(),
            )
        else:
            raise ValueError(f"不支持的加密算法: {algorithm}")

        # update() and finalize() must run on the same context. For GCM,
        # finalize() verifies the tag over the data passed to update(); a fresh
        # context would check it against empty input and always fail.
        decryptor = cipher.decryptor()
        plaintext = decryptor.update(ciphertext) + decryptor.finalize()

        if algorithm_name == "AES" and mode_name == "CBC":
            unpadder = sym_padding.PKCS7(128).unpadder()
            plaintext = unpadder.update(plaintext) + unpadder.finalize()
        return plaintext

    def decrypt_file(self, file_index, output_dir, private_key_path):
        """
        Decrypt a stored transfer and write the plaintext into output_dir.

        :param file_index: index of the transfer in self.received_files
        :param output_dir: directory for the decrypted file (created if missing)
        :param private_key_path: PEM private key (no password) used to unwrap
            the symmetric key
        :return: path of the decrypted file
        :raises IndexError: if file_index is out of range
        :raises Exception: if this transfer was already decrypted
        :raises ValueError: for an unsupported algorithm/mode, malformed hex
            fields or bad CBC padding
        :raises cryptography.exceptions.InvalidTag: if an AES-GCM tag does not verify
        """
        if file_index < 0 or file_index >= len(self.received_files):
            raise IndexError("文件索引超出范围")
        file_info = self.received_files[file_index]
        if file_info["status"] == "decrypted":
            raise Exception("文件已经解密过")
        transmission_data = file_info["info"]

        os.makedirs(output_dir, exist_ok=True)

        with open(private_key_path, 'rb') as key_file:
            private_key = serialization.load_pem_private_key(
                key_file.read(),
                password=None,
                backend=default_backend(),
            )

        # The stored name originated from the network; strip directory parts so
        # the output cannot land outside output_dir (e.g. "../../x" or "/etc/x").
        file_name = self._safe_file_name(transmission_data["file_name"])
        algorithm = transmission_data["algorithm"]
        mode = transmission_data["mode"]
        iv_bytes = bytes.fromhex(transmission_data["iv"])
        encrypted_file_bytes = bytes.fromhex(transmission_data["encrypted_file"])
        auth_tag = transmission_data.get("auth_tag")
        auth_tag_bytes = bytes.fromhex(auth_tag) if auth_tag else None
        encrypted_key_bytes = bytes.fromhex(transmission_data["encrypted_key"])

        # Open the digital envelope: recover the symmetric key with RSA-OAEP.
        symmetric_key = private_key.decrypt(
            encrypted_key_bytes,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            ),
        )

        file_data = self._decrypt_payload(
            symmetric_key, algorithm, mode, iv_bytes, encrypted_file_bytes, auth_tag_bytes
        )

        output_path = os.path.join(output_dir, file_name)
        with open(output_path, 'wb') as f:
            f.write(file_data)

        file_info["status"] = "decrypted"
        return output_path