You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
892 lines
33 KiB
892 lines
33 KiB
6 months ago
|
from __future__ import annotations
|
||
|
|
||
|
import random
|
||
|
from contextlib import asynccontextmanager
|
||
|
from itertools import count
|
||
|
from typing import TYPE_CHECKING, NoReturn
|
||
|
|
||
|
import attrs
|
||
|
import pytest
|
||
|
|
||
|
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||
|
|
||
|
try:
|
||
|
import trustme
|
||
|
from OpenSSL import SSL
|
||
|
except ImportError as error:
|
||
|
skip_if_optional_else_raise(error)
|
||
|
|
||
|
|
||
|
import trio
|
||
|
import trio.testing
|
||
|
from trio import DTLSChannel, DTLSEndpoint
|
||
|
from trio.testing._fake_net import FakeNet, UDPPacket
|
||
|
|
||
|
from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from collections.abc import AsyncGenerator
|
||
|
|
||
|
# Throwaway certificate authority, plus a certificate for "example.com" issued by it.
ca = trustme.CA()
server_cert = ca.issue_cert("example.com")

# Server-side DTLS context that presents the issued certificate.
server_ctx = SSL.Context(SSL.DTLS_METHOD)
server_cert.configure_cert(server_ctx)

# Client-side DTLS context that trusts the throwaway CA above.
client_ctx = SSL.Context(SSL.DTLS_METHOD)
ca.configure_trust(client_ctx)


# Parametrize a test over ipv6=False/True (ids "ipv4"/"ipv6"). The IPv6 case carries
# the ``binds_ipv6`` mark -- presumably skipping it where IPv6 can't be bound; confirm
# in tutil.
parametrize_ipv6 = pytest.mark.parametrize(
    "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"]
)
|
||
|
|
||
|
|
||
|
def endpoint(**kwargs: int | bool) -> DTLSEndpoint:
    """Build a ``DTLSEndpoint`` on top of a brand-new UDP socket.

    The optional ``ipv6`` keyword (default ``False``) is consumed here to pick the
    address family (``AF_INET6`` vs ``AF_INET``). Every other keyword is passed
    through unchanged to ``DTLSEndpoint``.
    """
    use_ipv6 = kwargs.pop("ipv6", False)
    udp_sock = trio.socket.socket(
        type=trio.socket.SOCK_DGRAM,
        family=trio.socket.AF_INET6 if use_ipv6 else trio.socket.AF_INET,
    )
    return DTLSEndpoint(udp_sock, **kwargs)
|
||
|
|
||
|
|
||
|
@asynccontextmanager
async def dtls_echo_server(
    *, autocancel: bool = True, mtu: int | None = None, ipv6: bool = False
) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]:
    """Run a DTLS echo server on localhost for the duration of an ``async with``.

    Yields ``(server_endpoint, bound_address)``. Every accepted channel optionally
    gets ``set_ciphertext_mtu(mtu)``, completes its handshake, then sends back each
    packet it receives.

    If ``autocancel`` is true, the serving nursery is cancelled when the block
    exits. Otherwise leaving the block waits for ``serve()`` to return on its own
    (e.g. after the endpoint has been closed).
    """
    with endpoint(ipv6=ipv6) as server_ep:
        bind_host = "::1" if ipv6 else "127.0.0.1"
        # Port 0: let the OS choose a free port; callers read it from the yield.
        await server_ep.socket.bind((bind_host, 0))
        async with trio.open_nursery() as nursery:

            async def handle_channel(dtls_channel: DTLSChannel) -> None:
                # Per-connection task: handshake, then echo until the peer goes away.
                print(
                    "echo handler started: "
                    f"server {dtls_channel.endpoint.socket.getsockname()!r} "
                    f"client {dtls_channel.peer_address!r}"
                )
                if mtu is not None:
                    dtls_channel.set_ciphertext_mtu(mtu)
                try:
                    print("server starting do_handshake")
                    await dtls_channel.do_handshake()
                    print("server finished do_handshake")
                    async for packet in dtls_channel:
                        print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}")
                        await dtls_channel.send(packet)
                except trio.BrokenResourceError:  # pragma: no cover
                    print("echo handler channel broken")

            await nursery.start(server_ep.serve, server_ctx, handle_channel)
            yield server_ep, server_ep.socket.getsockname()
            if autocancel:
                nursery.cancel_scope.cancel()
|
||
|
|
||
|
|
||
|
@parametrize_ipv6
async def test_smoke(ipv6: bool) -> None:
    """Basic round trip: handshake, echo, empty-send rejection and MTU accounting."""
    async with dtls_echo_server(ipv6=ipv6) as (_, address):
        with endpoint(ipv6=ipv6) as client_endpoint:
            chan = client_endpoint.connect(address, client_ctx)

            # Before the handshake, asking for the cleartext MTU must raise.
            with pytest.raises(trio.NeedHandshakeError):
                chan.get_cleartext_mtu()

            await chan.do_handshake()
            for message in (b"hello", b"goodbye"):
                await chan.send(message)
                assert await chan.receive() == message

            with pytest.raises(
                ValueError, match="^openssl doesn't support sending empty DTLS packets$"
            ):
                await chan.send(b"")

            # A larger ciphertext MTU gives a larger cleartext MTU, and going back to
            # the first setting restores the first cleartext MTU.
            chan.set_ciphertext_mtu(1234)
            small_cleartext_mtu = chan.get_cleartext_mtu()
            chan.set_ciphertext_mtu(4321)
            assert chan.get_cleartext_mtu() > small_cleartext_mtu
            chan.set_ciphertext_mtu(1234)
            assert chan.get_cleartext_mtu() == small_cleartext_mtu
|
||
|
|
||
|
|
||
|
@slow
async def test_handshake_over_terrible_network(
    autojump_clock: trio.testing.MockClock,
) -> None:
    """Run many sequential handshake + echo rounds over a lossy FakeNet.

    Each routed packet is randomly delivered, dropped, duplicated or delayed
    (weights 0.7 / 0.1 / 0.1 / 0.1). The RNG is seeded with 0, so the routing
    decisions are reproducible between runs.
    """
    HANDSHAKES = 100
    r = random.Random(0)
    fn = FakeNet()
    fn.enable()
    # avoid spurious timeouts on slow machines
    autojump_clock.autojump_threshold = 0.001

    async with dtls_echo_server() as (_, address):
        async with trio.open_nursery() as nursery:

            async def route_packet(packet: UDPPacket) -> None:
                # Keep choosing a fate until the packet is dropped or delivered.
                # "dupe" and "delay" go round the loop again, so one packet can be
                # duplicated and/or delayed several times before its final fate.
                while True:
                    op = r.choices(
                        ["deliver", "drop", "dupe", "delay"],
                        weights=[0.7, 0.1, 0.1, 0.1],
                    )[0]
                    print(f"{packet.source} -> {packet.destination}: {op}")
                    if op == "drop":
                        return
                    elif op == "dupe":
                        # NOTE(review): presumably re-injects a copy of the packet
                        # into FakeNet, where it is routed again -- confirm.
                        fn.send_packet(packet)
                    elif op == "delay":
                        await trio.sleep(r.random() * 3)
                    # I wanted to test random packet corruption too, but it turns out
                    # openssl has a bug in the following scenario:
                    #
                    # - client sends ClientHello
                    # - server sends HelloVerifyRequest with cookie -- but cookie is
                    #   invalid b/c either the ClientHello or HelloVerifyRequest was
                    #   corrupted
                    # - client re-sends ClientHello with invalid cookie
                    # - server replies with new HelloVerifyRequest and correct cookie
                    #
                    # At this point, the client *should* switch to the new, valid
                    # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
                    # the original, invalid cookie over and over. In theory we could
                    # work around this by detecting cookie changes and starting over
                    # with a whole new SSL object, but (a) it doesn't seem worth it, (b)
                    # when I tried then I ran into another issue where OpenSSL got stuck
                    # in an infinite loop sending alerts over and over, which I didn't
                    # dig into because see (a).
                    #
                    # elif op == "distort":
                    #     payload = bytearray(packet.payload)
                    #     payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
                    #     packet = attrs.evolve(packet, payload=payload)
                    else:
                        assert op == "deliver"
                        print(
                            f"{packet.source} -> {packet.destination}: delivered"
                            f" {packet.payload.hex()}"
                        )
                        fn.deliver_packet(packet)
                        break

            def route_packet_wrapper(packet: UDPPacket) -> None:
                # Route every packet in its own task, so a delayed packet does not
                # hold up the rest of the traffic.
                try:  # noqa: SIM105 # suppressible-exception
                    nursery.start_soon(route_packet, packet)
                except RuntimeError:  # pragma: no cover
                    # We're exiting the nursery, so any remaining packets can just get
                    # dropped
                    pass

            fn.route_packet = route_packet_wrapper  # type: ignore[assignment] # TODO: Fix FakeNet typing

            for i in range(HANDSHAKES):
                # Visual separator between rounds in the captured test output.
                print("#" * 80)
                print("#" * 80)
                print("#" * 80)
                # A fresh client endpoint each round, so every round performs a new
                # handshake from scratch.
                with endpoint() as client_endpoint:
                    client = client_endpoint.connect(address, client_ctx)
                    print("client starting do_handshake")
                    await client.do_handshake()
                    print("client finished do_handshake")
                    msg = str(i).encode()
                    # Make multiple attempts to send data, because the network might
                    # drop it
                    while True:
                        # Give each attempt 10 seconds; retry if it was cut short.
                        with trio.move_on_after(10) as cscope:
                            await client.send(msg)
                            assert await client.receive() == msg
                        if not cscope.cancelled_caught:
                            break
|
||
|
|
||
|
|
||
|
async def test_implicit_handshake() -> None:
    """``send()`` on a channel that never called ``do_handshake()`` still works."""
    async with dtls_echo_server() as (_, server_address):
        with endpoint() as client_ep:
            chan = client_ep.connect(server_address, client_ctx)
            payload = b"xyz"
            # No explicit do_handshake(): the first send has to trigger it.
            await chan.send(payload)
            assert await chan.receive() == payload
|
||
|
|
||
|
|
||
|
async def test_full_duplex() -> None:
    """Both peers send and receive at the same time over one DTLS association.

    Tests simultaneous send/receive, and also multiple methods implicitly invoking
    do_handshake simultaneously.
    """
    with endpoint() as server_ep, endpoint() as client_ep:
        await server_ep.socket.bind(("127.0.0.1", 0))
        async with trio.open_nursery() as server_nursery:

            async def send_and_receive(channel: DTLSChannel, outgoing: bytes) -> None:
                # Start send() and receive() concurrently on a channel that has not
                # handshaken explicitly, so both race to start the handshake.
                async with trio.open_nursery() as nursery:
                    nursery.start_soon(channel.send, outgoing)
                    nursery.start_soon(channel.receive)

            async def handler(channel: DTLSChannel) -> None:
                await send_and_receive(channel, b"from server")

            await server_nursery.start(server_ep.serve, server_ctx, handler)

            client = client_ep.connect(server_ep.socket.getsockname(), client_ctx)
            await send_and_receive(client, b"from client")

            server_nursery.cancel_scope.cancel()
|
||
|
|
||
|
|
||
|
async def test_channel_closing() -> None:
    """A closed channel rejects I/O, and closing again (sync or async) is harmless."""
    async with dtls_echo_server() as (_, address):
        with endpoint() as client_ep:
            chan = client_ep.connect(address, client_ctx)
            await chan.do_handshake()
            chan.close()

            with pytest.raises(trio.ClosedResourceError):
                await chan.send(b"abc")
            with pytest.raises(trio.ClosedResourceError):
                await chan.receive()

            # A second close() must be a no-op ...
            chan.close()
            # ... and so must the async flavour.
            await chan.aclose()
|
||
|
|
||
|
|
||
|
async def test_serve_exits_cleanly_on_close() -> None:
    """Closing the endpoint is enough to make ``serve()`` return on its own."""
    async with dtls_echo_server(autocancel=False) as (server_ep, _):
        # With autocancel=False, leaving this block only succeeds if the serving
        # nursery finishes without being cancelled. The second close() checks that
        # close() is idempotent.
        for _attempt in range(2):
            server_ep.close()
|
||
|
|
||
|
|
||
|
async def test_client_multiplex() -> None:
    """One client endpoint can hold channels to several servers at once."""
    async with dtls_echo_server() as (_, addr_a), dtls_echo_server() as (_, addr_b):
        with endpoint() as client_ep:
            chan_a = client_ep.connect(addr_a, client_ctx)
            chan_b = client_ep.connect(addr_b, client_ctx)

            # Interleave the traffic and collect replies in the opposite order.
            await chan_a.send(b"abc")
            await chan_b.send(b"xyz")
            assert await chan_b.receive() == b"xyz"
            assert await chan_a.receive() == b"abc"

            # Closing the endpoint closes every channel it owns ...
            client_ep.close()

            with pytest.raises(trio.ClosedResourceError):
                await chan_a.send(b"xxx")
            with pytest.raises(trio.ClosedResourceError):
                await chan_b.receive()
            # ... and the endpoint itself refuses both connect() and serve().
            with pytest.raises(trio.ClosedResourceError):
                client_ep.connect(addr_a, client_ctx)

            async def null_handler(_: object) -> None:  # pragma: no cover
                pass

            async with trio.open_nursery() as nursery:
                with pytest.raises(trio.ClosedResourceError):
                    await nursery.start(client_ep.serve, server_ctx, null_handler)
|
||
|
|
||
|
|
||
|
async def test_dtls_over_dgram_only() -> None:
    """``DTLSEndpoint`` refuses a socket that is not ``SOCK_DGRAM``."""
    # trio.socket.socket() with no arguments gives a stream socket.
    with trio.socket.socket() as stream_sock, pytest.raises(
        ValueError, match="^DTLS requires a SOCK_DGRAM socket$"
    ):
        DTLSEndpoint(stream_sock)
|
||
|
|
||
|
|
||
|
async def test_double_serve() -> None:
    """Only one ``serve()`` may run per endpoint at a time, but it can be restarted."""

    async def null_handler(_: object) -> None:  # pragma: no cover
        pass

    with endpoint() as server_ep:
        await server_ep.socket.bind(("127.0.0.1", 0))

        async with trio.open_nursery() as nursery:
            await nursery.start(server_ep.serve, server_ctx, null_handler)
            # A second, concurrent serve() on the same endpoint is rejected.
            with pytest.raises(trio.BusyResourceError):
                await nursery.start(server_ep.serve, server_ctx, null_handler)
            nursery.cancel_scope.cancel()

        # Once the first serve() has been cancelled, serving again is allowed.
        async with trio.open_nursery() as nursery:
            await nursery.start(server_ep.serve, server_ctx, null_handler)
            nursery.cancel_scope.cancel()
|
||
|
|
||
|
|
||
|
async def test_connect_to_non_server(autojump_clock: trio.abc.Clock) -> None:
    """A handshake aimed at a peer that never serves just keeps waiting."""
    fake_net = FakeNet()
    fake_net.enable()
    with endpoint() as silent_peer, endpoint() as client_ep:
        # silent_peer is bound but never calls serve(), so nothing answers.
        await silent_peer.socket.bind(("127.0.0.1", 0))
        # This should just time out
        with trio.move_on_after(100) as timeout_scope:
            chan = client_ep.connect(silent_peer.socket.getsockname(), client_ctx)
            await chan.do_handshake()
        assert timeout_scope.cancelled_caught
|
||
|
|
||
|
|
||
|
async def test_incoming_buffer_overflow(autojump_clock: trio.abc.Clock) -> None:
    """Packets beyond ``incoming_packets_buffer`` are dropped and counted."""
    fake_net = FakeNet()
    fake_net.enable()
    overflow = 15
    for buffer_size in (10, 20):
        async with dtls_echo_server() as (_, address):
            with endpoint(incoming_packets_buffer=buffer_size) as client_ep:
                assert client_ep.incoming_packets_buffer == buffer_size
                chan = client_ep.connect(address, client_ctx)
                # Nothing reads the echoes yet, so they pile up on the client side.
                # The sleep gives each echo time to arrive before the next send.
                for n in range(buffer_size + overflow):
                    await chan.send(str(n).encode())
                    await trio.sleep(1)
                assert chan.statistics().incoming_packets_dropped_in_trio == overflow
                # The earliest `buffer_size` echoes were kept, in order.
                for n in range(buffer_size):
                    assert await chan.receive() == str(n).encode()
                await chan.send(b"buffer clear now")
                assert await chan.receive() == b"buffer clear now"
|
||
|
|
||
|
|
||
|
async def test_server_socket_doesnt_crash_on_garbage(
    autojump_clock: trio.abc.Clock,
) -> None:
    """Send a series of malformed datagrams to a server; it must keep running."""
    fn = FakeNet()
    fn.enable()

    from trio._dtls import (
        ContentType,
        HandshakeFragment,
        HandshakeType,
        ProtocolVersion,
        Record,
        encode_handshake_fragment,
        encode_record,
    )

    def handshake_record(payload: bytes) -> bytes:
        # Wrap `payload` in a DTLS 1.0 handshake record with epoch_seqno 0.
        return encode_record(
            Record(
                content_type=ContentType.handshake,
                version=ProtocolVersion.DTLS10,
                epoch_seqno=0,
                payload=payload,
            )
        )

    def client_hello_fragment(msg_len: int) -> bytes:
        # A ClientHello fragment carrying 10 zero bytes at offset 0 and claiming the
        # whole message is `msg_len` bytes long.
        return encode_handshake_fragment(
            HandshakeFragment(
                msg_type=HandshakeType.client_hello,
                msg_len=msg_len,
                msg_seq=0,
                frag_offset=0,
                frag_len=10,
                frag=bytes(10),
            )
        )

    client_hello = handshake_record(client_hello_fragment(10))

    client_hello_extended = client_hello + b"\x00"
    client_hello_short = client_hello[:-1]
    # cuts off in middle of handshake message header
    client_hello_really_short = client_hello[:14]
    client_hello_corrupt_record_len = bytearray(client_hello)
    client_hello_corrupt_record_len[11] = 0xFF

    # msg_len 20 but only 10 bytes present -> a fragment of a bigger message.
    client_hello_fragmented = handshake_record(client_hello_fragment(20))
    client_hello_trailing_data_in_record = handshake_record(
        client_hello_fragment(20) + b"\x00"
    )
    handshake_empty = handshake_record(b"")
    client_hello_truncated_in_cookie = handshake_record(bytes(2 + 32 + 1) + b"\xff")

    async with dtls_echo_server() as (_, address):
        with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
            for bad_packet in [
                b"",
                b"xyz",
                client_hello_extended,
                client_hello_short,
                client_hello_really_short,
                client_hello_corrupt_record_len,
                client_hello_fragmented,
                client_hello_trailing_data_in_record,
                handshake_empty,
                client_hello_truncated_in_cookie,
            ]:
                await sock.sendto(bad_packet, address)
                await trio.sleep(1)
|
||
|
|
||
|
|
||
|
async def test_invalid_cookie_rejected(autojump_clock: trio.abc.Clock) -> None:
    """Flip one bit of every cookie-bearing ClientHello on its way to the server.

    Each ClientHello that carries a non-empty cookie gets the next byte offset
    (starting at 11) corrupted. Once every offset has been tried, the surrounding
    cancel scope is cancelled, which is the only way the loop below ends.
    """
    fn = FakeNet()
    fn.enable()

    from trio._dtls import BadPacket, decode_client_hello_untrusted

    with trio.CancelScope() as cscope:
        # the first 11 bytes of ClientHello aren't protected by the cookie, so only test
        # corrupting bytes after that.
        offset_to_corrupt = count(11)

        def route_packet(packet: UDPPacket) -> None:
            try:
                _, cookie, _ = decode_client_hello_untrusted(packet.payload)
            except BadPacket:
                # Not a ClientHello (e.g. a server reply): pass it through untouched.
                pass
            else:
                if len(cookie) != 0:
                    # this is a challenge response packet
                    # let's corrupt the next offset so the handshake should fail
                    payload = bytearray(packet.payload)
                    offset = next(offset_to_corrupt)
                    if offset >= len(payload):
                        # We've tried all offsets. Clamp offset to the end of the
                        # payload, and terminate the test.
                        offset = len(payload) - 1
                        cscope.cancel()
                    payload[offset] ^= 0x01
                    packet = attrs.evolve(packet, payload=payload)

            fn.deliver_packet(packet)

        fn.route_packet = route_packet  # type: ignore[assignment] # TODO: Fix FakeNet typing

        async with dtls_echo_server() as (_, address):
            # Keep starting handshakes until route_packet cancels cscope.
            while True:
                with endpoint() as client:
                    channel = client.connect(address, client_ctx)
                    await channel.do_handshake()
    assert cscope.cancelled_caught
|
||
|
|
||
|
|
||
|
async def test_client_cancels_handshake_and_starts_new_one(
    autojump_clock: trio.abc.Clock,
) -> None:
    """An abandoned handshake is broken when the client starts over from scratch."""
    # if a client disappears during the handshake, and then starts a new handshake from
    # scratch, then the first handler's channel should fail, and a new handler get
    # started
    fn = FakeNet()
    fn.enable()

    with endpoint() as server, endpoint() as client:
        await server.socket.bind(("127.0.0.1", 0))
        async with trio.open_nursery() as nursery:
            # Distinguishes the first (abandoned) channel from the second one.
            first_time = True

            async def handler(channel: DTLSChannel) -> None:
                nonlocal first_time
                if first_time:
                    first_time = False
                    print("handler: first time, cancelling connect")
                    # connect_cscope is bound below, before any handler can run.
                    connect_cscope.cancel()
                    # Let the client abandon this attempt and start a new connect.
                    await trio.sleep(0.5)
                    print("handler: handshake should fail now")
                    with pytest.raises(trio.BrokenResourceError):
                        await channel.do_handshake()
                else:
                    print("handler: not first time, sending hello")
                    await channel.send(b"hello")

            await nursery.start(server.serve, server_ctx, handler)

            print("client: starting first connect")
            with trio.CancelScope() as connect_cscope:
                channel = client.connect(server.socket.getsockname(), client_ctx)
                await channel.do_handshake()
            assert connect_cscope.cancelled_caught

            print("client: starting second connect")
            # Connecting again to the same address replaces the first association.
            channel = client.connect(server.socket.getsockname(), client_ctx)
            assert await channel.receive() == b"hello"

            # Give handlers a chance to finish
            await trio.sleep(10)
            nursery.cancel_scope.cancel()
|
||
|
|
||
|
|
||
|
async def test_swap_client_server() -> None:
    """Two endpoints swap client/server roles toward each other.

    First b connects to a (a serves). Then a connects to b, which is expected to
    break the earlier b -> a association on both ends.
    """
    with endpoint() as a, endpoint() as b:
        await a.socket.bind(("127.0.0.1", 0))
        await b.socket.bind(("127.0.0.1", 0))

        async def echo_handler(channel: DTLSChannel) -> None:
            async for packet in channel:
                await channel.send(packet)

        async def crashing_echo_handler(channel: DTLSChannel) -> None:
            # a's server-side channel is expected to break once a connects to b.
            with pytest.raises(trio.BrokenResourceError):
                await echo_handler(channel)

        async with trio.open_nursery() as nursery:
            await nursery.start(a.serve, server_ctx, crashing_echo_handler)
            await nursery.start(b.serve, server_ctx, echo_handler)

            # Round 1: b acts as client, a as server.
            b_to_a = b.connect(a.socket.getsockname(), client_ctx)
            await b_to_a.send(b"b as client")
            assert await b_to_a.receive() == b"b as client"

            # Round 2: a acts as client, b as server; the old b_to_a channel breaks.
            a_to_b = a.connect(b.socket.getsockname(), client_ctx)
            await a_to_b.do_handshake()
            with pytest.raises(trio.BrokenResourceError):
                await b_to_a.send(b"association broken")
            await a_to_b.send(b"a as client")
            assert await a_to_b.receive() == b"a as client"

            nursery.cancel_scope.cancel()
|
||
|
|
||
|
|
||
|
@slow
async def test_openssl_retransmit_doesnt_break_stuff() -> None:
    """An expired OpenSSL retransmit timer must not derail an in-progress handshake.

    The network drops everything at first; once OpenSSL's timer should have
    expired, delivery is re-enabled and a garbage packet wakes the client up.
    """
    # can't use autojump_clock here, because the point of the test is to wait for
    # openssl's built-in retransmit timer to expire, which is hard-coded to use
    # wall-clock time.
    fn = FakeNet()
    fn.enable()

    # Read by route_packet through the closure; flipped to False further down.
    blackholed = True

    def route_packet(packet: UDPPacket) -> None:
        if blackholed:
            print("dropped packet", packet)
            return
        print("delivered packet", packet)
        # Debugging aid: packets can be captured here (e.g. built with scapy) and
        # written to a pcap file for inspection.
        fn.deliver_packet(packet)

    fn.route_packet = route_packet  # type: ignore[assignment] # TODO add type annotations for FakeNet

    async with dtls_echo_server() as (server_endpoint, address):
        with endpoint() as client_endpoint:
            async with trio.open_nursery() as nursery:

                async def connecter() -> None:
                    client = client_endpoint.connect(address, client_ctx)
                    await client.do_handshake(initial_retransmit_timeout=1.5)
                    await client.send(b"hi")
                    assert await client.receive() == b"hi"

                nursery.start_soon(connecter)

                # openssl's default timeout is 1 second, so this ensures that it thinks
                # the timeout has expired
                await trio.sleep(1.1)
                # disable blackholing and send a garbage packet to wake up openssl so it
                # notices the timeout has expired
                blackholed = False
                await server_endpoint.socket.sendto(
                    b"xxx", client_endpoint.socket.getsockname()
                )
                # now the client task should finish connecting and exit cleanly
|
||
|
|
||
|
|
||
|
async def test_initial_retransmit_timeout_configuration(
    autojump_clock: trio.abc.Clock,
) -> None:
    """With one packet lost, a handshake takes exactly ``initial_retransmit_timeout``.

    Time here is virtual (autojump clock), so the elapsed time can be compared
    exactly.
    """
    fake_net = FakeNet()
    fake_net.enable()

    # While set, the next routed packet is silently dropped and the flag cleared.
    drop_next = True

    def route_packet(packet: UDPPacket) -> None:
        nonlocal drop_next
        if not drop_next:
            fake_net.deliver_packet(packet)
            return
        drop_next = False

    fake_net.route_packet = route_packet  # type: ignore[assignment] # TODO add type annotations for FakeNet

    async with dtls_echo_server() as (_, address):
        for timeout in (1, 2, 4):
            with endpoint() as client_ep:
                started = trio.current_time()
                # Lose the client's first packet so it has to retransmit it.
                drop_next = True
                chan = client_ep.connect(address, client_ctx)
                await chan.do_handshake(initial_retransmit_timeout=timeout)
                finished = trio.current_time()
                assert finished - started == timeout
|
||
|
|
||
|
|
||
|
async def test_explicit_tiny_mtu_is_respected() -> None:
    """An explicitly configured ciphertext MTU caps every datagram on the wire."""
    # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
    # be larger than that. (300 is still smaller than any real network though.)
    MTU = 300

    net = FakeNet()
    net.enable()

    def route_packet(packet: UDPPacket) -> None:
        print(f"delivering {packet}")
        print(f"payload size: {len(packet.payload)}")
        # Fails the test if either side ever emits an oversized datagram.
        assert len(packet.payload) <= MTU
        net.deliver_packet(packet)

    net.route_packet = route_packet  # type: ignore[assignment] # TODO add type annotations for FakeNet

    # The server applies MTU to each accepted channel; the client sets it by hand.
    async with dtls_echo_server(mtu=MTU) as (_, address):
        with endpoint() as client_ep:
            chan = client_ep.connect(address, client_ctx)
            chan.set_ciphertext_mtu(MTU)
            await chan.do_handshake()
            await chan.send(b"hi")
            assert await chan.receive() == b"hi"
|
||
|
|
||
|
|
||
|
@parametrize_ipv6
async def test_handshake_handles_minimum_network_mtu(
    ipv6: bool, autojump_clock: trio.abc.Clock
) -> None:
    """The handshake backs off until it fits the smallest MTU the protocol allows."""
    # Fake network that has the minimum allowable MTU for whatever protocol we're using.
    net = FakeNet()
    net.enable()

    if ipv6:
        # IPv6 minimum link MTU (1280) minus 40-byte IPv6 and 8-byte UDP headers.
        mtu = 1280 - 40 - 8
    else:
        # IPv4 minimum datagram size every host must accept (576) minus 20-byte
        # IPv4 and 8-byte UDP headers.
        mtu = 576 - 20 - 8

    def route_packet(packet: UDPPacket) -> None:
        if len(packet.payload) <= mtu:
            print(f"delivering {packet}")
            net.deliver_packet(packet)
        else:
            print(f"dropping {packet}")

    net.route_packet = route_packet  # type: ignore[assignment] # TODO add type annotations for FakeNet

    # See if we can successfully do a handshake -- some of the volleys will get dropped,
    # and the retransmit logic should detect this and back off the MTU to something
    # smaller until it succeeds.
    async with dtls_echo_server(ipv6=ipv6) as (_, address):
        with endpoint(ipv6=ipv6) as client_ep:
            chan = client_ep.connect(address, client_ctx)
            # The handshake's MTU backoff must not leak into get_cleartext_mtu(),
            # which stays under the user's control via set_ciphertext_mtu().
            chan.set_ciphertext_mtu(9999)
            await chan.send(b"xyz")
            assert await chan.receive() == b"xyz"
            assert chan.get_cleartext_mtu() > 9000  # as vegeta said
|
||
|
|
||
|
|
||
|
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_system_task_cleaned_up_on_gc() -> None:
    """An endpoint that is never closed has its receive task reaped once collected.

    Also expects a ResourceWarning for the unclosed endpoint.
    """
    before_tasks = trio.lowlevel.current_statistics().tasks_living

    # We put this into a sub-function so that everything automatically becomes garbage
    # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
    # with coverage enabled -- I think we were hitting this bug:
    # https://foss.heptapod.net/pypy/pypy/-/issues/3656
    async def start_and_forget_endpoint() -> int:
        # Deliberately not used as a context manager: the endpoint is never closed.
        e = endpoint()

        # This connection/handshake attempt can't succeed. The only purpose is to force
        # the endpoint to set up a receive loop.
        with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
            await s.bind(("127.0.0.1", 0))
            c = e.connect(s.getsockname(), client_ctx)
            async with trio.open_nursery() as nursery:
                nursery.start_soon(c.do_handshake)
                await trio.testing.wait_all_tasks_blocked()
                nursery.cancel_scope.cancel()

        # Task count while the endpoint's receive loop is still alive.
        during_tasks = trio.lowlevel.current_statistics().tasks_living
        return during_tasks

    with pytest.warns(ResourceWarning):
        during_tasks = await start_and_forget_endpoint()
        await trio.testing.wait_all_tasks_blocked()
        gc_collect_harder()

    # Let the receive task notice the endpoint is gone and exit.
    await trio.testing.wait_all_tasks_blocked()

    after_tasks = trio.lowlevel.current_statistics().tasks_living
    assert before_tasks < during_tasks
    assert before_tasks == after_tasks
|
||
|
|
||
|
|
||
|
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_gc_before_system_task_starts() -> None:
    """Collecting an unclosed endpoint before any checkpoint must not crash."""
    e = endpoint()

    # No await has happened since the endpoint was created, so any background task
    # it may have scheduled has not started running yet.
    with pytest.warns(ResourceWarning):
        del e
        gc_collect_harder()

    # Now let any pending tasks run against the already-collected endpoint.
    await trio.testing.wait_all_tasks_blocked()
|
||
|
|
||
|
|
||
|
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_gc_as_packet_received() -> None:
    """Collect an endpoint while its receive loop has a packet pending."""
    fn = FakeNet()
    fn.enable()

    e = endpoint()
    await e.socket.bind(("127.0.0.1", 0))
    # Private API: start the receive loop without connecting or serving.
    e._ensure_receive_loop()

    # Let the receive loop start and block waiting for a packet.
    await trio.testing.wait_all_tasks_blocked()

    with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
        await s.sendto(b"xxx", e.socket.getsockname())
    # At this point, the endpoint's receive loop has been marked runnable because it
    # just received a packet; closing the endpoint socket won't interrupt that. But by
    # the time it wakes up to process the packet, the endpoint will be gone.
    with pytest.warns(ResourceWarning):
        del e
        gc_collect_harder()
|
||
|
|
||
|
|
||
|
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
def test_gc_after_trio_exits() -> None:
    """Collecting an unclosed endpoint after ``trio.run`` has returned must not crash.

    Synchronous on purpose: the endpoint outlives the trio run that created it.
    """

    async def main() -> DTLSEndpoint:
        # We use fakenet just to make sure no real sockets can leak out of the test
        # case - on pypy somehow the socket was outliving the gc_collect_harder call
        # below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
        # when called after trio exits, it doesn't need a real socket.
        fn = FakeNet()
        fn.enable()
        return endpoint()

    e = trio.run(main)
    with pytest.warns(ResourceWarning):
        del e
        gc_collect_harder()
|
||
|
|
||
|
|
||
|
async def test_already_closed_socket_doesnt_crash() -> None:
    """The endpoint's system task copes with a socket closed before it starts."""
    with endpoint() as ep:
        # Closing before the first checkpoint means the socket is already closed by
        # the time the system task starts up.
        ep.socket.close()
        # Let the system task run; the test passes as long as nothing crashes.
        await trio.testing.wait_all_tasks_blocked()
|
||
|
|
||
|
|
||
|
async def test_socket_closed_while_processing_clienthello(
    autojump_clock: trio.abc.Clock,
) -> None:
    """The server survives its socket closing while it answers a ClientHello."""
    net = FakeNet()
    net.enable()

    # Check what happens if the socket is discovered to be closed when sending a
    # HelloVerifyRequest, since that has its own sending logic
    async with dtls_echo_server() as (server_ep, address):

        def deliver_then_close(packet: UDPPacket) -> None:
            # Deliver the packet, then pull the server's socket out from under it.
            net.deliver_packet(packet)
            server_ep.socket.close()

        net.route_packet = deliver_then_close  # type: ignore[assignment] # TODO add type annotations for FakeNet

        with endpoint() as client_ep, trio.move_on_after(10):
            chan = client_ep.connect(address, client_ctx)
            await chan.do_handshake()
|
||
|
|
||
|
|
||
|
async def test_association_replaced_while_handshake_running(
    autojump_clock: trio.abc.Clock,
) -> None:
    """Reconnecting to the same address breaks a handshake that is still running."""
    net = FakeNet()
    net.enable()

    def black_hole(packet: UDPPacket) -> None:
        # Drop everything, so the first handshake can never finish by itself.
        pass

    net.route_packet = black_hole  # type: ignore[assignment] # TODO add type annotations for FakeNet

    async with dtls_echo_server() as (_, address):
        with endpoint() as client_ep:
            first = client_ep.connect(address, client_ctx)

            async def doomed_handshake() -> None:
                with pytest.raises(trio.BrokenResourceError):
                    await first.do_handshake()

            async with trio.open_nursery() as nursery:
                nursery.start_soon(doomed_handshake)
                await trio.sleep(10)
                # A new connect() to the same address should break `first`.
                client_ep.connect(address, client_ctx)
|
||
|
|
||
|
|
||
|
async def test_association_replaced_before_handshake_starts() -> None:
    """A channel superseded before it ever handshakes fails its handshake."""
    net = FakeNet()
    net.enable()

    # This test shouldn't send any packets
    def route_packet(packet: UDPPacket) -> NoReturn:  # pragma: no cover
        raise AssertionError()

    net.route_packet = route_packet  # type: ignore[assignment] # TODO add type annotations for FakeNet

    async with dtls_echo_server() as (_, address):
        with endpoint() as client_ep:
            stale = client_ep.connect(address, client_ctx)
            # A second connect() to the same address replaces `stale`'s association.
            client_ep.connect(address, client_ctx)
            with pytest.raises(trio.BrokenResourceError):
                await stale.do_handshake()
|
||
|
|
||
|
|
||
|
async def test_send_to_closed_local_port() -> None:
    """Handshakes aimed at closed local ports must not break the endpoint.

    On Windows, sending a UDP packet to a closed local port can cause a weird
    ECONNRESET error later, inside the receive task. Make sure we're handling it
    properly.
    """
    async with dtls_echo_server() as (_, address):
        with endpoint() as client_ep:
            async with trio.open_nursery() as nursery:
                # Fire handshakes at low-numbered local ports, assumed to have
                # nothing listening.
                for port in range(1, 10):
                    doomed = client_ep.connect(("127.0.0.1", port), client_ctx)
                    nursery.start_soon(doomed.do_handshake)
                # The same endpoint must still talk to a real server afterwards.
                chan = client_ep.connect(address, client_ctx)
                await chan.send(b"xxx")
                assert await chan.receive() == b"xxx"
                nursery.cancel_scope.cancel()
|