import contextlib
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union

import torchgen.local as local
from torchgen.model import (
    BackendIndex,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
)
from torchgen.utils import context, S, T

# Helper functions for defining generators on things in the model

F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)

F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])


@contextlib.contextmanager
def native_function_manager(
    g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
) -> Iterator[None]:
    """Run the body of a ``with`` block under the codegen context of ``g``.

    A single representative ``NativeFunction`` is chosen from ``g``:

    * ``NativeFunctionsGroup`` -> its ``out`` variant. By default, errors for
      structured native functions are associated with the out variant. If a
      more specific place is wanted, call ``native_function_manager`` again on
      the inside.
    * ``NativeFunctionsViewGroup`` -> its ``view`` operator.
    * ``NativeFunction`` -> ``g`` itself.

    For the duration of the block, two things are active:

    * a lazily built message is registered with ``torchgen.utils.context``,
      naming the function's ``native_functions.yaml`` location and schema;
    * ``torchgen.local`` is parametrized from the chosen function, via
      ``use_const_ref_for_mutable_tensors`` and
      ``use_ilistref_for_tensor_lists`` (taken from
      ``part_of_structured_group``).
    """
    f = (
        g.out
        if isinstance(g, NativeFunctionsGroup)
        else g.view
        if isinstance(g, NativeFunctionsViewGroup)
        else g
    )

    # Built lazily: the message is only rendered if `context` actually needs it.
    def describe() -> str:
        return f"in native_functions.yaml line {f.loc}:\n {f.func}"

    # `with A, B:` enters A then B and exits in reverse, exactly like nesting.
    with context(describe), local.parametrize(
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=f.part_of_structured_group,
    ):
        yield


# Given a function that operates on NativeFunction, wrap it into a new function
# that sets the appropriate context managers for that native function.
# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
# (you will get an error if we try to access the local variables without having
# set them).
def _call_with_context(
    anchor: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction],
    fn: Callable[..., T],
    *args: Any,
) -> T:
    """Invoke ``fn(*args)`` while ``native_function_manager(anchor)`` is active.

    This is the shared body of every decorator below. They differ only in which
    argument supplies the context and in the signature they forward.
    """
    with native_function_manager(anchor):
        return fn(*args)


def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Decorate ``func(f)`` so that it runs under the context of ``f``."""

    @functools.wraps(func)
    def wrapper(f: F) -> T:
        return _call_with_context(f, func, f)

    return wrapper


def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Decorate ``func(f, f2)`` so that it runs under the context of ``f``.

    Only the first argument is treated as the native function that owns the
    context; ``f2`` is forwarded untouched.
    """

    @functools.wraps(func)
    def wrapper(f: F, f2: F2) -> T:
        return _call_with_context(f, func, f, f2)

    return wrapper


def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method form of ``with_native_function``: the context comes from ``f``."""

    @functools.wraps(func)
    def wrapper(slf: S, f: F) -> T:
        return _call_with_context(f, func, slf, f)

    return wrapper


def method_with_nested_native_function(
    func: Callable[[S, F3], T]
) -> Callable[[S, F3], T]:
    """Method decorator whose argument is a tuple/list led by a NativeFunction.

    The context is taken from ``f[0]``, and the whole ``f`` is passed to
    ``func`` unchanged.
    """

    @functools.wraps(func)
    def wrapper(slf: S, f: F3) -> T:
        return _call_with_context(f[0], func, slf, f)

    return wrapper


def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
    """Decorate ``func(f, backend_index)`` to run under the context of ``f``.

    Convenience decorator for functions that take a ``BackendIndex``
    explicitly, instead of capturing one in a closure.
    """

    @functools.wraps(func)
    def wrapper(f: F, backend_index: BackendIndex) -> T:
        return _call_with_context(f, func, f, backend_index)

    return wrapper


def with_native_function_and_indices(
    func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
    """Decorate ``func(f, backend_indices)`` to run under the context of ``f``.

    Convenience decorator for functions that take an explicit mapping from
    ``DispatchKey`` to ``BackendIndex``.
    """

    @functools.wraps(func)
    def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
        return _call_with_context(f, func, f, backend_indices)

    return wrapper