|
|
import logging
import socket
import struct
import threading
from typing import Callable, Optional, Tuple
|
|
|
|
|
|
# Wire-protocol constants.
#
# Every message on the wire is a fixed-size big-endian header followed by
# the payload:
#     magic (uint32) | message type (uint32) | payload length (uint64)
MAGIC_NUMBER = 0xABCD1234  # compared against every received header

# Message types carried in the header's "type" field.
MSG_TYPE_HANDSHAKE = 0x01
MSG_TYPE_FILE_TRANSFER = 0x02
MSG_TYPE_ACK = 0x03

# Size of the "!IIQ" header: magic(4) + type(4) + length(8) = 16 bytes.
HEADER_SIZE = struct.calcsize("!IIQ")
|
|
|
|
|
|
|
|
|
class P2PError(Exception):
    """Raised on protocol violations or a connection closed mid-message."""
|
|
|
|
|
|
|
|
|
def _recvall(sock: socket.socket, n: int) -> bytes:
    """Read exactly *n* bytes from *sock*.

    Data is accumulated in a ``bytearray`` (amortised linear time, unlike
    repeated ``bytes`` concatenation, which is quadratic). Each ``recv`` call
    is capped at 64 KiB, so a huge *n* does not make ``recv`` try to allocate
    the whole buffer up front. A huge *n* can happen, for example, when *n* is
    a length field announced by a remote peer.

    Args:
        sock: Connected stream socket to read from.
        n: Number of bytes to read. ``0`` or a negative value returns ``b""``.

    Returns:
        Exactly *n* bytes.

    Raises:
        P2PError: If the peer closes the connection before *n* bytes arrive.
        OSError: Propagated from ``sock.recv`` (including ``socket.timeout``).
    """
    max_chunk = 64 * 1024
    buf = bytearray()
    while len(buf) < n:
        part = sock.recv(min(n - len(buf), max_chunk))
        if not part:
            # recv() returning b"" means the peer performed an orderly shutdown.
            raise P2PError("Socket closed while receiving")
        buf += part
    return bytes(buf)
|
|
|
|
|
|
|
|
|
class P2PNode:
    """A TCP peer that can both receive and send envelopes (opaque bytes).

    Receiving: :meth:`start_listening` binds ``(host, listen_port)`` and
    accepts connections on a background daemon thread. Each connection is
    served on its own daemon thread, which expects HANDSHAKE followed by
    FILE_TRANSFER. It replies with ACK and then passes the payload to
    :attr:`on_envelope`.

    Sending: :meth:`send_envelope` connects to a peer, sends HANDSHAKE and
    FILE_TRANSFER, and waits for the peer's ACK.
    """

    def __init__(
        self,
        listen_port: int,
        host: str = "0.0.0.0",
        timeout: float = 10.0,
        *,
        max_message_size: Optional[int] = None,
    ):
        """Create a node; no sockets are opened until they are needed.

        Args:
            listen_port: TCP port that :meth:`start_listening` binds.
            host: Address that :meth:`start_listening` binds. The default
                ``"0.0.0.0"`` means all IPv4 interfaces.
            timeout: Socket timeout in seconds. It is applied to outgoing
                connections (:meth:`connect_to_peer`) and to accepted
                inbound connections.
            max_message_size: Optional upper bound, in bytes, on the payload
                length a peer may announce. Larger messages are rejected with
                :class:`P2PError` before any payload is read. ``None`` (the
                default) means no limit.
        """
        self.host = host
        self.listen_port = int(listen_port)
        self.timeout = timeout
        self.max_message_size = max_message_size

        self._server_sock: Optional[socket.socket] = None
        self._listener_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        # Receive callback: called as on_envelope(payload, peer_addr) for each
        # envelope received. It runs on the per-connection thread, after the
        # ACK has been sent.
        self.on_envelope: Optional[Callable[[bytes, Tuple[str, int]], None]] = None

    # ---------- server side ----------

    def start_listening(self) -> None:
        """Start listening and accept peers on a background daemon thread.

        Calling this while already listening is a no-op.

        Raises:
            OSError: If the socket cannot be bound or put into listening
                mode. The half-initialised socket is closed first.
        """
        if self._server_sock is not None:
            return

        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            srv.bind((self.host, self.listen_port))
            srv.listen(5)
            # A short accept timeout lets the loop re-check _stop_event
            # about once per second.
            srv.settimeout(1.0)
        except OSError:
            srv.close()
            raise

        self._server_sock = srv
        self._stop_event.clear()

        t = threading.Thread(target=self._accept_loop, args=(srv,), daemon=True)
        self._listener_thread = t
        t.start()

    def _accept_loop(self, srv: Optional[socket.socket] = None) -> None:
        """Accept connections until stopped, spawning one handler thread each.

        Args:
            srv: Listening socket to serve. Defaults to ``self._server_sock``.
                It is captured once because :meth:`close` resets the
                attribute to ``None`` from another thread. Re-reading the
                attribute could then dereference ``None``, or pick up a
                socket opened by a later :meth:`start_listening`.
        """
        if srv is None:
            srv = self._server_sock
        if srv is None:
            return

        while not self._stop_event.is_set():
            try:
                conn, addr = srv.accept()
            except socket.timeout:
                continue
            except OSError:
                # The listening socket was closed (e.g. by close()).
                break

            threading.Thread(
                target=self._handle_client_connection,
                args=(conn, addr),
                daemon=True,
            ).start()

    def _handle_client_connection(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
        """Serve one inbound peer: handshake, receive envelope, ACK, callback.

        The ACK is sent *before* :attr:`on_envelope` runs, so the sender sees
        success even if the callback later fails. This method runs on a
        daemon thread with no caller to report to. Any error is therefore
        logged as a warning rather than raised. The connection is always
        closed.
        """
        try:
            conn.settimeout(self.timeout)

            # 1) Handshake: only the message type is checked; its payload is ignored.
            msg_type, _ = self._receive_message(conn)
            if msg_type != MSG_TYPE_HANDSHAKE:
                raise P2PError("Expected HANDSHAKE")

            # 2) The file / envelope itself.
            msg_type, payload = self._receive_message(conn)
            if msg_type != MSG_TYPE_FILE_TRANSFER:
                raise P2PError("Expected FILE_TRANSFER")

            # 3) Acknowledge first, then hand the payload to the application
            #    (e.g. to save it).
            self._send_message(conn, MSG_TYPE_ACK, b"OK")

            callback = self.on_envelope  # read once in case it is reassigned concurrently
            if callback is not None:
                callback(payload, addr)

        except Exception:
            # Don't let the error escape the thread; record it instead of
            # silently discarding it.
            logging.getLogger(__name__).warning(
                "P2P connection from %s failed", addr, exc_info=True
            )

        finally:
            try:
                conn.close()
            except OSError:
                pass

    # ---------- client side ----------

    def connect_to_peer(self, peer_ip: str, peer_port: int) -> socket.socket:
        """Open an IPv4 TCP connection to ``(peer_ip, peer_port)``.

        The returned socket has ``self.timeout`` applied. The caller owns it
        and must close it. If connecting fails, the socket is closed before
        the error propagates.

        Raises:
            OSError: If the connection cannot be established (refused,
                unreachable, timed out, ...).
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.settimeout(self.timeout)
            s.connect((peer_ip, int(peer_port)))
        except BaseException:
            s.close()
            raise
        return s

    def send_envelope(
        self,
        peer_ip: str,
        peer_port: int,
        envelope: bytes,
        progress_cb: Optional[Callable[[int, int], None]] = None,
    ) -> None:
        """Send a digital envelope (raw bytes) to a peer and wait for its ACK.

        The sequence is HANDSHAKE(``b"HELLO"``), then FILE_TRANSFER(*envelope*),
        then the peer's reply, which must be ACK. The header carries the total
        length. The payload is streamed in 64 KiB slices of a memoryview, so
        header and payload are never concatenated into a second copy.

        Args:
            peer_ip: IPv4 address of the peer.
            peer_port: Port the peer is listening on.
            envelope: Payload bytes to send.
            progress_cb: Optional; called as ``progress_cb(bytes_sent, total)``
                after each slice. It is not called for an empty envelope.

        Raises:
            P2PError: If the peer replies with something other than ACK,
                closes the connection early, or sends a bad header.
            OSError: On network errors, including ``socket.timeout``.
        """
        s = self.connect_to_peer(peer_ip, peer_port)
        try:
            # 1) handshake
            self._send_message(s, MSG_TYPE_HANDSHAKE, b"HELLO")

            # 2) file transfer, streamed in chunks
            self._send_message(s, MSG_TYPE_FILE_TRANSFER, envelope, progress_cb=progress_cb)

            # 3) wait for the ACK
            msg_type, _ = self._receive_message(s)
            if msg_type != MSG_TYPE_ACK:
                raise P2PError("No ACK from peer")
        finally:
            try:
                s.close()
            except OSError:
                pass

    def receive_envelope(self, conn: socket.socket) -> bytes:
        """Read one FILE_TRANSFER message from *conn* and return its payload.

        Intended for callers that manage their own connection. No handshake
        is read and no ACK is sent here.

        Raises:
            P2PError: If the next message is not FILE_TRANSFER, or on a
                protocol/connection error while reading it.
        """
        msg_type, payload = self._receive_message(conn)
        if msg_type != MSG_TYPE_FILE_TRANSFER:
            raise P2PError("Expected FILE_TRANSFER")
        return payload

    # ---------- protocol ----------

    def _send_message(
        self,
        sock: socket.socket,
        msg_type: int,
        payload: bytes,
        progress_cb: Optional[Callable[[int, int], None]] = None,
    ) -> None:
        """Frame and send one message: the ``"!IIQ"`` header, then the payload.

        The payload is sent in 64 KiB memoryview slices, so no copy is made.
        If *progress_cb* is given, it is called as ``progress_cb(sent, total)``
        after each slice.
        """
        chunk_size = 64 * 1024
        total = len(payload)
        sock.sendall(struct.pack("!IIQ", MAGIC_NUMBER, int(msg_type), total))

        view = memoryview(payload)
        sent = 0
        while sent < total:
            end = min(sent + chunk_size, total)
            sock.sendall(view[sent:end])
            sent = end
            if progress_cb is not None:
                progress_cb(sent, total)

    def _receive_message(self, sock: socket.socket) -> Tuple[int, bytes]:
        """Read one framed message and return ``(msg_type, payload)``.

        Raises:
            P2PError: On a wrong magic number, or if the announced length
                exceeds ``max_message_size``. Also raised if the connection
                closes mid-message.
        """
        header = _recvall(sock, HEADER_SIZE)
        magic, msg_type, length = struct.unpack("!IIQ", header)
        if magic != MAGIC_NUMBER:
            raise P2PError("Bad magic number")

        # Reject oversized messages before reading (and buffering) any payload.
        limit = self.max_message_size
        if limit is not None and length > limit:
            raise P2PError(f"Message too large: {length} > {limit} bytes")

        payload = _recvall(sock, length) if length else b""
        return msg_type, payload

    # ---------- lifecycle ----------

    def close(self) -> None:
        """Stop listening. It is safe to call more than once.

        Signals the accept loop to stop and closes the listening socket. The
        loop exits on its next iteration, within about the 1 s accept
        timeout. Connections already being handled are not interrupted.
        """
        self._stop_event.set()
        srv, self._server_sock = self._server_sock, None
        if srv is not None:
            try:
                srv.close()
            except OSError:
                pass
|