You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

210 lines
6.5 KiB

5 months ago
from dataclasses import dataclass
from typing import List, Optional
import torchgen.api.types as api_types
from torchgen.api import cpp, structured
from torchgen.api.types import (
ArgName,
BaseCppType,
BaseCType,
Binding,
ConstRefCType,
CType,
NamedCType,
scalarT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
FunctionSchema,
NativeFunctionsGroup,
Type,
)
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for an out-variant schema.

    The result has the form ``ufunc_{func.name.name}_{dispatch_key}``.

    Raises:
        AssertionError: if ``func.is_out_fn()`` is false.
    """
    # Raised explicitly (instead of via an ``assert`` statement) so the check
    # is not stripped when Python runs with -O. The exception type and the
    # message are the same as before.
    if not func.is_out_fn():
        raise AssertionError(
            "ufunc.kernel_name should only be invoked on out schemas"
        )
    return f"ufunc_{func.name.name}_{dispatch_key}"
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for the out variant of group ``g``.

    This is ``schema_kernel_name`` applied to ``g.out.func``.
    """
    out_schema = g.out.func
    return schema_kernel_name(out_schema, dispatch_key)
# Tensors are omitted (as they are stored in TensorIterator), everything else is
# passed along (technically, we can pass tensors along too, it just wastes
# argument registers)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    """Return the C++ type an argument of type ``t`` has in a dispatch stub.

    Returns ``None`` for ``Tensor`` arguments, which are not passed to the
    stub at all.

    Raises:
        AssertionError: if ``t`` is neither a value type known to
            ``cpp.valuetype_type``, nor ``Scalar``, nor ``Tensor``.
    """
    # Dispatch stubs are always plain ints, hence symint=False.
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type

    if t == BaseType(BaseTy.Scalar):
        # Scalars are taken by const reference.
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if t == BaseType(BaseTy.Tensor):
        return None
    raise AssertionError(f"unrecognized type {repr(t)}")
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Map ``api_types.scalar_t`` to ``api_types.opmath_t``.

    Raises:
        NotImplementedError: for any other input type.
    """
    is_scalar_t = scalar_t == api_types.scalar_t
    if not is_scalar_t:
        raise NotImplementedError
    return api_types.opmath_t
# NB: Tensors in constructor are stored in opmath_t, not scalar_t,
# because a Tensor in the constructor = it's a scalar tensor partially applied =
# it can be higher precision and we want to compute in that higher precision
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """Return the C++ type of a ufunctor constructor parameter of type ``t``.

    Value types recognized by ``cpp.valuetype_type`` are returned as-is.
    ``Scalar`` and ``Tensor`` arguments are both typed as
    ``opmath_type(scalar_t)``.

    Raises:
        AssertionError: if ``t`` is none of the above.
    """
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type

    # Scalars and (scalar) tensors share the same representation here.
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    raise AssertionError(f"unrecognized type {repr(t)}")
# Only Tensors ever get passed directly to operator()
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    """Return the C++ type of a ufunctor ``operator()`` parameter.

    Only ``Tensor`` is accepted; it is typed as ``scalar_t``.

    Raises:
        AssertionError: if ``t`` is not ``Tensor``.
    """
    if t != BaseType(BaseTy.Tensor):
        raise AssertionError(f"unrecognized type {repr(t)}")
    return NamedCType(binds, BaseCType(scalar_t))
# The actual ufunc template function the user writes. Everything here
# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
# in CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Return the C++ type of a ufunc template-function parameter.

    Value types recognized by ``cpp.valuetype_type`` are returned as-is.
    ``Scalar`` and ``Tensor`` arguments are both typed as ``compute_t``.

    Raises:
        AssertionError: if ``t`` is none of the above.
    """
    value_type = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_type is not None:
        return value_type

    # Everything the user-written ufunc sees is in the computation type.
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    raise AssertionError(f"unrecognized type {repr(t)}")
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Build the ufunctor constructor ``Binding`` for argument ``a``.

    The binding keeps the argument's name, has no default, and is typed by
    ``ufunctor_ctor_type``.
    """
    ctor_nctype = ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(nctype=ctor_nctype, name=a.name, default=None, argument=a)
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Build the ufunctor ``operator()`` ``Binding`` for argument ``a``.

    The binding keeps the argument's name, has no default, and is typed by
    ``ufunctor_apply_type``.
    """
    apply_nctype = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(nctype=apply_nctype, name=a.name, default=None, argument=a)
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Build the ufunc template-function ``Binding`` for argument ``a``.

    The binding keeps the argument's name, has no default, and is typed by
    ``ufunc_type``.
    """
    ufunc_nctype = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(nctype=ufunc_nctype, name=a.name, default=None, argument=a)
@dataclass(frozen=True)
class UfunctorBindings:
    """Bindings of a ufunctor, split by where each argument is supplied."""

    # Parameters of the functor's constructor.
    ctor: List[Binding]
    # Parameters of the functor's operator().
    apply: List[Binding]
# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments on a host-side constructor, and the rest in the device-side
# apply. E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
) -> UfunctorBindings:
    """Split the non-out arguments of ``g.functional`` into ctor/apply bindings.

    Non-tensor arguments always become constructor bindings. Tensor-like
    arguments become ``operator()`` bindings, except for the one selected by
    ``scalar_tensor_idx``, which becomes a constructor binding instead.

    Args:
        g: the function group whose functional schema supplies the arguments.
        scalar_tensor_idx: 0-based position, counting tensor-like arguments
            only, of the tensor to route to the constructor; ``None`` to
            route every tensor-like argument to ``operator()``.
        scalar_t: scalar type forwarded to the per-argument binding helpers.

    Raises:
        AssertionError: if ``scalar_tensor_idx`` is given but never selects a
            tensor-like argument (for example, it is out of range).
    """
    ctor_bindings: List[Binding] = []
    apply_bindings: List[Binding] = []
    # Tensor-like arguments still to skip before reaching the selected one;
    # None once it has been consumed (or if none was requested).
    remaining = scalar_tensor_idx

    for arg in g.functional.func.arguments.flat_non_out:
        if not arg.type.is_tensor_like():
            ctor_bindings.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            continue

        if remaining == 0:
            # This is the selected scalar tensor: put it in the ctor anyway.
            ctor_bindings.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            remaining = None
            continue

        if remaining is not None:
            remaining -= 1
        apply_bindings.append(ufunctor_apply_argument(arg, scalar_t=scalar_t))

    # Raised explicitly (instead of via a bare ``assert``) so that the check
    # survives ``python -O`` and carries a diagnostic message.
    if remaining is not None:
        raise AssertionError(
            f"scalar_tensor_idx={scalar_tensor_idx!r} did not select any "
            "tensor-like argument"
        )
    return UfunctorBindings(ctor=ctor_bindings, apply=apply_bindings)
# ufuncs are the inner loop template functions that you wrote in ufunc/add.h
# which do the actual computation in question. E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
# return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
    """Return ufunc bindings for every non-out argument of ``g.functional``.

    Each argument is converted with ``ufunc_argument`` using ``compute_t``.
    """
    non_out_args = g.functional.func.arguments.flat_non_out
    return [ufunc_argument(arg, compute_t=compute_t) for arg in non_out_args]
# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions. E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    """Return the bindings a dispatch stub takes for group ``g``.

    Walks the non-out arguments of ``g.out`` and, for each one that is not
    tensor-like, collects the bindings produced by ``structured.argument``.
    """
    # Stubs drop all tensor arguments (they are implicit in the
    # TensorIterator argument) and keep everything else.
    bindings: List[Binding] = []
    for arg in g.out.func.arguments.flat_non_out:
        if arg.type.is_tensor_like():
            continue
        bindings.extend(structured.argument(arg))
    return bindings