You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
2.5 KiB
103 lines
2.5 KiB
from typing import (
|
|
Any,
|
|
Iterable,
|
|
NamedTuple,
|
|
Optional,
|
|
overload,
|
|
Sequence,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from typing_extensions import Self
|
|
|
|
from torch import Tensor
|
|
|
|
from torch._prims_common import DeviceLikeType
|
|
from torch.types import _dtype
|
|
|
|
class PackedSequence_(NamedTuple):
|
|
data: Tensor
|
|
batch_sizes: Tensor
|
|
sorted_indices: Optional[Tensor]
|
|
unsorted_indices: Optional[Tensor]
|
|
|
|
def bind(optional: Any, fn: Any): ...
|
|
|
|
_T = TypeVar("_T")
|
|
|
|
class PackedSequence(PackedSequence_):
|
|
def __new__(
|
|
cls,
|
|
data: Tensor,
|
|
batch_sizes: Optional[Tensor] = ...,
|
|
sorted_indices: Optional[Tensor] = ...,
|
|
unsorted_indices: Optional[Tensor] = ...,
|
|
) -> Self: ...
|
|
def pin_memory(self: _T) -> _T: ...
|
|
def cuda(self: _T, *args: Any, **kwargs: Any) -> _T: ...
|
|
def cpu(self: _T) -> _T: ...
|
|
def double(self: _T) -> _T: ...
|
|
def float(self: _T) -> _T: ...
|
|
def half(self: _T) -> _T: ...
|
|
def long(self: _T) -> _T: ...
|
|
def int(self: _T) -> _T: ...
|
|
def short(self: _T) -> _T: ...
|
|
def char(self: _T) -> _T: ...
|
|
def byte(self: _T) -> _T: ...
|
|
@overload
|
|
def to(
|
|
self: _T,
|
|
dtype: _dtype,
|
|
non_blocking: bool = False,
|
|
copy: bool = False,
|
|
) -> _T: ...
|
|
@overload
|
|
def to(
|
|
self: _T,
|
|
device: Optional[DeviceLikeType] = None,
|
|
dtype: Optional[_dtype] = None,
|
|
non_blocking: bool = False,
|
|
copy: bool = False,
|
|
) -> _T: ...
|
|
@overload
|
|
def to(
|
|
self: _T,
|
|
other: Tensor,
|
|
non_blocking: bool = False,
|
|
copy: bool = False,
|
|
) -> _T: ...
|
|
@property
|
|
def is_cuda(self) -> bool: ...
|
|
def is_pinned(self) -> bool: ...
|
|
|
|
def invert_permutation(permutation: Optional[Tensor]): ...
|
|
def pack_padded_sequence(
|
|
input: Tensor,
|
|
lengths: Tensor,
|
|
batch_first: bool = ...,
|
|
enforce_sorted: bool = ...,
|
|
) -> PackedSequence: ...
|
|
def pad_packed_sequence(
|
|
sequence: PackedSequence,
|
|
batch_first: bool = ...,
|
|
padding_value: float = ...,
|
|
total_length: Optional[int] = ...,
|
|
) -> Tuple[Tensor, ...]: ...
|
|
def pad_sequence(
|
|
sequences: Union[Tensor, Iterable[Tensor]],
|
|
batch_first: bool = False,
|
|
padding_value: float = ...,
|
|
) -> Tensor: ...
|
|
def pack_sequence(
|
|
sequences: Sequence[Tensor],
|
|
enforce_sorted: bool = ...,
|
|
) -> PackedSequence: ...
|
|
def get_packed_sequence(
|
|
data: Tensor,
|
|
batch_sizes: Optional[Tensor],
|
|
sorted_indices: Optional[Tensor],
|
|
unsorted_indices: Optional[Tensor],
|
|
) -> PackedSequence: ...
|