You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

129 lines
3.9 KiB

import contextlib
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
import torchgen.local as local
from torchgen.model import (
BackendIndex,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
from torchgen.utils import context, S, T
# Helper functions for defining generators on things in the model
# Constrained type variable for the primary "native function"-like argument
# of the decorators in this module. This is the value each wrapper hands to
# native_function_manager before calling the wrapped function.
F = TypeVar(
"F",
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
Union[NativeFunction, NativeFunctionsGroup],
Union[NativeFunction, NativeFunctionsViewGroup],
)
# Constrained type variable for the second argument accepted by functions
# wrapped with with_native_function_and. The wrapper passes it through
# untouched; only the first argument is used to set up the context.
F2 = TypeVar(
"F2",
NativeFunction,
NativeFunctionsGroup,
Optional[NativeFunction],
bool,
str,
)
# Constrained type variable for the argument of functions wrapped with
# method_with_nested_native_function: an indexable container whose element 0
# is what gets passed to native_function_manager.
F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
@contextlib.contextmanager
def native_function_manager(
    g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
) -> Iterator[None]:
    """Enter the error context and local parameters for ``g``.

    A single representative function is picked from ``g``:

    * ``NativeFunctionsGroup``     -> its ``out`` variant
    * ``NativeFunctionsViewGroup`` -> its ``view`` operator
    * anything else                -> ``g`` itself

    While the body of the ``with`` runs, ``context`` is active with a message
    naming the representative's ``loc`` and ``func``, and ``local.parametrize``
    is active with values read from the representative's
    ``use_const_ref_for_mutable_tensors`` and ``part_of_structured_group``
    attributes.
    """
    if isinstance(g, NativeFunctionsGroup):
        # Errors for structured native functions are hung on the out variant
        # by default. If a more specific location is wanted, nest another
        # native_function_manager on the inside.
        fn = g.out
    elif isinstance(g, NativeFunctionsViewGroup):
        # Errors for view groups are hung on the view operator.
        fn = g.view
    else:
        fn = g

    def describe_location() -> str:
        # Handed to context() as a callable, so the message is only built
        # when context() decides to call it.
        return f"in native_functions.yaml line {fn.loc}:\n {fn.func}"

    with context(describe_location), local.parametrize(
        use_const_ref_for_mutable_tensors=fn.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=fn.part_of_structured_group,
    ):
        yield
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Decorate ``func`` so it runs inside ``native_function_manager``.

    Given a function that operates on a NativeFunction, return a new function
    that first sets the appropriate context managers for that native function
    and then calls ``func`` with it.

    YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
    (you will get an error if we try to access the local variables without
    having set them).
    """

    @functools.wraps(func)
    def managed(f: F) -> T:
        # The call happens while the manager for ``f`` is still active.
        with native_function_manager(f):
            return func(f)

    return managed
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Two-argument variant of the native-function decorator.

    The returned function enters ``native_function_manager`` for its first
    argument and then calls ``func`` with both arguments unchanged.
    """

    @functools.wraps(func)
    def managed(f: F, f2: F2) -> T:
        # Only the first argument is assumed to carry the appropriate
        # context; the second is forwarded as-is.
        with native_function_manager(f):
            return func(f, f2)

    return managed
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method variant of the native-function decorator.

    The returned callable takes the receiver ``slf`` followed by ``f``,
    enters ``native_function_manager(f)``, and calls ``func(slf, f)``.
    """

    @functools.wraps(func)
    def managed(slf: S, f: F) -> T:
        # The receiver plays no part in choosing the context; only ``f`` does.
        with native_function_manager(f):
            return func(slf, f)

    return managed
def method_with_nested_native_function(
    func: Callable[[S, F3], T]
) -> Callable[[S, F3], T]:
    """Method decorator for arguments that contain the native function.

    The returned callable takes the receiver ``slf`` and a container ``f``.
    It enters ``native_function_manager`` for ``f[0]`` and then calls
    ``func(slf, f)`` with the whole container.
    """

    @functools.wraps(func)
    def managed(slf: S, f: F3) -> T:
        # Element 0 selects the context; ``func`` still receives all of ``f``.
        with native_function_manager(f[0]):
            return func(slf, f)

    return managed
def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
    """Convenience decorator for functions that explicitly take a BackendIndex.

    Use this when the BackendIndex is passed as an argument instead of being
    taken in indirectly as a closure. The returned function enters
    ``native_function_manager(f)`` and calls ``func(f, backend_index)``.
    """

    @functools.wraps(func)
    def managed(f: F, backend_index: BackendIndex) -> T:
        # The index is forwarded untouched; only ``f`` sets up the context.
        with native_function_manager(f):
            return func(f, backend_index)

    return managed
def with_native_function_and_indices(
    func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
    """Convenience decorator for functions that explicitly take a Dict of BackendIndices.

    The returned function enters ``native_function_manager(f)`` and calls
    ``func(f, backend_indices)``, where ``backend_indices`` is keyed by
    DispatchKey.
    """

    @functools.wraps(func)
    def managed(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
        # The mapping is forwarded untouched; only ``f`` sets up the context.
        with native_function_manager(f):
            return func(f, backend_indices)

    return managed