You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2938 lines
108 KiB

5 months ago
import argparse
import functools
import json
import os
import pathlib
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
import yaml
import torchgen.api.dispatcher as dispatcher
import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
Binding,
CppSignature,
CppSignatureGroup,
DispatcherSignature,
NamedCType,
NativeSignature,
SpecialArgName,
)
from torchgen.context import (
method_with_native_function,
native_function_manager,
with_native_function,
with_native_function_and_indices,
)
from torchgen.gen_aoti_c_shim import (
gen_aoti_c_shim,
gen_static_dispatch_backend_call_signature,
get_backend_index_for_aoti,
)
from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
from torchgen.model import (
Argument,
BackendIndex,
BackendMetadata,
BaseOperatorName,
DEFAULT_KERNEL_NAMESPACE,
DispatchKey,
FRAGMENT_NAMESPACES,
FunctionSchema,
is_cuda_dispatch_key,
is_generic_dispatch_key,
is_ufunc_dispatch_key,
Location,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
OperatorName,
OptionalType,
SchemaKind,
SelfArgument,
STRUCTURED_DISPATCH_KEYS,
TensorOptionsArguments,
Type,
Variant,
ViewSchemaKind,
)
from torchgen.native_function_generation import (
add_generated_native_functions,
gen_composite_functional_kernel,
gen_composite_out_kernel,
pre_group_native_functions,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
assert_never,
concatMap,
context,
FileManager,
make_file_manager,
mapMaybe,
NamespaceHelper,
Target,
)
from torchgen.yaml_utils import YamlDumper, YamlLoader
# Generic type variable. It is not referenced in the visible portion of this
# file; presumably it is used by generic helpers further down (TODO confirm).
T = TypeVar("T")
# Welcome to the ATen code generator v2! The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file. This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking. Typecheck it with
# `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
# - 'model' has the data model for native_functions.yaml. The classes
# in that file represent what you see when you look at
# a native_functions.yaml
# - 'api' has conversions for how to translate JIT schema into
# the various C++ APIs that the codegen interacts with. There
# are in fact THREE different C++ APIs: the public C++ API,
# the dispatcher API, and the legacy dispatcher API. See each
# of these respective files for more information
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    """YAML loader that tags every constructed mapping with its source line.

    Each mapping gets an extra ``"__line__"`` entry holding the 1-based line
    number of the node, so later stages can report errors with a location.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        result = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # The loader's marks count lines from 0; shift so the first line is 1.
        result["__line__"] = 1 + node.start_mark.line
        return result
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
# Result of parsing native_functions.yaml: the list of NativeFunctions and a
# mapping from each DispatchKey to its BackendIndex.
ParsedYaml = namedtuple("ParsedYaml", "native_functions backend_indices")

# Process-wide memo tables, keyed by the path of the YAML file that was parsed.
_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = dict()
_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = dict()
def parse_native_yaml_struct(
    es: object,
    valid_tags: Set[str],
    ignore_keys: Optional[Set[DispatchKey]] = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Build a ParsedYaml from an already-loaded native_functions.yaml document.

    Args:
        es: the loaded YAML; must be a list of mappings, each carrying the
            ``__line__`` key added by LineLoader.
        valid_tags: tag names accepted on function entries.
        ignore_keys: forwarded to NativeFunction.from_yaml.
        path: file name used when building error locations.
        skip_native_fns_gen: when True, add_generated_native_functions is not run.

    Returns:
        ParsedYaml(native_functions, backend_indices). backend_indices is a
        defaultdict that yields an Undefined, empty BackendIndex for unknown keys.
    """
    assert isinstance(es, list)
    native_functions: List[NativeFunction] = []
    kernels_by_key: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)

    for entry in es:
        assert isinstance(entry.get("__line__"), int), entry
        loc = Location(path, entry["__line__"])
        funcs = entry.get("func")
        with context(lambda: f"in {loc}:\n {funcs}"):
            parsed_func, func_kernels = NativeFunction.from_yaml(
                entry, loc, valid_tags, ignore_keys
            )
            native_functions.append(parsed_func)
            BackendIndex.grow_index(kernels_by_key, func_kernels)

    error_check_native_functions(native_functions)

    def _empty_backend_index() -> BackendIndex:
        # Returned for dispatch keys with no kernels yet, so lookups by the
        # codegen never fail.
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )

    indices: Dict[DispatchKey, BackendIndex] = defaultdict(_empty_backend_index)

    if not skip_native_fns_gen:
        add_generated_native_functions(native_functions, kernels_by_key)

    for key, op_kernels in kernels_by_key.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[key] = BackendIndex(
            dispatch_key=key,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(key),
            index=op_kernels,
        )
    return ParsedYaml(native_functions, indices)
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
    """Collect the tag names declared in a loaded tags.yaml document.

    Every entry must be a mapping with a ``tag`` name, a non-empty ``desc``,
    and the ``__line__`` key added by LineLoader (used for error context).
    """
    assert isinstance(es, list)
    tag_names: Set[str] = set()
    for entry in es:
        assert isinstance(entry.get("__line__"), int), entry
        loc = Location(path, entry["__line__"])
        tags = entry.get("tag")
        with context(lambda: f"in {loc}:\n {tags}"):
            remaining = entry.copy()
            tag_name = remaining.pop("tag")
            # Every tag must come with a non-empty description.
            assert remaining.pop("desc", "") != ""
            tag_names.add(tag_name)
    return tag_names
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> Set[str]:
    """Load the tags YAML file at `path` and return its tag names.

    Results are memoized per path, both by lru_cache and in
    _GLOBAL_PARSE_TAGS_YAML_CACHE.
    """
    cache = _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path in cache:
        return cache[path]
    with open(path) as f:
        es = yaml.load(f, Loader=LineLoader)
    cache[path] = parse_tags_yaml_struct(es, path=path)
    return cache[path]
def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: Optional[object] = None,
) -> ParsedYaml:
    """Parse the native functions YAML at `path`, memoizing the result per path.

    If `loaded_yaml` is given, it is used instead of reading `path`.

    NOTE(review): the memo key is `path` alone, so later calls for the same
    path return the first result regardless of the other arguments.
    """
    cache = _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path in cache:
        return cache[path]
    valid_tags = parse_tags_yaml(tags_yaml_path)
    if loaded_yaml is not None:
        es = loaded_yaml
    else:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)
    cache[path] = parse_native_yaml_struct(
        es,
        valid_tags,
        ignore_keys,
        path=path,
        skip_native_fns_gen=skip_native_fns_gen,
    )
    return cache[path]
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Run consistency checks that span multiple NativeFunctions.

    Checks:
      * every ``structured_delegate`` names an existing operator that is
        itself marked structured;
      * every op tagged ``inplace_view`` (except the resize_/resize_as_/set_
        special cases) has an inplace-style base name (trailing underscore),
        and some op with the corresponding out-of-place base name exists.

    Raises:
        AssertionError: if any check fails.
    """
    func_map: Dict[OperatorName, NativeFunction] = {}
    base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but no operator named {f.structured_delegate} exists."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) not in ("resize_", "resize_as_")
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            # Report the out-of-place name we looked for, not the inplace one.
            assert base_func_map.get(out_of_place_base_name), (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, "
                "but it didn't find one. "
            )
# Characters that need escaping inside a C++ string literal, mapped to their
# escape sequences. Doing every substitution in one str.translate() pass means
# no escape sequence is ever re-escaped by a later substitution.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        "\v": "\\v",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a C++ string literal, including the quotes.

    Backslash, double quote and the control characters with C escape
    sequences (BEL, BS, FF, LF, CR, TAB, VT) are escaped. A raw newline or
    carriage return inside a C++ string literal would otherwise break the
    generated source.
    """
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated. This pattern makes it convenient to use map, concatMap
# and similar functional combinators.
def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
    """Return the dispatch keys that static dispatch needs headers for.

    Empty when there are no static-dispatch backends. Otherwise it is each
    backend's own key, followed by the four composite keys.
    """
    if len(backends) == 0:
        return []
    keys = [index.dispatch_key for index in backends]
    keys.extend(
        (
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        )
    )
    return keys
def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> Optional[DispatchKey]:
    """Pick the dispatch key static dispatch should use for `f` on `backend_index`.

    The backend's own key wins when `f` is structured-delegated or the backend
    has a kernel for it. Otherwise the first composite kernel `f` has is used.
    Returns None when nothing applies.
    """
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    # Composite fallbacks, in priority order.
    candidates = (
        (f.has_composite_explicit_autograd_kernel, DispatchKey.CompositeExplicitAutograd),
        (
            f.has_composite_explicit_autograd_non_functional_kernel,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ),
        (f.has_composite_implicit_autograd_kernel, DispatchKey.CompositeImplicitAutograd),
        (
            f.has_composite_implicit_autograd_nested_tensor_kernel,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
        ),
    )
    for has_kernel, key in candidates:
        if has_kernel:
            return key
    return None
def static_dispatch_ops_header(
    f: NativeFunction, backend_index: List[BackendIndex]
) -> Optional[str]:
    """Return the per-op ``<op>_<key>_dispatch.h`` includes for statically dispatching `f`.

    Returns None when `backend_index` is None or `f` registers its kernels
    manually. Otherwise returns one #include line for each backend that
    resolves to a dispatch key, joined by newlines (possibly "").
    """
    if backend_index is None or f.manual_kernel_registration:
        return None
    resolved = (get_static_dispatch_backend(f, index) for index in backend_index)
    return "\n".join(
        f"#include <ATen/ops/{f.root_name}_{key.lower()}_dispatch.h>"
        for key in resolved
        if key is not None
    )
def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
    """Return one ``#include <ATen/<Key>Functions.h>`` line per static-dispatch key."""
    headers: List[str] = []
    for key in static_dispatch_keys(backends):
        headers.append(f"#include <ATen/{key}Functions.h>")
    return headers
# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# torchgen.api.translate.translate() yet, as its application is limited to static dispatch.
def translate_args(
    sig: Union[CppSignature, DispatcherSignature],
    cpp_sig: CppSignature,
) -> str:
    """Return the comma-separated C++ expressions that call `cpp_sig` using `sig`'s arguments.

    If any binding of `cpp_sig` is named
    SpecialArgName.possibly_redundant_memory_format, every ``memory_format``
    binding of `sig` is re-tagged with that special name first, so that
    translate() can match the two.
    """

    def _retag_memory_format(binding: Binding) -> Binding:
        # Only bindings literally named "memory_format" get the special NamedCType.
        if binding.name != "memory_format":
            return binding
        return Binding(
            nctype=NamedCType(
                SpecialArgName.possibly_redundant_memory_format,
                binding.nctype.type,
            ),
            name=binding.name,
            default=binding.default,
            argument=binding.argument,
        )

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    wants_special_memory_format = any(
        goal.nctype.name == SpecialArgName.possibly_redundant_memory_format
        for goal in goal_bindings
    )
    if wants_special_memory_format:
        src_bindings = [_retag_memory_format(b) for b in src_bindings]
    return ", ".join(e.expr for e in translate(src_bindings, goal_bindings))
def generate_static_dispatch_backend_call(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    """Emit ``return <ns>::<key>::<op>(<args>);`` calling `f`'s kernel on `backend_index`.

    <ns> is the kernel's cpp_namespace (or DEFAULT_KERNEL_NAMESPACE when it
    has none), with every "::native" occurrence removed.
    """
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    op_name = cpp_sig.name()
    call_args = translate_args(sig, cpp_sig)
    metadata = backend_index.get_kernel(f)
    if metadata and metadata.cpp_namespace:
        kernel_ns = metadata.cpp_namespace
    else:
        kernel_ns = DEFAULT_KERNEL_NAMESPACE
    ns = kernel_ns.replace("::native", "")
    key = backend_index.dispatch_key.lower()
    return f"return {ns}::{key}::{op_name}({call_args});"
def generate_static_dispatch_fallback_call(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """Emit the C++ statement used when no static-dispatch backend has a kernel for `f`.

    Calls the first composite kernel `f` has, checked in this order:
    CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional,
    CompositeImplicitAutograd, CompositeImplicitAutogradNestedTensor.
    If there is none, emits a ``TORCH_CHECK(false, ...)`` that names the
    backends that were considered.
    """
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
    composite_kernels = (
        (f.has_composite_explicit_autograd_kernel, DispatchKey.CompositeExplicitAutograd),
        (
            f.has_composite_explicit_autograd_non_functional_kernel,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ),
        (f.has_composite_implicit_autograd_kernel, DispatchKey.CompositeImplicitAutograd),
        (
            f.has_composite_implicit_autograd_nested_tensor_kernel,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
        ),
    )
    for has_kernel, key in composite_kernels:
        if has_kernel:
            return f"return {ns}::{key.lower()}::{name}({exprs});"
    backends = ", ".join(str(index.dispatch_key) for index in backend_indices)
    # Built on a single line: a backslash line-continuation inside the literal
    # would swallow the space between "for" and the backend list.
    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backends}");'
def static_dispatch(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
    backend has a kernel, emit a switch that picks the backend at runtime from a dispatch key computed
    from the inputs.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);".
        Empty string when there are no backends or `f` uses manual kernel registration.
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or (isinstance(a.argument, Argument) and a.argument.type.is_tensor_like())
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    subexprs: List[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    if not subexprs:
        # Nothing to compute a runtime dispatch key from. Emitting
        # `DispatchKeySet _dk_set = ;` would not compile, so use the same
        # fallback as when no backend has a kernel.
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    stmts = [
        f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""",
        "DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);",
    ]
    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        # The generated backend call already ends with ';'.
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)}"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"
    return f"""
{connector.join(stmts)}
switch (_dk) {{
{connector.join(dispatch_code)}
default:
{fallback}
}}
"""
# Generates RegisterSchema.cpp. Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    """Emit the ``m.def(...)`` lines of RegisterSchema.cpp for selected operators.

    Each distinct set of tags is declared once as ``tags_<idx>``. Later
    operators with the same tag set refer to it through `known_tags`.
    """

    selector: SelectiveBuilder
    # Rendered "{at::Tag::...}" initializer -> the <idx> of its tags_<idx> vector.
    # Mutated in place across calls even though the dataclass is frozen.
    known_tags: Dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not self.selector.is_native_function_selected(f):
            return None
        schema = cpp_string(str(f.func))
        tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
        if tags == "{}":
            return f"m.def({schema}, {{}});\n"
        declaration = ""
        idx = self.known_tags.get(tags)
        if idx is None:
            # First time this tag set is seen: declare its vector.
            idx = len(self.known_tags)
            self.known_tags[tags] = idx
            declaration = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
        return f"{declaration}m.def({schema}, tags_{idx});\n"
# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    # DECLARATION emits the at::_ops struct for an operator (Operators.h);
    # DEFINITION emits its out-of-line members and call()/redispatch() bodies
    # (Operators.cpp).
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    # When non-empty, the generated call() body is produced by static_dispatch()
    # into these backends instead of going through the dispatcher.
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the Operators.h or Operators.cpp code for operator `f`."""
        sig = DispatcherSignature.from_schema(f.func)
        # Overload-disambiguated name, used as the struct name (e.g. mul_Tensor).
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            # This is kind of necessary to avoid overhead.
            # For example: if it followed the C++ API, then all of the faithful C++ factory functions
            # would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            # This is helpful for pytorch extenders who would like to decltype() an aten operator,
            # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            # This is more of an implementation detail to avoid #include cycles,
            # since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            # This applies to stuff like __dispatch__is_complex(), and add_outf().
            # These aren't "real aten ops", they're just additional functions provided by the C++ API.
            # They're implemented as wrappers in Functions.h that call into the actual operators
            # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
using schema = {sig.type()};
using ptr_schema = schema*;
// See Note [static constexpr char* members for windows NVCC]
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
static {sig.defn(name="call", is_redispatching_fn=False)};
static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        elif self.target is Target.DEFINITION:
            # Out-of-line definitions of the name/overload_name/schema_str
            # members, plus a helper that looks up the typed operator handle
            # from the dispatcher by name and overload name.
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})
// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
return c10::Dispatcher::singleton()
.findSchemaOrThrow({name}::name, {name}::overload_name)
.typed<{name}::schema>();
}}
"""

            # Emit both entry points: call() and redispatch(). redispatch()
            # forwards an extra leading dispatchKeySet argument.
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    method_base = "redispatch"
                else:
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    method_base = "call"

                dispatcher_call = method_base
                method_name = f"{name}::{method_base}"

                # Default body: fetch the typed handle once (function-local
                # static in the generated C++) and forward the arguments to it.
                fn_body = f"""
static auto op = create_{name}_typed_handle();
return op.{dispatcher_call}({dispatcher_exprs_str});"""

                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    fn_body = static_dispatch(
                        sig, f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
{fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    """Emit Functions.h wrappers that forward to ``at::_ops::<op>::call``.

    One inline wrapper is emitted per C++ signature when `f` has a function
    variant. Ops that have a SymInt form also get a ``symint::`` template
    overload for every signature.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        has_symint = f.func.has_symint()
        emits_function = Variant.function in f.variants
        op_name = f.func.name.unambiguous_name()
        # See Note [The ATen Operators API]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        pieces: List[str] = []
        for sig in sig_group.signatures():
            exprs = translate(sig.arguments(), dispatcher_sig.arguments())
            exprs_str = ", ".join([e.expr for e in exprs])
            intlike_t = "c10::SymInt" if sig.symint else "int64_t"

            if emits_function:
                pieces.append(
                    f"""
// aten::{f.func}
inline {sig.decl()} {{
return at::_ops::{op_name}::call({exprs_str});
}}"""
                )

            # The symint:: template lets template code pick the SymInt or
            # int64_t flavour from a template argument. It is emitted even for
            # method-only ops; it lives in this header so per-op headers work.
            if has_symint:
                pieces.append(
                    f"""
namespace symint {{
template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
{sig.decl(suppress_symint_suffix=True)} {{
return at::_ops::{op_name}::call({exprs_str});
}}
}}
"""
                )
        return "".join(pieces)
# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    """Emit TensorBody.h method declarations or inline definitions for method variants."""

    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            return "".join(f"{sig.decl()} const;\n" for sig in sig_group.signatures())

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_name = f.func.name.unambiguous_name()
        definitions: List[str] = []
        for sig in sig_group.signatures():
            exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
            exprs_str = ", ".join([e.expr for e in exprs])
            definitions.append(
                f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
return at::_ops::{op_name}::call({exprs_str});
}}
"""
            )
        return "".join(definitions)
# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    """Emit RedispatchFunctions.h wrappers that forward to ``at::_ops::<op>::redispatch``."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        # Function-style wrappers are always emitted, whatever variants `f`
        # declares: free functions can be placed in their own namespace,
        # methods cannot.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_name = f.func.name.unambiguous_name()

        chunks: List[str] = []
        for sig in sig_group.signatures():
            call_args = ["dispatchKeySet"]
            call_args.extend(
                e.expr for e in translate(sig.arguments(), dispatcher_sig.arguments())
            )
            exprs_str = ", ".join(call_args)
            chunks.append(
                f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
return at::_ops::{op_name}::redispatch({exprs_str});
}}
"""
            )
        return "".join(chunks)
# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another, we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    """Return one ``{"aten::<name>", "<overload>"},`` initializer line for ATenOpList.cpp."""
    op = f.func.name
    return '{{"aten::{}", "{}"}},'.format(op.name, op.overload_name)
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
    """Return the MetaFunctions.h ``structured_<name>`` struct declaration for group `g`.

    Returns None for non-structured groups. When the out variant declares
    `precomputed` elements, a templated ``precompute_out`` struct with one
    "has been set" bool template parameter per element is also emitted, and
    meta() returns it with every parameter set to true.
    """
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        # The generated struct derives from structured_inherits if given,
        # otherwise from at::impl::MetaBase.
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        # g.structured is already known to be true here, so this is just g.out.precomputed.
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{precomputed_elements[i].name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")

                # NB: this inner loop rebinds `elem`; that is harmless because
                # the outer `elem` is not read again after `signature` is built.
                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
{signature} {{
{assert_stmt}
{construction_block}
}}
"""
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
{precompute_template_decl}
struct TORCH_API precompute_out {{
{setter_methods_decl}
{precomputed_elements_decl};
}};"""
        else:
            # No precomputed elements: meta() returns void and no extra struct is emitted.
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
{precomputed_decl}
{meta_return_typedef}
{meta_return} meta({args_str});
}};
"""
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    """Return whether `f` needs a generated BackendSelect kernel.

    Only ops that take TensorOptions qualify. ``*_like`` and ``new_*`` ops are
    excluded, and the op must also be chosen by `selector`.
    """
    base_name = str(f.func.name.name)
    is_excluded_family = base_name.endswith("_like") or base_name.startswith("new_")
    if is_excluded_family or f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    # DEFINITION emits the BackendSelect kernel body for an op; REGISTRATION
    # emits the matching m.impl(...) line.
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return BackendSelect code for `f`, or None if `f` needs no BackendSelect kernel."""
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        # Names of the tensor-like arguments, used to fold their dispatch keys
        # into the computed key set.
        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        # The generated kernel uses the dispatcher signature.
        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            # The kernel computes _dk and then redispatches the op with it.
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
{compute_dk}
return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def format_yaml(data: object) -> str:
    """Serialize ``data`` to a YAML string using ``YamlDumper``.

    Side effect: this (re)configures the ``YamlDumper`` class itself on every
    call. It disables aliases and registers an OrderedDict representer, so
    OrderedDicts are emitted as plain mappings built from ``.items()``.
    """

    def dict_representer(dumper: Any, data: Any) -> Any:
        return dumper.represent_dict(data.items())

    # Never emit YAML anchors/aliases for repeated objects.
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]
    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks; a huge
    # width turns off optional line breaks and keeps the output portable.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
# Some defaults we write to YAML are emitted as native YAML objects rather
# than uniformly as strings; this converts those cases into the matching
# Python objects.
def pythonify_default(s: str) -> object:
    """Convert a C++ default-expression string into a native Python value.

    "true"/"false" become booleans. Otherwise ``int`` is tried, then
    ``float``. If neither parses, the original string is returned unchanged.
    """
    booleans = {"true": True, "false": False}
    if s in booleans:
        return booleans[s]
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            pass
    return s
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    """Return the legacy "dynamic type" string for ``t``.

    Historically this captured the dtype-ness of types, but that went away
    with the removal of TH. Today it is mostly the C++ API argument type,
    except that ``Tensor`` and ``Tensor?`` arguments both present as
    ``at::Tensor``.
    """
    # Peel off every Optional wrapper (e.g. Tensor? -> Tensor).
    while isinstance(t, OptionalType):
        t = t.elem
    # Compare the string form instead of calling t.is_tensor_like(), which
    # would also accept Tensor[].
    if str(t) == "Tensor":
        return "at::Tensor"
    # This is a legacy concept, so SymInt is never reported.
    return cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    ).cpp_type()
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
    """Return the ``method_of`` list for a Declarations.yaml entry.

    The list always starts with "Type". It is followed by "Tensor" when the
    method variant is present, then "namespace" when the function variant is
    present, in exactly that order.
    """
    ordered_candidates = (("Tensor", Variant.method), ("namespace", Variant.function))
    return ["Type"] + [
        label for label, variant in ordered_candidates if variant in variants
    ]
def compute_returns_yaml(
    f: NativeFunction,
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    """Build the ``returns`` entries of a Declarations.yaml record for ``f``.

    :return: a pair ``(returns, name_to_field_name)``.
        - ``returns`` holds one dict per return value.
        - ``name_to_field_name`` is filled only for out= functions. It maps the
          name of the out argument at the same position as a named return to
          that return's name (see Note [name and field_name]).
    """
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}
    returns: List[Dict[str, str]] = []
    return_names = cpp.return_names(f)
    is_out = f.func.is_out_fn()
    for position, (ret, ret_name) in enumerate(zip(f.func.returns, return_names)):
        entry = {
            "dynamic_type": dynamic_type(ret.type),
            "name": ret_name,
            # legacy, report ints
            "type": cpp.return_type(ret, symint=False).cpp_type(),
        }
        if ret.name:
            # See Note [name and field_name]
            entry["field_name"] = ret.name
            if is_out:
                out_arg_name = f.func.arguments.out[position].name
                name_to_field_name[out_arg_name] = ret.name
        returns.append(entry)
    return returns, name_to_field_name
# arguments in yaml roughly correspond to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    """Describe one C++ API binding as a Declarations.yaml argument dict.

    TensorOptions bindings get a fixed, kwarg-only description. Plain
    ``Argument`` bindings are delegated to ``compute_argument_yaml``.

    :raises AssertionError: if the binding wraps a ``SelfArgument``, which is
        not supported here.
    """
    underlying = cpp_a.argument
    if isinstance(underlying, TensorOptionsArguments):
        options_entry: Dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            options_entry["default"] = cpp_a.default
        return options_entry
    if isinstance(underlying, SelfArgument):
        raise AssertionError()
    if isinstance(underlying, Argument):
        return compute_argument_yaml(
            underlying,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
    return None
def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    """Describe a single schema ``Argument`` as a Declarations.yaml dict.

    Optional keys are added only when they apply: "default", "kwarg_only",
    "output"/"allocate", "field_name" and "size". ``schema_order`` is
    accepted for signature parity with ``compute_cpp_argument_yaml`` but is
    not read here.
    """
    arg_name = a.name
    entry: Dict[str, object] = {
        "annotation": None if not a.annotation else str(a.annotation),
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": arg_name,
        # legacy, report ints
        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
    }
    if a.default is not None:
        default_cpp = cpp.default_expr(a.default, a.type, symint=False)
        entry["default"] = pythonify_default(default_cpp)
    if arg_name in kwarg_only_set:
        entry["kwarg_only"] = True
    if arg_name in out_arg_set:
        entry.update(output=True, allocate=True)
    # See Note [name and field_name]
    if arg_name in name_to_field_name:
        entry["field_name"] = name_to_field_name[arg_name]
    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>).
    list_type = a.type.is_list_like()
    if (
        list_type is not None
        and list_type.size is not None
        and str(list_type.elem) != "bool"
    ):
        entry["size"] = list_type.size
    return entry
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Build the Declarations.yaml record for a single native function.

    The record is an OrderedDict so that its field order is explicit.
    Arguments are listed twice:
      - "arguments": in C++ API order (non-method signature);
      - "schema_order_arguments": in JIT schema order.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # Name sets used to flag each argument as kwarg-only and/or out.
    kwarg_only_names = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_names = {a.name for a in f.func.arguments.out}

    cpp_bindings = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    ).signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            binding,
            schema_order=False,
            kwarg_only_set=kwarg_only_names,
            out_arg_set=out_names,
            name_to_field_name=name_to_field_name,
        )
        for binding in cpp_bindings
    ]

    jit_arguments = list(f.func.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(
            jit_arg,
            schema_order=True,
            kwarg_only_set=kwarg_only_names,
            out_arg_set=out_names,
            name_to_field_name=name_to_field_name,
        )
        for jit_arg in jit_arguments
    ]

    # C++ types of the schema-ordered arguments (method=False is irrelevant here).
    schema_order_types = [
        binding.type
        for jit_arg in jit_arguments
        for binding in cpp.argument(
            jit_arg,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
    ]
    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(schema_order_types)})"

    # A factory takes TensorOptions and is not exposed as a Tensor method.
    takes_tensor_options = any(
        isinstance(binding.argument, TensorOptionsArguments)
        for binding in cpp_bindings
    )
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    category_override = "" if f.category_override is None else f.category_override
    python_module = "" if f.python_module is None else f.python_module
    fields = [
        ("name", cpp.name(f.func)),
        ("operator_name", str(f.func.name.name)),
        ("overload_name", str(f.func.name.overload_name)),
        ("manual_kernel_registration", f.manual_kernel_registration),
        ("category_override", category_override),
        ("schema_string", f"aten::{f.func}"),
        ("arguments", arguments),
        ("schema_order_cpp_signature", schema_order_cpp_signature),
        ("schema_order_arguments", schema_order_arguments),
        ("method_of", compute_method_of_yaml(f.variants)),
        ("mode", "native"),
        ("python_module", python_module),
        ("returns", returns),
        ("inplace", f.func.name.name.inplace),
        ("is_factory_method", is_factory_method),
        ("abstract", f.is_abstract),
        ("device_guard", f.device_guard),
        ("with_gil", False),
        ("deprecated", False),
        ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
    ]
    return OrderedDict(fields)
# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    """Return True for structured (or structured-delegate) functional/inplace ops."""
    is_structured_family = f.structured or f.structured_delegate is not None
    if not is_structured_family:
        return False
    return f.func.kind() in (SchemaKind.functional, SchemaKind.inplace)
@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
) -> str:
    """Render one RegistrationDeclarations.h line for ``f``.

    The line is a dispatcher-style C++ declaration followed by a JSON comment
    carrying "schema", "dispatch" and "default" metadata.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args_str = ", ".join(
        binding.no_default().decl_registration_declarations()
        for binding in dispatcher.arguments(f.func)
    )
    keys_with_kernel = {
        key for key, index in backend_indices.items() if index.has_kernel(f)
    }
    # Kernel-key sets made up only of CompositeImplicitAutograd entries.
    implicit_autograd_only = (
        {DispatchKey.CompositeImplicitAutograd},
        {
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
        },
    )
    comment_data: Dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        "dispatch": str(keys_with_kernel not in implicit_autograd_only),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# RUN IT ALL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def get_custom_build_selector(
    provided_op_registration_allowlist: Optional[List[str]],
    op_selection_yaml_path: Optional[str],
) -> SelectiveBuilder:
    """Build the operator selector for a (possibly custom) build.

    At most one source may be given:
      - an explicit op allowlist, via the legacy-allowlist constructor;
      - a selection YAML file;
      - neither, in which case ``SelectiveBuilder.get_nop_selector()`` is
        returned.

    :raises AssertionError: if both sources are provided.
    """
    assert not (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    ), (
        "Both provided_op_registration_allowlist and "
        "op_selection_yaml_path can NOT be provided at the "
        "same time."
    )
    if provided_op_registration_allowlist is not None:
        return SelectiveBuilder.from_legacy_op_registration_allow_list(
            set(provided_op_registration_allowlist),
            True,
            False,
        )
    if op_selection_yaml_path is not None:
        return SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    return SelectiveBuilder.get_nop_selector()
def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
    """Bundle view ops with their inplace-view and view_copy counterparts.

    Functions are bucketed by ``f.func.view_signature()``. Within a bucket:
      - an aliasing view op, together with its optional aliasing_inplace and
        functional (view_copy) variants, becomes a ``NativeFunctionsViewGroup``;
      - every other function is returned on its own.
    """

    def maybe_create_view_group(
        bucket: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
    ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
        result: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
        if ViewSchemaKind.aliasing in bucket:
            # Pop the group members so they are not emitted a second time below.
            view = bucket.pop(ViewSchemaKind.aliasing)
            view_inplace = bucket.pop(ViewSchemaKind.aliasing_inplace, None)
            view_copy = bucket.pop(SchemaKind.functional, None)
            result.append(
                NativeFunctionsViewGroup(
                    view=view,
                    view_copy=view_copy,
                    view_inplace=view_inplace,
                )
            )
        # Whatever remains was not part of a view group.
        result.extend(bucket.values())
        return result

    buckets: Dict[
        FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        schema = f.func.view_signature()
        view_kind: ViewSchemaKind = f.view_schema_kind
        # A view group consists of:
        #   view op         (keyed by ViewSchemaKind.aliasing)
        #   view_inplace op (keyed by ViewSchemaKind.aliasing_inplace)
        #   view_copy op    (keyed by SchemaKind.functional)
        # so non-aliasing functions are keyed by their SchemaKind.
        slot: Union[SchemaKind, ViewSchemaKind] = (
            f.func.kind() if view_kind == ViewSchemaKind.non_aliasing else view_kind
        )
        bucket = buckets[schema]
        assert slot not in bucket
        bucket[slot] = f
    return list(concatMap(maybe_create_view_group, buckets.values()))
def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    """Combine related native functions into ``NativeFunctionsGroup`` objects.

    ``pre_group_native_functions`` yields ``SchemaKind -> NativeFunction``
    dicts. Each dict that forms a valid group is emitted as one
    ``NativeFunctionsGroup``; otherwise its members are emitted individually.
    """

    def flatten_pre_group(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        group = NativeFunctionsGroup.from_dict(d)
        if group is not None:
            return [group]
        # Invariant: any NativeFunctions that are code-generated
        # should have been grouped into NativeFunctionsGroup objects
        assert not any("generated" in f.tags for f in d.values())
        return list(d.values())

    pre_groups = pre_group_native_functions(native_functions)
    # ValuesView is not a Sequence, so materialize it before flattening.
    return list(concatMap(flatten_pre_group, list(pre_groups.values())))
def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
    ] = dest.compute_native_function_declaration,
) -> Dict[str, List[str]]:
    """Collect kernel declarations from every backend, bucketed by C++ namespace.

    ``native_function_decl_gen`` is called once per (function, backend index)
    pair. Its output is appended under that backend's kernel namespace, or
    under ``DEFAULT_KERNEL_NAMESPACE`` when the backend has no kernel for the
    function. The same declaration may be produced by several backends;
    ``get_native_function_declarations_from_ns_grouped_kernels`` removes
    such duplicates.

    :param grouped_native_functions: functions / function groups to declare.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: per-backend declaration generator.
    :return: namespace string -> list of declaration strings.
    :raises AssertionError: if one operator's kernels live in more than one
        namespace across backends.
    """
    ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
    for f in grouped_native_functions:
        native_function_namespaces: Set[str] = set()
        dispatch_keys: Set[DispatchKey] = set()
        for dispatch_key, backend_idx in backend_indices.items():
            backend_metadata = backend_idx.get_kernel(f)
            if backend_metadata:
                namespace = backend_metadata.cpp_namespace
                dispatch_keys.add(dispatch_key)
                native_function_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            assert (
                len(native_function_namespaces) <= 1
            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
            # Each backend contributes its own declarations. Doing this once
            # after the loop would declare only the last backend's kernels
            # (and fail outright when backend_indices is empty).
            ns_grouped_kernels[namespace].extend(
                native_function_decl_gen(f, backend_idx)
            )
    return ns_grouped_kernels
def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: Dict[str, List[str]],
) -> List[str]:
    """Wrap each namespace's kernel declarations in its C++ namespace block.

    :param ns_grouped_kernels: declaration strings keyed by C++ namespace.
    :return: the rendered text for all namespaces, split by newline.
    """
    rendered_lines: List[str] = []
    for namespace, kernels in ns_grouped_kernels.items():
        helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=4,
        )
        # Backends are allowed to repeat kernel names; keep only the first
        # occurrence of each declaration, preserving order.
        unique_kernels = list(dict.fromkeys(kernels))
        # Layout: an empty line, the prologue, the declarations, the epilogue
        # and a trailing line of eight spaces (keeps the generated text stable).
        block = "\n".join(
            ["", helper.prologue, "\n".join(unique_kernels), helper.epilogue, " " * 8]
        )
        rendered_lines.extend(block.split("\n"))
    return rendered_lines
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
    ] = dest.compute_native_function_declaration,
) -> List[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.

    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
    :return: the declarations of all namespaces as one block of text, split by newline.
    """
    return get_native_function_declarations_from_ns_grouped_kernels(
        ns_grouped_kernels=get_ns_grouped_kernels(
            grouped_native_functions=grouped_native_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=native_function_decl_gen,
        )
    )
def get_kernel_namespace(
    *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
) -> str:
    """Return the C++ namespace that holds ``f``'s kernel for ``backend_idx``.

    Returns ``DEFAULT_KERNEL_NAMESPACE`` when the backend has no kernel for
    ``f``.

    :raises AssertionError: if a kernel's namespace does not contain "::native".
    """
    backend_metadata = backend_idx.get_kernel(f)
    if not backend_metadata:
        return DEFAULT_KERNEL_NAMESPACE
    assert "::native" in backend_metadata.cpp_namespace, (
        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
        f"with dispatch key {backend_idx.dispatch_key}"
        f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
    )
    return backend_metadata.cpp_namespace
# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and etc.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> List[str]:
    """Render RegisterDispatchDefinitions.ini once per kernel namespace.

    Kernels are bucketed by their kernel namespace with "::native" removed.
    For every non-empty bucket, the template receives:
      - the namespaced and anonymous definitions;
      - one ``TORCH_LIBRARY_IMPL`` block per operator namespace, or "" when
        ``skip_dispatcher_op_registration`` is set.

    :return: the rendered text of all buckets, split by newline.
    """
    definitions: List[str] = []
    ns_definitions: Dict[str, List[str]] = defaultdict(list)
    anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
    # kernel namespace -> operator namespace -> registration lines.
    # Both levels default-construct, so operators from several namespaces that
    # share one kernel namespace all keep their registrations. Replacing the
    # inner dict when a new operator namespace appears would drop the
    # registrations already collected for the others.
    registrations: Dict[str, Dict[str, List[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    newline = "\n"
    ns_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    anonymous_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.ANONYMOUS_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    reg_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.REGISTRATION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )
        ns_definitions[kernel_namespace].extend(
            ns_gen(f),
        )
        anonymous_definitions[kernel_namespace].extend(
            anonymous_gen(f),
        )
        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        registrations[kernel_namespace][namespace].extend(
            reg_gen(f),
        )
    for kernel_namespace in ns_definitions:
        if len(ns_definitions[kernel_namespace]) == 0:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace in registrations[kernel_namespace]:
            if not registrations[kernel_namespace][namespace]:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(registrations[kernel_namespace][namespace])}
}};"""
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )
    return definitions
# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> List[str]:
    """Render the per-dispatch-key namespaced declarations.

    Every "native" in a kernel namespace is replaced by
    ``dispatch_key.lower()``. Namespaces that end up with no declarations are
    skipped.

    :return: the rendered text of all namespaces, split by newline.
    """
    declare = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
    )
    kernels_by_ns: Dict[str, List[str]] = defaultdict(list)
    for f in grouped_native_functions:
        ns = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "native", dispatch_key.lower()
        )
        kernels_by_ns[ns].extend(declare(f))

    rendered_lines: List[str] = []
    for ns, kernels in kernels_by_ns.items():
        if not kernels:
            continue
        helper = NamespaceHelper(namespace_str=ns, entity_name="", max_level=3)
        # Keep only the first occurrence of each declaration, preserving order.
        unique_kernels = list(dict.fromkeys(kernels))
        # Layout: an empty line, the prologue, the declarations, the epilogue
        # and a trailing line of eight spaces (keeps the generated text stable).
        block = "\n".join(
            ["", helper.prologue, "\n".join(unique_kernels), helper.epilogue, " " * 8]
        )
        rendered_lines.extend(block.split("\n"))
    return rendered_lines
# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> Tuple[List[str], str]:
    """Produce schema registrations split into aten and non-aten parts.

    :return: a pair.
        - The first item is the aten registration lines. They are kept
          separate because the template already hard-codes the aten library.
        - The second item is the concatenated ``TORCH_LIBRARY`` /
          ``TORCH_LIBRARY_FRAGMENT`` blocks for every other namespace.
    """
    functions_by_ns: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        functions_by_ns[native_function.namespace].append(native_function)

    aten_schema_registrations: List[str] = []
    schema_registrations = ""
    tab = "\t"
    for namespace, funcs in functions_by_ns.items():
        body = list(mapMaybe(RegisterSchema(schema_selector), funcs))
        if namespace == "aten":
            aten_schema_registrations = body
            continue
        # A predefined namespace must be extended with a library fragment
        # rather than a new library.
        if namespace in FRAGMENT_NAMESPACES:
            torch_library_macro = "TORCH_LIBRARY_FRAGMENT"
        else:
            torch_library_macro = "TORCH_LIBRARY"
        schema_registrations += f"""
{torch_library_macro}({namespace}, m) {{
  {tab.join(body)}
}};"""
    return (aten_schema_registrations, schema_registrations)
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write all operator declarations into a fixed set of aggregated headers.

    Through ``cpu_fm`` this writes NativeMetaFunctions.h, MethodOperators.h,
    Operators.h, Functions.h and NativeFunctions.h. For each dispatch key in
    ``functions_keys`` it also writes ``{key}Functions.h`` and
    ``{key}Functions_inl.h``, through ``cuda_fm`` for CUDA keys and
    ``cpu_fm`` otherwise.
    """
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )
    # Partition by whether the operator has a Tensor method variant. Testing
    # the variant directly keeps this linear; checking membership in the list
    # of method functions would be quadratic (with dataclass equality).
    method_native_functions: List[NativeFunction] = []
    non_method_native_functions: List[NativeFunction] = []
    for fn in native_functions:
        if Variant.method in fn.variants:
            method_native_functions.append(fn)
        else:
            non_method_native_functions.append(fn)
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )
    for dispatch_key in dispatch_keys:
        # Only keys in functions_keys get a {key}Functions.h API. In this
        # function, {key}Functions_inl.h is referenced only from that header,
        # and gen_per_operator_headers skips all other keys as well.
        if dispatch_key not in functions_keys:
            continue
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "DispatchKeyFunctions_inl_includes": [],
                "dispatch_namespace": dispatch_key.lower(),
                "dispatch_namespaced_declarations": get_namespaced_declaration(
                    grouped_native_functions=grouped_native_functions,
                    dispatch_key=dispatch_key,
                    backend_idx=backend_indices[dispatch_key],
                    selector=selector,
                    rocm=rocm,
                    symint=True,
                ),
            },
        )
        del fm
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write one set of headers per operator root name, plus umbrella headers.

    For every root name ``name`` this writes, through ``ops_fm``:
      - ``{name}_ops.h`` (Operator.h template);
      - ``{name}.h`` (Function.h);
      - ``{name}_meta.h`` (NativeMetaFunction.h), only when the name has a
        structured ``NativeFunctionsGroup``;
      - ``{name}_native.h`` (NativeFunction.h).

    For each dispatch key in ``functions_keys`` it also writes
    ``{name}_{key}_dispatch.h`` for every root name with at least one
    declaration.

    The umbrella headers only carry ``#include`` lines pointing at those
    per-operator files; their ``*_declarations`` entries are empty. They are
    Functions.h, Operators.h, NativeMetaFunctions.h, NativeFunctions.h,
    MethodOperators.h, and the per-key ``{key}Functions.h`` /
    ``{key}Functions_inl.h``.
    """
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)
    grouped_functions_by_root_name: Dict[
        str, List[Union[NativeFunction, NativeFunctionsGroup]]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        name = group.root_name
        grouped_functions_by_root_name[name].append(group)
    # Per-root-name headers. The lambdas close over the loop variables; this
    # relies on write_with_template invoking them before the next iteration.
    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )
        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )
        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = len(structured_functions) > 0
        # The meta header is only produced for root names with structured groups.
        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )
    # Umbrella headers that just include the per-operator headers, in sorted
    # root-name order.
    # NOTE(review): the "_meta" row includes {name}_meta.h for every root name,
    # but {name}_meta.h is only written above when `is_structured` -- confirm
    # the unstructured ones are produced elsewhere.
    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted(functions_by_root_name.keys())
                ],
                f"{category}_declarations": [],
            },
        )
    # Per-dispatch-key headers; keys outside functions_keys are skipped entirely.
    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue
        dispatch_namespace = dispatch_key.lower()
        # Root names that received a {name}_{key}_dispatch.h file.
        dispatch_names: List[str] = []
        for name, functions in functions_by_root_name.items():
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=rocm,
                        symint=True,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    grouped_functions,
                )
            )
            # No kernels for this key: don't emit an empty dispatch header.
            if len(declarations) == 0:
                continue
            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm
    # MethodOperators.h includes only the _ops.h headers of root names that
    # have at least one Tensor method variant.
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: Set[str],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """Generate the ATen header files.

    Operator declaration headers come from ``gen_per_operator_headers`` when
    ``per_operator_headers`` is true, and from ``gen_aggregated_headers``
    otherwise. The remaining headers are always written:
      - via ``core_fm``: TensorBody.h, aten_interned_strings.h, enum_tag.h;
      - via ``cpu_fm``: RedispatchFunctions.h, RegistrationDeclarations.h,
        VmapGeneratedPlumbing.h.
    """
    if per_operator_headers:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    # Tensor method declarations and their inline definitions.
    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )
    # One line per native function; see compute_registration_declarations.
    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )
    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> Dict[str, str]:
        """Build the substitutions for aten_interned_strings.h.

        "aten_symbols" has one ``_(aten, name)`` entry per ATen function name
        (C++ alternative-operator keywords excluded). "attr_symbols" has one
        ``_(attr, name)`` entry per argument name. Both are sorted and joined
        with a backslash line continuation.
        """
        attrs: Set[str] = set()  # All function argument names
        names: Set[str] = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)
            for arg in func.func.schema_order_arguments():
                attrs.add(arg.name)
        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }
        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> Dict[str, str]:
        """Return the sorted valid tags, joined by ",\\n", for enum_tag.h."""
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)
def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: List[BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    aoti_fm: FileManager,
    core_fm: FileManager,
    cpu_fm: FileManager,
    cpu_vec_fm: FileManager,
    cuda_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
) -> None:
    """Generate the ATen source files (as opposed to the headers).

    For every key in ``dispatch_keys`` this writes ``Register{Key}.cpp``
    (through ``cuda_fm`` for CUDA keys, ``cpu_fm`` otherwise). For
    structured ops that have a ``ufunc_inner_loop``, on ufunc dispatch keys,
    it also writes the per-op ``UfuncCPU_*`` / ``UfuncCPUKernel_*`` /
    ``UfuncCUDA_*`` files. For CPU and CUDA only, it writes the AOTInductor
    ``c_shim_{key}.h`` / ``c_shim_{key}.cpp`` pair through ``aoti_fm``.

    Afterwards it writes the key-independent files:
    ``RegisterBackendSelect.cpp``, ``RegisterSchema.cpp``, the sharded
    ``Operators.cpp`` and ``RegisterFunctionalization.cpp``,
    ``Functions.cpp`` and ``TensorMethods.cpp`` (both with an empty
    environment), ``ATenOpList.cpp``, ``FunctionalInverses.h`` and
    ``CompositeViewCopyKernels.cpp``.

    Args:
        rocm: use the HIP include set instead of the CUDA one.
        force_schema_registration: use ``SelectiveBuilder.get_nop_selector()``
            instead of ``selector`` when computing schema registrations.
        per_operator_headers: include per-operator ``ATen/ops/*.h`` headers
            in ``Register{Key}.cpp`` rather than the aggregated
            ``NativeFunctions.h`` / ``Functions.h`` headers.
        skip_dispatcher_op_registration: emit empty schema-registration
            lists in ``RegisterSchema.cpp``, and forward the flag to
            ``get_native_function_definitions``.
    """
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""
    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        if per_operator_headers:
            # NOTE: this closure reads `backend_index` and `dispatch_namespace`,
            # which are assigned further down in this loop iteration. It is only
            # called from the env lambda passed to `fm.write_with_template`
            # below, which comes after both assignments.
            def operator_headers() -> List[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                        # The above has_kernel test on a group will only test for
                        # the existence of out dispatch, because that's how
                        # structured kernels work. But sometimes functions can be
                        # grouped but not be structured, and then you need to check
                        # each individual piece, as they may have manual dispatch
                        # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                        # TODO: this condition is a bit questionable
                        # (It has to do with the fact that structured kernels get generated kernels
                        # to the Meta + CompositeExplicitAutogradNonFunctional keys).
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutogradNonFunctional,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue
                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if (
                        dispatch_key
                        == DispatchKey.CompositeExplicitAutogradNonFunctional
                    ):
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )
                return sorted(set(headers))
        else:
            def operator_headers() -> List[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers
        backend_index = backend_indices[dispatch_key]
        dispatch_namespace = str(dispatch_key).lower()
        # CompositeImplicitAutogradNestedTensor does not currently use the generated
        # dispatch helpers. Emitting them would leave unused functions behind, and
        # compilation fails when the `-Werror=unused-function` flag is set.
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )
        dispatch_definitions = get_native_function_definitions(
            fm=fm,
            grouped_native_functions=grouped_native_functions,
            dispatch_key=dispatch_key,
            backend_idx=backend_index,
            selector=selector,
            rocm=rocm,
            symint=True,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
            gen_dispatch_helpers=gen_dispatch_helpers,
        )
        fm.write_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            lambda: {
                "extra_cuda_headers": extra_cuda_headers
                if is_cuda_dispatch_key(dispatch_key)
                else "",
                "external_backend_headers": "",
                "dispatch_headers": dest.gen_registration_headers(
                    backend_index, per_operator_headers, rocm
                ),
                "ops_headers": operator_headers(),
                "dispatch_helpers": "",
                "dispatch_definitions": dispatch_definitions,
            },
        )
        # NOTE(review): the env lambdas below close over the loop variables `g`
        # and `name`. That is only correct if write_with_template invokes them
        # immediately, which appears to be the assumption here.
        for g in structured_native_functions:
            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
                continue
            name = g.functional.func.name.name
            if dispatch_key is DispatchKey.CPU:
                assert fm is cpu_fm
                fm.write_with_template(
                    f"UfuncCPU_{name}.cpp",
                    "UfuncCPU.cpp",
                    lambda: {
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cpu(g),
                    },
                )
                cpu_vec_fm.write_with_template(
                    f"UfuncCPUKernel_{name}.cpp",
                    "UfuncCPUKernel.cpp",
                    lambda: {
                        "name": name,
                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
                    },
                )
            elif dispatch_key is DispatchKey.CUDA:
                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
                if rocm:
                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
                fm.write_with_template(
                    f"UfuncCUDA_{name}.cu",
                    "UfuncCUDA.cu",
                    lambda: {
                        "name": name,
                        "cuda_headers": cuda_headers,
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cuda(g),
                    },
                )
            else:
                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
        # AOTInductor C shims are only generated for the CPU and CUDA keys.
        if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
            def get_header(
                f: NativeFunction,
            ) -> Optional[str]:
                backend_index = get_backend_index_for_aoti(
                    f, dispatch_key, backend_indices
                )
                return (
                    None
                    if backend_index is None
                    else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
                )
            def headers_for_aoti() -> str:
                headers = []
                for g in grouped_native_functions:
                    if isinstance(g, NativeFunctionsGroup):
                        for f in g.functions():
                            # some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
                            header = get_header(f)
                            if header is not None:
                                headers.append(header)
                    else:
                        header = get_header(g)
                        if header is not None:
                            headers.append(header)
                return "\n".join(sorted(set(headers)))
            extra_headers = (
                extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
            )
            aoti_fm.write(
                f"c_shim_{dispatch_key.lower()}.h",
                lambda: gen_aoti_c_shim(
                    native_functions,
                    dispatch_key,
                    backend_indices,
                    header=True,
                    includes="",
                ),
            )
            aoti_fm.write(
                f"c_shim_{dispatch_key.lower()}.cpp",
                lambda: gen_aoti_c_shim(
                    native_functions,
                    dispatch_key,
                    backend_indices,
                    header=False,
                    includes=headers_for_aoti() + "\n" + extra_headers,
                ),
            )
    # Drop the loop's file-manager alias; the writes below name their fm explicitly.
    del fm
    # BackendSelect is generated specially
    def gen_backend_select() -> Dict[str, List[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }
    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()
    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions, schema_selector=schema_selector
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "aten_schema_registrations": []
            if skip_dispatcher_op_registration
            else aten_schema_registrations,
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else schema_registrations,
        },
    )
    def key_func(
        fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> str:
        return fn.root_name
    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )
    # `dict` is passed as a zero-argument callable, so these templates get an empty env.
    cpu_fm.write("Functions.cpp", dict)
    core_fm.write("TensorMethods.cpp", dict)
    core_fm.write(
        "ATenOpList.cpp",
        lambda: {
            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
        },
    )
    def functionalization_env_callable(
        g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> Dict[str, List[str]]:
        def gen_op_headers(
            g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
        ) -> List[str]:
            if isinstance(g, NativeFunctionsViewGroup):
                # view ops always get a functionalization kernel
                headers = [
                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
                ]
                if g.view_copy is not None:
                    headers += [
                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                    ]
                return headers
            elif isinstance(g, NativeFunctionsGroup):
                headers = [
                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
                ]
                if g.inplace is not None:
                    headers += [
                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                    ]
                if g.mutable is not None:
                    headers += [
                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                    ]
                return headers
            else:
                return [
                    f"#include <ATen/ops/{g.root_name}_native.h>",
                    f"#include <ATen/ops/{g.root_name}_ops.h>",
                ]
        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
                selector,
                g,
            ),
            "func_registrations": gen_functionalization_registration(
                selector,
                g,
                backend_indices[DispatchKey.CompositeImplicitAutograd],
            ),
        }
    all_groups: List[
        Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    for f in native_functions:
        if f.func.name not in structured_map and f.func.name not in view_map:
            all_groups.append(f)
    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )
    cpu_fm.write(
        "FunctionalInverses.h",
        lambda: {
            "view_inverse_declarations": list(
                mapMaybe(
                    lambda g: gen_functionalization_view_inverse_declaration(
                        selector, g
                    ),
                    view_groups,
                )
            )
        },
    )
    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
    # Backends that use functionalization and don't know how to handle aliasing ops
    # are expected to implement kernels for these {view}_copy kernels instead.
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
    # so we codegen the following:
    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
    #     These are never explicitly invoked by the functionalization pass,
    #     but they could theoretically be called from user code (I added these kernels for completeness,
    #     since the ops are part of the public API).
    # (2) A derivative formula for every {view}_copy operator
    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
    #     so rather than stamping all of the entries out in derivatives.yaml,
    #     we codegen them in.
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is important as it ensures we
                    # set the visibility on generated view_copy kernels
                    # correctly
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>"
                    for f in [g.inplace, g.mutable, g.functional]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(
                    GenCompositeViewCopyKernel(
                        backend_indices[
                            DispatchKey.CompositeExplicitAutogradNonFunctional
                        ]
                    ),
                    view_groups,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
            "GeneratedCompositeOut_Definitions": list(
                mapMaybe(
                    gen_composite_out_kernel,
                    structured_native_functions,
                )
            ),
        },
    )
def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    """Emit ``Declarations.yaml`` through *cpu_fm*.

    The payload is produced lazily. ``cpu_fm.write`` receives a zero-argument
    callable which, when invoked, builds one ``compute_declaration_yaml``
    entry per native function and formats the list with ``format_yaml``.
    """
    def render_declarations():
        entries = [compute_declaration_yaml(nf) for nf in native_functions]
        return format_yaml(entries)
    cpu_fm.write("Declarations.yaml", render_declarations)
def get_torchgen_root() -> pathlib.Path:
    """Return the absolute directory that contains this module.

    Callers depending on torchgen out-of-tree can use this root to locate
    ``native_functions.yaml``.
    """
    module_file = pathlib.Path(__file__)
    containing_dir = module_file.parent
    return containing_dir.resolve()
def main() -> None:
    """Command-line entry point for ATen code generation.

    1. Parse the command-line options.
    2. Load ``native/native_functions.yaml`` and ``native/tags.yaml`` from
       ``--source-path`` and group the parsed native functions.
    3. Create one ``FileManager`` per output directory.
    4. Run the generators selected by ``--generate`` (``sources``,
       ``headers``, ``declarations_yaml``).

    With ``--output-dependencies``, every file manager except the
    AOTInductor one also writes the list of files it produced to a prefixed
    dependency file next to the given path.
    """
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        # NB: the trailing space matters; adjacent literals are concatenated.
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--aoti-install-dir",
        "--aoti_install_dir",
        help="output directory for the AOTInductor C shim files",
        default="torch/csrc/inductor/aoti_torch/generated",
    )
    options = parser.parse_args()
    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )
    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
    from torchgen.model import dispatch_keys
    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys: Set[DispatchKey] = set()
    if not options.mps:
        ignore_keys.add(DispatchKey.MPS)
        # NB: this removes MPS from the list object imported from
        # torchgen.model itself, not from a copy.
        if DispatchKey.MPS in dispatch_keys:
            dispatch_keys.remove(DispatchKey.MPS)
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    grouped_native_functions = get_grouped_native_functions(native_functions)
    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]
    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes. If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(
        options=options, install_dir=options.aoti_install_dir
    )
    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)
    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]
    static_dispatch_idx: List[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_keys = [
            DispatchKey.parse(key) for key in options.static_dispatch_backend
        ]
        static_dispatch_idx = [backend_indices[k] for k in static_dispatch_keys]
        # Statically dispatched backends also need their *Functions.h headers.
        functions_keys.update(static_dispatch_keys)
    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cpu_vec_fm=cpu_vec_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
        )
    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )
    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
    if options.output_dependencies:
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem
        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (cuda_fm, "cuda_"),
            (ops_fm, "ops_"),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()