From e4c02e8585f1744b9c58c324d951f103417b751a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=8D=9A=E6=96=87?= <15549487+FX_YBW@user.noreply.gitee.com> Date: Thu, 25 Dec 2025 23:15:22 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A1=AE=E4=BF=9D=E8=AF=AD=E9=9F=B3=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E6=B5=8B=E8=AF=95=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/__init__.py | 18 + client/voice_chat.py | 1478 ++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_voice_chat.py | 390 ++++++++++ 4 files changed, 1887 insertions(+) create mode 100644 client/voice_chat.py create mode 100644 tests/test_voice_chat.py diff --git a/client/__init__.py b/client/__init__.py index 50958d6..21e9b34 100644 --- a/client/__init__.py +++ b/client/__init__.py @@ -41,3 +41,21 @@ from client.media_player import ( PlaybackError, MediaLoadError, ) + +from client.voice_chat import ( + VoiceChatModule, + CallState, + VoiceChatError, + AudioDeviceError, + CallError, + AudioConfig, + CallInfo, + NetworkStats, + JitterBuffer, + AudioCapture, + AudioPlayback, + AudioEncoder, + AudioDecoder, + CallStateCallback, + IncomingCallCallback, +) diff --git a/client/voice_chat.py b/client/voice_chat.py new file mode 100644 index 0000000..9f4cfd9 --- /dev/null +++ b/client/voice_chat.py @@ -0,0 +1,1478 @@ +# P2P Network Communication - Voice Chat Module +""" +语音聊天模块 +负责实时语音通话功能,包括音频采集、编码、传输和播放 + +需求: 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7 +""" + +import asyncio +import logging +import socket +import struct +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Callable, Dict, List, Optional, Tuple, Any + +from shared.models import ( + Message, MessageType, NetworkQuality +) +from config import ClientConfig + + +# 设置日志 +logger = logging.getLogger(__name__) + + +class CallState(Enum): + """通话状态枚举""" + IDLE = "idle" # 空闲 + CALLING = "calling" # 正在呼叫 + RINGING = "ringing" # 来电响铃 + CONNECTED = "connected" # 通话中 + ENDING = "ending" # 正在结束 + + +class VoiceChatError(Exception): + """语音聊天错误""" + pass + + +class AudioDeviceError(VoiceChatError): + """音频设备错误""" + pass + + +class CallError(VoiceChatError): + """通话错误""" + pass + + +# 回调类型定义 +CallStateCallback = Callable[[CallState, Optional[str]], None] +IncomingCallCallback = Callable[[str, str], None] # peer_id, peer_name + + +@dataclass +class AudioConfig: + """音频配置""" + sample_rate: int = 16000 # 采样率 (Hz) + channels: int = 1 # 声道数 + chunk_duration: float = 0.02 # 每个音频块的时长 (秒) + bits_per_sample: int = 16 # 每个采样的位数 + + @property + def chunk_size(self) -> int: + """计算每个音频块的采样数""" + return int(self.sample_rate * self.chunk_duration) + + @property + def bytes_per_chunk(self) -> int: + """计算每个音频块的字节数""" + return self.chunk_size * self.channels * (self.bits_per_sample // 8) + + +@dataclass +class CallInfo: + """通话信息""" + peer_id: str + peer_name: str + start_time: Optional[datetime] = None + is_outgoing: bool = True + + @property + def duration(self) -> float: + """获取通话时长(秒)""" + if self.start_time is None: + return 0.0 + return (datetime.now() - self.start_time).total_seconds() + + +@dataclass +class NetworkStats: + """网络统计信息""" + packets_sent: int = 0 + packets_received: int = 0 + packets_lost: int = 0 + avg_latency: float = 0.0 # 平均延迟 (ms) + jitter: float = 0.0 # 抖动 (ms) + last_update: float = field(default_factory=time.time) + + @property + def packet_loss_rate(self) -> float: + """计算丢包率""" + total = self.packets_sent + self.packets_received + if total == 0: + return 0.0 + return self.packets_lost / total + + def get_quality(self) -> NetworkQuality: + """根据统计信息判断网络质量""" + if self.avg_latency < 50: + return NetworkQuality.EXCELLENT + elif self.avg_latency < 100: + return NetworkQuality.GOOD + elif self.avg_latency < 200: + return NetworkQuality.FAIR + elif self.avg_latency < 300: + return NetworkQuality.POOR + else: + return NetworkQuality.BAD + + +class JitterBuffer: + """ + 抖动缓冲区 + + 用于平滑网络抖动,确保音频播放的连续性 + 需求: 7.3, 7.4 + """ + + def __init__(self, target_delay: float = 0.06, max_delay: float = 0.2): + """ + 初始化抖动缓冲区 + + Args: + target_delay: 目标延迟(秒),默认60ms + max_delay: 最大延迟(秒),默认200ms + """ + self._buffer: deque = deque() + self._target_delay = target_delay + self._max_delay = max_delay + self._lock = threading.Lock() + self._sequence_number = 0 + self._last_played_seq = -1 + + def push(self, sequence: int, data: bytes, timestamp: float) -> None: + """ + 将音频数据包放入缓冲区 + + Args: + sequence: 序列号 + data: 音频数据 + timestamp: 时间戳 + """ + with self._lock: + # 丢弃过期的包 + if sequence <= self._last_played_seq: + return + + # 按序列号排序插入 + packet = (sequence, data, timestamp) + + # 找到正确的插入位置 + inserted = False + for i, (seq, _, _) in enumerate(self._buffer): + if sequence < seq: + self._buffer.insert(i, packet) + inserted = True + break + elif sequence == seq: + # 重复包,忽略 + return + + if not inserted: + self._buffer.append(packet) + + # 限制缓冲区大小 + while len(self._buffer) > 50: # 最多缓存50个包 + self._buffer.popleft() + + def pop(self) -> Optional[bytes]: + """ + 从缓冲区取出下一个音频数据包 + + Returns: + 音频数据,如果缓冲区为空则返回None + """ + with self._lock: + if not self._buffer: + return None + + # 检查是否有足够的缓冲 + current_time = time.time() + if self._buffer: + _, _, oldest_ts = self._buffer[0] + buffer_delay = current_time - oldest_ts + + # 如果缓冲不足,等待更多数据 + if buffer_delay < self._target_delay and len(self._buffer) < 3: + return None + + # 取出最早的包 + seq, data, _ = self._buffer.popleft() + self._last_played_seq = seq + return data + + def clear(self) -> None: + """清空缓冲区""" + with self._lock: + self._buffer.clear() + self._last_played_seq = -1 + + @property + def size(self) -> int: + """获取缓冲区大小""" + with self._lock: + return len(self._buffer) + + @property + def delay(self) -> float: + """获取当前缓冲延迟(秒)""" + with self._lock: + if not self._buffer: + return 0.0 + _, _, oldest_ts = self._buffer[0] + return time.time() - oldest_ts + + +class AudioCapture: + """ + 音频采集器 + + 使用PyAudio进行音频采集 + 需求: 7.3 + """ + + def __init__(self, config: AudioConfig): + """ + 初始化音频采集器 + + Args: + config: 音频配置 + """ + self._config = config + self._pyaudio = None + self._stream = None + self._is_capturing = False + self._capture_thread: Optional[threading.Thread] = None + self._audio_queue: deque = deque(maxlen=100) + self._lock = threading.Lock() + + def start(self) -> None: + """ + 开始音频采集 + + Raises: + AudioDeviceError: 无法打开音频设备时抛出 + """ + if self._is_capturing: + return + + try: + import pyaudio + + self._pyaudio = pyaudio.PyAudio() + + # 打开输入流 + self._stream = self._pyaudio.open( + format=pyaudio.paInt16, + channels=self._config.channels, + rate=self._config.sample_rate, + input=True, + frames_per_buffer=self._config.chunk_size + ) + + self._is_capturing = True + + # 启动采集线程 + self._capture_thread = threading.Thread( + target=self._capture_loop, + daemon=True + ) + self._capture_thread.start() + + logger.info("Audio capture started") + + except ImportError: + raise AudioDeviceError("PyAudio is not installed") + except Exception as e: + self._cleanup() + raise AudioDeviceError(f"Failed to start audio capture: {e}") + + def stop(self) -> None: + """停止音频采集""" + self._is_capturing = False + + if self._capture_thread and self._capture_thread.is_alive(): + self._capture_thread.join(timeout=1.0) + + self._cleanup() + logger.info("Audio capture stopped") + + def _cleanup(self) -> None: + """清理资源""" + if self._stream: + try: + self._stream.stop_stream() + self._stream.close() + except Exception: + pass + self._stream = None + + if self._pyaudio: + try: + self._pyaudio.terminate() + except Exception: + pass + self._pyaudio = None + + def _capture_loop(self) -> None: + """音频采集循环""" + while self._is_capturing and self._stream: + try: + # 读取音频数据 + data = self._stream.read( + self._config.chunk_size, + exception_on_overflow=False + ) + + with self._lock: + self._audio_queue.append(data) + + except Exception as e: + logger.error(f"Audio capture error: {e}") + break + + def get_audio(self) -> Optional[bytes]: + """ + 获取采集到的音频数据 + + Returns: + 音频数据,如果没有数据则返回None + """ + with self._lock: + if self._audio_queue: + return self._audio_queue.popleft() + return None + + @property + def is_capturing(self) -> bool: + """是否正在采集""" + return self._is_capturing + + +class AudioPlayback: + """ + 音频播放器 + + 使用PyAudio进行音频播放 + 需求: 7.3 + """ + + def __init__(self, config: AudioConfig): + """ + 初始化音频播放器 + + Args: + config: 音频配置 + """ + self._config = config + self._pyaudio = None + self._stream = None + self._is_playing = False + self._playback_thread: Optional[threading.Thread] = None + self._jitter_buffer = JitterBuffer() + self._lock = threading.Lock() + + def start(self) -> None: + """ + 开始音频播放 + + Raises: + AudioDeviceError: 无法打开音频设备时抛出 + """ + if self._is_playing: + return + + try: + import pyaudio + + self._pyaudio = pyaudio.PyAudio() + + # 打开输出流 + self._stream = self._pyaudio.open( + format=pyaudio.paInt16, + channels=self._config.channels, + rate=self._config.sample_rate, + output=True, + frames_per_buffer=self._config.chunk_size + ) + + self._is_playing = True + + # 启动播放线程 + self._playback_thread = threading.Thread( + target=self._playback_loop, + daemon=True + ) + self._playback_thread.start() + + logger.info("Audio playback started") + + except ImportError: + raise AudioDeviceError("PyAudio is not installed") + except Exception as e: + self._cleanup() + raise AudioDeviceError(f"Failed to start audio playback: {e}") + + def stop(self) -> None: + """停止音频播放""" + self._is_playing = False + + if self._playback_thread and self._playback_thread.is_alive(): + self._playback_thread.join(timeout=1.0) + + self._cleanup() + self._jitter_buffer.clear() + logger.info("Audio playback stopped") + + def _cleanup(self) -> None: + """清理资源""" + if self._stream: + try: + self._stream.stop_stream() + self._stream.close() + except Exception: + pass + self._stream = None + + if self._pyaudio: + try: + self._pyaudio.terminate() + except Exception: + pass + self._pyaudio = None + + def _playback_loop(self) -> None: + """音频播放循环""" + silence = bytes(self._config.bytes_per_chunk) + + while self._is_playing and self._stream: + try: + # 从抖动缓冲区获取数据 + data = self._jitter_buffer.pop() + + if data is None: + # 没有数据,播放静音 + data = silence + + # 播放音频 + self._stream.write(data) + + except Exception as e: + logger.error(f"Audio playback error: {e}") + break + + def push_audio(self, sequence: int, data: bytes, timestamp: float) -> None: + """ + 将音频数据放入播放队列 + + Args: + sequence: 序列号 + data: 音频数据 + timestamp: 时间戳 + """ + self._jitter_buffer.push(sequence, data, timestamp) + + @property + def is_playing(self) -> bool: + """是否正在播放""" + return self._is_playing + + @property + def buffer_delay(self) -> float: + """获取当前缓冲延迟""" + return self._jitter_buffer.delay + + +class AudioEncoder: + """ + 音频编码器 + + 使用Opus编码器进行音频压缩 + 需求: 7.3, 7.7 + """ + + # Opus应用类型 + APPLICATION_VOIP = 2048 + APPLICATION_AUDIO = 2049 + APPLICATION_RESTRICTED_LOWDELAY = 2051 + + def __init__(self, config: AudioConfig, bitrate: int = 24000): + """ + 初始化音频编码器 + + Args: + config: 音频配置 + bitrate: 目标比特率 (bps),默认24kbps + """ + self._config = config + self._bitrate = bitrate + self._encoder = None + self._use_opus = True + + try: + self._init_opus_encoder() + except Exception as e: + logger.warning(f"Failed to initialize Opus encoder: {e}. Using raw audio.") + self._use_opus = False + + def _init_opus_encoder(self) -> None: + """初始化Opus编码器""" + try: + import opuslib + + self._encoder = opuslib.Encoder( + self._config.sample_rate, + self._config.channels, + self.APPLICATION_VOIP + ) + + # 设置比特率 + self._encoder.bitrate = self._bitrate + + logger.info(f"Opus encoder initialized (bitrate: {self._bitrate}bps)") + + except ImportError: + raise VoiceChatError("opuslib is not installed") + + def encode(self, pcm_data: bytes) -> bytes: + """ + 编码PCM音频数据 + + Args: + pcm_data: 原始PCM数据 + + Returns: + 编码后的数据 + """ + if not self._use_opus or self._encoder is None: + # 不使用Opus,直接返回原始数据 + return pcm_data + + try: + # Opus编码 + encoded = self._encoder.encode( + pcm_data, + self._config.chunk_size + ) + return encoded + except Exception as e: + logger.error(f"Encoding error: {e}") + return pcm_data + + def set_bitrate(self, bitrate: int) -> None: + """ + 设置编码比特率 + + 用于自适应调整编码参数 (需求 7.7) + + Args: + bitrate: 目标比特率 (bps) + """ + self._bitrate = bitrate + if self._encoder: + try: + self._encoder.bitrate = bitrate + logger.info(f"Encoder bitrate set to {bitrate}bps") + except Exception as e: + logger.error(f"Failed to set bitrate: {e}") + + @property + def bitrate(self) -> int: + """获取当前比特率""" + return self._bitrate + + @property + def is_opus_enabled(self) -> bool: + """是否启用Opus编码""" + return self._use_opus + + +class AudioDecoder: + """ + 音频解码器 + + 使用Opus解码器进行音频解压 + 需求: 7.3 + """ + + def __init__(self, config: AudioConfig): + """ + 初始化音频解码器 + + Args: + config: 音频配置 + """ + self._config = config + self._decoder = None + self._use_opus = True + + try: + self._init_opus_decoder() + except Exception as e: + logger.warning(f"Failed to initialize Opus decoder: {e}. Using raw audio.") + self._use_opus = False + + def _init_opus_decoder(self) -> None: + """初始化Opus解码器""" + try: + import opuslib + + self._decoder = opuslib.Decoder( + self._config.sample_rate, + self._config.channels + ) + + logger.info("Opus decoder initialized") + + except ImportError: + raise VoiceChatError("opuslib is not installed") + + def decode(self, encoded_data: bytes) -> bytes: + """ + 解码音频数据 + + Args: + encoded_data: 编码后的数据 + + Returns: + 解码后的PCM数据 + """ + if not self._use_opus or self._decoder is None: + # 不使用Opus,直接返回原始数据 + return encoded_data + + try: + # Opus解码 + decoded = self._decoder.decode( + encoded_data, + self._config.chunk_size + ) + return decoded + except Exception as e: + logger.error(f"Decoding error: {e}") + # 解码失败,返回静音 + return bytes(self._config.bytes_per_chunk) + + @property + def is_opus_enabled(self) -> bool: + """是否启用Opus解码""" + return self._use_opus + + +class VoiceChatModule: + """ + 语音聊天模块 + + 负责实时语音通话功能,包括: + - 发起/接听/拒绝/结束通话 (需求 7.1, 7.2, 7.6) + - 实时音频采集和传输 (需求 7.3) + - 低延迟音频传输 (需求 7.4) + - 静音功能 (需求 7.5) + - 自适应编码参数调整 (需求 7.7) + """ + + # 音频数据包头格式: 序列号(4字节) + 时间戳(8字节) + AUDIO_HEADER_FORMAT = "!Id" + AUDIO_HEADER_SIZE = struct.calcsize(AUDIO_HEADER_FORMAT) + + # UDP端口范围 + UDP_PORT_MIN = 10000 + UDP_PORT_MAX = 20000 + + def __init__(self, config: Optional[ClientConfig] = None): + """ + 初始化语音聊天模块 + + Args: + config: 客户端配置 + """ + self._config = config or ClientConfig() + + # 音频配置 + self._audio_config = AudioConfig( + sample_rate=self._config.audio_sample_rate, + channels=self._config.audio_channels, + chunk_duration=self._config.audio_chunk_duration + ) + + # 音频组件 + self._capture: Optional[AudioCapture] = None + self._playback: Optional[AudioPlayback] = None + self._encoder: Optional[AudioEncoder] = None + self._decoder: Optional[AudioDecoder] = None + + # 通话状态 + self._state = CallState.IDLE + self._call_info: Optional[CallInfo] = None + self._is_muted = False + + # 网络组件 + self._udp_socket: Optional[socket.socket] = None + self._udp_port: int = 0 + self._peer_address: Optional[Tuple[str, int]] = None + + # 异步任务 + self._send_task: Optional[asyncio.Task] = None + self._receive_task: Optional[asyncio.Task] = None + + # 序列号 + self._sequence_number = 0 + + # 网络统计 + self._network_stats = NetworkStats() + + # 回调 + self._state_callbacks: List[CallStateCallback] = [] + self._incoming_call_callbacks: List[IncomingCallCallback] = [] + + # 消息发送回调(由ConnectionManager设置) + self._send_message_callback: Optional[Callable] = None + + # 用户信息 + self._user_id: str = "" + self._username: str = "" + + # 锁 + self._lock = asyncio.Lock() + + logger.info("VoiceChatModule initialized") + + # ==================== 属性 ==================== + + @property + def state(self) -> CallState: + """获取当前通话状态""" + return self._state + + @property + def is_in_call(self) -> bool: + """是否在通话中""" + return self._state == CallState.CONNECTED + + @property + def is_muted(self) -> bool: + """是否静音""" + return self._is_muted + + @property + def call_info(self) -> Optional[CallInfo]: + """获取当前通话信息""" + return self._call_info + + @property + def udp_port(self) -> int: + """获取UDP端口""" + return self._udp_port + + # ==================== 初始化和配置 ==================== + + def set_user_info(self, user_id: str, username: str) -> None: + """ + 设置用户信息 + + Args: + user_id: 用户ID + username: 用户名 + """ + self._user_id = user_id + self._username = username + + def set_send_message_callback(self, callback: Callable) -> None: + """ + 设置消息发送回调 + + Args: + callback: 发送消息的回调函数 + """ + self._send_message_callback = callback + + def add_state_callback(self, callback: CallStateCallback) -> None: + """添加状态变更回调""" + self._state_callbacks.append(callback) + + def remove_state_callback(self, callback: CallStateCallback) -> None: + """移除状态变更回调""" + if callback in self._state_callbacks: + self._state_callbacks.remove(callback) + + def add_incoming_call_callback(self, callback: IncomingCallCallback) -> None: + """添加来电回调""" + self._incoming_call_callbacks.append(callback) + + def remove_incoming_call_callback(self, callback: IncomingCallCallback) -> None: + """移除来电回调""" + if callback in self._incoming_call_callbacks: + self._incoming_call_callbacks.remove(callback) + + def _set_state(self, state: CallState, reason: Optional[str] = None) -> None: + """设置通话状态并通知回调""" + old_state = self._state + self._state = state + + if old_state != state: + logger.info(f"Call state changed: {old_state.value} -> {state.value}" + + (f" ({reason})" if reason else "")) + + for callback in self._state_callbacks: + try: + callback(state, reason) + except Exception as e: + logger.error(f"Error in state callback: {e}") + + # ==================== UDP网络 ==================== + + async def _init_udp_socket(self) -> bool: + """ + 初始化UDP套接字 + + Returns: + 成功返回True,否则返回False + """ + try: + self._udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._udp_socket.setblocking(False) + + # 尝试绑定到可用端口 + for port in range(self.UDP_PORT_MIN, self.UDP_PORT_MAX): + try: + self._udp_socket.bind(('', port)) + self._udp_port = port + logger.info(f"UDP socket bound to port {port}") + return True + except OSError: + continue + + logger.error("No available UDP port") + return False + + except Exception as e: + logger.error(f"Failed to initialize UDP socket: {e}") + return False + + def _close_udp_socket(self) -> None: + """关闭UDP套接字""" + if self._udp_socket: + try: + self._udp_socket.close() + except Exception: + pass + self._udp_socket = None + self._udp_port = 0 + + # ==================== 通话控制 (需求 7.1, 7.2, 7.6) ==================== + + async def start_call(self, peer_id: str, peer_name: str = "") -> bool: + """ + 发起语音通话 + + 向目标用户发送通话邀请 (需求 7.1) + + Args: + peer_id: 目标用户ID + peer_name: 目标用户名 + + Returns: + 发起成功返回True,否则返回False + """ + if self._state != CallState.IDLE: + logger.warning(f"Cannot start call: current state is {self._state.value}") + return False + + if not self._send_message_callback: + logger.error("No message callback set") + return False + + async with self._lock: + try: + # 初始化UDP + if not await self._init_udp_socket(): + return False + + # 创建通话信息 + self._call_info = CallInfo( + peer_id=peer_id, + peer_name=peer_name, + is_outgoing=True + ) + + # 发送通话请求 + call_data = { + "caller_id": self._user_id, + "caller_name": self._username, + "udp_port": self._udp_port + } + + message = Message( + msg_type=MessageType.VOICE_CALL_REQUEST, + sender_id=self._user_id, + receiver_id=peer_id, + timestamp=time.time(), + payload=str(call_data).encode('utf-8') + ) + + success = await self._send_message_callback(peer_id, message) + + if success: + self._set_state(CallState.CALLING, f"Calling {peer_name or peer_id}") + logger.info(f"Call request sent to {peer_id}") + return True + else: + self._close_udp_socket() + self._call_info = None + return False + + except Exception as e: + logger.error(f"Failed to start call: {e}") + self._close_udp_socket() + self._call_info = None + return False + + async def accept_call(self, peer_id: str) -> bool: + """ + 接听语音通话 + + 接听来电并建立语音连接 (需求 7.2) + + Args: + peer_id: 来电用户ID + + Returns: + 接听成功返回True,否则返回False + """ + if self._state != CallState.RINGING: + logger.warning(f"Cannot accept call: current state is {self._state.value}") + return False + + if not self._call_info or self._call_info.peer_id != peer_id: + logger.warning(f"No incoming call from {peer_id}") + return False + + async with self._lock: + try: + # 发送接听响应 + accept_data = { + "callee_id": self._user_id, + "callee_name": self._username, + "udp_port": self._udp_port + } + + message = Message( + msg_type=MessageType.VOICE_CALL_ACCEPT, + sender_id=self._user_id, + receiver_id=peer_id, + timestamp=time.time(), + payload=str(accept_data).encode('utf-8') + ) + + success = await self._send_message_callback(peer_id, message) + + if success: + # 开始通话 + await self._start_audio_session() + return True + else: + return False + + except Exception as e: + logger.error(f"Failed to accept call: {e}") + return False + + def reject_call(self, peer_id: str) -> None: + """ + 拒绝语音通话 + + 拒绝来电 (需求 7.2) + + Args: + peer_id: 来电用户ID + """ + if self._state != CallState.RINGING: + return + + if not self._call_info or self._call_info.peer_id != peer_id: + return + + try: + # 发送拒绝响应 + if self._send_message_callback: + message = Message( + msg_type=MessageType.VOICE_CALL_REJECT, + sender_id=self._user_id, + receiver_id=peer_id, + timestamp=time.time(), + payload=b"" + ) + + # 同步发送(不等待结果) + asyncio.create_task(self._send_message_callback(peer_id, message)) + + logger.info(f"Call from {peer_id} rejected") + + finally: + self._cleanup_call() + + def end_call(self) -> None: + """ + 结束通话 + + 释放音频资源并关闭连接 (需求 7.6) + """ + if self._state == CallState.IDLE: + return + + self._set_state(CallState.ENDING, "Ending call") + + try: + # 发送结束通知 + if self._call_info and self._send_message_callback: + message = Message( + msg_type=MessageType.VOICE_CALL_END, + sender_id=self._user_id, + receiver_id=self._call_info.peer_id, + timestamp=time.time(), + payload=b"" + ) + + asyncio.create_task(self._send_message_callback( + self._call_info.peer_id, message + )) + + logger.info("Call ended") + + finally: + self._cleanup_call() + + def _cleanup_call(self) -> None: + """清理通话资源""" + # 停止音频 + self._stop_audio_session() + + # 关闭UDP + self._close_udp_socket() + + # 重置状态 + self._call_info = None + self._peer_address = None + self._is_muted = False + self._sequence_number = 0 + self._network_stats = NetworkStats() + + self._set_state(CallState.IDLE, "Call ended") + + + # ==================== 音频会话管理 ==================== + + async def _start_audio_session(self) -> None: + """ + 开始音频会话 + + 初始化音频采集、播放、编解码器,并启动传输任务 + """ + try: + # 初始化音频组件 + self._capture = AudioCapture(self._audio_config) + self._playback = AudioPlayback(self._audio_config) + self._encoder = AudioEncoder(self._audio_config) + self._decoder = AudioDecoder(self._audio_config) + + # 启动音频采集和播放 + self._capture.start() + self._playback.start() + + # 记录通话开始时间 + if self._call_info: + self._call_info.start_time = datetime.now() + + # 启动发送和接收任务 + self._send_task = asyncio.create_task(self._audio_send_loop()) + self._receive_task = asyncio.create_task(self._audio_receive_loop()) + + self._set_state(CallState.CONNECTED, "Call connected") + logger.info("Audio session started") + + except Exception as e: + logger.error(f"Failed to start audio session: {e}") + self._stop_audio_session() + raise + + def _stop_audio_session(self) -> None: + """停止音频会话""" + # 取消任务 + if self._send_task and not self._send_task.done(): + self._send_task.cancel() + if self._receive_task and not self._receive_task.done(): + self._receive_task.cancel() + + # 停止音频组件 + if self._capture: + self._capture.stop() + self._capture = None + + if self._playback: + self._playback.stop() + self._playback = None + + self._encoder = None + self._decoder = None + + logger.info("Audio session stopped") + + # ==================== 实时音频传输 (需求 7.3, 7.4) ==================== + + async def _audio_send_loop(self) -> None: + """ + 音频发送循环 + + 实时采集和传输音频数据 (需求 7.3) + """ + loop = asyncio.get_event_loop() + + while self._state == CallState.CONNECTED: + try: + # 获取采集到的音频 + if self._capture and not self._is_muted: + audio_data = self._capture.get_audio() + + if audio_data: + # 编码音频 + if self._encoder: + encoded_data = self._encoder.encode(audio_data) + else: + encoded_data = audio_data + + # 发送音频数据 + await self._send_audio_packet(encoded_data) + + # 短暂休眠,避免CPU占用过高 + await asyncio.sleep(0.001) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Audio send error: {e}") + await asyncio.sleep(0.01) + + async def _audio_receive_loop(self) -> None: + """ + 音频接收循环 + + 接收和播放远端音频数据 (需求 7.3, 7.4) + """ + loop = asyncio.get_event_loop() + + while self._state == CallState.CONNECTED: + try: + if self._udp_socket: + try: + # 非阻塞接收 + data, addr = await asyncio.wait_for( + loop.sock_recvfrom(self._udp_socket, 4096), + timeout=0.1 + ) + + # 处理接收到的音频数据 + await self._handle_audio_packet(data, addr) + + except asyncio.TimeoutError: + continue + else: + await asyncio.sleep(0.01) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Audio receive error: {e}") + await asyncio.sleep(0.01) + + async def _send_audio_packet(self, audio_data: bytes) -> None: + """ + 发送音频数据包 + + Args: + audio_data: 编码后的音频数据 + """ + if not self._udp_socket or not self._peer_address: + return + + try: + # 构建数据包头 + self._sequence_number += 1 + timestamp = time.time() + + header = struct.pack( + self.AUDIO_HEADER_FORMAT, + self._sequence_number, + timestamp + ) + + # 发送数据包 + packet = header + audio_data + loop = asyncio.get_event_loop() + await loop.sock_sendto(self._udp_socket, packet, self._peer_address) + + # 更新统计 + self._network_stats.packets_sent += 1 + + except Exception as e: + logger.error(f"Failed to send audio packet: {e}") + + async def _handle_audio_packet(self, data: bytes, addr: Tuple[str, int]) -> None: + """ + 处理接收到的音频数据包 + + Args: + data: 接收到的数据 + addr: 发送方地址 + """ + if len(data) < self.AUDIO_HEADER_SIZE: + return + + try: + # 解析数据包头 + header = data[:self.AUDIO_HEADER_SIZE] + sequence, timestamp = struct.unpack(self.AUDIO_HEADER_FORMAT, header) + + # 提取音频数据 + audio_data = data[self.AUDIO_HEADER_SIZE:] + + # 解码音频 + if self._decoder: + pcm_data = self._decoder.decode(audio_data) + else: + pcm_data = audio_data + + # 放入播放缓冲区 + if self._playback: + self._playback.push_audio(sequence, pcm_data, timestamp) + + # 更新统计 + self._network_stats.packets_received += 1 + + # 计算延迟 + latency = (time.time() - timestamp) * 1000 # 转换为毫秒 + self._update_latency_stats(latency) + + except Exception as e: + logger.error(f"Failed to handle audio packet: {e}") + + def _update_latency_stats(self, latency: float) -> None: + """更新延迟统计""" + # 使用指数移动平均 + alpha = 0.1 + if self._network_stats.avg_latency == 0: + self._network_stats.avg_latency = latency + else: + self._network_stats.avg_latency = ( + alpha * latency + (1 - alpha) * self._network_stats.avg_latency + ) + + # 更新抖动 + jitter_diff = abs(latency - self._network_stats.avg_latency) + self._network_stats.jitter = ( + alpha * jitter_diff + (1 - alpha) * self._network_stats.jitter + ) + + self._network_stats.last_update = time.time() + + # ==================== 静音功能 (需求 7.5) ==================== + + def mute(self, muted: bool) -> None: + """ + 设置静音状态 + + 静音后停止发送本地音频但继续接收对方音频 (需求 7.5) + + Args: + muted: True表示静音,False表示取消静音 + """ + self._is_muted = muted + logger.info(f"Mute {'enabled' if muted else 'disabled'}") + + def toggle_mute(self) -> bool: + """ + 切换静音状态 + + Returns: + 切换后的静音状态 + """ + self._is_muted = not self._is_muted + logger.info(f"Mute {'enabled' if self._is_muted else 'disabled'}") + return self._is_muted + + # ==================== 网络质量和自适应 (需求 7.7) ==================== + + def get_call_duration(self) -> float: + """ + 获取通话时长 + + Returns: + 通话时长(秒) + """ + if self._call_info: + return self._call_info.duration + return 0.0 + + def get_network_quality(self) -> NetworkQuality: + """ + 获取网络质量指标 + + Returns: + 网络质量枚举值 + """ + return self._network_stats.get_quality() + + def get_network_stats(self) -> NetworkStats: + """ + 获取详细网络统计 + + Returns: + 网络统计信息 + """ + return self._network_stats + + def _adjust_encoding_quality(self) -> None: + """ + 自适应调整编码参数 + + 根据网络质量自动调整音频编码参数 (需求 7.7) + """ + if not self._encoder: + return + + quality = self.get_network_quality() + + # 根据网络质量调整比特率 + bitrate_map = { + NetworkQuality.EXCELLENT: 32000, # 32kbps + NetworkQuality.GOOD: 24000, # 24kbps + NetworkQuality.FAIR: 16000, # 16kbps + NetworkQuality.POOR: 12000, # 12kbps + NetworkQuality.BAD: 8000, # 8kbps + } + + new_bitrate = bitrate_map.get(quality, 24000) + + if new_bitrate != self._encoder.bitrate: + self._encoder.set_bitrate(new_bitrate) + logger.info(f"Adjusted bitrate to {new_bitrate}bps due to {quality.value} network") + + # ==================== 消息处理 ==================== + + async def handle_voice_message(self, message: Message) -> None: + """ + 处理语音相关消息 + + Args: + message: 接收到的消息 + """ + if message.msg_type == MessageType.VOICE_CALL_REQUEST: + await self._handle_call_request(message) + elif message.msg_type == MessageType.VOICE_CALL_ACCEPT: + await self._handle_call_accept(message) + elif message.msg_type == MessageType.VOICE_CALL_REJECT: + await self._handle_call_reject(message) + elif message.msg_type == MessageType.VOICE_CALL_END: + await self._handle_call_end(message) + + async def _handle_call_request(self, message: Message) -> None: + """处理来电请求""" + if self._state != CallState.IDLE: + # 正忙,自动拒绝 + if self._send_message_callback: + reject_msg = Message( + msg_type=MessageType.VOICE_CALL_REJECT, + sender_id=self._user_id, + receiver_id=message.sender_id, + timestamp=time.time(), + payload=b"busy" + ) + await self._send_message_callback(message.sender_id, reject_msg) + return + + try: + # 解析来电信息 + call_data = eval(message.payload.decode('utf-8')) + caller_id = call_data.get("caller_id", message.sender_id) + caller_name = call_data.get("caller_name", "") + peer_udp_port = call_data.get("udp_port", 0) + + # 初始化UDP + if not await self._init_udp_socket(): + return + + # 创建通话信息 + self._call_info = CallInfo( + peer_id=caller_id, + peer_name=caller_name, + is_outgoing=False + ) + + # 设置对端地址(需要从消息中获取IP) + # 这里假设通过服务器中转,实际IP需要从其他途径获取 + + self._set_state(CallState.RINGING, f"Incoming call from {caller_name or caller_id}") + + # 通知来电回调 + for callback in self._incoming_call_callbacks: + try: + callback(caller_id, caller_name) + except Exception as e: + logger.error(f"Error in incoming call callback: {e}") + + logger.info(f"Incoming call from {caller_id}") + + except Exception as e: + logger.error(f"Failed to handle call request: {e}") + self._cleanup_call() + + async def _handle_call_accept(self, message: Message) -> None: + """处理通话接受响应""" + if self._state != CallState.CALLING: + return + + if not self._call_info or self._call_info.peer_id != message.sender_id: + return + + try: + # 解析响应信息 + accept_data = eval(message.payload.decode('utf-8')) + peer_udp_port = accept_data.get("udp_port", 0) + + # 设置对端地址(需要从消息中获取IP) + # 这里需要实际的IP地址,暂时使用占位符 + # self._peer_address = (peer_ip, peer_udp_port) + + # 开始音频会话 + await self._start_audio_session() + + logger.info(f"Call accepted by {message.sender_id}") + + except Exception as e: + logger.error(f"Failed to handle call accept: {e}") + self._cleanup_call() + + async def _handle_call_reject(self, message: Message) -> None: + """处理通话拒绝响应""" + if self._state != CallState.CALLING: + return + + reason = message.payload.decode('utf-8') if message.payload else "rejected" + logger.info(f"Call rejected by {message.sender_id}: {reason}") + + self._cleanup_call() + + async def _handle_call_end(self, message: Message) -> None: + """处理通话结束消息""" + if self._state == CallState.IDLE: + return + + logger.info(f"Call ended by {message.sender_id}") + self._cleanup_call() diff --git a/requirements.txt b/requirements.txt index 0375246..32f6f19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ PyAudio>=0.2.13 opencv-python>=4.8.0 ffmpeg-python>=0.2.0 mutagen>=1.47.0 +opuslib>=3.0.1 # Image Processing Pillow>=10.0.0 diff --git a/tests/test_voice_chat.py b/tests/test_voice_chat.py new file mode 100644 index 0000000..5c1fdfe --- /dev/null +++ b/tests/test_voice_chat.py @@ -0,0 +1,390 @@ +# P2P Network Communication - Voice Chat Module Tests +""" +语音聊天模块测试 +测试音频采集、编码、通话控制和实时传输功能 +""" + +import asyncio +import pytest +import struct +import time +from unittest.mock import Mock, MagicMock, patch, AsyncMock +from datetime import datetime + +from client.voice_chat import ( + VoiceChatModule, + CallState, + VoiceChatError, + AudioDeviceError, + CallError, + AudioConfig, + CallInfo, + NetworkStats, + JitterBuffer, + AudioCapture, + AudioPlayback, + AudioEncoder, + AudioDecoder, +) +from shared.models import Message, MessageType, NetworkQuality + + +class TestAudioConfig: + """音频配置测试""" + + def test_default_config(self): + """测试默认配置""" + config = AudioConfig() + assert config.sample_rate == 16000 + assert config.channels == 1 + assert config.chunk_duration == 0.02 + assert config.bits_per_sample == 16 + + def test_chunk_size_calculation(self): + """测试音频块大小计算""" + config = AudioConfig(sample_rate=16000, chunk_duration=0.02) + # 16000 * 0.02 = 320 samples + assert config.chunk_size == 320 + + def test_bytes_per_chunk_calculation(self): + """测试每块字节数计算""" + config = AudioConfig( + sample_rate=16000, + channels=1, + chunk_duration=0.02, + bits_per_sample=16 + ) + # 320 samples * 1 channel * 2 bytes = 640 bytes + assert config.bytes_per_chunk == 640 + + +class TestCallInfo: + """通话信息测试""" + + def test_call_info_creation(self): + """测试通话信息创建""" + info = CallInfo( + peer_id="user123", + peer_name="Test User", + is_outgoing=True + ) + assert info.peer_id == "user123" + assert info.peer_name == "Test User" + assert info.is_outgoing is True + assert info.start_time is None + + def test_call_duration_not_started(self): + """测试未开始通话的时长""" + info = CallInfo(peer_id="user123", peer_name="Test") + assert info.duration == 0.0 + + def test_call_duration_started(self): + """测试已开始通话的时长""" + info = CallInfo(peer_id="user123", peer_name="Test") + info.start_time = datetime.now() + time.sleep(0.1) + assert info.duration >= 0.1 + + +class TestNetworkStats: + """网络统计测试""" + + def test_default_stats(self): + """测试默认统计值""" + stats = NetworkStats() + assert stats.packets_sent == 0 + assert stats.packets_received == 0 + assert stats.packets_lost == 0 + assert stats.avg_latency == 0.0 + + def test_packet_loss_rate_no_packets(self): + """测试无数据包时的丢包率""" + stats = NetworkStats() + assert stats.packet_loss_rate == 0.0 + + def test_packet_loss_rate_with_loss(self): + """测试有丢包时的丢包率""" + stats = NetworkStats( + packets_sent=100, + packets_received=90, + packets_lost=10 + ) + # 10 / (100 + 90) ≈ 0.0526 + assert 0.05 < stats.packet_loss_rate < 0.06 + + def test_network_quality_excellent(self): + """测试优秀网络质量""" + stats = NetworkStats(avg_latency=30) + assert stats.get_quality() == NetworkQuality.EXCELLENT + + def test_network_quality_good(self): + """测试良好网络质量""" + stats = NetworkStats(avg_latency=75) + assert stats.get_quality() == NetworkQuality.GOOD + + def test_network_quality_fair(self): + """测试一般网络质量""" + stats = NetworkStats(avg_latency=150) + assert stats.get_quality() == NetworkQuality.FAIR + + def test_network_quality_poor(self): + """测试较差网络质量""" + stats = NetworkStats(avg_latency=250) + assert stats.get_quality() == NetworkQuality.POOR + + def test_network_quality_bad(self): + """测试很差网络质量""" + stats = NetworkStats(avg_latency=400) + assert stats.get_quality() == NetworkQuality.BAD + + +class TestJitterBuffer: + """抖动缓冲区测试""" + + def test_buffer_creation(self): + """测试缓冲区创建""" + buffer = JitterBuffer() + assert buffer.size == 0 + assert buffer.delay == 0.0 + + def test_push_and_pop(self): + """测试数据包入队和出队""" + buffer = JitterBuffer(target_delay=0.0) # 禁用延迟等待 + + # 添加数据包 + buffer.push(1, b"audio_data_1", time.time() - 0.1) + buffer.push(2, b"audio_data_2", time.time() - 0.05) + buffer.push(3, b"audio_data_3", time.time()) + + assert buffer.size == 3 + + # 取出数据包(按序列号顺序) + data = buffer.pop() + assert data == b"audio_data_1" + assert buffer.size == 2 + + def test_out_of_order_packets(self): + """测试乱序数据包处理""" + buffer = JitterBuffer(target_delay=0.0) + + # 乱序添加 + buffer.push(3, b"data_3", time.time()) + buffer.push(1, b"data_1", time.time()) + buffer.push(2, b"data_2", time.time()) + + # 应该按序列号顺序取出 + assert buffer.pop() == b"data_1" + assert buffer.pop() == b"data_2" + assert buffer.pop() == b"data_3" + + def test_duplicate_packet_ignored(self): + """测试重复数据包被忽略""" + buffer = JitterBuffer(target_delay=0.0) + + buffer.push(1, b"data_1", time.time()) + buffer.push(1, b"data_1_dup", time.time()) # 重复 + + assert buffer.size == 1 + + def test_old_packet_ignored(self): + """测试过期数据包被忽略""" + buffer = JitterBuffer(target_delay=0.0) + + buffer.push(5, b"data_5", time.time()) + buffer.pop() # 取出序列号5 + + buffer.push(3, b"data_3", time.time()) # 旧包,应被忽略 + assert buffer.size == 0 + + def test_clear_buffer(self): + """测试清空缓冲区""" + buffer = JitterBuffer() + buffer.push(1, b"data", time.time()) + buffer.push(2, b"data", time.time()) + + buffer.clear() + assert buffer.size == 0 + + +class TestVoiceChatModule: + """语音聊天模块测试""" + + @pytest.fixture + def voice_chat(self): + """创建语音聊天模块实例""" + module = VoiceChatModule() + module.set_user_info("test_user", "Test User") + return module + + def test_initial_state(self, voice_chat): + """测试初始状态""" + assert voice_chat.state == CallState.IDLE + assert voice_chat.is_in_call is False + assert voice_chat.is_muted is False + assert voice_chat.call_info is None + + def test_mute_toggle(self, voice_chat): + """测试静音切换""" + assert voice_chat.is_muted is False + + voice_chat.mute(True) + assert voice_chat.is_muted is True + + voice_chat.mute(False) + assert voice_chat.is_muted is False + + def test_toggle_mute(self, voice_chat): + """测试静音切换方法""" + assert voice_chat.toggle_mute() is True + assert voice_chat.toggle_mute() is False + + def test_get_call_duration_no_call(self, voice_chat): + """测试无通话时的时长""" + assert voice_chat.get_call_duration() == 0.0 + + def test_get_network_quality_default(self, voice_chat): + """测试默认网络质量""" + # 默认延迟为0,应该是EXCELLENT + quality = voice_chat.get_network_quality() + assert quality == NetworkQuality.EXCELLENT + + def test_state_callback(self, voice_chat): + """测试状态回调""" + callback_called = [] + + def state_callback(state, reason): + callback_called.append((state, reason)) + + voice_chat.add_state_callback(state_callback) + voice_chat._set_state(CallState.CALLING, "test") + + assert len(callback_called) == 1 + assert callback_called[0][0] == CallState.CALLING + assert callback_called[0][1] == "test" + + def test_remove_state_callback(self, voice_chat): + """测试移除状态回调""" + callback_called = [] + + def state_callback(state, reason): + callback_called.append((state, reason)) + + voice_chat.add_state_callback(state_callback) + voice_chat.remove_state_callback(state_callback) + voice_chat._set_state(CallState.CALLING, "test") + + assert len(callback_called) == 0 + + @pytest.mark.asyncio + async def test_start_call_no_callback(self, voice_chat): + """测试无消息回调时发起通话""" + result = await voice_chat.start_call("peer123", "Peer User") + assert result is False + assert voice_chat.state == CallState.IDLE + + @pytest.mark.asyncio + async def test_start_call_not_idle(self, voice_chat): + """测试非空闲状态发起通话""" + voice_chat._state = CallState.CONNECTED + + async def mock_send(peer_id, msg): + return True + + voice_chat.set_send_message_callback(mock_send) + result = await voice_chat.start_call("peer123") + + assert result is False + + def test_reject_call_not_ringing(self, voice_chat): + """测试非响铃状态拒绝通话""" + voice_chat.reject_call("peer123") + # 应该不会改变状态 + assert voice_chat.state == CallState.IDLE + + def test_end_call_idle(self, voice_chat): + """测试空闲状态结束通话""" + voice_chat.end_call() + # 应该保持空闲状态 + assert voice_chat.state == CallState.IDLE + + @pytest.mark.asyncio + async def test_accept_call_not_ringing(self, voice_chat): + """测试非响铃状态接听通话""" + result = await voice_chat.accept_call("peer123") + assert result is False + + +class TestAudioEncoderDecoder: + """音频编解码器测试""" + + def test_encoder_without_opus(self): + """测试无Opus时的编码器""" + config = AudioConfig() + + with patch.dict('sys.modules', {'opuslib': None}): + encoder = AudioEncoder(config) + # 应该回退到原始音频 + assert encoder.is_opus_enabled is False or encoder._encoder is None + + def test_decoder_without_opus(self): + """测试无Opus时的解码器""" + config = AudioConfig() + + with patch.dict('sys.modules', {'opuslib': None}): + decoder = AudioDecoder(config) + # 应该回退到原始音频 + assert decoder.is_opus_enabled is False or decoder._decoder is None + + def test_encoder_raw_passthrough(self): + """测试编码器原始数据透传""" + config = AudioConfig() + encoder = AudioEncoder(config) + encoder._use_opus = False + + test_data = b"test_audio_data" + result = encoder.encode(test_data) + assert result == test_data + + def test_decoder_raw_passthrough(self): + """测试解码器原始数据透传""" + config = AudioConfig() + decoder = AudioDecoder(config) + decoder._use_opus = False + + test_data = b"test_audio_data" + result = decoder.decode(test_data) + assert result == test_data + + def test_encoder_bitrate_property(self): + """测试编码器比特率属性""" + config = AudioConfig() + encoder = AudioEncoder(config, bitrate=32000) + assert encoder.bitrate == 32000 + + +class TestAudioPacketFormat: + """音频数据包格式测试""" + + def test_header_format(self): + """测试数据包头格式""" + # 格式: 序列号(4字节) + 时间戳(8字节) + header_size = struct.calcsize("!Id") + assert header_size == VoiceChatModule.AUDIO_HEADER_SIZE + + def test_pack_unpack_header(self): + """测试数据包头打包和解包""" + sequence = 12345 + timestamp = time.time() + + # 打包 + header = struct.pack("!Id", sequence, timestamp) + + # 解包 + unpacked_seq, unpacked_ts = struct.unpack("!Id", header) + + assert unpacked_seq == sequence + assert abs(unpacked_ts - timestamp) < 0.001 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])