You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1338 lines
50 KiB
1338 lines
50 KiB
from __future__ import annotations
|
|
|
|
import os
|
|
import socket as stdlib_socket
|
|
import ssl
|
|
import sys
|
|
import threading
|
|
from contextlib import asynccontextmanager, contextmanager, suppress
|
|
from functools import partial
|
|
from ssl import SSLContext
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
AsyncIterator,
|
|
Awaitable,
|
|
Callable,
|
|
Iterator,
|
|
NoReturn,
|
|
)
|
|
|
|
import pytest
|
|
|
|
from trio import StapledStream
|
|
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
|
from trio.abc import ReceiveStream, SendStream
|
|
from trio.testing import (
|
|
Matcher,
|
|
MemoryReceiveStream,
|
|
MemorySendStream,
|
|
RaisesGroup,
|
|
)
|
|
|
|
try:
|
|
import trustme
|
|
from OpenSSL import SSL
|
|
except ImportError as error:
|
|
skip_if_optional_else_raise(error)
|
|
|
|
import trio
|
|
|
|
from .. import _core, socket as tsocket
|
|
from .._abc import Stream
|
|
from .._core import BrokenResourceError, ClosedResourceError
|
|
from .._core._tests.tutil import slow
|
|
from .._highlevel_generic import aclose_forcefully
|
|
from .._highlevel_open_tcp_stream import open_tcp_stream
|
|
from .._highlevel_socket import SocketListener, SocketStream
|
|
from .._ssl import NeedHandshakeError, SSLListener, SSLStream, _is_eof
|
|
from .._util import ConflictDetector
|
|
from ..testing import (
|
|
Sequencer,
|
|
assert_checkpoints,
|
|
check_two_way_stream,
|
|
lockstep_stream_pair,
|
|
memory_stream_pair,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from typing_extensions import TypeAlias
|
|
|
|
from trio._core import MockClock
|
|
from trio._ssl import T_Stream
|
|
|
|
from .._core._run import CancelScope
|
|
|
|
# We have two different kinds of echo server fixtures we use for testing. The
|
|
# first is a real server written using the stdlib ssl module and blocking
|
|
# sockets. It runs in a thread and we talk to it over a real socketpair(), to
|
|
# validate interoperability in a semi-realistic setting.
|
|
#
|
|
# The second is a very weird virtual echo server that lives inside a custom
|
|
# Stream class. It lives entirely inside the Python object space; there are no
|
|
# operating system calls in it at all. No threads, no I/O, nothing. It's
|
|
# 'send_all' call takes encrypted data from a client and feeds it directly into
|
|
# the server-side TLS state engine to decrypt, then takes that data, feeds it
|
|
# back through to get the encrypted response, and returns it from 'receive_some'. This
|
|
# gives us full control and reproducibility. This server is written using
|
|
# PyOpenSSL, so that we can trigger renegotiations on demand. It also allows
|
|
# us to insert random (virtual) delays, to really exercise all the weird paths
|
|
# in SSLStream's state engine.
|
|
#
|
|
# Both present a certificate for "trio-test-1.example.org".
|
|
|
|
TRIO_TEST_CA = trustme.CA()
|
|
TRIO_TEST_1_CERT = TRIO_TEST_CA.issue_server_cert("trio-test-1.example.org")
|
|
|
|
SERVER_CTX = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
|
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
|
|
SERVER_CTX.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
|
|
|
|
TRIO_TEST_1_CERT.configure_cert(SERVER_CTX)
|
|
|
|
|
|
# TLS 1.3 has a lot of changes from previous versions. So we want to run tests
|
|
# with both TLS 1.3, and TLS 1.2.
|
|
# "tls13" means that we're willing to negotiate TLS 1.3. Usually that's
|
|
# what will happen, but the renegotiation tests explicitly force a
|
|
# downgrade on the server side. "tls12" means we refuse to negotiate TLS
|
|
# 1.3, so we'll almost certainly use TLS 1.2.
|
|
@pytest.fixture(scope="module", params=["tls13", "tls12"])
|
|
def client_ctx(request: pytest.FixtureRequest) -> ssl.SSLContext:
|
|
ctx = ssl.create_default_context()
|
|
|
|
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
|
|
ctx.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
|
|
|
|
TRIO_TEST_CA.configure_trust(ctx)
|
|
if request.param in ["default", "tls13"]:
|
|
return ctx
|
|
elif request.param == "tls12":
|
|
ctx.maximum_version = ssl.TLSVersion.TLSv1_2
|
|
return ctx
|
|
else: # pragma: no cover
|
|
raise AssertionError()
|
|
|
|
|
|
# The blocking socket server.
|
|
def ssl_echo_serve_sync(
|
|
sock: stdlib_socket.socket, *, expect_fail: bool = False
|
|
) -> None:
|
|
try:
|
|
wrapped = SERVER_CTX.wrap_socket(
|
|
sock, server_side=True, suppress_ragged_eofs=False
|
|
)
|
|
with wrapped:
|
|
wrapped.do_handshake()
|
|
while True:
|
|
data = wrapped.recv(4096)
|
|
if not data:
|
|
# other side has initiated a graceful shutdown; we try to
|
|
# respond in kind but it's legal for them to have already
|
|
# gone away.
|
|
with suppress(BrokenPipeError, ssl.SSLZeroReturnError):
|
|
wrapped.unwrap()
|
|
return
|
|
wrapped.sendall(data)
|
|
# This is an obscure workaround for an openssl bug. In server mode, in
|
|
# some versions, openssl sends some extra data at the end of do_handshake
|
|
# that it shouldn't send. Normally this is harmless, but, if the other
|
|
# side shuts down the connection before it reads that data, it might cause
|
|
# the OS to report a ECONNREST or even ECONNABORTED (which is just wrong,
|
|
# since ECONNABORTED is supposed to mean that connect() failed, but what
|
|
# can you do). In this case the other side did nothing wrong, but there's
|
|
# no way to recover, so we let it pass, and just cross our fingers its not
|
|
# hiding any (other) real bugs. For more details see:
|
|
#
|
|
# https://github.com/python-trio/trio/issues/1293
|
|
#
|
|
# Also, this happens frequently but non-deterministically, so we have to
|
|
# 'no cover' it to avoid coverage flapping.
|
|
except (ConnectionResetError, ConnectionAbortedError): # pragma: no cover
|
|
return
|
|
except Exception as exc:
|
|
if expect_fail:
|
|
print("ssl_echo_serve_sync got error as expected:", exc)
|
|
else: # pragma: no cover
|
|
print("ssl_echo_serve_sync got unexpected error:", exc)
|
|
raise
|
|
else:
|
|
if expect_fail: # pragma: no cover
|
|
raise RuntimeError("failed to fail?")
|
|
finally:
|
|
sock.close()
|
|
|
|
|
|
# Fixture that gives a raw socket connected to a trio-test-1 echo server
|
|
# (running in a thread). Useful for testing making connections with different
|
|
# SSLContexts.
|
|
@asynccontextmanager # type: ignore[misc] # decorated contains Any
|
|
async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
|
|
a, b = stdlib_socket.socketpair()
|
|
async with trio.open_nursery() as nursery:
|
|
# Exiting the 'with a, b' context manager closes the sockets, which
|
|
# causes the thread to exit (possibly with an error), which allows the
|
|
# nursery context manager to exit too.
|
|
with a, b:
|
|
nursery.start_soon(
|
|
trio.to_thread.run_sync, partial(ssl_echo_serve_sync, b, **kwargs)
|
|
)
|
|
|
|
yield SocketStream(tsocket.from_stdlib_socket(a))
|
|
|
|
|
|
# Fixture that gives a properly set up SSLStream connected to a trio-test-1
|
|
# echo server (running in a thread)
|
|
@asynccontextmanager # type: ignore[misc] # decorated contains Any
|
|
async def ssl_echo_server(
|
|
client_ctx: SSLContext, **kwargs: Any
|
|
) -> AsyncIterator[SSLStream[Stream]]:
|
|
async with ssl_echo_server_raw(**kwargs) as sock:
|
|
yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")
|
|
|
|
|
|
# The weird in-memory server ... thing.
|
|
# Doesn't inherit from Stream because I left out the methods that we don't
|
|
# actually need.
|
|
# jakkdl: it seems to implement all the abstract methods (now), so I made it inherit
|
|
# from Stream for the sake of typechecking.
|
|
class PyOpenSSLEchoStream(Stream):
|
|
def __init__(self, sleeper: None = None) -> None:
|
|
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
|
# TLS 1.3 removes renegotiation support. Which is great for them, but
|
|
# we still have to support versions before that, and that means we
|
|
# need to test renegotiation support, which means we need to force this
|
|
# to use a lower version where this test server can trigger
|
|
# renegotiations. Of course TLS 1.3 support isn't released yet, but
|
|
# I'm told that this will work once it is. (And once it is we can
|
|
# remove the pragma: no cover too.) Alternatively, we could switch to
|
|
# using TLSv1_2_METHOD.
|
|
#
|
|
# Discussion: https://github.com/pyca/pyopenssl/issues/624
|
|
|
|
# This is the right way, but we can't use it until this PR is in a
|
|
# released:
|
|
# https://github.com/pyca/pyopenssl/pull/861
|
|
#
|
|
# if hasattr(SSL, "OP_NO_TLSv1_3"):
|
|
# ctx.set_options(SSL.OP_NO_TLSv1_3)
|
|
#
|
|
# Fortunately pyopenssl uses cryptography under the hood, so we can be
|
|
# confident that they're using the same version of openssl
|
|
from cryptography.hazmat.bindings.openssl.binding import Binding
|
|
|
|
b = Binding()
|
|
if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"):
|
|
ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3)
|
|
|
|
# Unfortunately there's currently no way to say "use 1.3 or worse", we
|
|
# can only disable specific versions. And if the two sides start
|
|
# negotiating 1.4 at some point in the future, it *might* mean that
|
|
# our tests silently stop working properly. So the next line is a
|
|
# tripwire to remind us we need to revisit this stuff in 5 years or
|
|
# whatever when the next TLS version is released:
|
|
assert not hasattr(SSL, "OP_NO_TLSv1_4")
|
|
TRIO_TEST_1_CERT.configure_cert(ctx)
|
|
self._conn = SSL.Connection(ctx, None)
|
|
self._conn.set_accept_state()
|
|
self._lot = _core.ParkingLot()
|
|
self._pending_cleartext = bytearray()
|
|
|
|
self._send_all_conflict_detector = ConflictDetector(
|
|
"simultaneous calls to PyOpenSSLEchoStream.send_all"
|
|
)
|
|
self._receive_some_conflict_detector = ConflictDetector(
|
|
"simultaneous calls to PyOpenSSLEchoStream.receive_some"
|
|
)
|
|
|
|
if sleeper is None:
|
|
|
|
async def no_op_sleeper(_: object) -> None:
|
|
return
|
|
|
|
self.sleeper = no_op_sleeper
|
|
else:
|
|
self.sleeper = sleeper
|
|
|
|
async def aclose(self) -> None:
|
|
self._conn.bio_shutdown()
|
|
|
|
def renegotiate_pending(self) -> bool:
|
|
return self._conn.renegotiate_pending()
|
|
|
|
def renegotiate(self) -> None:
|
|
# Returns false if a renegotiation is already in progress, meaning
|
|
# nothing happens.
|
|
assert self._conn.renegotiate()
|
|
|
|
async def wait_send_all_might_not_block(self) -> None:
|
|
with self._send_all_conflict_detector:
|
|
await _core.checkpoint()
|
|
await _core.checkpoint()
|
|
await self.sleeper("wait_send_all_might_not_block")
|
|
|
|
async def send_all(self, data: bytes) -> None:
|
|
print(" --> transport_stream.send_all")
|
|
with self._send_all_conflict_detector:
|
|
await _core.checkpoint()
|
|
await _core.checkpoint()
|
|
await self.sleeper("send_all")
|
|
self._conn.bio_write(data)
|
|
while True:
|
|
await self.sleeper("send_all")
|
|
try:
|
|
data = self._conn.recv(1)
|
|
except SSL.ZeroReturnError:
|
|
self._conn.shutdown()
|
|
print("renegotiations:", self._conn.total_renegotiations())
|
|
break
|
|
except SSL.WantReadError:
|
|
break
|
|
else:
|
|
self._pending_cleartext += data
|
|
self._lot.unpark_all()
|
|
await self.sleeper("send_all")
|
|
print(" <-- transport_stream.send_all finished")
|
|
|
|
async def receive_some(self, nbytes: int | None = None) -> bytes:
|
|
print(" --> transport_stream.receive_some")
|
|
if nbytes is None:
|
|
nbytes = 65536 # arbitrary
|
|
with self._receive_some_conflict_detector:
|
|
try:
|
|
await _core.checkpoint()
|
|
await _core.checkpoint()
|
|
while True:
|
|
await self.sleeper("receive_some")
|
|
try:
|
|
return self._conn.bio_read(nbytes)
|
|
except SSL.WantReadError:
|
|
# No data in our ciphertext buffer; try to generate
|
|
# some.
|
|
if self._pending_cleartext:
|
|
# We have some cleartext; maybe we can encrypt it
|
|
# and then return it.
|
|
print(" trying", self._pending_cleartext)
|
|
try:
|
|
# PyOpenSSL bug: doesn't accept bytearray
|
|
# https://github.com/pyca/pyopenssl/issues/621
|
|
next_byte = self._pending_cleartext[0:1]
|
|
self._conn.send(bytes(next_byte))
|
|
# Apparently this next bit never gets hit in the
|
|
# test suite, but it's not an interesting omission
|
|
# so let's pragma it.
|
|
except SSL.WantReadError: # pragma: no cover
|
|
# We didn't manage to send the cleartext (and
|
|
# in particular we better leave it there to
|
|
# try again, due to openssl's retry
|
|
# semantics), but it's possible we pushed a
|
|
# renegotiation forward and *now* we have data
|
|
# to send.
|
|
try:
|
|
return self._conn.bio_read(nbytes)
|
|
except SSL.WantReadError:
|
|
# Nope. We're just going to have to wait
|
|
# for someone to call send_all() to give
|
|
# use more data.
|
|
print("parking (a)")
|
|
await self._lot.park()
|
|
else:
|
|
# We successfully sent that byte, so we don't
|
|
# have to again.
|
|
del self._pending_cleartext[0:1]
|
|
else:
|
|
# no pending cleartext; nothing to do but wait for
|
|
# someone to call send_all
|
|
print("parking (b)")
|
|
await self._lot.park()
|
|
finally:
|
|
await self.sleeper("receive_some")
|
|
print(" <-- transport_stream.receive_some finished")
|
|
|
|
|
|
async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None:
|
|
# Make sure that PyOpenSSLEchoStream complains if two tasks call send_all
|
|
# at the same time, or ditto for receive_some. The tricky cases where SSLStream
|
|
# might accidentally do this are during renegotiation, which we test using
|
|
# PyOpenSSLEchoStream, so this makes sure that if we do have a bug then
|
|
# PyOpenSSLEchoStream will notice and complain.
|
|
|
|
async def do_test(
|
|
func1: str, args1: tuple[object, ...], func2: str, args2: tuple[object, ...]
|
|
) -> None:
|
|
s = PyOpenSSLEchoStream()
|
|
with RaisesGroup(Matcher(_core.BusyResourceError, "simultaneous")):
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(getattr(s, func1), *args1)
|
|
nursery.start_soon(getattr(s, func2), *args2)
|
|
|
|
await do_test("send_all", (b"x",), "send_all", (b"x",))
|
|
await do_test("send_all", (b"x",), "wait_send_all_might_not_block", ())
|
|
await do_test(
|
|
"wait_send_all_might_not_block", (), "wait_send_all_might_not_block", ()
|
|
)
|
|
await do_test("receive_some", (1,), "receive_some", (1,))
|
|
|
|
|
|
@contextmanager # type: ignore[misc] # decorated contains Any
|
|
def virtual_ssl_echo_server(
|
|
client_ctx: SSLContext, **kwargs: Any
|
|
) -> Iterator[SSLStream[PyOpenSSLEchoStream]]:
|
|
fakesock = PyOpenSSLEchoStream(**kwargs)
|
|
yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org")
|
|
|
|
|
|
def ssl_wrap_pair(
|
|
client_ctx: SSLContext,
|
|
client_transport: T_Stream,
|
|
server_transport: T_Stream,
|
|
*,
|
|
client_kwargs: dict[str, Any] | None = None,
|
|
server_kwargs: dict[str, Any] | None = None,
|
|
) -> tuple[SSLStream[T_Stream], SSLStream[T_Stream]]:
|
|
if server_kwargs is None:
|
|
server_kwargs = {}
|
|
if client_kwargs is None:
|
|
client_kwargs = {}
|
|
client_ssl = SSLStream(
|
|
client_transport,
|
|
client_ctx,
|
|
server_hostname="trio-test-1.example.org",
|
|
**client_kwargs,
|
|
)
|
|
server_ssl = SSLStream(
|
|
server_transport, SERVER_CTX, server_side=True, **server_kwargs
|
|
)
|
|
return client_ssl, server_ssl
|
|
|
|
|
|
MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream]
|
|
|
|
|
|
def ssl_memory_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[
|
|
SSLStream[MemoryStapledStream],
|
|
SSLStream[MemoryStapledStream],
|
|
]:
|
|
client_transport, server_transport = memory_stream_pair()
|
|
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
|
|
|
|
|
|
MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream]
|
|
|
|
|
|
def ssl_lockstep_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[
|
|
SSLStream[MyStapledStream],
|
|
SSLStream[MyStapledStream],
|
|
]:
|
|
client_transport, server_transport = lockstep_stream_pair()
|
|
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
|
|
|
|
|
|
# Simple smoke test for handshake/send/receive/shutdown talking to a
|
|
# synchronous server, plus make sure that we do the bare minimum of
|
|
# certificate checking (even though this is really Python's responsibility)
|
|
async def test_ssl_client_basics(client_ctx: SSLContext) -> None:
|
|
# Everything OK
|
|
async with ssl_echo_server(client_ctx) as s:
|
|
assert not s.server_side
|
|
await s.send_all(b"x")
|
|
assert await s.receive_some(1) == b"x"
|
|
await s.aclose()
|
|
|
|
# Didn't configure the CA file, should fail
|
|
async with ssl_echo_server_raw(expect_fail=True) as sock:
|
|
bad_client_ctx = ssl.create_default_context()
|
|
s = SSLStream(sock, bad_client_ctx, server_hostname="trio-test-1.example.org")
|
|
assert not s.server_side
|
|
with pytest.raises(BrokenResourceError) as excinfo:
|
|
await s.send_all(b"x")
|
|
assert isinstance(excinfo.value.__cause__, ssl.SSLError)
|
|
|
|
# Trusted CA, but wrong host name
|
|
async with ssl_echo_server_raw(expect_fail=True) as sock:
|
|
s = SSLStream(sock, client_ctx, server_hostname="trio-test-2.example.org")
|
|
assert not s.server_side
|
|
with pytest.raises(BrokenResourceError) as excinfo:
|
|
await s.send_all(b"x")
|
|
assert isinstance(excinfo.value.__cause__, ssl.CertificateError)
|
|
|
|
|
|
async def test_ssl_server_basics(client_ctx: SSLContext) -> None:
|
|
a, b = stdlib_socket.socketpair()
|
|
with a, b:
|
|
server_sock = tsocket.from_stdlib_socket(b)
|
|
server_transport = SSLStream(
|
|
SocketStream(server_sock), SERVER_CTX, server_side=True
|
|
)
|
|
assert server_transport.server_side
|
|
|
|
def client() -> None:
|
|
with client_ctx.wrap_socket(
|
|
a, server_hostname="trio-test-1.example.org"
|
|
) as client_sock:
|
|
client_sock.sendall(b"x")
|
|
assert client_sock.recv(1) == b"y"
|
|
client_sock.sendall(b"z")
|
|
client_sock.unwrap()
|
|
|
|
t = threading.Thread(target=client)
|
|
t.start()
|
|
|
|
assert await server_transport.receive_some(1) == b"x"
|
|
await server_transport.send_all(b"y")
|
|
assert await server_transport.receive_some(1) == b"z"
|
|
assert await server_transport.receive_some(1) == b""
|
|
await server_transport.aclose()
|
|
|
|
t.join()
|
|
|
|
|
|
async def test_attributes(client_ctx: SSLContext) -> None:
|
|
async with ssl_echo_server_raw(expect_fail=True) as sock:
|
|
good_ctx = client_ctx
|
|
bad_ctx = ssl.create_default_context()
|
|
s = SSLStream(sock, good_ctx, server_hostname="trio-test-1.example.org")
|
|
|
|
assert s.transport_stream is sock
|
|
|
|
# Forwarded attribute getting
|
|
assert s.context is good_ctx
|
|
assert s.server_side == False # noqa
|
|
assert s.server_hostname == "trio-test-1.example.org"
|
|
with pytest.raises(AttributeError):
|
|
s.asfdasdfsa # noqa: B018 # "useless expression"
|
|
|
|
# __dir__
|
|
assert "transport_stream" in dir(s)
|
|
assert "context" in dir(s)
|
|
|
|
# Setting the attribute goes through to the underlying object
|
|
|
|
# most attributes on SSLObject are read-only
|
|
with pytest.raises(AttributeError):
|
|
s.server_side = True
|
|
with pytest.raises(AttributeError):
|
|
s.server_hostname = "asdf"
|
|
|
|
# but .context is *not*. Check that we forward attribute setting by
|
|
# making sure that after we set the bad context our handshake indeed
|
|
# fails:
|
|
s.context = bad_ctx
|
|
assert s.context is bad_ctx
|
|
with pytest.raises(BrokenResourceError) as excinfo:
|
|
await s.do_handshake()
|
|
assert isinstance(excinfo.value.__cause__, ssl.SSLError)
|
|
|
|
|
|
# Note: this test fails horribly if we force TLS 1.2 and trigger a
|
|
# renegotiation at the beginning (e.g. by switching to the pyopenssl
|
|
# server). Usually the client crashes in SSLObject.write with "UNEXPECTED
|
|
# RECORD"; sometimes we get something more exotic like a SyscallError. This is
|
|
# odd because openssl isn't doing any syscalls, but so it goes. After lots of
|
|
# websearching I'm pretty sure this is due to a bug in OpenSSL, where it just
|
|
# can't reliably handle full-duplex communication combined with
|
|
# renegotiation. Nice, eh?
|
|
#
|
|
# https://rt.openssl.org/Ticket/Display.html?id=3712
|
|
# https://rt.openssl.org/Ticket/Display.html?id=2481
|
|
# http://openssl.6102.n7.nabble.com/TLS-renegotiation-failure-on-receiving-application-data-during-handshake-td48127.html
|
|
# https://stackoverflow.com/questions/18728355/ssl-renegotiation-with-full-duplex-socket-communication
|
|
#
|
|
# In some variants of this test (maybe only against the java server?) I've
|
|
# also seen cases where our send_all blocks waiting to write, and then our receive_some
|
|
# also blocks waiting to write, and they never wake up again. It looks like
|
|
# some kind of deadlock. I suspect there may be an issue where we've filled up
|
|
# the send buffers, and the remote side is trying to handle the renegotiation
|
|
# from inside a write() call, so it has a problem: there's all this application
|
|
# data clogging up the pipe, but it can't process and return it to the
|
|
# application because it's in write(), and it doesn't want to buffer infinite
|
|
# amounts of data, and... actually I guess those are the only two choices.
|
|
#
|
|
# NSS even documents that you shouldn't try to do a renegotiation except when
|
|
# the connection is idle:
|
|
#
|
|
# https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/SSL_functions/sslfnc.html#1061582
|
|
#
|
|
# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it...
|
|
|
|
|
|
async def test_full_duplex_basics(client_ctx: SSLContext) -> None:
|
|
CHUNKS = 30
|
|
CHUNK_SIZE = 32768
|
|
EXPECTED = CHUNKS * CHUNK_SIZE
|
|
|
|
sent = bytearray()
|
|
received = bytearray()
|
|
|
|
async def sender(s: Stream) -> None:
|
|
nonlocal sent
|
|
for i in range(CHUNKS):
|
|
print(i)
|
|
chunk = bytes([i] * CHUNK_SIZE)
|
|
sent += chunk
|
|
await s.send_all(chunk)
|
|
|
|
async def receiver(s: Stream) -> None:
|
|
nonlocal received
|
|
while len(received) < EXPECTED:
|
|
chunk = await s.receive_some(CHUNK_SIZE // 2)
|
|
received += chunk
|
|
|
|
async with ssl_echo_server(client_ctx) as s:
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(sender, s)
|
|
nursery.start_soon(receiver, s)
|
|
# And let's have some doing handshakes too, everyone
|
|
# simultaneously
|
|
nursery.start_soon(s.do_handshake)
|
|
nursery.start_soon(s.do_handshake)
|
|
|
|
await s.aclose()
|
|
|
|
assert len(sent) == len(received) == EXPECTED
|
|
assert sent == received
|
|
|
|
|
|
async def test_renegotiation_simple(client_ctx: SSLContext) -> None:
|
|
with virtual_ssl_echo_server(client_ctx) as s:
|
|
await s.do_handshake()
|
|
s.transport_stream.renegotiate()
|
|
await s.send_all(b"a")
|
|
assert await s.receive_some(1) == b"a"
|
|
|
|
# Have to send some more data back and forth to make sure the
|
|
# renegotiation is finished before shutting down the
|
|
# connection... otherwise openssl raises an error. I think this is a
|
|
# bug in openssl but what can ya do.
|
|
await s.send_all(b"b")
|
|
assert await s.receive_some(1) == b"b"
|
|
|
|
await s.aclose()
|
|
|
|
|
|
@slow
|
|
async def test_renegotiation_randomized(
|
|
mock_clock: MockClock, client_ctx: SSLContext
|
|
) -> None:
|
|
# The only blocking things in this function are our random sleeps, so 0 is
|
|
# a good threshold.
|
|
mock_clock.autojump_threshold = 0
|
|
|
|
import random
|
|
|
|
r = random.Random(0)
|
|
|
|
async def sleeper(_: object) -> None:
|
|
await trio.sleep(r.uniform(0, 10))
|
|
|
|
async def clear() -> None:
|
|
while s.transport_stream.renegotiate_pending():
|
|
with assert_checkpoints():
|
|
await send(b"-")
|
|
with assert_checkpoints():
|
|
await expect(b"-")
|
|
print("-- clear --")
|
|
|
|
async def send(byte: bytes) -> None:
|
|
await s.transport_stream.sleeper("outer send")
|
|
print("calling SSLStream.send_all", byte)
|
|
with assert_checkpoints():
|
|
await s.send_all(byte)
|
|
|
|
async def expect(expected: bytes) -> None:
|
|
await s.transport_stream.sleeper("expect")
|
|
print("calling SSLStream.receive_some, expecting", expected)
|
|
assert len(expected) == 1
|
|
with assert_checkpoints():
|
|
assert await s.receive_some(1) == expected
|
|
|
|
with virtual_ssl_echo_server(client_ctx, sleeper=sleeper) as s:
|
|
await s.do_handshake()
|
|
|
|
await send(b"a")
|
|
s.transport_stream.renegotiate()
|
|
await expect(b"a")
|
|
|
|
await clear()
|
|
|
|
for i in range(100):
|
|
b1 = bytes([i % 0xFF])
|
|
b2 = bytes([(2 * i) % 0xFF])
|
|
s.transport_stream.renegotiate()
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(send, b1)
|
|
nursery.start_soon(expect, b1)
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(expect, b2)
|
|
nursery.start_soon(send, b2)
|
|
await clear()
|
|
|
|
for i in range(100):
|
|
b1 = bytes([i % 0xFF])
|
|
b2 = bytes([(2 * i) % 0xFF])
|
|
await send(b1)
|
|
s.transport_stream.renegotiate()
|
|
await expect(b1)
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(expect, b2)
|
|
nursery.start_soon(send, b2)
|
|
await clear()
|
|
|
|
# Checking that wait_send_all_might_not_block and receive_some don't
|
|
# conflict:
|
|
|
|
# 1) Set up a situation where expect (receive_some) is blocked sending,
|
|
# and wait_send_all_might_not_block comes in.
|
|
|
|
# Our receive_some() call will get stuck when it hits send_all
|
|
async def sleeper_with_slow_send_all(method: str) -> None:
|
|
if method == "send_all":
|
|
await trio.sleep(100000)
|
|
|
|
# And our wait_send_all_might_not_block call will give it time to get
|
|
# stuck, and then start
|
|
async def sleep_then_wait_writable() -> None:
|
|
await trio.sleep(1000)
|
|
await s.wait_send_all_might_not_block()
|
|
|
|
with virtual_ssl_echo_server(client_ctx, sleeper=sleeper_with_slow_send_all) as s:
|
|
await send(b"x")
|
|
s.transport_stream.renegotiate()
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(expect, b"x")
|
|
nursery.start_soon(sleep_then_wait_writable)
|
|
|
|
await clear()
|
|
|
|
await s.aclose()
|
|
|
|
# 2) Same, but now wait_send_all_might_not_block is stuck when
|
|
# receive_some tries to send.
|
|
|
|
async def sleeper_with_slow_wait_writable_and_expect(method: str) -> None:
|
|
if method == "wait_send_all_might_not_block":
|
|
await trio.sleep(100000)
|
|
elif method == "expect":
|
|
await trio.sleep(1000)
|
|
|
|
with virtual_ssl_echo_server(
|
|
client_ctx, sleeper=sleeper_with_slow_wait_writable_and_expect
|
|
) as s:
|
|
await send(b"x")
|
|
s.transport_stream.renegotiate()
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(expect, b"x")
|
|
nursery.start_soon(s.wait_send_all_might_not_block)
|
|
|
|
await clear()
|
|
|
|
await s.aclose()
|
|
|
|
|
|
async def test_resource_busy_errors(client_ctx: SSLContext) -> None:
|
|
S: TypeAlias = trio.SSLStream[
|
|
trio.StapledStream[trio.abc.SendStream, trio.abc.ReceiveStream]
|
|
]
|
|
|
|
async def do_send_all(s: S) -> None:
|
|
with assert_checkpoints():
|
|
await s.send_all(b"x")
|
|
|
|
async def do_receive_some(s: S) -> None:
|
|
with assert_checkpoints():
|
|
await s.receive_some(1)
|
|
|
|
async def do_wait_send_all_might_not_block(s: S) -> None:
|
|
with assert_checkpoints():
|
|
await s.wait_send_all_might_not_block()
|
|
|
|
async def do_test(
|
|
func1: Callable[[S], Awaitable[None]], func2: Callable[[S], Awaitable[None]]
|
|
) -> None:
|
|
s, _ = ssl_lockstep_stream_pair(client_ctx)
|
|
with RaisesGroup(Matcher(_core.BusyResourceError, "another task")):
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(func1, s)
|
|
nursery.start_soon(func2, s)
|
|
|
|
await do_test(do_send_all, do_send_all)
|
|
await do_test(do_receive_some, do_receive_some)
|
|
await do_test(do_send_all, do_wait_send_all_might_not_block)
|
|
await do_test(do_wait_send_all_might_not_block, do_wait_send_all_might_not_block)
|
|
|
|
|
|
async def test_wait_writable_calls_underlying_wait_writable() -> None:
|
|
record = []
|
|
|
|
class NotAStream(Stream):
|
|
async def wait_send_all_might_not_block(self) -> None:
|
|
record.append("ok")
|
|
|
|
# define methods that are abstract in Stream
|
|
async def aclose(self) -> None:
|
|
raise AssertionError("Should not get called") # pragma: no cover
|
|
|
|
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
|
|
raise AssertionError("Should not get called") # pragma: no cover
|
|
|
|
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
|
raise AssertionError("Should not get called") # pragma: no cover
|
|
|
|
ctx = ssl.create_default_context()
|
|
s = SSLStream(NotAStream(), ctx, server_hostname="x")
|
|
await s.wait_send_all_might_not_block()
|
|
assert record == ["ok"]
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
os.name == "nt" and sys.version_info >= (3, 10),
|
|
reason="frequently fails on Windows + Python 3.10",
|
|
)
|
|
async def test_checkpoints(client_ctx: SSLContext) -> None:
|
|
async with ssl_echo_server(client_ctx) as s:
|
|
with assert_checkpoints():
|
|
await s.do_handshake()
|
|
with assert_checkpoints():
|
|
await s.do_handshake()
|
|
with assert_checkpoints():
|
|
await s.wait_send_all_might_not_block()
|
|
with assert_checkpoints():
|
|
await s.send_all(b"xxx")
|
|
with assert_checkpoints():
|
|
await s.receive_some(1)
|
|
# These receive_some's in theory could return immediately, because the
|
|
# "xxx" was sent in a single record and after the first
|
|
# receive_some(1) the rest are sitting inside the SSLObject's internal
|
|
# buffers.
|
|
with assert_checkpoints():
|
|
await s.receive_some(1)
|
|
with assert_checkpoints():
|
|
await s.receive_some(1)
|
|
with assert_checkpoints():
|
|
await s.unwrap()
|
|
|
|
async with ssl_echo_server(client_ctx) as s:
|
|
await s.do_handshake()
|
|
with assert_checkpoints():
|
|
await s.aclose()
|
|
|
|
|
|
async def test_send_all_empty_string(client_ctx: SSLContext) -> None:
|
|
async with ssl_echo_server(client_ctx) as s:
|
|
await s.do_handshake()
|
|
|
|
# underlying SSLObject interprets writing b"" as indicating an EOF,
|
|
# for some reason. Make sure we don't inherit this.
|
|
with assert_checkpoints():
|
|
await s.send_all(b"")
|
|
with assert_checkpoints():
|
|
await s.send_all(b"")
|
|
await s.send_all(b"x")
|
|
assert await s.receive_some(1) == b"x"
|
|
|
|
await s.aclose()
|
|
|
|
|
|
@pytest.mark.parametrize("https_compatible", [False, True])
|
|
async def test_SSLStream_generic(
|
|
client_ctx: SSLContext, https_compatible: bool
|
|
) -> None:
|
|
async def stream_maker() -> tuple[
|
|
SSLStream[MemoryStapledStream],
|
|
SSLStream[MemoryStapledStream],
|
|
]:
|
|
return ssl_memory_stream_pair(
|
|
client_ctx,
|
|
client_kwargs={"https_compatible": https_compatible},
|
|
server_kwargs={"https_compatible": https_compatible},
|
|
)
|
|
|
|
async def clogged_stream_maker() -> tuple[
|
|
SSLStream[MyStapledStream],
|
|
SSLStream[MyStapledStream],
|
|
]:
|
|
client, server = ssl_lockstep_stream_pair(client_ctx)
|
|
# If we don't do handshakes up front, then we run into a problem in
|
|
# the following situation:
|
|
# - server does wait_send_all_might_not_block
|
|
# - client does receive_some to unclog it
|
|
# Then the client's receive_some will actually send some data to start
|
|
# the handshake, and itself get stuck.
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
return client, server
|
|
|
|
await check_two_way_stream(stream_maker, clogged_stream_maker)
|
|
|
|
|
|
async def test_unwrap(client_ctx: SSLContext) -> None:
|
|
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
|
|
client_transport = client_ssl.transport_stream
|
|
server_transport = server_ssl.transport_stream
|
|
|
|
seq = Sequencer()
|
|
|
|
async def client() -> None:
|
|
await client_ssl.do_handshake()
|
|
await client_ssl.send_all(b"x")
|
|
assert await client_ssl.receive_some(1) == b"y"
|
|
await client_ssl.send_all(b"z")
|
|
|
|
# After sending that, disable outgoing data from our end, to make
|
|
# sure the server doesn't see our EOF until after we've sent some
|
|
# trailing data
|
|
async with seq(0):
|
|
send_all_hook = client_transport.send_stream.send_all_hook
|
|
client_transport.send_stream.send_all_hook = None
|
|
|
|
assert await client_ssl.receive_some(1) == b""
|
|
assert client_ssl.transport_stream is client_transport
|
|
# We just received EOF. Unwrap the connection and send some more.
|
|
raw, trailing = await client_ssl.unwrap()
|
|
assert raw is client_transport
|
|
assert trailing == b""
|
|
assert client_ssl.transport_stream is None
|
|
await raw.send_all(b"trailing")
|
|
|
|
# Reconnect the streams. Now the server will receive both our shutdown
|
|
# acknowledgement + the trailing data in a single lump.
|
|
client_transport.send_stream.send_all_hook = send_all_hook
|
|
await client_transport.send_stream.send_all_hook()
|
|
|
|
async def server() -> None:
|
|
await server_ssl.do_handshake()
|
|
assert await server_ssl.receive_some(1) == b"x"
|
|
await server_ssl.send_all(b"y")
|
|
assert await server_ssl.receive_some(1) == b"z"
|
|
# Now client is blocked waiting for us to send something, but
|
|
# instead we close the TLS connection (with sequencer to make sure
|
|
# that the client won't see and automatically respond before we've had
|
|
# a chance to disable the client->server transport)
|
|
async with seq(1):
|
|
raw, trailing = await server_ssl.unwrap()
|
|
assert raw is server_transport
|
|
assert trailing == b"trailing"
|
|
assert server_ssl.transport_stream is None
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client)
|
|
nursery.start_soon(server)
|
|
|
|
|
|
async def test_closing_nice_case(client_ctx: SSLContext) -> None:
|
|
# the nice case: graceful closes all around
|
|
|
|
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
|
|
client_transport = client_ssl.transport_stream
|
|
|
|
# Both the handshake and the close require back-and-forth discussion, so
|
|
# we need to run them concurrently
|
|
async def client_closer() -> None:
|
|
with assert_checkpoints():
|
|
await client_ssl.aclose()
|
|
|
|
async def server_closer() -> None:
|
|
assert await server_ssl.receive_some(10) == b""
|
|
assert await server_ssl.receive_some(10) == b""
|
|
with assert_checkpoints():
|
|
await server_ssl.aclose()
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client_closer)
|
|
nursery.start_soon(server_closer)
|
|
|
|
# closing the SSLStream also closes its transport
|
|
with pytest.raises(ClosedResourceError):
|
|
await client_transport.send_all(b"123")
|
|
|
|
# once closed, it's OK to close again
|
|
with assert_checkpoints():
|
|
await client_ssl.aclose()
|
|
with assert_checkpoints():
|
|
await client_ssl.aclose()
|
|
|
|
# Trying to send more data does not work
|
|
with pytest.raises(ClosedResourceError):
|
|
await server_ssl.send_all(b"123")
|
|
|
|
# And once the connection is has been closed *locally*, then instead of
|
|
# getting empty bytestrings we get a proper error
|
|
with pytest.raises(ClosedResourceError):
|
|
assert await client_ssl.receive_some(10) == b""
|
|
|
|
with pytest.raises(ClosedResourceError):
|
|
await client_ssl.unwrap()
|
|
|
|
with pytest.raises(ClosedResourceError):
|
|
await client_ssl.do_handshake()
|
|
|
|
# Check that a graceful close *before* handshaking gives a clean EOF on
|
|
# the other side
|
|
client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async def expect_eof_server() -> None:
|
|
with assert_checkpoints():
|
|
assert await server_ssl.receive_some(10) == b""
|
|
with assert_checkpoints():
|
|
await server_ssl.aclose()
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client_ssl.aclose)
|
|
nursery.start_soon(expect_eof_server)
|
|
|
|
|
|
async def test_send_all_fails_in_the_middle(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
async def bad_hook() -> NoReturn:
|
|
raise KeyError
|
|
|
|
client.transport_stream.send_stream.send_all_hook = bad_hook
|
|
|
|
with pytest.raises(KeyError):
|
|
await client.send_all(b"x")
|
|
|
|
with pytest.raises(BrokenResourceError):
|
|
await client.wait_send_all_might_not_block()
|
|
|
|
closed = 0
|
|
|
|
def close_hook() -> None:
|
|
nonlocal closed
|
|
closed += 1
|
|
|
|
client.transport_stream.send_stream.close_hook = close_hook
|
|
client.transport_stream.receive_stream.close_hook = close_hook
|
|
await client.aclose()
|
|
|
|
assert closed == 2
|
|
|
|
|
|
async def test_ssl_over_ssl(client_ctx: SSLContext) -> None:
|
|
client_0, server_0 = memory_stream_pair()
|
|
|
|
client_1 = SSLStream(
|
|
client_0, client_ctx, server_hostname="trio-test-1.example.org"
|
|
)
|
|
server_1 = SSLStream(server_0, SERVER_CTX, server_side=True)
|
|
|
|
client_2 = SSLStream(
|
|
client_1, client_ctx, server_hostname="trio-test-1.example.org"
|
|
)
|
|
server_2 = SSLStream(server_1, SERVER_CTX, server_side=True)
|
|
|
|
async def client() -> None:
|
|
await client_2.send_all(b"hi")
|
|
assert await client_2.receive_some(10) == b"bye"
|
|
|
|
async def server() -> None:
|
|
assert await server_2.receive_some(10) == b"hi"
|
|
await server_2.send_all(b"bye")
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client)
|
|
nursery.start_soon(server)
|
|
|
|
|
|
async def test_ssl_bad_shutdown(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
await trio.aclose_forcefully(client)
|
|
# now the server sees a broken stream
|
|
with pytest.raises(BrokenResourceError):
|
|
await server.receive_some(10)
|
|
with pytest.raises(BrokenResourceError):
|
|
await server.send_all(b"x" * 10)
|
|
|
|
await server.aclose()
|
|
|
|
|
|
async def test_ssl_bad_shutdown_but_its_ok(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(
|
|
client_ctx,
|
|
server_kwargs={"https_compatible": True},
|
|
client_kwargs={"https_compatible": True},
|
|
)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
await trio.aclose_forcefully(client)
|
|
# the server sees that as a clean shutdown
|
|
assert await server.receive_some(10) == b""
|
|
with pytest.raises(BrokenResourceError):
|
|
await server.send_all(b"x" * 10)
|
|
|
|
await server.aclose()
|
|
|
|
|
|
async def test_ssl_handshake_failure_during_aclose() -> None:
|
|
# Weird scenario: aclose() triggers an automatic handshake, and this
|
|
# fails. This also exercises a bit of code in aclose() that was otherwise
|
|
# uncovered, for re-raising exceptions after calling aclose_forcefully on
|
|
# the underlying transport.
|
|
async with ssl_echo_server_raw(expect_fail=True) as sock:
|
|
# Don't configure trust correctly
|
|
client_ctx = ssl.create_default_context()
|
|
s = SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")
|
|
# It's a little unclear here whether aclose should swallow the error
|
|
# or let it escape. We *do* swallow the error if it arrives when we're
|
|
# sending close_notify, because both sides closing the connection
|
|
# simultaneously is allowed. But I guess when https_compatible=False
|
|
# then it's bad if we can get through a whole connection with a peer
|
|
# that has no valid certificate, and never raise an error.
|
|
with pytest.raises(BrokenResourceError):
|
|
await s.aclose()
|
|
|
|
|
|
async def test_ssl_only_closes_stream_once(client_ctx: SSLContext) -> None:
|
|
# We used to have a bug where if transport_stream.aclose() raised an
|
|
# error, we would call it again. This checks that that's fixed.
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
client_orig_close_hook = client.transport_stream.send_stream.close_hook
|
|
transport_close_count = 0
|
|
|
|
def close_hook() -> NoReturn:
|
|
nonlocal transport_close_count
|
|
assert client_orig_close_hook is not None
|
|
client_orig_close_hook()
|
|
transport_close_count += 1
|
|
raise KeyError
|
|
|
|
client.transport_stream.send_stream.close_hook = close_hook
|
|
|
|
with pytest.raises(KeyError):
|
|
await client.aclose()
|
|
assert transport_close_count == 1
|
|
|
|
|
|
async def test_ssl_https_compatibility_disagreement(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(
|
|
client_ctx,
|
|
server_kwargs={"https_compatible": False},
|
|
client_kwargs={"https_compatible": True},
|
|
)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
# client is in HTTPS-mode, server is not
|
|
# so client doing graceful_shutdown causes an error on server
|
|
async def receive_and_expect_error() -> None:
|
|
with pytest.raises(BrokenResourceError) as excinfo:
|
|
await server.receive_some(10)
|
|
|
|
assert _is_eof(excinfo.value.__cause__)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.aclose)
|
|
nursery.start_soon(receive_and_expect_error)
|
|
|
|
|
|
async def test_https_mode_eof_before_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(
|
|
client_ctx,
|
|
server_kwargs={"https_compatible": True},
|
|
client_kwargs={"https_compatible": True},
|
|
)
|
|
|
|
async def server_expect_clean_eof() -> None:
|
|
assert await server.receive_some(10) == b""
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.aclose)
|
|
nursery.start_soon(server_expect_clean_eof)
|
|
|
|
|
|
async def test_send_error_during_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async def bad_hook() -> NoReturn:
|
|
raise KeyError
|
|
|
|
client.transport_stream.send_stream.send_all_hook = bad_hook
|
|
|
|
with pytest.raises(KeyError):
|
|
with assert_checkpoints():
|
|
await client.do_handshake()
|
|
|
|
with pytest.raises(BrokenResourceError):
|
|
with assert_checkpoints():
|
|
await client.do_handshake()
|
|
|
|
|
|
async def test_receive_error_during_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async def bad_hook() -> NoReturn:
|
|
raise KeyError
|
|
|
|
client.transport_stream.receive_stream.receive_some_hook = bad_hook
|
|
|
|
async def client_side(cancel_scope: CancelScope) -> None:
|
|
with pytest.raises(KeyError):
|
|
with assert_checkpoints():
|
|
await client.do_handshake()
|
|
cancel_scope.cancel()
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client_side, nursery.cancel_scope)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
with pytest.raises(BrokenResourceError):
|
|
with assert_checkpoints():
|
|
await client.do_handshake()
|
|
|
|
|
|
async def test_selected_alpn_protocol_before_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
with pytest.raises(NeedHandshakeError):
|
|
client.selected_alpn_protocol()
|
|
|
|
with pytest.raises(NeedHandshakeError):
|
|
server.selected_alpn_protocol()
|
|
|
|
|
|
async def test_selected_alpn_protocol_when_not_set(client_ctx: SSLContext) -> None:
|
|
# ALPN protocol still returns None when it's not set,
|
|
# instead of raising an exception
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
assert client.selected_alpn_protocol() is None
|
|
assert server.selected_alpn_protocol() is None
|
|
|
|
assert client.selected_alpn_protocol() == server.selected_alpn_protocol()
|
|
|
|
|
|
async def test_selected_npn_protocol_before_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
with pytest.raises(NeedHandshakeError):
|
|
client.selected_npn_protocol()
|
|
|
|
with pytest.raises(NeedHandshakeError):
|
|
server.selected_npn_protocol()
|
|
|
|
|
|
@pytest.mark.filterwarnings(
|
|
r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning",
|
|
r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning",
|
|
)
|
|
async def test_selected_npn_protocol_when_not_set(client_ctx: SSLContext) -> None:
|
|
# NPN protocol still returns None when it's not set,
|
|
# instead of raising an exception
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
assert client.selected_npn_protocol() is None
|
|
assert server.selected_npn_protocol() is None
|
|
|
|
assert client.selected_npn_protocol() == server.selected_npn_protocol()
|
|
|
|
|
|
async def test_get_channel_binding_before_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
with pytest.raises(NeedHandshakeError):
|
|
client.get_channel_binding()
|
|
|
|
with pytest.raises(NeedHandshakeError):
|
|
server.get_channel_binding()
|
|
|
|
|
|
async def test_get_channel_binding_after_handshake(client_ctx: SSLContext) -> None:
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
assert client.get_channel_binding() is not None
|
|
assert server.get_channel_binding() is not None
|
|
|
|
assert client.get_channel_binding() == server.get_channel_binding()
|
|
|
|
|
|
async def test_getpeercert(client_ctx: SSLContext) -> None:
|
|
# Make sure we're not affected by https://bugs.python.org/issue29334
|
|
client, server = ssl_memory_stream_pair(client_ctx)
|
|
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(client.do_handshake)
|
|
nursery.start_soon(server.do_handshake)
|
|
|
|
assert server.getpeercert() is None
|
|
print(client.getpeercert())
|
|
assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"]
|
|
|
|
|
|
async def test_SSLListener(client_ctx: SSLContext) -> None:
|
|
async def setup(
|
|
**kwargs: Any,
|
|
) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]:
|
|
listen_sock = tsocket.socket()
|
|
await listen_sock.bind(("127.0.0.1", 0))
|
|
listen_sock.listen(1)
|
|
socket_listener = SocketListener(listen_sock)
|
|
ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs)
|
|
|
|
transport_client = await open_tcp_stream(*listen_sock.getsockname())
|
|
ssl_client = SSLStream(
|
|
transport_client, client_ctx, server_hostname="trio-test-1.example.org"
|
|
)
|
|
return listen_sock, ssl_listener, ssl_client
|
|
|
|
listen_sock, ssl_listener, ssl_client = await setup()
|
|
|
|
async with ssl_client:
|
|
ssl_server = await ssl_listener.accept()
|
|
|
|
async with ssl_server:
|
|
assert not ssl_server._https_compatible
|
|
|
|
# Make sure the connection works
|
|
async with _core.open_nursery() as nursery:
|
|
nursery.start_soon(ssl_client.do_handshake)
|
|
nursery.start_soon(ssl_server.do_handshake)
|
|
|
|
# Test SSLListener.aclose
|
|
await ssl_listener.aclose()
|
|
assert listen_sock.fileno() == -1
|
|
|
|
################
|
|
|
|
# Test https_compatible
|
|
_, ssl_listener, ssl_client = await setup(https_compatible=True)
|
|
|
|
ssl_server = await ssl_listener.accept()
|
|
|
|
assert ssl_server._https_compatible
|
|
|
|
await aclose_forcefully(ssl_listener)
|
|
await aclose_forcefully(ssl_client)
|
|
await aclose_forcefully(ssl_server)
|