You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
266 lines
9.0 KiB
266 lines
9.0 KiB
5 months ago
|
import textwrap
|
||
|
from dataclasses import dataclass
|
||
|
from typing import List, Optional, Sequence, Tuple
|
||
|
|
||
|
from torchgen.api.translate import translate
|
||
|
from torchgen.api.types import DispatcherSignature
|
||
|
from torchgen.context import method_with_native_function
|
||
|
from torchgen.model import (
|
||
|
Argument,
|
||
|
BaseTy,
|
||
|
BaseType,
|
||
|
FunctionSchema,
|
||
|
ListType,
|
||
|
NativeFunction,
|
||
|
OptionalType,
|
||
|
Return,
|
||
|
SchemaKind,
|
||
|
Type,
|
||
|
)
|
||
|
from torchgen.utils import mapMaybe
|
||
|
|
||
|
|
||
|
def is_tensor(typ: Type) -> bool:
    """Return True when ``typ`` is a ``BaseType`` whose name is ``BaseTy.Tensor``.

    Optional and list wrappers around Tensor do not qualify; see
    ``is_optional_tensor`` / ``is_tensor_list`` for those.
    """
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.Tensor
|
||
|
|
||
|
|
||
|
def is_optional_tensor(typ: Type) -> bool:
    """Return True when ``typ`` is an ``OptionalType`` whose element satisfies ``is_tensor``."""
    if not isinstance(typ, OptionalType):
        return False
    return is_tensor(typ.elem)
|
||
|
|
||
|
|
||
|
def is_tensor_list(typ: Type) -> bool:
    """Return True when ``typ`` is a ``ListType`` whose element satisfies ``is_tensor``."""
    if not isinstance(typ, ListType):
        return False
    return is_tensor(typ.elem)
|
||
|
|
||
|
|
||
|
def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
    """Emit C++ that splits tensor argument ``name`` into a value and a batch dim.

    The generated statements declare ``<name>_value`` (a ``Tensor``) and
    ``<name>_bdim`` (an ``optional<int64_t>``) and fill both from
    ``unwrapTensorAtLevel(<name>, <cur_level_var>)``.

    Args:
        name: C++ name of the tensor argument to unwrap.
        cur_level_var: C++ name of the variable holding the current level.

    Returns:
        The generated C++ code, one line per list element.
    """
    value_var = f"{name}_value"
    bdim_var = f"{name}_bdim"
    statements = (
        f"Tensor {value_var};",
        f"optional<int64_t> {bdim_var};",
        f"std::tie({value_var}, {bdim_var}) = unwrapTensorAtLevel({name}, {cur_level_var});",
    )
    # Indent every line uniformly and dedent, so the result matches an
    # indented triple-quoted template passed through textwrap.dedent.
    indented = "\n".join("    " + statement for statement in statements)
    return textwrap.dedent(indented).split("\n")
|
||
|
|
||
|
|
||
|
def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
    """Emit C++ that unwraps optional tensor argument ``name`` when it is present.

    ``<name>_value`` (``optional<Tensor>``) and ``<name>_bdim``
    (``optional<int64_t>``) are always declared. They are only assigned, from
    ``unwrapTensorAtLevel(<name>.value(), <cur_level_var>)``, inside an
    ``if (<name>)`` block.

    Args:
        name: C++ name of the optional tensor argument to unwrap.
        cur_level_var: C++ name of the variable holding the current level.

    Returns:
        The generated C++ code, one line per list element.
    """
    value_var = f"{name}_value"
    bdim_var = f"{name}_bdim"
    statements = (
        f"optional<Tensor> {value_var};",
        f"optional<int64_t> {bdim_var};",
        f"if ({name}) {{",
        # Body of the if-block is indented one level relative to the rest.
        f"    std::tie({value_var}, {bdim_var}) = unwrapTensorAtLevel({name}.value(), {cur_level_var});",
        "}",
    )
    indented = "\n".join("    " + statement for statement in statements)
    return textwrap.dedent(indented).split("\n")
|
||
|
|
||
|
|
||
|
def gen_unwraps(
    flat_arguments: Sequence[Argument], cur_level_var: str
) -> Tuple[str, List[str]]:
    """Generate unwrapping code for every Tensor and optional-Tensor argument.

    Args:
        flat_arguments: the operator's arguments, in declaration order.
        cur_level_var: C++ name of the variable holding the current level.

    Returns:
        A pair ``(unwrap_code, unwrapped_arg_list)``:

        - ``unwrap_code``: newline-joined C++ statements. All plain tensors
          are unwrapped first, then all optional tensors.
        - ``unwrapped_arg_list``: the argument names to forward to the batch
          rule, in the original order. Each (optional) tensor ``x`` is
          replaced by the pair ``x_value, x_bdim``.
    """
    plain_tensor_names = [a.name for a in flat_arguments if is_tensor(a.type)]
    optional_tensor_names = [
        a.name for a in flat_arguments if is_optional_tensor(a.type)
    ]

    code_lines: List[str] = []
    for tensor_name in plain_tensor_names:
        code_lines.extend(unwrap_tensor(tensor_name, cur_level_var))
    for tensor_name in optional_tensor_names:
        code_lines.extend(unwrap_optional_tensor(tensor_name, cur_level_var))

    needs_unwrap = set(plain_tensor_names) | set(optional_tensor_names)
    batch_rule_args: List[str] = []
    for argument in flat_arguments:
        if argument.name in needs_unwrap:
            batch_rule_args.extend((f"{argument.name}_value", f"{argument.name}_bdim"))
        else:
            batch_rule_args.append(argument.name)

    return "\n".join(code_lines), batch_rule_args
|
||
|
|
||
|
|
||
|
def gen_case_where_all_bdims_are_none(
    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
    """Generate the fast path taken when no tensor input is batched at this level.

    The emitted C++ ``if`` tests ``!isBatchedAtLevel(<arg>, <cur_level_var>)``
    for every tensor-like argument of ``schema``. When all tests pass, it
    returns ``at::_ops::<op>::call(...)`` directly. The call arguments are
    produced by translating ``outer_sig``'s arguments into the dispatcher
    signature of ``schema``.

    NOTE(review): if ``schema`` had no tensor-like argument the condition
    would be empty (``if ()``). The callers visible here check for at least
    one tensor-like input first.
    """
    conditions = [
        f"!isBatchedAtLevel({arg.name}, {cur_level_var})"
        for arg in schema.arguments.flat_all
        if arg.type.is_tensor_like()
    ]
    guard_expr = " && ".join(conditions)

    inner_sig = DispatcherSignature.from_schema(schema)
    bindings = translate(outer_sig.arguments(), inner_sig.arguments())
    call_args = ", ".join(binding.expr for binding in bindings)
    op_name = inner_sig.func.name.unambiguous_name()

    return (
        f"if ({guard_expr}) {{\n"
        f"  return at::_ops::{op_name}::call({call_args});\n"
        "}"
    )
|
||
|
|
||
|
|
||
|
def gen_returns(
    returns: Tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
    """Generate the C++ ``return`` statement that re-wraps the batch rule's results.

    ``results_var`` is read as a tuple laid out as follows:

    - Tensor and tensor-list returns use two consecutive slots (value,
      batch dim). They are re-batched with ``makeBatched`` /
      ``makeBatchedVector`` at ``cur_level_var``.
    - Any other return uses one slot and is forwarded unchanged.

    A single return is returned directly. Several returns are packed with
    ``std::make_tuple``.
    """
    wrapped: List[str] = []
    slot = 0
    for ret in returns:
        if is_tensor(ret.type):
            wrapper = "makeBatched"
        elif is_tensor_list(ret.type):
            wrapper = "makeBatchedVector"
        else:
            wrapped.append(f"std::get<{slot}>({results_var})")
            slot += 1
            continue
        wrapped.append(
            f"{wrapper}(std::get<{slot}>({results_var}), "
            f"std::get<{slot + 1}>({results_var}), {cur_level_var})"
        )
        slot += 2

    if len(wrapped) == 1:
        return f"return {wrapped[0]};"
    return f"return std::make_tuple({', '.join(wrapped)});"
|
||
|
|
||
|
|
||
|
def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
    """Return True if any argument of ``schema`` reports ``is_tensor_like()``."""
    for argument in schema.arguments.flat_all:
        if argument.type.is_tensor_like():
            return True
    return False
|
||
|
|
||
|
|
||
|
def is_mutated_arg(argument: Argument) -> bool:
    """Return whether ``argument`` has an annotation whose ``is_write`` flag is set."""
    annotation = argument.annotation
    if annotation is None:
        return False
    return annotation.is_write
|
||
|
|
||
|
|
||
|
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
    """Generate vmap plumbing for an in-place operator, or None if unsupported.

    Supported operators must satisfy all of the following. Anything else
    returns None, leaving support to future work.

    - exactly one argument is modified in-place;
    - that argument is the first argument;
    - there is at least one return, and every return is a Tensor or a
      TensorList;
    - at least one argument is tensor-like.

    The generated wrapper takes the all-unbatched fast path when possible.
    Otherwise it unwraps the tensor arguments, calls ``batch_rule`` for its
    side effect, and returns the first (mutated) argument.

    Raises:
        AssertionError: if ``native_function`` is not an in-place schema.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Check assumptions. If these are invalid we return None
    # and punt the work to handle them to the future.
    assert schema.kind() == SchemaKind.inplace
    flat_args = schema.arguments.flat_all
    # Guard the empty case: an argument-less schema has no first argument
    # to mutate, so bail out rather than raising IndexError.
    if not flat_args or not is_mutated_arg(flat_args[0]):
        return None
    if sum(1 for arg in flat_args if is_mutated_arg(arg)) != 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if not returns:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(flat_args, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
    mutated_arg_name = flat_args[0].name

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
  return {mutated_arg_name};
}}"""
|
||
|
|
||
|
|
||
|
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    """Generate vmap plumbing for an operator whose schema has no returns.

    The generated wrapper takes the all-unbatched fast path when possible.
    Otherwise it unwraps the tensor arguments and calls ``batch_rule``,
    discarding nothing because there is nothing to return.
    """
    cur_level_var = "cur_level"
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    fast_path = textwrap.indent(
        gen_case_where_all_bdims_are_none(sig, schema, cur_level_var), "  "
    )
    unwrap_block = textwrap.indent(unwraps, "  ")
    batch_rule_args = ", ".join(unwrapped_arg_list)
    plumbing_name = schema.name.unambiguous_name() + "_generated_plumbing"

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=plumbing_name)} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  int64_t {cur_level_var} = maybe_layer->layerId();
{fast_path}
{unwrap_block}
  batch_rule({batch_rule_args});
}}"""
|
||
|
|
||
|
|
||
|
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
    """Generate the vmap plumbing template for ``native_function``.

    Schemas with no returns are delegated to ``gen_vmap_plumbing_no_returns``
    and in-place schemas to ``gen_vmap_inplace_plumbing``. Functional schemas
    whose returns are all tensor-like are handled here.

    Returns:
        The generated C++ source, or None when the operator is not supported:
        no tensor-like input, a return that is not tensor-like, an
        ``inplace_view`` tag, a schema kind other than functional/in-place,
        or an in-place schema rejected by ``gen_vmap_inplace_plumbing``.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Operators without any tensor-like input are skipped.
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    # Only support cases where every return is tensor-like.
    # NOTE(review): if Type.is_tensor_like() also admits types other than
    # Tensor / Tensor list (e.g. optional tensors), gen_returns forwards
    # those results without re-batching them. Confirm this is intended.
    if not all(ret.type.is_tensor_like() for ret in returns):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    kind = schema.kind()
    if kind == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if kind != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  {wrapped_returns}
}}"""
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
    """Callable mapping a ``NativeFunction`` to its generated vmap plumbing.

    Calling an instance returns the result of ``gen_vmap_plumbing``: the
    generated C++ source, or None for operators that are not supported.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        return gen_vmap_plumbing(f)
|
||
|
|
||
|
|
||
|
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    """Generate the complete vmap plumbing header for ``native_functions``.

    Each operator is passed through ``ComputeBatchRulePlumbing`` via
    ``mapMaybe``; presumably operators yielding None are dropped (verify
    against ``mapMaybe``). The generated templates are joined with newlines
    and wrapped in ``namespace at::functorch``. The header includes
    ``ATen/Operators.h`` and ``ATen/functorch/PlumbingHelper.h``.
    """
    # str.join consumes any iterable, so no intermediate list is needed.
    body = "\n".join(mapMaybe(ComputeBatchRulePlumbing(), native_functions))
    return f"""
#pragma once
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at {{ namespace functorch {{

{body}

}}}} // namespace at::functorch
"""
|