You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

465 lines
17 KiB

5 months ago
from typing import Any, Dict, List, Optional, Tuple, Union
from torchgen.api.types import (
BaseCppType,
BaseCType,
boolT,
CType,
deviceT,
doubleT,
generatorT,
layoutT,
ListCType,
longT,
memoryFormatT,
NamedCType,
OptionalCType,
scalarT,
scalarTypeT,
stringT,
SymIntT,
VectorCType,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
ListType,
OperatorName,
OptionalType,
Return,
TensorOptionsArguments,
Type,
)
_valueT: Optional[BaseCppType] = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    """Return the IR value type registered with setValueT().

    Raises:
        NotImplementedError: if no value type has been registered yet.
    """
    # The module global is only read here, so no ``global`` statement is
    # needed; "unset" is represented by None, so compare against None
    # explicitly instead of relying on truthiness.
    if _valueT is None:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )
    return _valueT
def setValueT(val: BaseCppType) -> None:
    """Register ``val`` as the IR value type returned by getValueT().

    getValueT() raises NotImplementedError until this has been called.
    """
    # Rebinding a module-level name requires an explicit global declaration.
    global _valueT
    _valueT = val
# HACK: a TensorList argument reaches the IR as a single torch::lazy::Value
# (see process_ir_type) rather than as a list of values. A cleaner design would
# model each schema argument as an object, so that special properties like this
# one could be attached to the argument itself.
tensorListValueT = BaseCppType(ns="torch::lazy", name="Value")
def process_ir_type(
    typ: Type, properties: "LazyIrProperties", *, symint: bool
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    Convert a type from NativeFunctions into the CType used by lazy tensor codegen.

    The conversion currently:
    (1) turns at::Tensor into lazy::Value (which wraps the lazy::Node that Lazy
        IR uses to represent a tensor); tensor-like types such as
        Optional[Tensor] and List[Tensor] get special handling,
    (2) wraps everything in a BaseCType, possibly inside a container CType, and
    (3) replaces C++ reference types with owning value types (e.g. a vector
        instead of an IntArrayRef).

    at::Scalar also becomes a lazy::Value unless
    ``properties.TreatScalarsAsConstants`` is set, and SymInt becomes a
    lazy::Value only when ``symint`` is true.

    The mapping is incomplete: unsupported types raise AssertionError, and more
    cases are expected to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        base = typ.name
        if base == BaseTy.Tensor:
            return BaseCType(getValueT())
        if base == BaseTy.Scalar:
            # at::Scalar is special: by default it is wrapped in a lazy::Value
            # just like at::Tensor, unless scalars are kept as IR constants.
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            return BaseCType(getValueT())
        if base == BaseTy.SymInt:
            return BaseCType(getValueT() if symint else longT)
        # Base types that map one-to-one onto a fixed C++ type.
        fixed_types = {
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Generator: generatorT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if base not in fixed_types:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
        return BaseCType(fixed_types[base])

    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))

    if isinstance(typ, ListType):
        elem = typ.elem
        elem_str = str(elem)
        if elem_str == "Tensor?":
            # TODO(whc) is this actually correct? Or should it use a VectorCType
            # like the generic list case below?
            return ListCType(OptionalCType(BaseCType(getValueT())))
        if elem_str == "Tensor":
            # A TensorList comes in from GetTensorList as a single Value.
            return BaseCType(tensorListValueT)
        if elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. This has the same problem as
            # tensorListValueT: nodes save the Values of ALL their arguments in
            # one vector, so a SymInt[] cannot conveniently be saved as a list
            # of Values. That needs a separate IR node representing all of the
            # size nodes assembled into a list.
            return VectorCType(BaseCType(longT))
        return VectorCType(process_ir_type(elem, properties, symint=symint))

    raise AssertionError(f"unrecognized type {repr(typ)}")
# TODO: Deciding this from a CType is fragile; it should be computed from the
# schema Type directly, so that the same logic as process_ir_type can be used.
#
# Invariant: ``typ`` must be an *owning* CType (e.g. ArrayRef<Value> is reported
# as NOT a value type).
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Return True if ``typ`` is, or wraps, a lazy IR value.

    This is the post-conversion counterpart of "is tensor-like": it inspects a
    CType that has already been produced by the lazy type conversion. The IR
    value type and SymInt always count as values; at::Scalar counts unless
    ``properties`` says to treat scalars as constants. Optional/List/Vector
    wrappers are looked through recursively, except that a vector of SymInt is
    not (yet) reported as a value.
    """
    if isinstance(typ, BaseCType):
        # Only at::Scalar is wrapped in a lazy value; the other "scalar" types
        # are kept as plain scalars in the IR.
        scalars_are_values = not (properties and properties.TreatScalarsAsConstants)
        inner = typ.type
        return (
            inner == getValueT()
            or (inner == scalarT and scalars_are_values)
            or inner == SymIntT
        )
    if typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    if isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
def isSymIntType(typ: Type) -> bool:
    """Return True if ``typ`` is exactly the SymInt base type (no Optional/list wrapper)."""
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.SymInt
def isWrappedScalarType(typ: Type) -> bool:
    """
    Return True if ``typ`` is a c10::Scalar, looking through any number of
    Optional[...] and list wrappers.

    The lazy type conversion typically replaces Scalar with the IR value type,
    which loses the original type; this predicate is used to remember which
    arguments were wrapped scalars.
    """
    # Peel off Optional / list layers iteratively.
    while isinstance(typ, (OptionalType, ListType)):
        typ = typ.elem
    # Only at::Scalar is wrapped in a lazy value; other "scalar" types stay
    # plain scalars in the IR.
    return isinstance(typ, BaseType) and typ.name == BaseTy.Scalar
# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    """Return True if ``typ`` is Generator, possibly wrapped in Optional[...] layers."""
    while isinstance(typ, OptionalType):
        typ = typ.elem
    return isinstance(typ, BaseType) and typ.name == BaseTy.Generator
# Caches a few derived properties computed from a schema Argument and the
# LazyIrProperties of the node the argument belongs to.
class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: Optional[CType]
    # True if the argument's schema type is Optional[...]
    is_optional: bool
    # True if the schema type is (or contains, through Optional/list wrappers)
    # a Scalar; see isWrappedScalarType
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: the name is misleading -- this is False for lists of SymInt
    is_symint_or_list: bool
    # Whether this argument is being treated as SymInt or not
    symint: bool
    # True if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
        schema_type = arg.type
        self.name = arg.name
        self.orig_type = schema_type
        self.symint = symint
        self.is_optional = isinstance(schema_type, OptionalType)
        self.is_generator = isGeneratorType(schema_type)
        self.lazy_type_ = process_ir_type(schema_type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(schema_type)
        # Peel off at most one Optional layer. Lists of SymInt are deliberately
        # not matched: they are not currently treated as value types.
        unwrapped = (
            schema_type.elem if isinstance(schema_type, OptionalType) else schema_type
        )
        self.is_symint_or_list = symint and isSymIntType(unwrapped)
        self.is_lazy_value = isValueType(self.lazy_type, properties)

    @property
    def lazy_type(self) -> CType:
        """The lazy-IR CType of this argument; asserts that it has been computed."""
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
class LazyIrProperties:
    """Collection of properties for an IR node.

    The properties are organised into the groups listed in ``Properties``.
    Each group is mutually exclusive: at most one property from a group can be
    True at any one time. Properties are read and written like ordinary
    boolean attributes, and the exclusivity is handled automatically:

    * assigning a truthy value makes that property the active member of its
      group, deactivating whichever member was active before;
    * assigning a falsy value deactivates the property if it is the active
      member of its group, and is a no-op otherwise;
    * assigning an unknown name raises KeyError; reading one raises
      AttributeError.
    """

    Properties: Tuple[Tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str):
        """Create a property set with ``default_properties`` enabled, applied in order."""
        # Maps each group to the name of its active property, or None.
        properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
            LazyIrProperties.Properties
        )
        # Write through __dict__ directly: __setattr__ only accepts property names.
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    @staticmethod
    def _group_of(key: str) -> Optional[Tuple[str, ...]]:
        """Return the group of ``Properties`` that contains ``key``, or None."""
        for group in LazyIrProperties.Properties:
            if key in group:
                return group
        return None

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails. It must raise
        # AttributeError (never KeyError) for non-property names, and also on
        # instances created without __init__ (as copy/pickle do), so that
        # hasattr() and getattr(obj, name, default) keep working.
        properties = self.__dict__.get("properties")
        group = LazyIrProperties._group_of(key)
        if properties is None or group is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        return properties[group] == key

    def __setattr__(self, key: str, value: Any) -> Any:
        group = LazyIrProperties._group_of(key)
        if group is None:
            raise KeyError(f"Invalid property: {key}")
        properties = self.__dict__["properties"]
        if value:
            properties[group] = key
        elif properties[group] == key:
            # Only clear the group when this property is the active one;
            # turning off an inactive property must not disable its sibling.
            properties[group] = None
        return value
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR
# nodes, while preserving the original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # If this schema has a Generator argument, its name and original type are
    # recorded here. A LazyArgument is still built for it (flagged with
    # is_generator) so that filtered_args() can include or omit it; lazy IR
    # itself does not support generators.
    generator_arg: Optional[NamedCType] = None

    # original function schema
    func: FunctionSchema

    # Whether or not we are code-genning for SymInt or not
    symint: bool

    # Default properties for schemas constructed without explicit properties.
    # Each such schema receives its own copy (see __init__), so this object is
    # never shared between instances.
    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )

    opkind: Optional[str] = None

    def __init__(
        self,
        func: FunctionSchema,
        properties: Optional[LazyIrProperties] = None,
        *,
        symint: bool,
    ):
        """Build the lazy IR view of ``func``.

        Args:
            func: the native function schema to convert.
            properties: node properties; when omitted, a private copy of the
                class-level default ``LazyIrSchema.properties`` is used.
            symint: forwarded to every LazyArgument; controls whether SymInt
                arguments are converted to IR values.
        """
        # Copy rather than alias the class-level default: otherwise mutating
        # one schema's properties would change every schema built without
        # explicit properties.
        self.properties = (
            properties
            if properties is not None
            else self._copy_properties(type(self).properties)
        )
        self.func = func
        self.symint = symint

        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = func.arguments.self_arg.argument
                positional_args.append(
                    LazyArgument(arg, self.properties, symint=symint)
                )
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(
                            arg.name, arg.type  # type:ignore[arg-type]
                        )
                keyword_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    @staticmethod
    def _copy_properties(src: LazyIrProperties) -> LazyIrProperties:
        """Return a new LazyIrProperties with the same properties enabled as ``src``."""
        enabled = [
            prop
            for group in LazyIrProperties.Properties
            for prop in group
            if getattr(src, prop)
        ]
        return LazyIrProperties(*enabled)

    @property
    def node_name(self) -> str:
        """
        Return camel-case version of op in node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        # Empty words (e.g. from an empty overload name) capitalize to "".
        return "".join(word.capitalize() for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = True,
    ) -> List[LazyArgument]:
        """Return a filtered view of the arguments, preserving schema order.

        Args:
            positional: include positional arguments.
            keyword: include keyword arguments.
            values: include arguments that are (or contain) lazy IR values.
            scalars: include arguments that are not lazy IR values.
            generator: when False, omit Generator arguments.
        """
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy
        # value or leave it alone. Generators are special cased, as they are
        # needed for fallback/shape-inference but not supported in TS lowerings
        # and therefore also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]
        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )