You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
502 lines
15 KiB
502 lines
15 KiB
6 months ago
|
from __future__ import annotations
|
||
|
|
||
|
import io
|
||
|
from functools import partial
|
||
|
from typing import (
|
||
|
IO,
|
||
|
TYPE_CHECKING,
|
||
|
Any,
|
||
|
AnyStr,
|
||
|
BinaryIO,
|
||
|
Callable,
|
||
|
Generic,
|
||
|
Iterable,
|
||
|
TypeVar,
|
||
|
Union,
|
||
|
overload,
|
||
|
)
|
||
|
|
||
|
import trio
|
||
|
|
||
|
from ._util import async_wraps
|
||
|
from .abc import AsyncResource
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from _typeshed import (
|
||
|
OpenBinaryMode,
|
||
|
OpenBinaryModeReading,
|
||
|
OpenBinaryModeUpdating,
|
||
|
OpenBinaryModeWriting,
|
||
|
OpenTextMode,
|
||
|
StrOrBytesPath,
|
||
|
)
|
||
|
from typing_extensions import Literal
|
||
|
|
||
|
# This list is also in the docs, make sure to keep them in sync
# Names that AsyncIOWrapper forwards synchronously (no worker thread) to the
# wrapped file object. Built as the union of two groups for readability.
_FILE_SYNC_ATTRS: set[str] = {
    # defined on the io *IOBase classes:
    "closed",
    "encoding",
    "errors",
    "fileno",
    "isatty",
    "newlines",
    "readable",
    "seekable",
    "writable",
} | {
    # not defined in *IOBase:
    "buffer",
    "raw",
    "line_buffering",
    "closefd",
    "name",
    "mode",
    "getvalue",
    "getbuffer",
}
|
||
|
|
||
|
# This list is also in the docs, make sure to keep them in sync
# Methods that AsyncIOWrapper exposes as coroutines; each call is run via
# trio.to_thread.run_sync. Built as the union of two groups for readability.
_FILE_ASYNC_METHODS: set[str] = {
    # defined on the io *IOBase classes:
    "flush",
    "read",
    "read1",
    "readall",
    "readinto",
    "readline",
    "readlines",
    "seek",
    "tell",
    "truncate",
    "write",
    "writelines",
} | {
    # not defined in *IOBase:
    "readinto1",
    "peek",
}
|
||
|
|
||
|
|
||
|
# Invariant type variables.
FileT = TypeVar("FileT")
T = TypeVar("T")

# Covariant type variables: used where a generic/protocol only produces values.
FileT_co = TypeVar("FileT_co", covariant=True)
T_co = TypeVar("T_co", covariant=True)

# Contravariant type variable: used where a protocol only consumes values.
T_contra = TypeVar("T_contra", contravariant=True)

# str-or-bytes constrained variables (like typing.AnyStr) with explicit variance.
AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True)
AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True)
|
||
|
|
||
|
# This is a little complicated. IO objects have a lot of methods, and which are available on
# different types varies wildly. We want to match the interface of whatever file we're wrapping.
# This pile of protocols each has one sync method/property, meaning they're going to be compatible
# with a file class that supports that method/property. The ones parameterized with AnyStr take
# either str or bytes, depending on the wrapped file.

# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're
# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be
# conditional - it's only valid to call them if the object you're accessing them on is compatible
# with that type hint. By using the protocols, the type checker will be checking to see if the
# wrapped type has that method, and only allow the methods that do to be called. We can then alter
# the signature however it needs to match runtime behaviour.
# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types
if TYPE_CHECKING:
    # Everything in this block exists only for static type checkers; none of
    # these protocols is defined at runtime.
    from typing_extensions import Buffer, Protocol

    # fmt: off

    # --- Single-member protocols for the synchronous attributes/methods ---

    class _HasClosed(Protocol):
        @property
        def closed(self) -> bool: ...

    class _HasEncoding(Protocol):
        @property
        def encoding(self) -> str: ...

    class _HasErrors(Protocol):
        @property
        def errors(self) -> str | None: ...

    class _HasFileNo(Protocol):
        def fileno(self) -> int: ...

    class _HasIsATTY(Protocol):
        def isatty(self) -> bool: ...

    class _HasNewlines(Protocol[T_co]):
        # Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any.
        @property
        def newlines(self) -> T_co: ...

    class _HasReadable(Protocol):
        def readable(self) -> bool: ...

    class _HasSeekable(Protocol):
        def seekable(self) -> bool: ...

    class _HasWritable(Protocol):
        def writable(self) -> bool: ...

    class _HasBuffer(Protocol):
        @property
        def buffer(self) -> BinaryIO: ...

    class _HasRaw(Protocol):
        @property
        def raw(self) -> io.RawIOBase: ...

    class _HasLineBuffering(Protocol):
        @property
        def line_buffering(self) -> bool: ...

    class _HasCloseFD(Protocol):
        @property
        def closefd(self) -> bool: ...

    class _HasName(Protocol):
        @property
        def name(self) -> str: ...

    class _HasMode(Protocol):
        @property
        def mode(self) -> str: ...

    class _CanGetValue(Protocol[AnyStr_co]):
        def getvalue(self) -> AnyStr_co: ...

    class _CanGetBuffer(Protocol):
        def getbuffer(self) -> memoryview: ...

    # --- Single-method protocols for the methods the wrapper runs in a thread ---

    class _CanFlush(Protocol):
        def flush(self) -> None: ...

    class _CanRead(Protocol[AnyStr_co]):
        def read(self, size: int | None = ..., /) -> AnyStr_co: ...

    class _CanRead1(Protocol):
        def read1(self, size: int | None = ..., /) -> bytes: ...

    class _CanReadAll(Protocol[AnyStr_co]):
        def readall(self) -> AnyStr_co: ...

    class _CanReadInto(Protocol):
        def readinto(self, buf: Buffer, /) -> int | None: ...

    class _CanReadInto1(Protocol):
        def readinto1(self, buffer: Buffer, /) -> int: ...

    class _CanReadLine(Protocol[AnyStr_co]):
        def readline(self, size: int = ..., /) -> AnyStr_co: ...

    # Uses the invariant AnyStr (not AnyStr_co) because list[...] is invariant.
    class _CanReadLines(Protocol[AnyStr]):
        def readlines(self, hint: int = ..., /) -> list[AnyStr]: ...

    class _CanSeek(Protocol):
        def seek(self, target: int, whence: int = 0, /) -> int: ...

    class _CanTell(Protocol):
        def tell(self) -> int: ...

    class _CanTruncate(Protocol):
        def truncate(self, size: int | None = ..., /) -> int: ...

    # Contravariant: the typevar only appears as an input (the data written).
    class _CanWrite(Protocol[T_contra]):
        def write(self, data: T_contra, /) -> int: ...

    class _CanWriteLines(Protocol[T_contra]):
        # The lines parameter varies for bytes/str, so use a typevar to make the async match.
        def writelines(self, lines: Iterable[T_contra], /) -> None: ...

    class _CanPeek(Protocol[AnyStr_co]):
        def peek(self, size: int = 0, /) -> AnyStr_co: ...

    # --- Protocols for methods AsyncIOWrapper implements explicitly ---

    class _CanDetach(Protocol[T_co]):
        # The T typevar will be the unbuffered/binary file this file wraps.
        def detach(self) -> T_co: ...

    class _CanClose(Protocol):
        def close(self) -> None: ...
|
||
|
|
||
|
|
||
|
# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a
# subtype of the protocols.
class AsyncIOWrapper(AsyncResource, Generic[FileT_co]):
    """A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous
    file object` interface. Wrapped methods that could block are executed in
    :meth:`trio.to_thread.run_sync`.

    All properties and methods defined in :mod:`~io` are exposed by this
    wrapper, if they exist in the wrapped file object.
    """

    def __init__(self, file: FileT_co) -> None:
        # The synchronous file object that attributes/methods are forwarded to.
        self._wrapped = file

    @property
    def wrapped(self) -> FileT_co:
        """object: A reference to the wrapped file object"""

        return self._wrapped

    if not TYPE_CHECKING:
        # At runtime, forwarded names are resolved dynamically here; type
        # checkers instead see the explicit stubs in the ``if TYPE_CHECKING``
        # section at the end of this class.

        def __getattr__(self, name: str) -> object:
            # Only reached when normal attribute lookup fails.
            if name in _FILE_SYNC_ATTRS:
                # Forwarded synchronously, without a worker thread.
                return getattr(self._wrapped, name)
            if name in _FILE_ASYNC_METHODS:
                # Raises AttributeError if the wrapped file lacks this method.
                meth = getattr(self._wrapped, name)

                @async_wraps(self.__class__, self._wrapped.__class__, name)
                async def wrapper(*args, **kwargs):
                    # Run the (potentially blocking) call in a worker thread.
                    func = partial(meth, *args, **kwargs)
                    return await trio.to_thread.run_sync(func)

                # cache the generated method on the instance so later lookups
                # find it directly and no longer go through __getattr__
                setattr(self, name, wrapper)
                return wrapper

            raise AttributeError(name)

    def __dir__(self) -> Iterable[str]:
        # Advertise only the forwarded names the wrapped file actually has.
        attrs = set(super().__dir__())
        attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a))
        attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a))
        return attrs

    def __aiter__(self) -> AsyncIOWrapper[FileT_co]:
        return self

    async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr:
        # Yield lines until readline() returns an empty (falsy) value.
        line = await self.readline()
        if line:
            return line
        else:
            raise StopAsyncIteration

    async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]:
        """Like :meth:`io.BufferedIOBase.detach`, but async.

        This also re-wraps the result in a new :term:`asynchronous file object`
        wrapper.

        """

        raw = await trio.to_thread.run_sync(self._wrapped.detach)
        return wrap_file(raw)

    async def aclose(self: AsyncIOWrapper[_CanClose]) -> None:
        """Like :meth:`io.IOBase.close`, but async.

        This is also shielded from cancellation; if a cancellation scope is
        cancelled, the wrapped file object will still be safely closed.

        """

        # ensure the underlying file is closed during cancellation
        with trio.CancelScope(shield=True):
            await trio.to_thread.run_sync(self._wrapped.close)

        # The shield suppressed cancellation during the close; re-check for it
        # now that the file has been closed.
        await trio.lowlevel.checkpoint_if_cancelled()

    if TYPE_CHECKING:
        # fmt: off
        # Based on typing.IO and io stubs.
        # Each stub's ``self`` annotation makes it available only when the
        # wrapped file type satisfies the matching protocol.
        @property
        def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ...
        @property
        def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ...
        @property
        def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ...
        @property
        def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ...
        @property
        def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ...
        @property
        def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ...
        # Matches _HasLineBuffering.line_buffering, which is a bool.
        @property
        def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> bool: ...
        @property
        def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ...
        @property
        def name(self: AsyncIOWrapper[_HasName]) -> str: ...
        @property
        def mode(self: AsyncIOWrapper[_HasMode]) -> str: ...

        def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ...
        def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ...
        def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ...
        def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ...
        def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ...
        def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ...
        def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ...
        async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ...
        async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ...
        async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ...
        async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ...
        async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ...
        async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ...
        # ``hint`` is forwarded to the wrapped readlines(), as in _CanReadLines.
        async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]], hint: int = -1, /) -> list[AnyStr]: ...
        async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ...
        async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ...
        async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ...
        async def write(self: AsyncIOWrapper[_CanWrite[T]], data: T, /) -> int: ...
        async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ...
        async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ...
        async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ...
|
||
|
|
||
|
|
||
|
# Type hints are copied from builtin open.
# The ``opener`` callback type: called with (str, int) and returning an int,
# mirroring builtin open()'s ``opener`` parameter.
_Opener = Callable[[str, int], int]
# Accepted ``file`` argument: a path-like value or an int. ``StrOrBytesPath``
# is only imported under TYPE_CHECKING, so it must stay a string reference.
_OpenFile = Union["StrOrBytesPath", int]
|
||
|
|
||
|
|
||
|
# Typing overloads for open_file(), copied from builtin open(): the ``mode`` and
# ``buffering`` literals select which concrete io class ends up wrapped. Type
# checkers try overloads in order, so the order below is significant.


# Text modes (including the default "r") -> wraps an io.TextIOWrapper.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenTextMode = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.TextIOWrapper]: ...


# Binary mode with buffering=0 (unbuffered) -> wraps an io.FileIO.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryMode,
    buffering: Literal[0],
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.FileIO]: ...


# Binary updating modes with buffering -1 or 1 -> wraps an io.BufferedRandom.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeUpdating,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedRandom]: ...


# Binary writing modes with buffering -1 or 1 -> wraps an io.BufferedWriter.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeWriting,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedWriter]: ...


# Binary reading modes with buffering -1 or 1 -> wraps an io.BufferedReader.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeReading,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedReader]: ...


# Any other binary-mode/buffering combination -> a generic BinaryIO.
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryMode,
    buffering: int,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[BinaryIO]: ...


# Fallback for a mode only known as ``str`` -> IO[Any].
@overload
async def open_file(  # type: ignore[misc] # Any usage matches builtins.open().
    file: _OpenFile,
    mode: str,
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[IO[Any]]: ...
|
||
|
|
||
|
|
||
|
async def open_file(
    file: _OpenFile,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[Any]:
    """Asynchronous version of :func:`open`.

    Returns:
        An :term:`asynchronous file object`

    Example::

        async with await trio.open_file(filename) as f:
            async for line in f:
                pass

        assert f.closed

    See also:
        :func:`trio.Path.open`

    """
    # io.open() is a blocking call, so it runs in a worker thread; every
    # argument is passed through positionally, in io.open()'s order.
    sync_file = await trio.to_thread.run_sync(
        io.open, file, mode, buffering, encoding, errors, newline, closefd, opener
    )
    # Give the freshly opened synchronous file its async interface.
    return wrap_file(sync_file)
|
||
|
|
||
|
|
||
|
def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]:
    """Wrap any file object so that it provides an asynchronous file object
    interface.

    Args:
      file: a :term:`file object`

    Returns:
      An :term:`asynchronous file object` that wraps ``file``

    Raises:
      TypeError: if ``file`` has no callable ``close``, or has neither a
        callable ``read`` nor a callable ``write``.

    Example::

      async_file = trio.wrap_file(StringIO('asdf'))

      assert await async_file.read() == 'asdf'

    """

    def _is_method(attr: str) -> bool:
        # An attribute only counts if it exists and is callable.
        if not hasattr(file, attr):
            return False
        return callable(getattr(file, attr))

    # Minimal duck-typing check: closable, and readable or writable.
    if not _is_method("close") or not (_is_method("read") or _is_method("write")):
        raise TypeError(
            f"{file} does not implement required duck-file methods: "
            "close and (read or write)"
        )

    return AsyncIOWrapper(file)
|