# Code generation for ufunc kernels: the functions below emit C++ source text
# for the CUDA functors/kernels and for the CPU dispatch stubs/kernels.

from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
Expr,
NamedCType,
opmath_t,
scalar_t,
StructuredImplSignature,
VectorizedCType,
)
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
NativeFunctionsGroup,
ScalarType,
UfuncKey,
)
from torchgen.utils import OrderedSet
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
# TODO: use BackendIndex
# dispatch_key: DispatchKey # only CPU/CUDA right now
# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because when USERS instantiate functors
# they are templated. A functor looks something like this:
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
# : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) const {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#
@dataclass(frozen=True)
class UfunctorSignature:
    """Describes one C++ functor struct used to implement a CUDA ufunc.

    The bindings come from ``ufunc.ufunctor_arguments`` and are split into
    ``ctor`` bindings (constructor parameters, also stored as struct fields)
    and ``apply`` bindings (parameters of ``operator()``).
    """

    # Operator group the functor is generated for.
    g: NativeFunctionsGroup
    # Passed through unchanged to ufunc.ufunctor_arguments; may be None.
    scalar_tensor_idx: Optional[int]
    # C++ name of the functor struct.
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the ctor/apply bindings for this functor."""
        return ufunc.ufunctor_arguments(
            self.g,
            scalar_t=scalar_t,
            scalar_tensor_idx=self.scalar_tensor_idx,
        )

    def fields(self) -> List[Binding]:
        """Return the ctor bindings renamed for use as struct fields."""
        members: List[Binding] = []
        for ctor_arg in self.arguments().ctor:
            # Fields get a trailing underscore, as is conventional.
            members.append(ctor_arg.rename(f"{ctor_arg.name}_"))
        return members

    def returns_type(self) -> CType:
        """Return the C++ type that ``operator()`` returns."""
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        result_t = BaseCType(scalar_t)
        return result_t

    def decl_fields(self) -> str:
        """Return one ``<type> <name>;`` line per field, newline-separated."""
        decls = [f"{member.type} {member.name};" for member in self.fields()]
        return "\n".join(decls)

    def inline_defn_ctor(self) -> str:
        """Return the inline constructor copying each parameter into its field."""
        ctor_args = self.arguments().ctor
        params = ", ".join(arg.decl() for arg in ctor_args)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        inits = ", ".join(f"{arg.name}_({arg.name})" for arg in ctor_args)
        return f"{self.name}({params}) : {inits} {{}}"

    def decl_apply(self) -> str:
        """Return the declaration of the functor's const ``operator()``."""
        params = ", ".join(arg.decl() for arg in self.arguments().apply)
        ret = self.returns_type().cpp_type()
        return f"{ret} operator()({params}) const"
@dataclass(frozen=True)
class UfuncSignature:
    """Describes a call to a ufunc template function at a given compute type."""

    # Operator group whose ufunc is being called.
    g: NativeFunctionsGroup
    # Name used verbatim as the callee in the emitted call expression.
    name: str
    # Compute type forwarded to ufunc.ufunc_arguments.
    compute_t: CType

    def arguments(self) -> List[Binding]:
        """Return the bindings of the ufunc at ``compute_t``."""
        bindings = ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
        return bindings

    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
        """Return ``name(args...)`` with each argument translated from ``ctx``."""
        translated = translate(ctx, self.arguments())
        arg_list = ", ".join(e.expr for e in translated)
        return f"{self.name}({arg_list})"
# steps:
# 1. take the functional signature
# 2. use api.ufunc to convert it to template signature. this establishes
# the type of the template function
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
# this establishes the context in which we call the template signature
#
# StructuredImplSignature context
# ~> functor constructor sig
#
# Functor constructor context
# ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
# ~> template sig
#
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True iff the functional variant of ``g`` has exactly two
    tensor-like non-out arguments."""
    tensor_like_args = [
        arg
        for arg in g.functional.func.arguments.flat_non_out
        if arg.type.is_tensor_like()
    ]
    return len(tensor_like_args) == 2
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor signatures and C++ functor definitions for ``g``.

    Returns:
        A pair ``(sigs, source)``.  ``sigs`` maps each supported dtype to the
        functor signature to use for each ``UfuncKey``.  ``source`` is the C++
        text of every functor struct generated here; functors whose key is
        named directly in ``g.out.ufunc_inner_loop`` are assumed to have been
        written by the user and are not emitted.

    Raises:
        AssertionError: if a scalar-specialized functor key is used on an
            operator that is not eligible for binary scalar specialization,
            if ScalarOnly and Generic name different ufuncs, or if a functor
            must be generated but neither ScalarOnly nor Generic is defined.
    """
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }

    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"

    # First, build the functors.
    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: List[str] = []

    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        # Without a generic entry there is no ufunc to wrap; say which
        # operator and key are affected instead of failing with no context.
        assert ufunc_name is not None, (
            f"cannot generate {k} for {g.functional.func.name}: "
            "neither ScalarOnly nor Generic is defined in ufunc_inner_loop"
        )

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        # operator() sees the stored fields plus its own parameters.
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
using opmath_t = at::opmath_type<scalar_t>;
{ufunctor_sig.decl_fields()}
{ufunctor_sig.inline_defn_ctor()}
__device__ {ufunctor_sig.decl_apply()} {{
return {ufunc_sig.call(apply_ctx)};
}}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """Settings for one scalar-operand specialization of a binary CUDA ufunc."""

    # Zero-based index of the scalar tensor among the inputs.
    scalar_idx: int
    # Binding name given to the scalar value handed to the functor constructor.
    ctor_tensor: str
    # Key under which the specialized functor is looked up.
    ufunc_key: UfuncKey
# The two scalar specializations, in the order their branches are emitted:
# first the case where input 0 ("self") is the scalar, then input 1 ("other").
BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
        scalar_idx=0,
        ctor_tensor="self",
    ),
    BinaryScalarSpecializationConfig(
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
        scalar_idx=1,
        ctor_tensor="other",
    ),
]
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Return the C++ body of the CUDA dispatch case for one dtype.

    The emitted text is an ``if``/``else if`` chain: one branch per
    binary-scalar specialization present in ``inner_loops``, followed by an
    ``else`` that launches the ``UfuncKey.CUDAFunctor`` functor.  ``g`` and
    ``dtype`` are accepted but not consulted here.

    Raises:
        KeyError: if ``inner_loops`` has no ``UfuncKey.CUDAFunctor`` entry.
    """
    pieces: List[str] = [
        "using opmath_t = at::opmath_type<scalar_t>;",
        "if (false) {}\n",  # for ease of codegen
    ]

    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        specialized_sig = inner_loops[config.ufunc_key]
        # Operand index on the iterator; the +1 presumably skips the output
        # operand -- TODO confirm against TensorIterator.
        operand_idx = config.scalar_idx + 1

        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        scalar_value = Expr(
            expr=f"iter.scalar_value<opmath_t>({operand_idx})",
            type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
        )
        ctx: List[Union[Expr, Binding]] = [*parent_ctx, scalar_value]
        ctor_exprs = translate(ctx, specialized_sig.arguments().ctor)
        ctor_args_str = ", ".join(e.expr for e in ctor_exprs)

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        pieces.append(
            f"""\
else if (iter.is_cpu_scalar({operand_idx})) {{
{specialized_sig.name}<scalar_t> ufunctor({ctor_args_str});
iter.remove_operand({operand_idx});
gpu_kernel(iter, ufunctor);
}}"""
        )

    generic_sig = inner_loops[UfuncKey.CUDAFunctor]
    generic_ctor_exprs = translate(parent_ctx, generic_sig.arguments().ctor)
    generic_ctor_args_str = ", ".join(e.expr for e in generic_ctor_exprs)
    pieces.append(
        f"""
else {{
gpu_kernel(iter, {generic_sig.name}<scalar_t>({generic_ctor_args_str}));
}}
"""
    )
    return "".join(pieces)
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Return the C++ source for the CUDA implementation of ufunc ``g``.

    The text contains the generated functors, the stub type and dispatch
    declarations, a kernel that switches on ``iter.common_dtype()``, the
    dispatch registration, and the structured impl that calls the kernel
    directly.
    """
    # First, build the functors, indexing them by dtype
    sigs_by_dtype, functor_defs = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    cases = [
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cuda_dtype_body(g, dtype, per_key_sigs, sig.arguments())}
}}
)
"""
        for dtype, per_key_sigs in sigs_by_dtype.items()
    ]
    dtype_cases_str = "\n".join(cases)

    stub_sig = StubSignature(g)

    return f"""
{functor_defs}
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
{dtype_cases_str}
);
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
{sig.defn()} {{
{stub_sig.direct_call(sig.arguments())};
}}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@dataclass(frozen=True)
class StubSignature:
    """Names and C++ snippets for the dispatch stub of a ufunc group."""

    # Operator group the stub belongs to.
    g: NativeFunctionsGroup

    def _op_base(self) -> str:
        """Return the operator's base name as a string."""
        return str(self.g.functional.func.name.name)

    @property
    def name(self) -> str:
        """Name of the dispatch stub (``<op>_stub``)."""
        return self._op_base() + "_stub"

    @property
    def kernel_name(self) -> str:
        """Name of the kernel function (``<op>_kernel``)."""
        return self._op_base() + "_kernel"

    @property
    def type_name(self) -> str:
        """Name of the function-pointer type alias (``<op>_fn``)."""
        return self._op_base() + "_fn"

    def arguments(self) -> List[Binding]:
        """Return the stub's bindings as computed by ``ufunc.stub_arguments``."""
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        """Return the C++ function-pointer type of the stub."""
        arg_types = ", ".join(a.type for a in self.arguments())
        return f"void(*)(TensorIteratorBase&, {arg_types})"

    def dispatch_decl(self) -> str:
        """Return the ``DECLARE_DISPATCH`` invocation (no trailing semicolon)."""
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        """Return the ``DEFINE_DISPATCH`` invocation (no trailing semicolon)."""
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        """Return the kernel function header, without a body."""
        params = ", ".join(a.defn() for a in self.arguments())
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {params})"

    def type_defn(self) -> str:
        """Return the ``using`` alias for the stub's function-pointer type."""
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        """Return a call through the dispatch stub, with arguments from ``ctx``."""
        translated = translate(ctx, self.arguments())
        forwarded = ", ".join(a.expr for a in translated)
        return f"{self.name}(device_type(), *this, {forwarded})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        """Return a direct call to the kernel, with arguments from ``ctx``."""
        translated = translate(ctx, self.arguments())
        forwarded = ", ".join(a.expr for a in translated)
        return f"{self.kernel_name}(*this, {forwarded})"
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Return the C++ source for the CPU structured impl of ufunc ``g``.

    The text declares and defines the dispatch stub, then defines the
    structured impl as a single call through that stub.
    """
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    # The leading and trailing empty entries reproduce the blank first line
    # and the final newline of the emitted text.
    lines = [
        "",
        f"{stub_sig.type_defn()};",
        f"{stub_sig.dispatch_decl()};",
        f"{stub_sig.dispatch_defn()};",
        f"{sig.defn()} {{",
        f"{stub_sig.call(sig.arguments())};",
        "}",
        "",
    ]
    return "\n".join(lines)
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Return the C++ body of the CPU dispatch case for one dtype.

    The body first unpacks the scalar arguments from ``parent_ctx`` into
    ``_s_<name>`` locals (and ``_v_<name>`` vectorized copies when a vector
    loop exists), then calls ``cpu_kernel_vec`` if ``inner_loops`` has a
    ``CPUVector`` entry and ``cpu_kernel`` otherwise.

    Raises:
        AssertionError: if ``CPUScalar`` is missing from ``inner_loops``, if
            ``inner_loops`` has keys other than ``CPUScalar``/``CPUVector``,
            or if a tensor-like argument is not a plain ``Tensor``.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    def needs_unpacking(binding: Binding) -> bool:
        # Only bindings whose Argument is of a non-Scalar type are skipped.
        arg = binding.argument
        return not (isinstance(arg, Argument) and arg.type != BaseType(BaseTy.Scalar))

    unpacked = [b for b in parent_ctx if needs_unpacking(b)]

    # Setup scalar in scope
    prologue: List[str] = []
    ctx: List[Union[Expr, Binding]] = []
    for b in unpacked:
        prologue.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        # Vectorized copies come after all scalar unpackings, in the same order.
        for b in unpacked:
            prologue.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings: List[Binding] = []
    vec_bindings: List[Binding] = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is None:
            continue
        vec_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                argument=a,
            )
        )

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        # Unpacked scalars first, then the lambda's own parameters.
        return [*ctx, *b]

    def lambda_src(bindings: List[Binding], loop: UfuncSignature) -> str:
        # One capturing lambda that forwards its parameters to the ufunc.
        params = ", ".join(b.decl() for b in bindings)
        return f"[=]({params}) {{ return {loop.call(with_ctx(bindings))}; }}"

    body_str = "\n".join(prologue)
    scalar_lambda = lambda_src(scalar_bindings, scalar_loop)
    if vec_loop is None:
        return f"""
{body_str}
cpu_kernel(iter,
{scalar_lambda}
);
"""
    vec_lambda = lambda_src(vec_bindings, vec_loop)
    return f"""
{body_str}
cpu_kernel_vec(iter,
{scalar_lambda},
{vec_lambda}
);
"""
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Return the C++ source for the CPU kernel of ufunc ``g``.

    The text defines the kernel in an anonymous namespace, switching on
    ``iter.common_dtype()`` with one case per supported dtype, and then
    declares the stub type, declares the dispatch and registers the kernel.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    # Compute type used for each CPU loop flavour, in processing order.
    compute_types: Dict[UfuncKey, CType] = {
        UfuncKey.CPUScalar: BaseCType(scalar_t),
        UfuncKey.CPUVector: VectorizedCType(BaseCType(scalar_t)),
    }
    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for k, compute_t in compute_types.items():
        # ORDER MATTERS: this specifies overriding precedence
        candidates = []
        if k in loops:  # should happen rarely
            candidates.append(k)
        if k is UfuncKey.CPUScalar and UfuncKey.ScalarOnly in loops:
            candidates.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            candidates.append(UfuncKey.Generic)

        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in candidates:
            loop = loops[lk]
            for dtype in loop.supported_dtypes:
                per_key = ufunc_sigs.setdefault(dtype, {})
                if k in per_key:
                    # A higher-precedence entry already claimed this dtype.
                    continue
                per_key[k] = UfuncSignature(
                    g, name=f"ufunc::{loop.name}", compute_t=compute_t
                )

    # Build the conditionals
    cases = [
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cpu_dtype_body(g, dtype, per_key_sigs, stub_sig.arguments())}
}}
)
"""
        for dtype, per_key_sigs in ufunc_sigs.items()
    ]
    dtype_cases_str = "\n".join(cases)

    return f"""
namespace {{
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
{dtype_cases_str}
);
}}
}} // anonymous namespace
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""