You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

471 lines
13 KiB

from __future__ import annotations
import dataclasses
import enum
import io
import secrets
import struct
from typing import Callable, Generator, Optional, Sequence, Tuple
from . import exceptions, extensions
from .typing import Data
try:
from .speedups import apply_mask
except ImportError:
from .utils import apply_mask
# Public API of this module. CloseCode is listed because it is a public enum:
# Close.parse() returns its members and the close-code tables below use it.
__all__ = [
    "Opcode",
    "OP_CONT",
    "OP_TEXT",
    "OP_BINARY",
    "OP_CLOSE",
    "OP_PING",
    "OP_PONG",
    "DATA_OPCODES",
    "CTRL_OPCODES",
    "CloseCode",
    "Frame",
    "prepare_data",
    "prepare_ctrl",
    "Close",
]
class Opcode(enum.IntEnum):
    """Opcode values for WebSocket frames (the low 4 bits of the first header byte)."""

    # Opcodes of data frames.
    CONT = 0x00
    TEXT = 0x01
    BINARY = 0x02

    # Opcodes of control frames.
    CLOSE = 0x08
    PING = 0x09
    PONG = 0x0A
# Module-level shortcuts for each opcode.
OP_CONT, OP_TEXT, OP_BINARY = Opcode.CONT, Opcode.TEXT, Opcode.BINARY
OP_CLOSE, OP_PING, OP_PONG = Opcode.CLOSE, Opcode.PING, Opcode.PONG

# Opcodes grouped by frame kind: data frames, then control frames.
DATA_OPCODES = (OP_CONT, OP_TEXT, OP_BINARY)
CTRL_OPCODES = (OP_CLOSE, OP_PING, OP_PONG)
class CloseCode(enum.IntEnum):
    """Status codes that can appear in, or describe the absence of, a close frame."""

    NORMAL_CLOSURE = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    # 1004 is reserved and deliberately has no member.
    NO_STATUS_RCVD = 1005  # internal: the close frame carried no code
    ABNORMAL_CLOSURE = 1006  # internal
    INVALID_DATA = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
    TLS_HANDSHAKE = 1015  # internal
# Short human-readable descriptions of close codes, used by Close.__str__.
# Codes tagged "[internal]" never travel in a close frame.
# See https://www.iana.org/assignments/websocket/websocket.xhtml
CLOSE_CODE_EXPLANATIONS: dict[int, str] = {
    CloseCode.NORMAL_CLOSURE: "OK",
    CloseCode.GOING_AWAY: "going away",
    CloseCode.PROTOCOL_ERROR: "protocol error",
    CloseCode.UNSUPPORTED_DATA: "unsupported data",
    CloseCode.NO_STATUS_RCVD: "no status received [internal]",
    CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]",
    CloseCode.INVALID_DATA: "invalid frame payload data",
    CloseCode.POLICY_VIOLATION: "policy violation",
    CloseCode.MESSAGE_TOO_BIG: "message too big",
    CloseCode.MANDATORY_EXTENSION: "mandatory extension",
    CloseCode.INTERNAL_ERROR: "internal error",
    CloseCode.SERVICE_RESTART: "service restart",
    CloseCode.TRY_AGAIN_LATER: "try again later",
    CloseCode.BAD_GATEWAY: "bad gateway",
    CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]",
}

# Close codes that are allowed in a close frame: every CloseCode member except
# the three "[internal]" ones (1005, 1006, 1015).
# A set keeps `code in EXTERNAL_CLOSE_CODES` fast.
EXTERNAL_CLOSE_CODES = {
    CloseCode.NORMAL_CLOSURE,
    CloseCode.GOING_AWAY,
    CloseCode.PROTOCOL_ERROR,
    CloseCode.UNSUPPORTED_DATA,
    CloseCode.INVALID_DATA,
    CloseCode.POLICY_VIOLATION,
    CloseCode.MESSAGE_TOO_BIG,
    CloseCode.MANDATORY_EXTENSION,
    CloseCode.INTERNAL_ERROR,
    CloseCode.SERVICE_RESTART,
    CloseCode.TRY_AGAIN_LATER,
    CloseCode.BAD_GATEWAY,
}

# NOTE(review): presumably the codes that callers treat as a clean closure;
# this module does not use it itself — confirm against callers.
OK_CLOSE_CODES = {
    CloseCode.NORMAL_CLOSURE,
    CloseCode.GOING_AWAY,
    CloseCode.NO_STATUS_RCVD,
}

# Types accepted as bytes-like payloads. This is a tuple so it can be passed
# directly to isinstance().
BytesLike = (bytes, bytearray, memoryview)
@dataclasses.dataclass
class Frame:
    """
    WebSocket frame.

    Attributes:
        opcode: Opcode.
        data: Payload data.
        fin: FIN bit.
        rsv1: RSV1 bit.
        rsv2: RSV2 bit.
        rsv3: RSV3 bit.

    Only these fields are needed. The MASK bit, payload length and masking-key
    are handled on the fly when parsing and serializing frames.

    """

    opcode: Opcode
    data: bytes
    fin: bool = True
    rsv1: bool = False
    rsv2: bool = False
    rsv3: bool = False

    def __str__(self) -> str:
        """
        Return a human-readable representation of a frame.

        Payloads that cannot be rendered as their opcode suggests — a text
        fragment ending in the middle of a UTF-8 sequence, or an invalid close
        payload — are shown as a hex dump instead of making ``str()`` raise.

        """
        coding = None
        length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}"
        non_final = "" if self.fin else "continued"

        if self.opcode is OP_TEXT:
            # Decoding only the beginning and the end is needlessly hard.
            # Decode the entire payload then elide later if necessary.
            try:
                data = repr(bytes(self.data).decode())
            except UnicodeDecodeError:
                # A text frame that is one fragment of a message may contain
                # a partial UTF-8 sequence (RFC 6455, section 5.6).
                data = self._format_binary(self.data)
                coding = "binary"
        elif self.opcode is OP_BINARY:
            data = self._format_binary(self.data)
        elif self.opcode is OP_CLOSE:
            try:
                data = str(Close.parse(bytes(self.data)))
            except (exceptions.ProtocolError, UnicodeDecodeError):
                # Not a valid close payload; still display what was received.
                data = self._format_binary(self.data)
                coding = "binary"
        elif self.data:
            # We don't know if a Continuation frame contains text or binary.
            # Ping and Pong frames could contain UTF-8.
            # Attempt to decode as UTF-8 and display it as text; fallback to
            # binary. If self.data is a memoryview, it has no decode() method,
            # which raises AttributeError.
            try:
                data = repr(self.data.decode())
                coding = "text"
            except (UnicodeDecodeError, AttributeError):
                data = self._format_binary(self.data)
                coding = "binary"
        else:
            data = "''"

        # Keep the output short by eliding the middle of long payloads.
        if len(data) > 75:
            data = data[:48] + "..." + data[-24:]

        metadata = ", ".join(filter(None, [coding, length, non_final]))

        return f"{self.opcode.name} {data} [{metadata}]"

    @staticmethod
    def _format_binary(binary: bytes) -> str:
        """
        Return a space-separated hex dump of ``binary`` for ``__str__``.

        Payloads longer than 25 bytes are reduced to their first 16 bytes,
        two placeholder bytes and their last 8 bytes. The resulting 77-char
        string exceeds the 75-char limit in ``__str__``, so the middle part,
        including the placeholders, is then replaced with "...".

        """
        if len(binary) > 25:
            binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]])
        return " ".join(f"{byte:02x}" for byte in binary)

    @classmethod
    def parse(
        cls,
        read_exact: Callable[[int], Generator[None, None, bytes]],
        *,
        mask: bool,
        max_size: Optional[int] = None,
        extensions: Optional[Sequence[extensions.Extension]] = None,
    ) -> Generator[None, None, Frame]:
        """
        Parse a WebSocket frame.

        This is a generator-based coroutine.

        Args:
            read_exact: generator-based coroutine that reads the requested
                bytes or raises an exception if there isn't enough data.
            mask: whether the frame should be masked i.e. whether the read
                happens on the server side.
            max_size: maximum payload size in bytes.
            extensions: list of extensions, applied in reverse order.

        Returns:
            The parsed frame, after all extensions have decoded it.

        Raises:
            EOFError: if the connection is closed without a full WebSocket frame.
            UnicodeDecodeError: if the frame contains invalid UTF-8.
            PayloadTooBig: if the frame's payload size exceeds ``max_size``.
            ProtocolError: if the frame contains incorrect values.

        """
        # Read the header.
        data = yield from read_exact(2)
        head1, head2 = struct.unpack("!BB", data)

        # While not Pythonic, this is marginally faster than calling bool().
        fin = True if head1 & 0b10000000 else False
        rsv1 = True if head1 & 0b01000000 else False
        rsv2 = True if head1 & 0b00100000 else False
        rsv3 = True if head1 & 0b00010000 else False

        try:
            opcode = Opcode(head1 & 0b00001111)
        except ValueError as exc:
            raise exceptions.ProtocolError("invalid opcode") from exc

        # The MASK bit must be set exactly when this side expects masking.
        if (True if head2 & 0b10000000 else False) != mask:
            raise exceptions.ProtocolError("incorrect masking")

        # A 7-bit length of 126 or 127 announces a 16-bit or 64-bit length.
        length = head2 & 0b01111111
        if length == 126:
            data = yield from read_exact(2)
            (length,) = struct.unpack("!H", data)
        elif length == 127:
            data = yield from read_exact(8)
            (length,) = struct.unpack("!Q", data)

        # Enforce the limit before the payload is read.
        if max_size is not None and length > max_size:
            raise exceptions.PayloadTooBig(
                f"over size limit ({length} > {max_size} bytes)"
            )
        if mask:
            mask_bytes = yield from read_exact(4)

        # Read the data.
        data = yield from read_exact(length)
        if mask:
            data = apply_mask(data, mask_bytes)

        frame = cls(opcode, data, fin, rsv1, rsv2, rsv3)

        if extensions is None:
            extensions = []
        for extension in reversed(extensions):
            frame = extension.decode(frame, max_size=max_size)

        # Validate only after extensions have decoded the frame.
        frame.check()

        return frame

    def serialize(
        self,
        *,
        mask: bool,
        extensions: Optional[Sequence[extensions.Extension]] = None,
    ) -> bytes:
        """
        Serialize a WebSocket frame.

        Args:
            mask: whether the frame should be masked i.e. whether the write
                happens on the client side.
            extensions: list of extensions, applied in order.

        Returns:
            The encoded frame: header, optional masking key, then payload.

        Raises:
            ProtocolError: if the frame contains incorrect values.

        """
        # Validate the frame as given, before extensions transform it.
        self.check()

        # Apply extensions to a local reference rather than rebinding self.
        frame = self
        if extensions is not None:
            for extension in extensions:
                frame = extension.encode(frame)

        output = io.BytesIO()

        # Prepare the header.
        head1 = (
            (0b10000000 if frame.fin else 0)
            | (0b01000000 if frame.rsv1 else 0)
            | (0b00100000 if frame.rsv2 else 0)
            | (0b00010000 if frame.rsv3 else 0)
            | frame.opcode
        )

        head2 = 0b10000000 if mask else 0

        # Use the shortest of the 7-bit, 16-bit and 64-bit length encodings.
        length = len(frame.data)
        if length < 126:
            output.write(struct.pack("!BB", head1, head2 | length))
        elif length < 65536:
            output.write(struct.pack("!BBH", head1, head2 | 126, length))
        else:
            output.write(struct.pack("!BBQ", head1, head2 | 127, length))

        if mask:
            mask_bytes = secrets.token_bytes(4)
            output.write(mask_bytes)

        # Prepare the data.
        if mask:
            data = apply_mask(frame.data, mask_bytes)
        else:
            data = frame.data
        output.write(data)

        return output.getvalue()

    def check(self) -> None:
        """
        Check that reserved bits and opcode have acceptable values.

        Raises:
            ProtocolError: if a reserved bit or the opcode is invalid.

        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise exceptions.ProtocolError("reserved bits must be 0")

        if self.opcode in CTRL_OPCODES:
            if len(self.data) > 125:
                raise exceptions.ProtocolError("control frame too long")
            if not self.fin:
                raise exceptions.ProtocolError("fragmented control frame")
def prepare_data(data: Data) -> Tuple[int, bytes]:
    """
    Map an application message to the opcode and payload of a data frame.

    A :class:`str` yields ``OP_TEXT`` and its UTF-8 encoding as a
    :class:`bytes` object. A bytes-like object yields ``OP_BINARY`` and the
    very same object, without copying it.

    Raises:
        TypeError: if ``data`` is neither a :class:`str` nor bytes-like.

    """
    if isinstance(data, str):
        payload = data.encode("utf-8")
        return OP_TEXT, payload

    if not isinstance(data, BytesLike):
        raise TypeError("data must be str or bytes-like")

    return OP_BINARY, data
def prepare_ctrl(data: Data) -> bytes:
    """
    Turn the payload of a ping or pong frame into a :class:`bytes` object.

    A :class:`str` is encoded in UTF-8. A bytes-like object is copied into a
    new :class:`bytes` object.

    Raises:
        TypeError: if ``data`` is neither a :class:`str` nor bytes-like.

    """
    if isinstance(data, str):
        return data.encode("utf-8")

    if not isinstance(data, BytesLike):
        raise TypeError("data must be str or bytes-like")

    return bytes(data)
@dataclasses.dataclass
class Close:
    """
    Code and reason for WebSocket close frames.

    Attributes:
        code: Close code.
        reason: Close reason.

    """

    code: int
    reason: str

    def __str__(self) -> str:
        """
        Return a human-readable representation of a close code and reason.

        """
        if 3000 <= self.code < 4000:
            explanation = "registered"
        elif 4000 <= self.code < 5000:
            explanation = "private use"
        else:
            explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown")
        result = f"{self.code} ({explanation})"

        if self.reason:
            result = f"{result} {self.reason}"

        return result

    @classmethod
    def parse(cls, data: bytes) -> Close:
        """
        Parse the payload of a close frame.

        Args:
            data: payload of the close frame. Any bytes-like object
                (``bytes``, ``bytearray``, ``memoryview``) is accepted.

        Returns:
            The close code and reason. An empty payload yields
            ``NO_STATUS_RCVD`` with an empty reason.

        Raises:
            ProtocolError: if data is ill-formed.
            UnicodeDecodeError: if the reason isn't valid UTF-8.

        """
        if len(data) >= 2:
            (code,) = struct.unpack("!H", data[:2])
            # bytes() makes this work for memoryview, which has no decode().
            reason = bytes(data[2:]).decode("utf-8")
            close = cls(code, reason)
            close.check()
            return close
        elif len(data) == 0:
            return cls(CloseCode.NO_STATUS_RCVD, "")
        else:
            raise exceptions.ProtocolError("close frame too short")

    def serialize(self) -> bytes:
        """
        Serialize the payload of a close frame.

        Raises:
            ProtocolError: if the close code is invalid.

        """
        self.check()
        return struct.pack("!H", self.code) + self.reason.encode("utf-8")

    def check(self) -> None:
        """
        Check that the close code has a valid value for a close frame.

        Valid codes are the registered and private-use range 3000-4999 and
        the codes in ``EXTERNAL_CLOSE_CODES``.

        Raises:
            ProtocolError: if the close code is invalid.

        """
        # The cheap range comparison is evaluated first; the result is the
        # same as testing set membership first.
        if not (3000 <= self.code < 5000 or self.code in EXTERNAL_CLOSE_CODES):
            raise exceptions.ProtocolError("invalid status code")