You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

221 lines
7.5 KiB

# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Tuple, Union
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
)
from torchgen.utils import assert_never
# Version of the kernel-key format. It is the default value of
# ETKernelKey.version and shows up as the "v<N>/" prefix of every non-default
# serialized kernel key.
KERNEL_KEY_VERSION = 1
# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
    """Integer codes for the dtypes that a kernel key may refer to.

    Members are looked up by name (``ScalarType["Float"]``) and their integer
    value is what gets written into a serialized kernel key.

    The numbering is intentionally sparse (5, 8, 9 and 10 are not declared).
    NOTE(review): the values presumably mirror the dtype numbering used by the
    runtime -- confirm before adding or renumbering members.
    """

    # Integral types.
    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4

    # Floating point types.
    Float = 6
    Double = 7

    # Boolean.
    Bool = 11
# Lightweight two-field record: `native_functions` and `kernel_index`.
# NOTE(review): presumably the result of parsing an operator YAML file (the
# parsed native functions together with their kernel index) -- confirm against
# the code that constructs it.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Dtype and dimension order of a single operator argument in a kernel key.

    The dataclass is frozen, so instances are immutable and hashable.
    """

    # Name of the operator argument this entry describes.
    arg_name: str
    # Name of a ScalarType member (e.g. "Float"); resolved by name on serialization.
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: Tuple[int, ...]

    def to_native_string(self) -> str:
        """Serialize this entry as ``"<dtype code>;<dim>,<dim>,..."``.

        Returns:
            The integer value of the ScalarType named by ``dtype``, a ``;``,
            then the entries of ``dim_order`` joined by ``,`` with no spaces.
            An empty ``dim_order`` yields nothing after the ``;``.

        Raises:
            KeyError: if ``dtype`` is not the name of a ScalarType member.
        """
        dtype_str = ScalarType[self.dtype].value
        # Join the entries explicitly rather than slicing the tuple's repr:
        # the repr of a one-element tuple is "(0,)", which would leave a
        # trailing comma ("0,") in the serialized key.
        dim_str = ",".join(str(dim) for dim in self.dim_order)
        return f"{dtype_str};{dim_str}"
@dataclass(frozen=True)
class ETKernelKey:
    """Key that identifies one kernel specialization of an operator.

    A key is either the catch-all key (``default=True``) or a tuple of
    per-argument dtype / dim-order entries. The dataclass is frozen, so keys
    are hashable and can be used as dictionary keys.
    """

    # Per-argument metadata; left empty when the key is the catch-all
    # (default=True) key.
    arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()
    # Indicator for this kernel being used as a catch all
    default: bool = False
    # Version of the key format, emitted as the "v<N>/" prefix when serialized.
    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: Dict[str, Tuple[str, str]],
        type_alias_map: Dict[str, List[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: Dict[str, List[int]],
    ) -> List["ETKernelKey"]:
        """Generate ETKernelKeys from arg kernel specs.

        Multiple ETKernelKeys are returned due to dtype permutations from
        utilizing type_alias_map (actualizing each potential type permutation
        as a KernelKey).

        Args:
            args: Mapping from argument name to kernel specs.
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map
                arguments.
            type_alias_map: Mapping from type alias to potential type enums,
                i.e. { T0 : [Double, Int] } means T0 can be either Double or
                Int. Used for lookup by args.
            dim_order_alias_map: Mapping from alias to a list of dimension
                orders. Used for lookup by args.

        Returns:
            One ETKernelKey per combination of dtype choices for the aliases
            that ``args`` actually uses. The order is deterministic: aliases
            vary in the order they first appear in ``args``, and each alias
            walks its dtypes in the order listed in ``type_alias_map``.

        Raises:
            AssertionError: if an argument refers to a type alias or
                dim_order alias that is missing from the alias maps.
        """
        # Cast to dim order to int
        dim_order_alias_map = {
            k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
        }

        # Collect the dtype aliases in use. A dict (insertion ordered) is used
        # instead of a set so that the order of the generated keys does not
        # depend on string hash randomization.
        dtype_alias_used: Dict[str, None] = {}
        for type_alias, dim_order in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order in dim_order_alias_map
            ), "Undefined dim_order alias: " + str(dim_order)
            dtype_alias_used[type_alias] = None

        # For every alias, the list of (alias, concrete dtype) choices.
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_alias_used
        ]

        # Using each alias value permutation, generate kernel keys.
        # Identical (arg, dtype, dim_order) entries are shared between keys.
        kernel_keys = []
        op_arg_cache = {}
        for combination in itertools.product(*alias_dtypes):
            permutation = dict(combination)
            arg_list = []
            for arg_name, (type_alias, dim_order_alias) in args.items():
                cache_key = (
                    arg_name,
                    permutation[type_alias],
                    tuple(dim_order_alias_map[dim_order_alias]),
                )
                if cache_key not in op_arg_cache:
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key)  # type: ignore[arg-type]
                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        """Serialize the key to its string form.

        Returns:
            ``"default"`` for the catch-all key; otherwise ``"v<version>/"``
            followed by the serialized argument entries joined by ``|``.
        """
        if self.default:
            return "default"
        # Use the instance's own version (which defaults to
        # KERNEL_KEY_VERSION) so the string agrees with the field that takes
        # part in equality and hashing.
        serialized_args = "|".join(arg.to_native_string() for arg in self.arg_meta)
        return f"v{self.version}/{serialized_args}"
@dataclass(frozen=True)
class ETKernelIndex:
    """Index of kernels: operator name -> {ETKernelKey -> BackendMetadata}.

    The dataclass is frozen, but the ``index`` dictionary itself is mutable
    and is updated in place by ``grow``.
    """

    index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """Return True if at least one kernel is registered for ``g``."""
        # get_kernels() returns an empty dict (never None) for unknown
        # operators, so test for emptiness rather than for None.
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Dict[ETKernelKey, BackendMetadata]:
        """Return the kernels registered for ``g``.

        For a NativeFunctionsGroup the functional variant is looked up.

        Returns:
            The operator's kernel dictionary as stored in the index (not a
            copy), or a new empty dict if the operator is not in the index.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        # Test membership before indexing so that a defaultdict-backed index
        # does not gain an empty entry just because it was queried.
        if f.func.name not in self.index:
            return {}
        return self.index[f.func.name]

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Register every backend entry in ``kernel_index`` as a default kernel.

        ``kernel_index`` is modified in place. Each operator found in
        ``backend_indices`` gets its ``ETKernelKey(default=True)`` entry set
        to the backend metadata, replacing any previous default entry; if the
        same operator occurs under several dispatch keys, the last one
        iterated wins.
        """
        for index in backend_indices.values():
            for op, backend_metadata in index.items():
                kernel_index.setdefault(op, {})[
                    ETKernelKey(default=True)
                ] = backend_metadata

    @staticmethod
    def from_backend_indices(
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Build a new ETKernelIndex holding only default kernels."""
        kernel_index: Dict[
            OperatorName, Dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Add default kernels from ``backend_indices`` in place; return self."""
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.

        Convert to a CPU BackendIndex. Every operator must have exactly one
        kernel. If that kernel is not the default-key kernel, a placeholder
        BackendMetadata with an empty kernel name is stored instead.

        Raises:
            AssertionError: if any operator does not have exactly one kernel.
        """
        index: Dict[OperatorName, BackendMetadata] = {}
        for op, kernel_dict in self.index.items():
            assert (
                len(kernel_dict) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(
        index_a: "ETKernelIndex", index_b: "ETKernelIndex"
    ) -> "ETKernelIndex":
        """Return a new index containing the kernels of both inputs.

        Neither input is modified. When both define the same ETKernelKey for
        an operator, the metadata from ``index_b`` is kept.
        """
        # Copy the per-operator dicts as well as the outer mapping; a shallow
        # copy would share them with index_a, and the updates below would
        # then leak into index_a.
        combined = defaultdict(
            dict, {op: dict(kernels) for op, kernels in index_a.index.items()}
        )
        for op, entry in index_b.index.items():
            combined[op].update(entry)
        return ETKernelIndex(combined)