You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

391 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# P2P Network Communication - Voice Chat Module Tests
"""
语音聊天模块测试
测试音频采集、编码、通话控制和实时传输功能
"""
import asyncio
import struct
import time
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from client.voice_chat import (
    VoiceChatModule,
    CallState,
    VoiceChatError,
    AudioDeviceError,
    CallError,
    AudioConfig,
    CallInfo,
    NetworkStats,
    JitterBuffer,
    AudioCapture,
    AudioPlayback,
    AudioEncoder,
    AudioDecoder,
)
from shared.models import Message, MessageType, NetworkQuality
class TestAudioConfig:
    """Tests for AudioConfig defaults and its derived chunk-size properties."""

    def test_default_config(self):
        """A default config is 16 kHz, mono, 20 ms chunks, 16-bit samples."""
        cfg = AudioConfig()
        actual = (cfg.sample_rate, cfg.channels, cfg.chunk_duration, cfg.bits_per_sample)
        assert actual == (16000, 1, 0.02, 16)

    def test_chunk_size_calculation(self):
        """chunk_size is the sample count per chunk (sample_rate * chunk_duration)."""
        cfg = AudioConfig(sample_rate=16000, chunk_duration=0.02)
        # 16000 samples/s * 0.02 s -> 320 samples
        expected_samples = 320
        assert cfg.chunk_size == expected_samples

    def test_bytes_per_chunk_calculation(self):
        """bytes_per_chunk = samples per chunk * channels * bytes per sample."""
        cfg = AudioConfig(
            sample_rate=16000,
            channels=1,
            chunk_duration=0.02,
            bits_per_sample=16,
        )
        # 320 samples * 1 channel * (16 bits / 8) bytes -> 640 bytes
        expected_bytes = 320 * 1 * (16 // 8)
        assert cfg.bytes_per_chunk == expected_bytes
class TestCallInfo:
    """Tests for CallInfo construction and call-duration reporting."""

    def test_call_info_creation(self):
        """The constructor stores peer identity and direction; start_time starts unset."""
        info = CallInfo(
            peer_id="user123",
            peer_name="Test User",
            is_outgoing=True,
        )
        assert info.peer_id == "user123"
        assert info.peer_name == "Test User"
        assert info.is_outgoing is True
        assert info.start_time is None

    def test_call_duration_not_started(self):
        """A call without a start_time reports a duration of 0.0."""
        info = CallInfo(peer_id="user123", peer_name="Test")
        assert info.duration == 0.0

    def test_call_duration_started(self):
        """A started call reports the seconds elapsed since start_time."""
        info = CallInfo(peer_id="user123", peer_name="Test")
        # Backdate the start instead of sleeping: the test stays fast and does
        # not depend on sleep/clock granularity to reach the threshold.
        info.start_time = datetime.now() - timedelta(seconds=1)
        assert info.duration >= 1.0
class TestNetworkStats:
    """Tests for NetworkStats counters, packet-loss rate and latency-based quality."""

    @staticmethod
    def _quality_at(latency):
        """Return the quality that NetworkStats reports for the given avg_latency."""
        return NetworkStats(avg_latency=latency).get_quality()

    def test_default_stats(self):
        """Fresh stats have zeroed packet counters and zero average latency."""
        fresh = NetworkStats()
        counters = (fresh.packets_sent, fresh.packets_received, fresh.packets_lost)
        assert counters == (0, 0, 0)
        assert fresh.avg_latency == 0.0

    def test_packet_loss_rate_no_packets(self):
        """With no traffic recorded, the loss rate is 0.0."""
        assert NetworkStats().packet_loss_rate == 0.0

    def test_packet_loss_rate_with_loss(self):
        """The loss rate reflects the number of lost packets."""
        lossy = NetworkStats(
            packets_sent=100,
            packets_received=90,
            packets_lost=10,
        )
        rate = lossy.packet_loss_rate
        # Expected 10 / (100 + 90) ~= 0.0526
        assert 0.05 < rate < 0.06

    def test_network_quality_excellent(self):
        """30 ms average latency rates as EXCELLENT."""
        assert self._quality_at(30) == NetworkQuality.EXCELLENT

    def test_network_quality_good(self):
        """75 ms average latency rates as GOOD."""
        assert self._quality_at(75) == NetworkQuality.GOOD

    def test_network_quality_fair(self):
        """150 ms average latency rates as FAIR."""
        assert self._quality_at(150) == NetworkQuality.FAIR

    def test_network_quality_poor(self):
        """250 ms average latency rates as POOR."""
        assert self._quality_at(250) == NetworkQuality.POOR

    def test_network_quality_bad(self):
        """400 ms average latency rates as BAD."""
        assert self._quality_at(400) == NetworkQuality.BAD
class TestJitterBuffer:
    """Tests for JitterBuffer ordering, de-duplication, stale-packet and reset handling."""

    def test_buffer_creation(self):
        """A new buffer is empty and reports zero delay."""
        buffer = JitterBuffer()
        assert buffer.size == 0
        assert buffer.delay == 0.0

    def test_push_and_pop(self):
        """Pushed packets are counted, and pop() returns the lowest sequence first."""
        buffer = JitterBuffer(target_delay=0.0)  # disable delay waiting
        now = time.time()
        buffer.push(1, b"audio_data_1", now - 0.1)
        buffer.push(2, b"audio_data_2", now - 0.05)
        buffer.push(3, b"audio_data_3", now)
        assert buffer.size == 3
        # Packets come out in sequence-number order and are removed on pop.
        assert buffer.pop() == b"audio_data_1"
        assert buffer.size == 2

    def test_out_of_order_packets(self):
        """Packets pushed out of order are popped in sequence-number order."""
        buffer = JitterBuffer(target_delay=0.0)
        buffer.push(3, b"data_3", time.time())
        buffer.push(1, b"data_1", time.time())
        buffer.push(2, b"data_2", time.time())
        assert buffer.pop() == b"data_1"
        assert buffer.pop() == b"data_2"
        assert buffer.pop() == b"data_3"

    def test_duplicate_packet_ignored(self):
        """A second push with an already-buffered sequence number is dropped."""
        buffer = JitterBuffer(target_delay=0.0)
        buffer.push(1, b"data_1", time.time())
        buffer.push(1, b"data_1_dup", time.time())  # duplicate sequence number
        assert buffer.size == 1
        # "Ignored" means the first payload is kept, not overwritten by the duplicate.
        assert buffer.pop() == b"data_1"

    def test_old_packet_ignored(self):
        """A packet older than the last popped sequence number is discarded."""
        buffer = JitterBuffer(target_delay=0.0)
        buffer.push(5, b"data_5", time.time())
        buffer.pop()  # consumes sequence 5
        buffer.push(3, b"data_3", time.time())  # stale: sequence 3 < 5
        assert buffer.size == 0

    def test_clear_buffer(self):
        """clear() drops every buffered packet."""
        buffer = JitterBuffer()
        buffer.push(1, b"data", time.time())
        buffer.push(2, b"data", time.time())
        buffer.clear()
        assert buffer.size == 0
class TestVoiceChatModule:
    """Tests for VoiceChatModule state, mute control, callbacks and call-control guards."""

    @pytest.fixture
    def voice_chat(self):
        """Provide a fresh VoiceChatModule with local user info set."""
        module = VoiceChatModule()
        module.set_user_info("test_user", "Test User")
        return module

    def test_initial_state(self, voice_chat):
        """A new module is idle, not in a call, unmuted, and has no call info."""
        assert voice_chat.state == CallState.IDLE
        assert voice_chat.is_in_call is False
        assert voice_chat.is_muted is False
        assert voice_chat.call_info is None

    def test_mute_toggle(self, voice_chat):
        """mute(flag) sets the mute state explicitly."""
        assert voice_chat.is_muted is False
        voice_chat.mute(True)
        assert voice_chat.is_muted is True
        voice_chat.mute(False)
        assert voice_chat.is_muted is False

    def test_toggle_mute(self, voice_chat):
        """toggle_mute() flips the mute state and returns the new value."""
        assert voice_chat.toggle_mute() is True
        # The return value alone does not prove the flag changed; check it too.
        assert voice_chat.is_muted is True
        assert voice_chat.toggle_mute() is False
        assert voice_chat.is_muted is False

    def test_get_call_duration_no_call(self, voice_chat):
        """Without an active call, the reported duration is 0.0."""
        assert voice_chat.get_call_duration() == 0.0

    def test_get_network_quality_default(self, voice_chat):
        """Before any measurements (default latency 0), quality is EXCELLENT."""
        quality = voice_chat.get_network_quality()
        assert quality == NetworkQuality.EXCELLENT

    def test_state_callback(self, voice_chat):
        """A registered state callback receives (state, reason) on each change."""
        callback_called = []

        def state_callback(state, reason):
            callback_called.append((state, reason))

        voice_chat.add_state_callback(state_callback)
        voice_chat._set_state(CallState.CALLING, "test")
        assert len(callback_called) == 1
        assert callback_called[0][0] == CallState.CALLING
        assert callback_called[0][1] == "test"

    def test_remove_state_callback(self, voice_chat):
        """A removed state callback is no longer invoked."""
        callback_called = []

        def state_callback(state, reason):
            callback_called.append((state, reason))

        voice_chat.add_state_callback(state_callback)
        voice_chat.remove_state_callback(state_callback)
        voice_chat._set_state(CallState.CALLING, "test")
        assert len(callback_called) == 0

    @pytest.mark.asyncio
    async def test_start_call_no_callback(self, voice_chat):
        """Without a send-message callback, start_call fails and the module stays idle."""
        result = await voice_chat.start_call("peer123", "Peer User")
        assert result is False
        assert voice_chat.state == CallState.IDLE

    @pytest.mark.asyncio
    async def test_start_call_not_idle(self, voice_chat):
        """start_call is refused while a call is already connected, and the state is kept."""
        voice_chat._state = CallState.CONNECTED

        async def mock_send(peer_id, msg):
            return True

        voice_chat.set_send_message_callback(mock_send)
        result = await voice_chat.start_call("peer123")
        assert result is False
        # A refused call must not disturb the existing connection.
        # NOTE(review): assumes `state` reflects `_state` - confirm.
        assert voice_chat.state == CallState.CONNECTED

    def test_reject_call_not_ringing(self, voice_chat):
        """Rejecting when nothing is ringing leaves the module idle."""
        voice_chat.reject_call("peer123")
        assert voice_chat.state == CallState.IDLE

    def test_end_call_idle(self, voice_chat):
        """Ending a call while idle leaves the module idle."""
        voice_chat.end_call()
        assert voice_chat.state == CallState.IDLE

    @pytest.mark.asyncio
    async def test_accept_call_not_ringing(self, voice_chat):
        """Accepting when nothing is ringing fails and leaves the module idle."""
        result = await voice_chat.accept_call("peer123")
        assert result is False
        assert voice_chat.state == CallState.IDLE
class TestAudioEncoderDecoder:
    """Tests for AudioEncoder/AudioDecoder Opus fallback and raw passthrough."""

    def test_encoder_without_opus(self):
        """If opuslib cannot be imported, the encoder falls back to raw audio."""
        cfg = AudioConfig()
        # A None entry in sys.modules makes `import opuslib` fail during the block.
        with patch.dict('sys.modules', {'opuslib': None}):
            enc = AudioEncoder(cfg)
            opus_off = enc.is_opus_enabled is False or enc._encoder is None
            assert opus_off

    def test_decoder_without_opus(self):
        """If opuslib cannot be imported, the decoder falls back to raw audio."""
        cfg = AudioConfig()
        with patch.dict('sys.modules', {'opuslib': None}):
            dec = AudioDecoder(cfg)
            opus_off = dec.is_opus_enabled is False or dec._decoder is None
            assert opus_off

    def test_encoder_raw_passthrough(self):
        """With Opus disabled, encode() returns its input bytes unchanged."""
        enc = AudioEncoder(AudioConfig())
        enc._use_opus = False
        payload = b"test_audio_data"
        assert enc.encode(payload) == payload

    def test_decoder_raw_passthrough(self):
        """With Opus disabled, decode() returns its input bytes unchanged."""
        dec = AudioDecoder(AudioConfig())
        dec._use_opus = False
        payload = b"test_audio_data"
        assert dec.decode(payload) == payload

    def test_encoder_bitrate_property(self):
        """The bitrate passed to the constructor is exposed via .bitrate."""
        enc = AudioEncoder(AudioConfig(), bitrate=32000)
        assert enc.bitrate == 32000
class TestAudioPacketFormat:
    """Tests for the audio packet header wire format."""

    # Network byte order, no padding: uint32 sequence number (4 bytes)
    # followed by a float64 timestamp (8 bytes).
    HEADER_FORMAT = "!Id"

    def test_header_format(self):
        """The module's declared header size matches the struct layout."""
        header_size = struct.calcsize(self.HEADER_FORMAT)
        assert header_size == VoiceChatModule.AUDIO_HEADER_SIZE

    def test_pack_unpack_header(self):
        """Headers round-trip losslessly, including sequence-number extremes."""
        timestamp = time.time()
        header_size = struct.calcsize(self.HEADER_FORMAT)
        for sequence in (0, 12345, 0xFFFFFFFF):
            header = struct.pack(self.HEADER_FORMAT, sequence, timestamp)
            assert len(header) == header_size
            unpacked_seq, unpacked_ts = struct.unpack(self.HEADER_FORMAT, header)
            assert unpacked_seq == sequence
            # 'd' packs as IEEE 754 binary64, the same representation as a
            # Python float, so the timestamp must survive the round trip exactly.
            assert unpacked_ts == timestamp
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))