You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

502 lines
15 KiB

6 months ago
from __future__ import annotations
import io
from functools import partial
from typing import (
IO,
TYPE_CHECKING,
Any,
AnyStr,
BinaryIO,
Callable,
Generic,
Iterable,
TypeVar,
Union,
overload,
)
import trio
from ._util import async_wraps
from .abc import AsyncResource
if TYPE_CHECKING:
from _typeshed import (
OpenBinaryMode,
OpenBinaryModeReading,
OpenBinaryModeUpdating,
OpenBinaryModeWriting,
OpenTextMode,
StrOrBytesPath,
)
from typing_extensions import Literal
# This list is also in the docs, make sure to keep them in sync
_FILE_SYNC_ATTRS: set[str] = {
"closed",
"encoding",
"errors",
"fileno",
"isatty",
"newlines",
"readable",
"seekable",
"writable",
# not defined in *IOBase:
"buffer",
"raw",
"line_buffering",
"closefd",
"name",
"mode",
"getvalue",
"getbuffer",
}
# This list is also in the docs, make sure to keep them in sync
_FILE_ASYNC_METHODS: set[str] = {
"flush",
"read",
"read1",
"readall",
"readinto",
"readline",
"readlines",
"seek",
"tell",
"truncate",
"write",
"writelines",
# not defined in *IOBase:
"readinto1",
"peek",
}
# Invariant type variables.
T = TypeVar("T")
FileT = TypeVar("FileT")

# Covariant type variables (only ever produced, never consumed).
T_co = TypeVar("T_co", covariant=True)
FileT_co = TypeVar("FileT_co", covariant=True)
AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True)

# Contravariant type variables (only ever consumed, never produced).
T_contra = TypeVar("T_contra", contravariant=True)
AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True)
# This is a little complicated. IO objects have a lot of methods, and which are available on
# different types varies wildly. We want to match the interface of whatever file we're wrapping.
# This pile of protocols each has one sync method/property, meaning they're going to be compatible
# with a file class that supports that method/property. The ones parameterized with AnyStr take
# either str or bytes depending.
# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're
# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be
# conditional - it's only valid to call them if the object you're accessing them on is compatible
# with that type hint. By using the protocols, the type checker will be checking to see if the
# wrapped type has that method, and only allow the methods that do to be called. We can then alter
# the signature however it needs to match runtime behaviour.
# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types
# Everything below exists only for the type checker; none of it is executed at runtime.
if TYPE_CHECKING:
    from typing_extensions import Buffer, Protocol

    # fmt: off

    # --- Protocols for attributes/methods forwarded synchronously (see _FILE_SYNC_ATTRS) ---

    class _HasClosed(Protocol):
        @property
        def closed(self) -> bool: ...

    class _HasEncoding(Protocol):
        @property
        def encoding(self) -> str: ...

    class _HasErrors(Protocol):
        @property
        def errors(self) -> str | None: ...

    class _HasFileNo(Protocol):
        def fileno(self) -> int: ...

    class _HasIsATTY(Protocol):
        def isatty(self) -> bool: ...

    class _HasNewlines(Protocol[T_co]):
        # Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any.
        @property
        def newlines(self) -> T_co: ...

    class _HasReadable(Protocol):
        def readable(self) -> bool: ...

    class _HasSeekable(Protocol):
        def seekable(self) -> bool: ...

    class _HasWritable(Protocol):
        def writable(self) -> bool: ...

    class _HasBuffer(Protocol):
        @property
        def buffer(self) -> BinaryIO: ...

    class _HasRaw(Protocol):
        @property
        def raw(self) -> io.RawIOBase: ...

    class _HasLineBuffering(Protocol):
        @property
        def line_buffering(self) -> bool: ...

    class _HasCloseFD(Protocol):
        @property
        def closefd(self) -> bool: ...

    class _HasName(Protocol):
        @property
        def name(self) -> str: ...

    class _HasMode(Protocol):
        @property
        def mode(self) -> str: ...

    class _CanGetValue(Protocol[AnyStr_co]):
        def getvalue(self) -> AnyStr_co: ...

    class _CanGetBuffer(Protocol):
        def getbuffer(self) -> memoryview: ...

    # --- Protocols for methods run in a worker thread (see _FILE_ASYNC_METHODS) ---
    # Parameters are positional-only ("/") so that any parameter naming used by the
    # concrete file class is accepted.

    class _CanFlush(Protocol):
        def flush(self) -> None: ...

    class _CanRead(Protocol[AnyStr_co]):
        def read(self, size: int | None = ..., /) -> AnyStr_co: ...

    class _CanRead1(Protocol):
        def read1(self, size: int | None = ..., /) -> bytes: ...

    class _CanReadAll(Protocol[AnyStr_co]):
        def readall(self) -> AnyStr_co: ...

    class _CanReadInto(Protocol):
        def readinto(self, buf: Buffer, /) -> int | None: ...

    class _CanReadInto1(Protocol):
        def readinto1(self, buffer: Buffer, /) -> int: ...

    class _CanReadLine(Protocol[AnyStr_co]):
        def readline(self, size: int = ..., /) -> AnyStr_co: ...

    # Uses the invariant AnyStr rather than AnyStr_co, presumably because the return
    # type is list[...] and list is invariant in its element type.
    class _CanReadLines(Protocol[AnyStr]):
        def readlines(self, hint: int = ..., /) -> list[AnyStr]: ...

    class _CanSeek(Protocol):
        def seek(self, target: int, whence: int = 0, /) -> int: ...

    class _CanTell(Protocol):
        def tell(self) -> int: ...

    class _CanTruncate(Protocol):
        def truncate(self, size: int | None = ..., /) -> int: ...

    class _CanWrite(Protocol[T_contra]):
        def write(self, data: T_contra, /) -> int: ...

    class _CanWriteLines(Protocol[T_contra]):
        # The lines parameter varies for bytes/str, so use a typevar to make the async match.
        def writelines(self, lines: Iterable[T_contra], /) -> None: ...

    class _CanPeek(Protocol[AnyStr_co]):
        def peek(self, size: int = 0, /) -> AnyStr_co: ...

    # --- Protocols for methods AsyncIOWrapper defines explicitly (detach / aclose) ---

    class _CanDetach(Protocol[T_co]):
        # The T_co typevar will be the unbuffered/binary file this file wraps.
        def detach(self) -> T_co: ...

    class _CanClose(Protocol):
        def close(self) -> None: ...
# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a
# subtype of the protocols.
class AsyncIOWrapper(AsyncResource, Generic[FileT_co]):
    """A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous
    file object` interface. Wrapped methods that could block are executed in
    :meth:`trio.to_thread.run_sync`.

    All properties and methods defined in :mod:`~io` are exposed by this
    wrapper, if they exist in the wrapped file object.
    """

    def __init__(self, file: FileT_co) -> None:
        self._wrapped = file

    @property
    def wrapped(self) -> FileT_co:
        """object: A reference to the wrapped file object"""
        return self._wrapped

    if not TYPE_CHECKING:
        # At runtime, forwarded attributes are resolved dynamically here. The type
        # checker instead sees the explicit stubs in the TYPE_CHECKING section below.

        def __getattr__(self, name: str) -> object:
            # Only called when normal attribute lookup fails.
            if name in _FILE_SYNC_ATTRS:
                # Forwarded directly, without a worker thread.
                return getattr(self._wrapped, name)
            if name in _FILE_ASYNC_METHODS:
                # Raises AttributeError if the wrapped file lacks this method.
                meth = getattr(self._wrapped, name)

                @async_wraps(self.__class__, self._wrapped.__class__, name)
                async def wrapper(*args, **kwargs):
                    func = partial(meth, *args, **kwargs)
                    return await trio.to_thread.run_sync(func)

                # cache the generated method on the instance so that later
                # lookups of this name no longer reach __getattr__
                setattr(self, name, wrapper)
                return wrapper

            raise AttributeError(name)

    def __dir__(self) -> Iterable[str]:
        # Advertise only the forwarded names that the wrapped file actually has.
        attrs = set(super().__dir__())
        attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a))
        attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a))
        return attrs

    def __aiter__(self) -> AsyncIOWrapper[FileT_co]:
        return self

    async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr:
        """Return the next line; stop iterating once ``readline`` returns an empty value."""
        line = await self.readline()
        if line:
            return line
        else:
            raise StopAsyncIteration

    async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]:
        """Like :meth:`io.BufferedIOBase.detach`, but async.

        This also re-wraps the result in a new :term:`asynchronous file object`
        wrapper.

        """
        raw = await trio.to_thread.run_sync(self._wrapped.detach)
        return wrap_file(raw)

    async def aclose(self: AsyncIOWrapper[_CanClose]) -> None:
        """Like :meth:`io.IOBase.close`, but async.

        This is also shielded from cancellation; if a cancellation scope is
        cancelled, the wrapped file object will still be safely closed.

        """
        # ensure the underlying file is closed during cancellation
        with trio.CancelScope(shield=True):
            await trio.to_thread.run_sync(self._wrapped.close)

        # Deliver any cancellation that arrived while the close was shielded.
        await trio.lowlevel.checkpoint_if_cancelled()

    if TYPE_CHECKING:
        # fmt: off
        # Based on typing.IO and io stubs.
        # Each stub narrows ``self`` to a wrapper around a file satisfying the
        # matching protocol, so the call only type-checks when the wrapped file
        # supports it. Signatures mirror the protocols above.
        @property
        def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ...
        @property
        def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ...
        @property
        def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ...
        @property
        def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ...
        @property
        def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ...
        @property
        def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ...
        @property
        def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> bool: ...
        @property
        def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ...
        @property
        def name(self: AsyncIOWrapper[_HasName]) -> str: ...
        @property
        def mode(self: AsyncIOWrapper[_HasMode]) -> str: ...

        def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ...
        def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ...
        def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ...
        def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ...
        def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ...
        def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ...
        def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ...
        async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ...
        async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ...
        async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ...
        async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ...
        async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ...
        async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ...
        async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]], hint: int = -1, /) -> list[AnyStr]: ...
        async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ...
        async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ...
        async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ...
        async def write(self: AsyncIOWrapper[_CanWrite[T]], data: T, /) -> int: ...
        async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ...
        async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ...
        async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ...
# Type hints are copied from builtin open.
_OpenFile = Union["StrOrBytesPath", int]
_Opener = Callable[[str, int], int]
# Static-typing overloads for open_file(), mirroring builtin open(): the wrapped
# file type depends on ``mode`` and ``buffering``. Overloads are tried in order,
# so the more specific binary cases come before the catch-alls.


# Text modes (the default "r") -> wraps io.TextIOWrapper.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenTextMode = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.TextIOWrapper]: ...


# Binary mode with buffering=0 (unbuffered) -> wraps io.FileIO.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryMode,
    buffering: Literal[0],
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.FileIO]: ...


# Binary updating modes with default/line buffering -> wraps io.BufferedRandom.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeUpdating,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedRandom]: ...


# Binary writing modes with default/line buffering -> wraps io.BufferedWriter.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeWriting,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedWriter]: ...


# Binary reading modes with default/line buffering -> wraps io.BufferedReader.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeReading,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedReader]: ...


# Any binary mode with some other explicit buffer size -> a generic BinaryIO.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryMode,
    buffering: int,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[BinaryIO]: ...


# Catch-all for any other ``str`` mode (e.g. one not known statically) -> IO[Any].
@overload
async def open_file(  # type: ignore[misc] # Any usage matches builtins.open().
    file: _OpenFile,
    mode: str,
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[IO[Any]]: ...
async def open_file(
    file: _OpenFile,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[Any]:
    """Asynchronous version of :func:`open`.

    The blocking :func:`io.open` call runs in a worker thread via
    :func:`trio.to_thread.run_sync`; the resulting file object is then wrapped
    with :func:`wrap_file`. All arguments are passed through to :func:`io.open`
    unchanged.

    Returns:
        An :term:`asynchronous file object`

    Example::

        async with await trio.open_file(filename) as f:
            async for line in f:
                pass

        assert f.closed

    See also:
      :func:`trio.Path.open`

    """
    # Positional order must match io.open's signature.
    open_args = (file, mode, buffering, encoding, errors, newline, closefd, opener)
    sync_file = await trio.to_thread.run_sync(io.open, *open_args)
    return wrap_file(sync_file)
def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]:
    """Wrap a synchronous file object so that it exposes an async file interface.

    Args:
      file: a :term:`file object`

    Returns:
      An :term:`asynchronous file object` that wraps ``file``

    Raises:
      TypeError: if ``file`` does not have a callable ``close`` plus a callable
        ``read`` or ``write``.

    Example::

      async_file = trio.wrap_file(StringIO('asdf'))
      assert await async_file.read() == 'asdf'

    """

    def _provides(method_name: str) -> bool:
        # An attribute only counts if it exists and is callable.
        return hasattr(file, method_name) and callable(getattr(file, method_name))

    looks_like_file = _provides("close") and (_provides("read") or _provides("write"))
    if looks_like_file:
        return AsyncIOWrapper(file)

    raise TypeError(
        f"{file} does not implement required duck-file methods: "
        "close and (read or write)"
    )