You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
6.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import socket
import struct
import threading
from typing import Callable, Optional, Tuple
# First header field of every message; frames carrying any other value are rejected.
MAGIC_NUMBER = 0xABCD1234

# Message type codes (second header field).
MSG_TYPE_HANDSHAKE = 0x01
MSG_TYPE_FILE_TRANSFER = 0x02
MSG_TYPE_ACK = 0x03

# Header layout "!IIQ" (network byte order, no padding):
# magic (4 bytes) + message type (4 bytes) + payload length (8 bytes) = 16 bytes.
HEADER_SIZE = struct.calcsize("!IIQ")
class P2PError(Exception):
    """Protocol-level failure: bad framing, an unexpected message type, or an early close."""
def _recvall(sock: socket.socket, n: int, chunk_size: int = 64 * 1024) -> bytes:
    """Receive exactly ``n`` bytes from a connected stream socket.

    Args:
        sock: Connected stream socket to read from.
        n: Number of bytes to read. 0 (or a negative value) returns ``b""``
            without touching the socket.
        chunk_size: Upper bound on the size requested from a single
            ``sock.recv`` call.

    Returns:
        Exactly ``n`` bytes, as an immutable ``bytes`` object.

    Raises:
        P2PError: if the peer closes the connection before ``n`` bytes arrive.
        OSError: propagated from ``sock.recv`` (including ``socket.timeout``).
    """
    # A bytearray grows in amortized O(1); repeated ``bytes += part`` is O(n^2).
    buf = bytearray()
    while len(buf) < n:
        # Bound each read. In CPython, recv(k) allocates a k-byte buffer before
        # any data arrives, and ``n`` may come straight from an untrusted
        # 64-bit length header sent by the peer.
        part = sock.recv(min(n - len(buf), chunk_size))
        if not part:
            raise P2PError("Socket closed while receiving")
        buf += part
    return bytes(buf)
class P2PNode:
    """TCP node that sends or receives exactly one envelope per connection.

    Every message on the wire is a 16-byte header packed as ``"!IIQ"``
    (magic number, message type, payload length), followed by ``length``
    payload bytes.

    Outbound (``send_envelope``): connect, send HANDSHAKE, send FILE_TRANSFER
    carrying the envelope, wait for ACK, then close.
    Inbound (``start_listening``): each accepted connection is served on its
    own daemon thread. The thread expects HANDSHAKE then FILE_TRANSFER,
    replies ACK, then passes the envelope to ``on_envelope``.
    """

    # Errors raised on background connection threads are reported here.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        listen_port: int,
        host: str = "0.0.0.0",
        timeout: float = 10.0,
        max_payload_size: Optional[int] = None,
    ):
        """Configure the node. No socket is opened until start_listening().

        Args:
            listen_port: TCP port to listen on. 0 lets the OS choose one;
                start_listening() then stores the real port back here.
            host: Local IPv4 address the listening socket binds to.
            timeout: Socket timeout in seconds, applied to outgoing
                connections and to every accepted connection.
            max_payload_size: Optional limit, in bytes, on the payload length
                a peer may announce in a header. ``None`` (the default) means
                no limit.
        """
        self.host = host
        self.listen_port = int(listen_port)
        self.timeout = timeout
        self.max_payload_size = max_payload_size
        self._server_sock: Optional[socket.socket] = None
        self._listener_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        # Receive callback, invoked as on_envelope(envelope_bytes, peer_addr)
        # for every envelope received.
        self.on_envelope: Optional[Callable[[bytes, Tuple[str, int]], None]] = None

    # ---------- server side ----------
    def start_listening(self) -> None:
        """Bind and listen on ``(host, listen_port)`` and start the accept thread.

        This is a no-op if the node is already listening.

        Raises:
            OSError: if bind/listen fails (e.g. the port is in use). The
                socket is closed before the error propagates.
        """
        if self._server_sock:
            return
        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            srv.bind((self.host, self.listen_port))
            srv.listen(5)
            # A short accept timeout makes the loop re-check _stop_event regularly.
            srv.settimeout(1.0)
            # With port 0 the OS picked the port; record the one actually bound.
            self.listen_port = srv.getsockname()[1]
        except BaseException:
            srv.close()
            raise
        self._server_sock = srv
        self._stop_event.clear()
        t = threading.Thread(target=self._accept_loop, args=(srv,), daemon=True)
        self._listener_thread = t
        t.start()

    def _accept_loop(self, srv: Optional[socket.socket] = None) -> None:
        """Accept connections until stopped and serve each one on a new thread.

        The listening socket is held in a local. close() resets
        ``self._server_sock`` to None, and a later start_listening() installs
        a new socket, so neither can swap the socket out from under this
        loop. A stale loop simply fails on its own closed socket and exits.
        """
        if srv is None:
            srv = self._server_sock
        if srv is None:
            return
        while not self._stop_event.is_set():
            try:
                conn, addr = srv.accept()
            except socket.timeout:
                continue
            except OSError:
                # The listening socket was closed or failed: stop accepting.
                break
            threading.Thread(
                target=self._handle_client_connection,
                args=(conn, addr),
                daemon=True,
            ).start()

    def _handle_client_connection(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
        """Serve one inbound connection: handshake -> receive envelope -> ACK -> callback.

        The ACK is sent *before* ``on_envelope`` runs. It therefore confirms
        receipt only, not that the callback succeeded. This method runs on a
        background daemon thread, so errors are logged rather than raised.
        The connection is always closed.
        """
        try:
            conn.settimeout(self.timeout)
            # 1) Handshake. Only the message type is checked, not the payload.
            msg_type, _ = self._receive_message(conn)
            if msg_type != MSG_TYPE_HANDSHAKE:
                raise P2PError("Expected HANDSHAKE")
            # 2) The envelope itself.
            msg_type, payload = self._receive_message(conn)
            if msg_type != MSG_TYPE_FILE_TRANSFER:
                raise P2PError("Expected FILE_TRANSFER")
            # 3) ACK first, then hand the envelope to the upper layer.
            self._send_message(conn, MSG_TYPE_ACK, b"OK")
            if self.on_envelope:
                self.on_envelope(payload, addr)
        except Exception:
            self._logger.exception("Error handling connection from %s", addr)
        finally:
            try:
                conn.close()
            except OSError:
                pass

    # ---------- client side ----------
    def connect_to_peer(self, peer_ip: str, peer_port: int) -> socket.socket:
        """Open an IPv4 TCP connection to a peer and return the connected socket.

        The socket carries ``self.timeout``. The caller owns it and must
        close it.

        Raises:
            OSError: if the connection fails (including ``socket.timeout``).
                The socket is closed before the error propagates.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.settimeout(self.timeout)
            s.connect((peer_ip, int(peer_port)))
        except BaseException:
            s.close()
            raise
        return s

    def send_envelope(
        self,
        peer_ip: str,
        peer_port: int,
        envelope: bytes,
        progress_cb: Optional[Callable[[int, int], None]] = None,
    ) -> None:
        """Send one envelope to a peer over a fresh connection and wait for its ACK.

        Args:
            peer_ip: Peer IPv4 address.
            peer_port: Peer TCP port.
            envelope: Raw envelope bytes to transfer.
            progress_cb: Optional callback, invoked as
                ``progress_cb(bytes_sent, total)`` after each 64 KiB chunk.

        Raises:
            P2PError: if the peer's reply is not an ACK, or the peer closes
                the connection early.
            OSError: on connection or socket errors (including timeouts).
        """
        s = self.connect_to_peer(peer_ip, peer_port)
        try:
            # 1) Handshake.
            self._send_message(s, MSG_TYPE_HANDSHAKE, b"HELLO")
            # 2) File transfer. The header announces the total length up front.
            #    The payload is then streamed in chunks through a zero-copy
            #    memoryview, so progress can be reported during the send.
            total = len(envelope)
            chunk_size = 64 * 1024
            with memoryview(envelope) as mv:
                s.sendall(struct.pack("!IIQ", MAGIC_NUMBER, MSG_TYPE_FILE_TRANSFER, total))
                for start in range(0, total, chunk_size):
                    end = min(start + chunk_size, total)
                    s.sendall(mv[start:end])
                    if progress_cb:
                        progress_cb(end, total)
            # 3) Wait for the ACK.
            msg_type, _ = self._receive_message(s)
            if msg_type != MSG_TYPE_ACK:
                raise P2PError("No ACK from peer")
        finally:
            try:
                s.close()
            except OSError:
                pass

    def receive_envelope(self, conn: socket.socket) -> bytes:
        """Read one FILE_TRANSFER message from ``conn`` and return its payload.

        Use this when the connection is managed outside the built-in
        listener. No handshake is read and no ACK is sent.

        Raises:
            P2PError: on bad framing, or if the message is not FILE_TRANSFER.
        """
        msg_type, payload = self._receive_message(conn)
        if msg_type != MSG_TYPE_FILE_TRANSFER:
            raise P2PError("Expected FILE_TRANSFER")
        return payload

    # ---------- protocol ----------
    def _send_message(self, sock: socket.socket, msg_type: int, payload: bytes) -> None:
        """Send one framed message: a ``"!IIQ"`` header followed by ``payload``.

        Header and payload go out in a single ``sendall``. Two back-to-back
        small writes can be delayed by Nagle's algorithm interacting with
        delayed ACKs.
        """
        header = struct.pack("!IIQ", MAGIC_NUMBER, int(msg_type), len(payload))
        sock.sendall(header + payload)

    def _receive_message(self, sock: socket.socket) -> Tuple[int, bytes]:
        """Read one framed message and return ``(msg_type, payload)``.

        Raises:
            P2PError: on a bad magic number, a payload length above
                ``max_payload_size`` (when set), or an early close.
        """
        header = _recvall(sock, HEADER_SIZE)
        magic, msg_type, length = struct.unpack("!IIQ", header)
        if magic != MAGIC_NUMBER:
            raise P2PError("Bad magic number")
        # The length is peer-controlled (up to 2**64 - 1); enforce the limit if configured.
        limit = self.max_payload_size
        if limit is not None and length > limit:
            raise P2PError(f"Payload too large: {length} > {limit} bytes")
        payload = _recvall(sock, length) if length else b""
        return msg_type, payload

    # ---------- lifecycle ----------
    def close(self) -> None:
        """Stop accepting new connections and close the listening socket.

        Connections already being served on worker threads are not
        interrupted. It is safe to call this more than once, or when the
        node never started listening.
        """
        self._stop_event.set()
        srv, self._server_sock = self._server_sock, None
        if srv is not None:
            try:
                srv.close()
            except OSError:
                pass
        self._listener_thread = None