You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1338 lines
50 KiB

from __future__ import annotations
import os
import socket as stdlib_socket
import ssl
import sys
import threading
from contextlib import asynccontextmanager, contextmanager, suppress
from functools import partial
from ssl import SSLContext
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
NoReturn,
)
import pytest
from trio import StapledStream
from trio._tests.pytest_plugin import skip_if_optional_else_raise
from trio.abc import ReceiveStream, SendStream
from trio.testing import (
Matcher,
MemoryReceiveStream,
MemorySendStream,
RaisesGroup,
)
try:
import trustme
from OpenSSL import SSL
except ImportError as error:
skip_if_optional_else_raise(error)
import trio
from .. import _core, socket as tsocket
from .._abc import Stream
from .._core import BrokenResourceError, ClosedResourceError
from .._core._tests.tutil import slow
from .._highlevel_generic import aclose_forcefully
from .._highlevel_open_tcp_stream import open_tcp_stream
from .._highlevel_socket import SocketListener, SocketStream
from .._ssl import NeedHandshakeError, SSLListener, SSLStream, _is_eof
from .._util import ConflictDetector
from ..testing import (
Sequencer,
assert_checkpoints,
check_two_way_stream,
lockstep_stream_pair,
memory_stream_pair,
)
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from trio._core import MockClock
from trio._ssl import T_Stream
from .._core._run import CancelScope
# We have two different kinds of echo server fixtures we use for testing. The
# first is a real server written using the stdlib ssl module and blocking
# sockets. It runs in a thread and we talk to it over a real socketpair(), to
# validate interoperability in a semi-realistic setting.
#
# The second is a very weird virtual echo server that lives inside a custom
# Stream class. It lives entirely inside the Python object space; there are no
# operating system calls in it at all. No threads, no I/O, nothing. Its
# 'send_all' call takes encrypted data from a client and feeds it directly into
# the server-side TLS state engine to decrypt, then takes that data, feeds it
# back through to get the encrypted response, and returns it from 'receive_some'. This
# gives us full control and reproducibility. This server is written using
# PyOpenSSL, so that we can trigger renegotiations on demand. It also allows
# us to insert random (virtual) delays, to really exercise all the weird paths
# in SSLStream's state engine.
#
# Both present a certificate for "trio-test-1.example.org".
# Throwaway certificate authority for the test suite, and the server
# certificate it issues for "trio-test-1.example.org".
TRIO_TEST_CA = trustme.CA()
TRIO_TEST_1_CERT = TRIO_TEST_CA.issue_server_cert("trio-test-1.example.org")
# Server-side context used by the server ends of the fixtures and tests below.
SERVER_CTX = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
# When this ssl module exposes OP_IGNORE_UNEXPECTED_EOF, clear that option on
# the context so an unexpected EOF is not silently ignored.
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
    SERVER_CTX.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
# Install the trio-test-1 certificate (and key) as the server identity.
TRIO_TEST_1_CERT.configure_cert(SERVER_CTX)
# TLS 1.3 has a lot of changes from previous versions. So we want to run tests
# with both TLS 1.3, and TLS 1.2.
# "tls13" means that we're willing to negotiate TLS 1.3. Usually that's
# what will happen, but the renegotiation tests explicitly force a
# downgrade on the server side. "tls12" means we refuse to negotiate TLS
# 1.3, so we'll almost certainly use TLS 1.2.
@pytest.fixture(scope="module", params=["tls13", "tls12"])
def client_ctx(request: pytest.FixtureRequest) -> ssl.SSLContext:
    """Client-side SSLContext that trusts TRIO_TEST_CA.

    The "tls13" variant leaves the protocol range alone. The "tls12" variant
    caps the maximum version at TLS 1.2.
    """
    ctx = ssl.create_default_context()
    if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
        ctx.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
    TRIO_TEST_CA.configure_trust(ctx)

    flavor = request.param
    if flavor == "tls12":
        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
    elif flavor not in ("default", "tls13"):  # pragma: no cover
        raise AssertionError()
    return ctx
# The blocking socket server.
def ssl_echo_serve_sync(
    sock: stdlib_socket.socket, *, expect_fail: bool = False
) -> None:
    """Serve one TLS echo session on ``sock`` using blocking stdlib I/O.

    ``sock`` is wrapped with SERVER_CTX in server mode, with ragged EOFs
    reported rather than suppressed. The handshake is completed, and every
    received chunk is echoed back until the peer starts a graceful shutdown.
    ``sock`` is always closed before returning.

    :param sock: a connected stdlib socket; this function takes ownership.
    :param expect_fail: when True the session is expected to raise. The error
        is printed and swallowed. A session that instead ends cleanly raises
        ``RuntimeError("failed to fail?")``. When False, unexpected errors
        are printed and re-raised.
    """
    try:
        wrapped = SERVER_CTX.wrap_socket(
            sock, server_side=True, suppress_ragged_eofs=False
        )
        with wrapped:
            wrapped.do_handshake()
            while True:
                data = wrapped.recv(4096)
                if not data:
                    # The other side has initiated a graceful shutdown; we try
                    # to respond in kind, but it's legal for them to have
                    # already gone away.
                    with suppress(BrokenPipeError, ssl.SSLZeroReturnError):
                        wrapped.unwrap()
                    # `break` rather than `return`: returning from inside the
                    # try body skips the `else:` clause below, which made the
                    # "failed to fail?" check unreachable.
                    break
                wrapped.sendall(data)
    # This is an obscure workaround for an openssl bug. In server mode, in
    # some versions, openssl sends some extra data at the end of do_handshake
    # that it shouldn't send. Normally this is harmless, but, if the other
    # side shuts down the connection before it reads that data, it might cause
    # the OS to report a ECONNRESET or even ECONNABORTED (which is just wrong,
    # since ECONNABORTED is supposed to mean that connect() failed, but what
    # can you do). In this case the other side did nothing wrong, but there's
    # no way to recover, so we let it pass, and just cross our fingers it's not
    # hiding any (other) real bugs. For more details see:
    #
    #   https://github.com/python-trio/trio/issues/1293
    #
    # Also, this happens frequently but non-deterministically, so we have to
    # 'no cover' it to avoid coverage flapping.
    except (ConnectionResetError, ConnectionAbortedError):  # pragma: no cover
        return
    except Exception as exc:
        if expect_fail:
            print("ssl_echo_serve_sync got error as expected:", exc)
        else:  # pragma: no cover
            print("ssl_echo_serve_sync got unexpected error:", exc)
            raise
    else:
        # Reached only when the session completed without any error.
        if expect_fail:  # pragma: no cover
            raise RuntimeError("failed to fail?")
    finally:
        sock.close()
# Fixture that gives a raw socket connected to a trio-test-1 echo server
# (running in a thread). Useful for testing making connections with different
# SSLContexts.
@asynccontextmanager  # type: ignore[misc] # decorated contains Any
async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
    """Yield a plain SocketStream whose peer is a threaded TLS echo server.

    ``kwargs`` are forwarded to :func:`ssl_echo_serve_sync` (e.g.
    ``expect_fail=True``).
    """
    client_end, server_end = stdlib_socket.socketpair()
    async with trio.open_nursery() as nursery:
        # Leaving this `with` closes both sockets. That makes the server
        # thread exit (possibly with an error), which in turn lets the
        # nursery finish.
        with client_end, server_end:
            serve = partial(ssl_echo_serve_sync, server_end, **kwargs)
            nursery.start_soon(trio.to_thread.run_sync, serve)
            yield SocketStream(tsocket.from_stdlib_socket(client_end))
# Fixture that gives a properly set up SSLStream connected to a trio-test-1
# echo server (running in a thread)
@asynccontextmanager  # type: ignore[misc] # decorated contains Any
async def ssl_echo_server(
    client_ctx: SSLContext, **kwargs: Any
) -> AsyncIterator[SSLStream[Stream]]:
    """Yield a client SSLStream connected to the threaded echo server.

    The client validates the server as "trio-test-1.example.org" using
    ``client_ctx``; ``kwargs`` go to :func:`ssl_echo_server_raw`.
    """
    async with ssl_echo_server_raw(**kwargs) as raw_stream:
        tls_stream = SSLStream(
            raw_stream, client_ctx, server_hostname="trio-test-1.example.org"
        )
        yield tls_stream
# The weird in-memory server ... thing.
# It originally didn't inherit from Stream because some unneeded methods were
# left out. It now implements every abstract method of Stream (aclose,
# send_all, wait_send_all_might_not_block, receive_some), so it inherits from
# Stream for the sake of type checking.
class PyOpenSSLEchoStream(Stream):
    """In-memory Stream that plays the part of a TLS echo server.

    Ciphertext passed to ``send_all`` is written into a server-side pyOpenSSL
    ``Connection``. The cleartext it decrypts is buffered in
    ``_pending_cleartext``. ``receive_some`` re-encrypts that cleartext one
    byte at a time and returns whatever ciphertext the connection produces.
    No OS I/O is involved. An optional ``sleeper`` coroutine is awaited at
    many points, to inject (virtual) delays into the interleavings.
    """

    def __init__(
        self, sleeper: Callable[[str], Awaitable[None]] | None = None
    ) -> None:
        """Create the server-side TLS state machine.

        :param sleeper: optional async callable, awaited with a short label
            naming the call site (e.g. "send_all", "receive_some"). When None,
            a no-op sleeper is used.
        """
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        # TLS 1.3 removes renegotiation support. Which is great for them, but
        # we still have to support versions before that, and that means we
        # need to test renegotiation support. So this test server disables
        # TLS 1.3 whenever the linked OpenSSL knows about it, forcing a lower
        # version where it can trigger renegotiations. Alternatively, we could
        # switch to using TLSv1_2_METHOD.
        #
        # Discussion: https://github.com/pyca/pyopenssl/issues/624
        # This is the right way, but we can't use it until this PR is in a
        # release:
        #   https://github.com/pyca/pyopenssl/pull/861
        #
        #     if hasattr(SSL, "OP_NO_TLSv1_3"):
        #         ctx.set_options(SSL.OP_NO_TLSv1_3)
        #
        # Fortunately pyopenssl uses cryptography under the hood, so we can be
        # confident that they're using the same version of openssl
        from cryptography.hazmat.bindings.openssl.binding import Binding

        b = Binding()
        if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"):
            ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3)
        # Unfortunately there's currently no way to say "use 1.3 or worse", we
        # can only disable specific versions. And if the two sides start
        # negotiating 1.4 at some point in the future, it *might* mean that
        # our tests silently stop working properly. So the next line is a
        # tripwire to remind us we need to revisit this stuff in 5 years or
        # whatever when the next TLS version is released:
        assert not hasattr(SSL, "OP_NO_TLSv1_4")
        TRIO_TEST_1_CERT.configure_cert(ctx)
        # No socket: the connection is driven purely through its memory BIO
        # (bio_write / bio_read below), acting as the server.
        self._conn = SSL.Connection(ctx, None)
        self._conn.set_accept_state()
        # receive_some parks here while it has no ciphertext to return;
        # send_all unparks it after feeding in new data.
        self._lot = _core.ParkingLot()
        # Decrypted bytes from the client that still need to be echoed back.
        self._pending_cleartext = bytearray()
        # These make overlapping calls raise BusyResourceError, so tests can
        # detect SSLStream making overlapping transport calls.
        self._send_all_conflict_detector = ConflictDetector(
            "simultaneous calls to PyOpenSSLEchoStream.send_all"
        )
        self._receive_some_conflict_detector = ConflictDetector(
            "simultaneous calls to PyOpenSSLEchoStream.receive_some"
        )
        if sleeper is None:

            async def no_op_sleeper(_: object) -> None:
                return

            self.sleeper: Callable[[str], Awaitable[None]] = no_op_sleeper
        else:
            self.sleeper = sleeper

    async def aclose(self) -> None:
        """Call pyOpenSSL's ``bio_shutdown()`` on the server connection."""
        self._conn.bio_shutdown()

    def renegotiate_pending(self) -> bool:
        """Return the server connection's ``renegotiate_pending()`` result."""
        return self._conn.renegotiate_pending()

    def renegotiate(self) -> None:
        """Ask the server side to start a renegotiation."""
        # Returns false if a renegotiation is already in progress, meaning
        # nothing happens.
        assert self._conn.renegotiate()

    async def wait_send_all_might_not_block(self) -> None:
        """Checkpoint twice and call the sleeper; never actually blocks.

        Holds the send_all conflict detector, so overlapping with send_all is
        reported as an error.
        """
        with self._send_all_conflict_detector:
            await _core.checkpoint()
            await _core.checkpoint()
            await self.sleeper("wait_send_all_might_not_block")

    async def send_all(self, data: bytes) -> None:
        """Feed client ciphertext into the server and collect any cleartext.

        Decrypted bytes are pulled one at a time and appended to
        ``_pending_cleartext``. The loop stops when the connection wants more
        input, or after responding to a close_notify (ZeroReturnError).
        Parked receivers are then woken.
        """
        print(" --> transport_stream.send_all")
        with self._send_all_conflict_detector:
            await _core.checkpoint()
            await _core.checkpoint()
            await self.sleeper("send_all")
            self._conn.bio_write(data)
            while True:
                await self.sleeper("send_all")
                try:
                    data = self._conn.recv(1)
                except SSL.ZeroReturnError:
                    # The client closed the TLS session; send our shutdown.
                    self._conn.shutdown()
                    print("renegotiations:", self._conn.total_renegotiations())
                    break
                except SSL.WantReadError:
                    # All currently available input has been consumed.
                    break
                else:
                    self._pending_cleartext += data
            # Wake any parked receive_some: new handshake and/or echo data may
            # now be available, even if no cleartext was produced.
            self._lot.unpark_all()
            await self.sleeper("send_all")
            print(" <-- transport_stream.send_all finished")

    async def receive_some(self, nbytes: int | None = None) -> bytes:
        """Return up to ``nbytes`` bytes of server ciphertext.

        Blocks (parks) until some ciphertext can be produced. When None,
        ``nbytes`` defaults to 65536.
        """
        print(" --> transport_stream.receive_some")
        if nbytes is None:
            nbytes = 65536  # arbitrary
        with self._receive_some_conflict_detector:
            try:
                await _core.checkpoint()
                await _core.checkpoint()
                while True:
                    await self.sleeper("receive_some")
                    try:
                        return self._conn.bio_read(nbytes)
                    except SSL.WantReadError:
                        # No data in our ciphertext buffer; try to generate
                        # some.
                        if self._pending_cleartext:
                            # We have some cleartext; maybe we can encrypt it
                            # and then return it.
                            print(" trying", self._pending_cleartext)
                            try:
                                # PyOpenSSL bug: doesn't accept bytearray
                                # https://github.com/pyca/pyopenssl/issues/621
                                next_byte = self._pending_cleartext[0:1]
                                self._conn.send(bytes(next_byte))
                            # Apparently this next bit never gets hit in the
                            # test suite, but it's not an interesting omission
                            # so let's pragma it.
                            except SSL.WantReadError:  # pragma: no cover
                                # We didn't manage to send the cleartext (and
                                # in particular we better leave it there to
                                # try again, due to openssl's retry
                                # semantics), but it's possible we pushed a
                                # renegotiation forward and *now* we have data
                                # to send.
                                try:
                                    return self._conn.bio_read(nbytes)
                                except SSL.WantReadError:
                                    # Nope. We're just going to have to wait
                                    # for someone to call send_all() to give
                                    # us more data.
                                    print("parking (a)")
                                    await self._lot.park()
                            else:
                                # We successfully sent that byte, so we don't
                                # have to again.
                                del self._pending_cleartext[0:1]
                        else:
                            # no pending cleartext; nothing to do but wait for
                            # someone to call send_all
                            print("parking (b)")
                            await self._lot.park()
            finally:
                await self.sleeper("receive_some")
                print(" <-- transport_stream.receive_some finished")
async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None:
    """PyOpenSSLEchoStream must reject overlapping send/receive calls.

    SSLStream is most likely to make overlapping transport calls during
    renegotiation, which is tested against PyOpenSSLEchoStream. This test
    makes sure such a bug would be reported there instead of slipping by.
    """

    async def do_test(
        func1: str, args1: tuple[object, ...], func2: str, args2: tuple[object, ...]
    ) -> None:
        stream = PyOpenSSLEchoStream()
        first_call = getattr(stream, func1)
        second_call = getattr(stream, func2)
        with RaisesGroup(Matcher(_core.BusyResourceError, "simultaneous")):
            async with _core.open_nursery() as nursery:
                nursery.start_soon(first_call, *args1)
                nursery.start_soon(second_call, *args2)

    conflicting_pairs: list[tuple[str, tuple[object, ...], str, tuple[object, ...]]] = [
        ("send_all", (b"x",), "send_all", (b"x",)),
        ("send_all", (b"x",), "wait_send_all_might_not_block", ()),
        ("wait_send_all_might_not_block", (), "wait_send_all_might_not_block", ()),
        ("receive_some", (1,), "receive_some", (1,)),
    ]
    for name1, call_args1, name2, call_args2 in conflicting_pairs:
        await do_test(name1, call_args1, name2, call_args2)
@contextmanager  # type: ignore[misc] # decorated contains Any
def virtual_ssl_echo_server(
    client_ctx: SSLContext, **kwargs: Any
) -> Iterator[SSLStream[PyOpenSSLEchoStream]]:
    """Yield a client SSLStream whose transport is a PyOpenSSLEchoStream.

    ``kwargs`` are forwarded to :class:`PyOpenSSLEchoStream` (e.g. ``sleeper``).
    """
    echo_transport = PyOpenSSLEchoStream(**kwargs)
    client_stream = SSLStream(
        echo_transport, client_ctx, server_hostname="trio-test-1.example.org"
    )
    yield client_stream
def ssl_wrap_pair(
    client_ctx: SSLContext,
    client_transport: T_Stream,
    server_transport: T_Stream,
    *,
    client_kwargs: dict[str, Any] | None = None,
    server_kwargs: dict[str, Any] | None = None,
) -> tuple[SSLStream[T_Stream], SSLStream[T_Stream]]:
    """Wrap two transports as the client and server ends of one TLS session.

    The client uses ``client_ctx`` and expects "trio-test-1.example.org"; the
    server uses SERVER_CTX. Extra SSLStream constructor arguments for each
    side may be supplied through ``client_kwargs`` / ``server_kwargs``.
    """
    extra_client = {} if client_kwargs is None else client_kwargs
    extra_server = {} if server_kwargs is None else server_kwargs
    client_ssl = SSLStream(
        client_transport,
        client_ctx,
        server_hostname="trio-test-1.example.org",
        **extra_client,
    )
    server_ssl = SSLStream(
        server_transport, SERVER_CTX, server_side=True, **extra_server
    )
    return client_ssl, server_ssl
MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream]


def ssl_memory_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[
    SSLStream[MemoryStapledStream],
    SSLStream[MemoryStapledStream],
]:
    """Return a (client, server) SSLStream pair over ``memory_stream_pair()``.

    ``kwargs`` are forwarded to :func:`ssl_wrap_pair`.
    """
    client_side, server_side = memory_stream_pair()
    return ssl_wrap_pair(client_ctx, client_side, server_side, **kwargs)
MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream]


def ssl_lockstep_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[
    SSLStream[MyStapledStream],
    SSLStream[MyStapledStream],
]:
    """Return a (client, server) SSLStream pair over ``lockstep_stream_pair()``.

    ``kwargs`` are forwarded to :func:`ssl_wrap_pair`.
    """
    client_side, server_side = lockstep_stream_pair()
    return ssl_wrap_pair(client_ctx, client_side, server_side, **kwargs)
# Simple smoke test for handshake/send/receive/shutdown talking to a
# synchronous server, plus make sure that we do the bare minimum of
# certificate checking (even though this is really Python's responsibility)
async def test_ssl_client_basics(client_ctx: SSLContext) -> None:
    # Happy path: handshake, echo one byte, close cleanly.
    async with ssl_echo_server(client_ctx) as stream:
        assert not stream.server_side
        await stream.send_all(b"x")
        assert await stream.receive_some(1) == b"x"
        await stream.aclose()

    async def expect_handshake_failure(
        make_ctx: Callable[[], SSLContext],
        hostname: str,
        cause_type: type[BaseException],
    ) -> None:
        # The first send triggers the handshake, which must fail with a
        # BrokenResourceError caused by `cause_type`.
        async with ssl_echo_server_raw(expect_fail=True) as sock:
            failing = SSLStream(sock, make_ctx(), server_hostname=hostname)
            assert not failing.server_side
            with pytest.raises(BrokenResourceError) as excinfo:
                await failing.send_all(b"x")
            assert isinstance(excinfo.value.__cause__, cause_type)

    # Didn't configure the CA file, should fail
    await expect_handshake_failure(
        ssl.create_default_context, "trio-test-1.example.org", ssl.SSLError
    )
    # Trusted CA, but wrong host name
    await expect_handshake_failure(
        lambda: client_ctx, "trio-test-2.example.org", ssl.CertificateError
    )
async def test_ssl_server_basics(client_ctx: SSLContext) -> None:
    """A trio server-side SSLStream talks to a blocking stdlib ssl client."""
    client_raw, server_raw = stdlib_socket.socketpair()
    with client_raw, server_raw:
        server_stream = SSLStream(
            SocketStream(tsocket.from_stdlib_socket(server_raw)),
            SERVER_CTX,
            server_side=True,
        )
        assert server_stream.server_side

        def run_client() -> None:
            # Blocking client: send x, expect y, send z, then TLS shutdown.
            with client_ctx.wrap_socket(
                client_raw, server_hostname="trio-test-1.example.org"
            ) as tls_sock:
                tls_sock.sendall(b"x")
                assert tls_sock.recv(1) == b"y"
                tls_sock.sendall(b"z")
                tls_sock.unwrap()

        client_thread = threading.Thread(target=run_client)
        client_thread.start()

        assert await server_stream.receive_some(1) == b"x"
        await server_stream.send_all(b"y")
        assert await server_stream.receive_some(1) == b"z"
        # The client's unwrap() shows up here as EOF.
        assert await server_stream.receive_some(1) == b""
        await server_stream.aclose()

        client_thread.join()
async def test_attributes(client_ctx: SSLContext) -> None:
    """SSLStream forwards attribute reads, writes and dir() to its SSLObject."""
    async with ssl_echo_server_raw(expect_fail=True) as sock:
        good_ctx = client_ctx
        bad_ctx = ssl.create_default_context()
        stream = SSLStream(sock, good_ctx, server_hostname="trio-test-1.example.org")

        assert stream.transport_stream is sock

        # Reads are forwarded.
        assert stream.context is good_ctx
        assert stream.server_side == False  # noqa
        assert stream.server_hostname == "trio-test-1.example.org"
        with pytest.raises(AttributeError):
            stream.asfdasdfsa  # noqa: B018 # "useless expression"

        # dir() lists both our own and forwarded attributes.
        names = dir(stream)
        assert "transport_stream" in names
        assert "context" in names

        # Writes are forwarded too, and most SSLObject attributes are
        # read-only...
        for attr, value in (("server_side", True), ("server_hostname", "asdf")):
            with pytest.raises(AttributeError):
                setattr(stream, attr, value)

        # ...but .context is writable. Prove the write really reached the
        # SSLObject: with the untrusted context the handshake must now fail.
        stream.context = bad_ctx
        assert stream.context is bad_ctx
        with pytest.raises(BrokenResourceError) as excinfo:
            await stream.do_handshake()
        assert isinstance(excinfo.value.__cause__, ssl.SSLError)
# Note: this test fails horribly if we force TLS 1.2 and trigger a
# renegotiation at the beginning (e.g. by switching to the pyopenssl
# server). Usually the client crashes in SSLObject.write with "UNEXPECTED
# RECORD"; sometimes we get something more exotic like a SyscallError. This is
# odd because openssl isn't doing any syscalls, but so it goes. After lots of
# websearching I'm pretty sure this is due to a bug in OpenSSL, where it just
# can't reliably handle full-duplex communication combined with
# renegotiation. Nice, eh?
#
# https://rt.openssl.org/Ticket/Display.html?id=3712
# https://rt.openssl.org/Ticket/Display.html?id=2481
# http://openssl.6102.n7.nabble.com/TLS-renegotiation-failure-on-receiving-application-data-during-handshake-td48127.html
# https://stackoverflow.com/questions/18728355/ssl-renegotiation-with-full-duplex-socket-communication
#
# In some variants of this test (maybe only against the java server?) I've
# also seen cases where our send_all blocks waiting to write, and then our receive_some
# also blocks waiting to write, and they never wake up again. It looks like
# some kind of deadlock. I suspect there may be an issue where we've filled up
# the send buffers, and the remote side is trying to handle the renegotiation
# from inside a write() call, so it has a problem: there's all this application
# data clogging up the pipe, but it can't process and return it to the
# application because it's in write(), and it doesn't want to buffer infinite
# amounts of data, and... actually I guess those are the only two choices.
#
# NSS even documents that you shouldn't try to do a renegotiation except when
# the connection is idle:
#
# https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/SSL_functions/sslfnc.html#1061582
#
# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it...
async def test_full_duplex_basics(client_ctx: SSLContext) -> None:
    """Concurrent sending, receiving and handshaking on one SSLStream."""
    CHUNKS = 30
    CHUNK_SIZE = 32768
    EXPECTED = CHUNKS * CHUNK_SIZE

    sent = bytearray()
    received = bytearray()

    async def sender(s: Stream) -> None:
        for index in range(CHUNKS):
            print(index)
            payload = bytes([index]) * CHUNK_SIZE
            sent.extend(payload)
            await s.send_all(payload)

    async def receiver(s: Stream) -> None:
        while len(received) < EXPECTED:
            received.extend(await s.receive_some(CHUNK_SIZE // 2))

    async with ssl_echo_server(client_ctx) as s:
        async with _core.open_nursery() as nursery:
            nursery.start_soon(sender, s)
            nursery.start_soon(receiver, s)
            # Throw in a couple of simultaneous handshake calls as well.
            for _ in range(2):
                nursery.start_soon(s.do_handshake)
        await s.aclose()

    assert len(sent) == len(received) == EXPECTED
    assert sent == received
async def test_renegotiation_simple(client_ctx: SSLContext) -> None:
    """Data still flows after the server starts a renegotiation."""
    with virtual_ssl_echo_server(client_ctx) as s:
        await s.do_handshake()
        s.transport_stream.renegotiate()
        # The second round-trip makes sure the renegotiation is finished
        # before shutting down the connection... otherwise openssl raises an
        # error. I think this is a bug in openssl but what can ya do.
        for payload in (b"a", b"b"):
            await s.send_all(payload)
            assert await s.receive_some(1) == payload
        await s.aclose()
@slow
async def test_renegotiation_randomized(
    mock_clock: MockClock, client_ctx: SSLContext
) -> None:
    """Stress SSLStream with server-initiated renegotiations and random delays.

    PyOpenSSLEchoStream's sleeper injects seeded random virtual sleeps, so
    send/receive interleave in many different orders. The last two sections
    check that wait_send_all_might_not_block and receive_some don't make
    conflicting calls on the transport.
    """
    # The only blocking things in this function are our random sleeps, so 0 is
    # a good threshold.
    mock_clock.autojump_threshold = 0
    import random

    # Fixed seed, so the sequence of delays is reproducible.
    r = random.Random(0)

    async def sleeper(_: object) -> None:
        await trio.sleep(r.uniform(0, 10))

    async def clear() -> None:
        # Ping-pong filler bytes until the server reports no renegotiation
        # pending.
        while s.transport_stream.renegotiate_pending():
            with assert_checkpoints():
                await send(b"-")
            with assert_checkpoints():
                await expect(b"-")
        print("-- clear --")

    async def send(byte: bytes) -> None:
        await s.transport_stream.sleeper("outer send")
        print("calling SSLStream.send_all", byte)
        with assert_checkpoints():
            await s.send_all(byte)

    async def expect(expected: bytes) -> None:
        await s.transport_stream.sleeper("expect")
        print("calling SSLStream.receive_some, expecting", expected)
        assert len(expected) == 1
        with assert_checkpoints():
            assert await s.receive_some(1) == expected

    with virtual_ssl_echo_server(client_ctx, sleeper=sleeper) as s:
        await s.do_handshake()
        await send(b"a")
        s.transport_stream.renegotiate()
        await expect(b"a")
        await clear()
        # Renegotiate, then send/receive concurrently in both start orders.
        # (Byte values come from `% 0xFF`, i.e. they range over 0..254.)
        for i in range(100):
            b1 = bytes([i % 0xFF])
            b2 = bytes([(2 * i) % 0xFF])
            s.transport_stream.renegotiate()
            async with _core.open_nursery() as nursery:
                nursery.start_soon(send, b1)
                nursery.start_soon(expect, b1)
            async with _core.open_nursery() as nursery:
                nursery.start_soon(expect, b2)
                nursery.start_soon(send, b2)
            await clear()
        # Renegotiate between a send and its matching receive.
        for i in range(100):
            b1 = bytes([i % 0xFF])
            b2 = bytes([(2 * i) % 0xFF])
            await send(b1)
            s.transport_stream.renegotiate()
            await expect(b1)
            async with _core.open_nursery() as nursery:
                nursery.start_soon(expect, b2)
                nursery.start_soon(send, b2)
            await clear()

    # Checking that wait_send_all_might_not_block and receive_some don't
    # conflict:
    # 1) Set up a situation where expect (receive_some) is blocked sending,
    # and wait_send_all_might_not_block comes in.
    # Our receive_some() call will get stuck when it hits send_all
    async def sleeper_with_slow_send_all(method: str) -> None:
        if method == "send_all":
            await trio.sleep(100000)

    # And our wait_send_all_might_not_block call will give it time to get
    # stuck, and then start
    async def sleep_then_wait_writable() -> None:
        await trio.sleep(1000)
        await s.wait_send_all_might_not_block()

    with virtual_ssl_echo_server(client_ctx, sleeper=sleeper_with_slow_send_all) as s:
        await send(b"x")
        s.transport_stream.renegotiate()
        async with _core.open_nursery() as nursery:
            nursery.start_soon(expect, b"x")
            nursery.start_soon(sleep_then_wait_writable)
        await clear()
        await s.aclose()

    # 2) Same, but now wait_send_all_might_not_block is stuck when
    # receive_some tries to send.
    async def sleeper_with_slow_wait_writable_and_expect(method: str) -> None:
        if method == "wait_send_all_might_not_block":
            await trio.sleep(100000)
        elif method == "expect":
            await trio.sleep(1000)

    with virtual_ssl_echo_server(
        client_ctx, sleeper=sleeper_with_slow_wait_writable_and_expect
    ) as s:
        await send(b"x")
        s.transport_stream.renegotiate()
        async with _core.open_nursery() as nursery:
            nursery.start_soon(expect, b"x")
            nursery.start_soon(s.wait_send_all_might_not_block)
        await clear()
        await s.aclose()
async def test_resource_busy_errors(client_ctx: SSLContext) -> None:
    """SSLStream raises BusyResourceError on conflicting concurrent calls."""
    S: TypeAlias = trio.SSLStream[
        trio.StapledStream[trio.abc.SendStream, trio.abc.ReceiveStream]
    ]

    async def do_send_all(stream: S) -> None:
        with assert_checkpoints():
            await stream.send_all(b"x")

    async def do_receive_some(stream: S) -> None:
        with assert_checkpoints():
            await stream.receive_some(1)

    async def do_wait_send_all_might_not_block(stream: S) -> None:
        with assert_checkpoints():
            await stream.wait_send_all_might_not_block()

    async def do_test(
        func1: Callable[[S], Awaitable[None]], func2: Callable[[S], Awaitable[None]]
    ) -> None:
        # A lockstep pair means nothing completes on its own, so both calls
        # are guaranteed to overlap.
        client_stream, _server_stream = ssl_lockstep_stream_pair(client_ctx)
        with RaisesGroup(Matcher(_core.BusyResourceError, "another task")):
            async with _core.open_nursery() as nursery:
                nursery.start_soon(func1, client_stream)
                nursery.start_soon(func2, client_stream)

    conflicting: list[
        tuple[Callable[[S], Awaitable[None]], Callable[[S], Awaitable[None]]]
    ] = [
        (do_send_all, do_send_all),
        (do_receive_some, do_receive_some),
        (do_send_all, do_wait_send_all_might_not_block),
        (do_wait_send_all_might_not_block, do_wait_send_all_might_not_block),
    ]
    for first, second in conflicting:
        await do_test(first, second)
async def test_wait_writable_calls_underlying_wait_writable() -> None:
    """SSLStream.wait_send_all_might_not_block delegates to its transport."""
    calls: list[str] = []

    class NotAStream(Stream):
        async def wait_send_all_might_not_block(self) -> None:
            calls.append("ok")

        # Stream's other abstract methods must exist, but must never be used.
        async def aclose(self) -> None:
            raise AssertionError("Should not get called")  # pragma: no cover

        async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
            raise AssertionError("Should not get called")  # pragma: no cover

        async def send_all(self, data: bytes | bytearray | memoryview) -> None:
            raise AssertionError("Should not get called")  # pragma: no cover

    stream = SSLStream(
        NotAStream(), ssl.create_default_context(), server_hostname="x"
    )
    await stream.wait_send_all_might_not_block()
    assert calls == ["ok"]
@pytest.mark.skipif(
    os.name == "nt" and sys.version_info >= (3, 10),
    reason="frequently fails on Windows + Python 3.10",
)
async def test_checkpoints(client_ctx: SSLContext) -> None:
    """Every SSLStream operation executes a checkpoint, even when it could
    complete immediately."""
    async with ssl_echo_server(client_ctx) as s:
        # The 2nd and 3rd receive_some(1) could in theory return at once:
        # "xxx" arrives in a single record, so after the first receive_some(1)
        # the rest is sitting in the SSLObject's internal buffers.
        operations: list[Callable[[], Awaitable[object]]] = [
            s.do_handshake,
            s.do_handshake,
            s.wait_send_all_might_not_block,
            partial(s.send_all, b"xxx"),
            partial(s.receive_some, 1),
            partial(s.receive_some, 1),
            partial(s.receive_some, 1),
            s.unwrap,
        ]
        for operation in operations:
            with assert_checkpoints():
                await operation()

    async with ssl_echo_server(client_ctx) as s:
        await s.do_handshake()
        with assert_checkpoints():
            await s.aclose()
async def test_send_all_empty_string(client_ctx: SSLContext) -> None:
    """send_all(b"") is a checkpointing no-op, not an EOF."""
    async with ssl_echo_server(client_ctx) as s:
        await s.do_handshake()

        # The underlying SSLObject interprets writing b"" as indicating an
        # EOF, for some reason. Make sure we don't inherit this.
        for _ in range(2):
            with assert_checkpoints():
                await s.send_all(b"")

        # The connection is still usable afterwards.
        await s.send_all(b"x")
        assert await s.receive_some(1) == b"x"

        await s.aclose()
@pytest.mark.parametrize("https_compatible", [False, True])
async def test_SSLStream_generic(
    client_ctx: SSLContext, https_compatible: bool
) -> None:
    """Run trio's generic two-way-stream conformance checks on SSLStream."""

    async def make_streams() -> tuple[
        SSLStream[MemoryStapledStream],
        SSLStream[MemoryStapledStream],
    ]:
        return ssl_memory_stream_pair(
            client_ctx,
            client_kwargs={"https_compatible": https_compatible},
            server_kwargs={"https_compatible": https_compatible},
        )

    async def make_clogged_streams() -> tuple[
        SSLStream[MyStapledStream],
        SSLStream[MyStapledStream],
    ]:
        client, server = ssl_lockstep_stream_pair(client_ctx)
        # Handshake up front. Otherwise, if the server calls
        # wait_send_all_might_not_block and the client calls receive_some to
        # unclog it, the client's receive_some would first send handshake data
        # and get stuck itself.
        async with _core.open_nursery() as nursery:
            nursery.start_soon(client.do_handshake)
            nursery.start_soon(server.do_handshake)
        return client, server

    await check_two_way_stream(make_streams, make_clogged_streams)
async def test_unwrap(client_ctx: SSLContext) -> None:
    """unwrap() hands back the transport plus any bytes read past close_notify.

    The client unwraps after seeing the server's EOF and then writes plaintext
    b"trailing" on the raw transport. The server's unwrap() must return that
    data as its ``trailing`` value.
    """
    client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
    client_transport = client_ssl.transport_stream
    server_transport = server_ssl.transport_stream
    # Orders the `seq(0)` block in client() before the `seq(1)` block in
    # server().
    seq = Sequencer()

    async def client() -> None:
        await client_ssl.do_handshake()
        await client_ssl.send_all(b"x")
        assert await client_ssl.receive_some(1) == b"y"
        await client_ssl.send_all(b"z")
        # After sending that, disable outgoing data from our end, to make
        # sure the server doesn't see our EOF until after we've sent some
        # trailing data
        async with seq(0):
            send_all_hook = client_transport.send_stream.send_all_hook
            client_transport.send_stream.send_all_hook = None
        assert await client_ssl.receive_some(1) == b""
        assert client_ssl.transport_stream is client_transport
        # We just received EOF. Unwrap the connection and send some more.
        raw, trailing = await client_ssl.unwrap()
        assert raw is client_transport
        assert trailing == b""
        # After unwrap the SSLStream no longer owns a transport.
        assert client_ssl.transport_stream is None
        await raw.send_all(b"trailing")
        # Reconnect the streams. Now the server will receive both our shutdown
        # acknowledgement + the trailing data in a single lump.
        client_transport.send_stream.send_all_hook = send_all_hook
        # Call the restored hook once by hand to push the buffered bytes.
        await client_transport.send_stream.send_all_hook()

    async def server() -> None:
        await server_ssl.do_handshake()
        assert await server_ssl.receive_some(1) == b"x"
        await server_ssl.send_all(b"y")
        assert await server_ssl.receive_some(1) == b"z"
        # Now client is blocked waiting for us to send something, but
        # instead we close the TLS connection (with sequencer to make sure
        # that the client won't see and automatically respond before we've had
        # a chance to disable the client->server transport)
        async with seq(1):
            raw, trailing = await server_ssl.unwrap()
        assert raw is server_transport
        assert trailing == b"trailing"
        assert server_ssl.transport_stream is None

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client)
        nursery.start_soon(server)
async def test_closing_nice_case(client_ctx: SSLContext) -> None:
    """Graceful closes all around, plus the behaviour of a closed SSLStream."""
    client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
    client_transport = client_ssl.transport_stream

    # Both the handshake and the close require back-and-forth discussion, so
    # the two sides have to run concurrently.
    async def close_client() -> None:
        with assert_checkpoints():
            await client_ssl.aclose()

    async def drain_then_close_server() -> None:
        for _ in range(2):
            assert await server_ssl.receive_some(10) == b""
        with assert_checkpoints():
            await server_ssl.aclose()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(close_client)
        nursery.start_soon(drain_then_close_server)

    # Closing the SSLStream also closed its transport.
    with pytest.raises(ClosedResourceError):
        await client_transport.send_all(b"123")

    # Once closed, closing again is fine (and still checkpoints).
    for _ in range(2):
        with assert_checkpoints():
            await client_ssl.aclose()

    # The server can no longer send.
    with pytest.raises(ClosedResourceError):
        await server_ssl.send_all(b"123")

    # Once the connection has been closed *locally*, we get a proper error
    # instead of empty bytestrings.
    after_local_close: list[Callable[[], Awaitable[object]]] = [
        partial(client_ssl.receive_some, 10),
        client_ssl.unwrap,
        client_ssl.do_handshake,
    ]
    for operation in after_local_close:
        with pytest.raises(ClosedResourceError):
            await operation()

    # A graceful close *before* handshaking gives a clean EOF on the other
    # side.
    early_client, early_server = ssl_memory_stream_pair(client_ctx)

    async def expect_eof_server() -> None:
        with assert_checkpoints():
            assert await early_server.receive_some(10) == b""
        with assert_checkpoints():
            await early_server.aclose()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(early_client.aclose)
        nursery.start_soon(expect_eof_server)
async def test_send_all_fails_in_the_middle(client_ctx: SSLContext) -> None:
    """A transport error during ``send_all`` leaves the SSLStream broken.

    The transport's exception propagates out of ``send_all``. After that,
    ``wait_send_all_might_not_block`` raises ``BrokenResourceError``, and
    ``aclose`` still closes both halves of the underlying transport.
    """
    client_stream, server_stream = ssl_memory_stream_pair(client_ctx)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.do_handshake)
        nursery.start_soon(server_stream.do_handshake)

    async def failing_send_hook() -> NoReturn:
        raise KeyError

    transport = client_stream.transport_stream
    transport.send_stream.send_all_hook = failing_send_hook

    with pytest.raises(KeyError):
        await client_stream.send_all(b"x")

    with pytest.raises(BrokenResourceError):
        await client_stream.wait_send_all_might_not_block()

    # Count transport closes; both the send and receive halves should be hit.
    close_calls = 0

    def count_close() -> None:
        nonlocal close_calls
        close_calls += 1

    transport.send_stream.close_hook = count_close
    transport.receive_stream.close_hook = count_close
    await client_stream.aclose()

    assert close_calls == 2
async def test_ssl_over_ssl(client_ctx: SSLContext) -> None:
    """Tunnel TLS inside TLS: an SSLStream wrapping another SSLStream."""
    hostname = "trio-test-1.example.org"
    raw_client, raw_server = memory_stream_pair()

    # Lower layer: TLS directly over the in-memory transport.
    lower_client = SSLStream(raw_client, client_ctx, server_hostname=hostname)
    lower_server = SSLStream(raw_server, SERVER_CTX, server_side=True)

    # Upper layer: TLS carried over the lower TLS connection.
    upper_client = SSLStream(lower_client, client_ctx, server_hostname=hostname)
    upper_server = SSLStream(lower_server, SERVER_CTX, server_side=True)

    async def run_client() -> None:
        await upper_client.send_all(b"hi")
        assert await upper_client.receive_some(10) == b"bye"

    async def run_server() -> None:
        assert await upper_server.receive_some(10) == b"hi"
        await upper_server.send_all(b"bye")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(run_client)
        nursery.start_soon(run_server)
async def test_ssl_bad_shutdown(client_ctx: SSLContext) -> None:
    """A forced (non-graceful) client close makes the server's stream broken."""
    client_stream, server_stream = ssl_memory_stream_pair(client_ctx)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.do_handshake)
        nursery.start_soon(server_stream.do_handshake)

    await trio.aclose_forcefully(client_stream)

    # Without https_compatible, both reading and writing on the server side
    # report the connection as broken.
    with pytest.raises(BrokenResourceError):
        await server_stream.receive_some(10)
    with pytest.raises(BrokenResourceError):
        await server_stream.send_all(b"x" * 10)

    await server_stream.aclose()
async def test_ssl_bad_shutdown_but_its_ok(client_ctx: SSLContext) -> None:
    """With https_compatible on both ends, a forced client close reads as EOF.

    Sending from the server afterwards still raises ``BrokenResourceError``.
    """
    client_stream, server_stream = ssl_memory_stream_pair(
        client_ctx,
        server_kwargs={"https_compatible": True},
        client_kwargs={"https_compatible": True},
    )

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.do_handshake)
        nursery.start_soon(server_stream.do_handshake)

    await trio.aclose_forcefully(client_stream)

    # In HTTPS-compatible mode the abrupt close is treated as a clean EOF...
    assert await server_stream.receive_some(10) == b""
    # ...but the connection can no longer carry outgoing data.
    with pytest.raises(BrokenResourceError):
        await server_stream.send_all(b"x" * 10)

    await server_stream.aclose()
async def test_ssl_handshake_failure_during_aclose() -> None:
    """``aclose()`` triggers the handshake, the handshake fails, and aclose raises.

    This also covers the path in ``aclose()`` that re-raises an exception
    after calling ``aclose_forcefully`` on the underlying transport.
    """
    async with ssl_echo_server_raw(expect_fail=True) as sock:
        # A default context is deliberately used here: trust is not
        # configured for the test server, so the handshake cannot succeed.
        untrusting_ctx = ssl.create_default_context()
        stream = SSLStream(
            sock, untrusting_ctx, server_hostname="trio-test-1.example.org"
        )
        # Whether aclose should swallow this error is debatable. Errors that
        # arrive while sending close_notify *are* swallowed, since both sides
        # closing simultaneously is legitimate. But with https_compatible=False
        # it would be bad to get through an entire connection with a peer
        # lacking a valid certificate without ever raising, so here it escapes.
        with pytest.raises(BrokenResourceError):
            await stream.aclose()
async def test_ssl_only_closes_stream_once(client_ctx: SSLContext) -> None:
    """If the transport's close raises, ``SSLStream.aclose`` must not retry it.

    Regression test for a bug where an error from ``transport_stream.aclose()``
    caused it to be called a second time.
    """
    client_stream, server_stream = ssl_memory_stream_pair(client_ctx)
    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.do_handshake)
        nursery.start_soon(server_stream.do_handshake)

    send_stream = client_stream.transport_stream.send_stream
    original_close_hook = send_stream.close_hook
    close_count = 0

    def close_then_fail() -> NoReturn:
        # Run the real close logic, record the call, then blow up.
        nonlocal close_count
        assert original_close_hook is not None
        original_close_hook()
        close_count += 1
        raise KeyError

    send_stream.close_hook = close_then_fail

    with pytest.raises(KeyError):
        await client_stream.aclose()
    assert close_count == 1
async def test_ssl_https_compatibility_disagreement(client_ctx: SSLContext) -> None:
    """HTTPS-mode client and strict server: the client's close breaks the server.

    The server's ``receive_some`` raises ``BrokenResourceError``, and the
    chained cause must satisfy ``_is_eof``.
    """
    client_stream, server_stream = ssl_memory_stream_pair(
        client_ctx,
        server_kwargs={"https_compatible": False},
        client_kwargs={"https_compatible": True},
    )

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.do_handshake)
        nursery.start_soon(server_stream.do_handshake)

    # The client is in HTTPS mode but the server is not, so the client's
    # graceful shutdown surfaces as an error on the server side.
    async def server_expects_broken_eof() -> None:
        with pytest.raises(BrokenResourceError) as excinfo:
            await server_stream.receive_some(10)

        assert _is_eof(excinfo.value.__cause__)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.aclose)
        nursery.start_soon(server_expects_broken_eof)
async def test_https_mode_eof_before_handshake(client_ctx: SSLContext) -> None:
    """In HTTPS mode, closing before any handshake gives the peer a clean EOF."""
    client_stream, server_stream = ssl_memory_stream_pair(
        client_ctx,
        server_kwargs={"https_compatible": True},
        client_kwargs={"https_compatible": True},
    )

    async def server_sees_clean_eof() -> None:
        assert await server_stream.receive_some(10) == b""

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_stream.aclose)
        nursery.start_soon(server_sees_clean_eof)
async def test_send_error_during_handshake(client_ctx: SSLContext) -> None:
    """A transport send failure mid-handshake propagates, then the stream is broken."""
    client_stream, _server_stream = ssl_memory_stream_pair(client_ctx)

    async def exploding_send_hook() -> NoReturn:
        raise KeyError

    client_stream.transport_stream.send_stream.send_all_hook = exploding_send_hook

    # First attempt: the transport's own exception escapes (after a checkpoint).
    with pytest.raises(KeyError), assert_checkpoints():
        await client_stream.do_handshake()

    # Retrying afterwards reports the stream as broken.
    with pytest.raises(BrokenResourceError), assert_checkpoints():
        await client_stream.do_handshake()
async def test_receive_error_during_handshake(client_ctx: SSLContext) -> None:
    """A transport receive failure mid-handshake propagates, then the stream is broken."""
    client_stream, server_stream = ssl_memory_stream_pair(client_ctx)

    async def exploding_receive_hook() -> NoReturn:
        raise KeyError

    receive_stream = client_stream.transport_stream.receive_stream
    receive_stream.receive_some_hook = exploding_receive_hook

    async def client_side(cancel_scope: CancelScope) -> None:
        with pytest.raises(KeyError), assert_checkpoints():
            await client_stream.do_handshake()
        # Cancel the nursery so the server's pending handshake doesn't keep
        # it open.
        cancel_scope.cancel()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client_side, nursery.cancel_scope)
        nursery.start_soon(server_stream.do_handshake)

    with pytest.raises(BrokenResourceError), assert_checkpoints():
        await client_stream.do_handshake()
async def test_selected_alpn_protocol_before_handshake(client_ctx: SSLContext) -> None:
    """Querying ALPN before the handshake raises NeedHandshakeError on both ends."""
    client, server = ssl_memory_stream_pair(client_ctx)

    for stream in (client, server):
        with pytest.raises(NeedHandshakeError):
            stream.selected_alpn_protocol()
async def test_selected_alpn_protocol_when_not_set(client_ctx: SSLContext) -> None:
    """With no ALPN configured, both ends report ``None`` rather than raising."""
    client, server = ssl_memory_stream_pair(client_ctx)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client.do_handshake)
        nursery.start_soon(server.do_handshake)

    for stream in (client, server):
        assert stream.selected_alpn_protocol() is None
    assert client.selected_alpn_protocol() == server.selected_alpn_protocol()
async def test_selected_npn_protocol_before_handshake(client_ctx: SSLContext) -> None:
    """Querying NPN before the handshake raises NeedHandshakeError on both ends."""
    client, server = ssl_memory_stream_pair(client_ctx)

    for stream in (client, server):
        with pytest.raises(NeedHandshakeError):
            stream.selected_npn_protocol()
@pytest.mark.filterwarnings(
    r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning",
    r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning",
)
async def test_selected_npn_protocol_when_not_set(client_ctx: SSLContext) -> None:
    """With no NPN configured, both ends report ``None`` rather than raising.

    The NPN deprecation warnings emitted by the ssl module are filtered out.
    """
    client, server = ssl_memory_stream_pair(client_ctx)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client.do_handshake)
        nursery.start_soon(server.do_handshake)

    for stream in (client, server):
        assert stream.selected_npn_protocol() is None
    assert client.selected_npn_protocol() == server.selected_npn_protocol()
async def test_get_channel_binding_before_handshake(client_ctx: SSLContext) -> None:
    """Channel binding is unavailable until the handshake has happened."""
    client, server = ssl_memory_stream_pair(client_ctx)

    for stream in (client, server):
        with pytest.raises(NeedHandshakeError):
            stream.get_channel_binding()
async def test_get_channel_binding_after_handshake(client_ctx: SSLContext) -> None:
    """After the handshake, both ends expose the same, non-None channel binding."""
    client, server = ssl_memory_stream_pair(client_ctx)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client.do_handshake)
        nursery.start_soon(server.do_handshake)

    for stream in (client, server):
        assert stream.get_channel_binding() is not None
    assert client.get_channel_binding() == server.get_channel_binding()
async def test_getpeercert(client_ctx: SSLContext) -> None:
    """Check ``getpeercert()`` on both ends after a completed handshake.

    The server side is expected to have no peer certificate (``None``). The
    client side must see the test certificate's DNS subjectAltName.
    """
    # Make sure we're not affected by https://bugs.python.org/issue29334
    client, server = ssl_memory_stream_pair(client_ctx)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(client.do_handshake)
        nursery.start_soon(server.do_handshake)

    assert server.getpeercert() is None

    peer_cert = client.getpeercert()
    # Printed so the certificate appears in pytest's captured output if the
    # assertions below fail.
    print(peer_cert)
    # Fail with a clear assertion instead of a TypeError from subscripting None.
    assert peer_cert is not None
    assert ("DNS", "trio-test-1.example.org") in peer_cert["subjectAltName"]
async def test_SSLListener(client_ctx: SSLContext) -> None:
    """Exercise SSLListener: accept + handshake, aclose, and https_compatible."""

    async def setup(
        **kwargs: Any,
    ) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]:
        """Build an SSLListener on a loopback socket plus a connected SSL client.

        ``kwargs`` are forwarded to ``SSLListener``. Returns the raw listening
        socket, the listener wrapping it, and the client-side SSLStream.
        """
        listen_sock = tsocket.socket()
        await listen_sock.bind(("127.0.0.1", 0))
        listen_sock.listen(1)
        socket_listener = SocketListener(listen_sock)
        ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs)

        transport_client = await open_tcp_stream(*listen_sock.getsockname())
        ssl_client = SSLStream(
            transport_client, client_ctx, server_hostname="trio-test-1.example.org"
        )
        return listen_sock, ssl_listener, ssl_client

    listen_sock, ssl_listener, ssl_client = await setup()

    async with ssl_client:
        ssl_server = await ssl_listener.accept()

        async with ssl_server:
            assert not ssl_server._https_compatible

            # Make sure the connection works
            async with _core.open_nursery() as nursery:
                nursery.start_soon(ssl_client.do_handshake)
                nursery.start_soon(ssl_server.do_handshake)

        # Test SSLListener.aclose: it must close the underlying socket too.
        await ssl_listener.aclose()
        assert listen_sock.fileno() == -1

    ################

    # Test https_compatible
    _, ssl_listener, ssl_client = await setup(https_compatible=True)
    try:
        ssl_server = await ssl_listener.accept()
        try:
            assert ssl_server._https_compatible
        finally:
            await aclose_forcefully(ssl_server)
    finally:
        # Clean up even if accept() or the assertion failed, so a failing run
        # does not leak the listening socket or the client connection.
        await aclose_forcefully(ssl_listener)
        await aclose_forcefully(ssl_client)