import ast
import builtins
import dis
import enum
import inspect
import re
import typing
import warnings
from textwrap import dedent
from typing import Type

import torch
from torch._C import (
    _GeneratorType,
    AnyType,
    AwaitType,
    BoolType,
    ComplexType,
    DeviceObjType,
    DictType,
    EnumType,
    FloatType,
    FutureType,
    InterfaceType,
    IntType,
    ListType,
    NoneType,
    NumberType,
    OptionalType,
    StreamObjType,
    StringType,
    TensorType,
    TupleType,
    UnionType,
)
from torch._sources import get_source_lines_and_file

from .._jit_internal import (  # type: ignore[attr-defined]
    _Await,
    _qualified_name,
    Any,
    BroadcastingList1,
    BroadcastingList2,
    BroadcastingList3,
    Dict,
    Future,
    is_await,
    is_dict,
    is_future,
    is_ignored_fn,
    is_list,
    is_optional,
    is_tuple,
    is_union,
    List,
    Optional,
    Tuple,
    Union,
)
from ._state import _get_script_class

if torch.distributed.rpc.is_available():
    from torch._C import RRefType

    from .._jit_internal import is_rref, RRef

from torch._ops import OpOverloadPacket


class Module:
    """Minimal namespace stand-in used when evaluating type comments.

    Attribute access is served from the ``members`` mapping; a missing name
    raises ``RuntimeError`` mentioning the module name.
    """

    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __getattr__(self, name):
        try:
            return self.members[name]
        except KeyError:
            raise RuntimeError(
                f"Module {self.name} has no member called {name}"
            ) from None


class EvalEnv:
    """Name-lookup mapping handed to ``eval`` when parsing type comments.

    Lookup order in ``__getitem__``: the fixed ``env`` table, then the
    resolution callback ``rcb`` (if one was given), then ``builtins``
    (yielding ``None`` when the builtin does not exist).
    """

    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        if torch.distributed.rpc.is_available():
            # NOTE: this writes into the class-level ``env`` dict, so the
            # entry is shared by every EvalEnv instance.
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        if name in self.env:
            return self.env[name]
        if self.rcb is not None:
            return self.rcb(name)
        return getattr(builtins, name, None)


def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or ``None``.

    Real annotations are tried first (via ``try_real_annotations``); if none
    are found, the function source is searched for a ``# type:`` comment,
    which is then parsed with ``parse_type_line`` using ``rcb`` to resolve
    names.  For an ``OpOverloadPacket`` the wrapped ``fn.op`` is inspected
    when reading real annotations.
    """
    if isinstance(fn, OpOverloadPacket):
        signature = try_real_annotations(fn.op, loc)
    else:
        signature = try_real_annotations(fn, loc)
    if signature is not None and is_method:
        # If this is a method, then the signature will include a type for
        # `self`, but type comments do not contain a `self`. So strip it
        # away here so everything is consistent (`inspect.ismethod` does
        # not work here since `fn` is unbound at this point)
        param_types, return_type = signature
        param_types = param_types[1:]
        signature = (param_types, return_type)

    if signature is None:
        type_line, source = None, None
        try:
            source = dedent("".join(get_source_lines_and_file(fn)[0]))
            type_line = get_type_line(source)
        except TypeError:
            pass
        # This might happen both because we failed to get the source of fn, or
        # because it didn't have any annotations.
        if type_line is not None:
            signature = parse_type_line(type_line, rcb, loc)

    return signature


def is_function_or_method(the_callable):
    """Return True only for Python-level functions and methods."""
    # A stricter version of `inspect.isroutine` that does not pass for built-in
    # functions
    return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)


def is_vararg(the_callable):
    """Return True if ``the_callable`` accepts ``*args``.

    Callable objects that are not functions/methods are inspected through
    their ``__call__``.  Anything that still is not a function or method
    yields ``False``.
    """
    if not is_function_or_method(the_callable) and callable(the_callable):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        the_callable = the_callable.__call__

    if is_function_or_method(the_callable):
        return inspect.getfullargspec(the_callable).varargs is not None
    else:
        return False


def get_param_names(fn, n_args):
    """Return the positional parameter names of ``fn``.

    When ``fn`` (after unwrapping an ``OpOverloadPacket`` and de-sugaring
    ``__call__``) is not a Python function or method, the names fall back to
    ``"0" .. str(n_args - 1)``.
    """
    if isinstance(fn, OpOverloadPacket):
        fn = fn.op

    if (
        not is_function_or_method(fn)
        and callable(fn)
        and is_function_or_method(fn.__call__)
    ):  # noqa: B004
        # De-sugar calls to classes
        fn = fn.__call__

    if is_function_or_method(fn):
        if is_ignored_fn(fn):
            fn = inspect.unwrap(fn)
        return inspect.getfullargspec(fn).args
    else:
        # The `fn` was not a method or function (maybe a class with a __call__
        # method, so use a default param name list)
        return [str(i) for i in range(n_args)]


def check_fn(fn, loc):
    """Validate that the source of ``fn`` is a single top-level function.

    Returns silently when the source cannot be obtained.  Raises
    ``torch.jit.frontend.FrontendError`` if the source is a class definition
    or anything other than exactly one ``def``.
    """
    # Make sure the function definition is not a class instantiation
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
    except (OSError, TypeError):
        return
    if source is None:
        return

    py_ast = ast.parse(source)
    if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{py_ast.body[0].name}' in a script function",
        )
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )


def _eval_no_call(stmt, glob, loc):
    """Evaluate statement as long as it does not contain any method/function calls."""
    bytecode = compile(stmt, "", mode="eval")
    # Reject the expression if any compiled instruction is a CALL* opcode.
    for insn in dis.get_instructions(bytecode):
        if "CALL" in insn.opname:
            raise RuntimeError(
                f"Type annotation should not contain calls, but '{stmt}' does"
            )
    # NOTE(review): `eval` on annotation text; the CALL check above is the
    # only guard visible here.
    return eval(bytecode, glob, loc)  # type: ignore[arg-type] # noqa: P204


def parse_type_line(type_line, rcb, loc):
    """Parse a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns ``(arg_types, return_type)`` as produced by ``ann_to_type``.
    Raises ``RuntimeError`` when either side fails to evaluate with a
    ``NameError`` or ``SyntaxError``.
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    try:
        arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb))
    except (NameError, SyntaxError) as e:
        raise RuntimeError(
            "Failed to parse the argument list of a type annotation"
        ) from e

    # A single parenthesised argument evaluates to a bare value, not a tuple.
    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    try:
        ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb))
    except (NameError, SyntaxError) as e:
        raise RuntimeError(
            "Failed to parse the return type of a type annotation"
        ) from e

    arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann, loc)


def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Returns the stripped type line, a synthesized single line when the
    annotation is split across several ``# type:`` comments, or ``None`` when
    no type comment is present.  Raises ``RuntimeError`` for a probably
    mistyped prefix or a multi-line annotation without a return line.
    """
    type_comment = "# type:"

    lines = source.split("\n")
    lines = list(enumerate(lines))
    type_lines = list(filter(lambda line: type_comment in line[1], lines))
    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.

    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # This ignore statement must be at the end of the line

    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = list(filter(lambda line: not type_pattern.search(line[1]), type_lines))

    if len(type_lines) == 0:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = list(
            filter(lambda line: wrong_type_pattern.search(line[1]), lines)
        )
        if len(wrong_type_lines) > 0:
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                + "\nfor examples"
            )
        return None
    elif len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if "# type: (...) -> " in line:
            return_line = (line_num, line)
            break
        elif type_comment in line:
            parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    def get_parameter_type(line):
        # Text following the `# type:` marker, with surrounding whitespace removed.
        item_type = line[line.find(type_comment) + len(type_comment) :]
        return item_type.strip()

    types = map(get_parameter_type, parameter_type_lines)
    parameter_types = ", ".join(types)

    # Substitute the collected per-parameter types for the `...` placeholder.
    return return_line[1].replace("...", parameter_types)


def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    Raises ``RuntimeError`` when the line contains no ``->``.
    """
    start_offset = len("# type:")
    try:
        arrow_pos = type_line.index("->")
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (cound't find `->`)"
        ) from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip()


def try_real_annotations(fn, loc):
    """Try to use the Py3.5+ annotation syntax to get the type.

    Returns ``(arg_types, return_type)``, or ``None`` when no signature can
    be obtained or when neither the parameters nor the return carry an
    annotation.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    all_annots = [sig.return_annotation] + [
        p.annotation for p in sig.parameters.values()
    ]
    if all(ann is sig.empty for ann in all_annots):
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in sig.parameters.values()]
    return_type = ann_to_type(sig.return_annotation, loc)
    return arg_types, return_type


# Finds common type for enum values belonging to an Enum class. If not all
# values have the same type, AnyType is returned.
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Return the unified IR type of the values of enum class ``e``.

    Raises ``ValueError`` when the enum defines no members.  Returns
    ``AnyType`` when ``torch._C.unify_type_list`` gives a falsy result.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Report the enum class itself; `e.__class__` would be the metaclass.
        raise ValueError(f"No enum values defined for: '{e}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # Enum values of different Python types are supported by Python, but we
    # chose not to implement full support to avoid overcomplicating the logic
    # for a rare use case; a falsy unification result is mapped to AnyType
    # below. Please report a feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res


def is_tensor(ann):
    """Return True if class ``ann`` should be treated as a Tensor type.

    ``ann`` must be a class (callers guard with ``inspect.isclass``).  The
    dtype-specific tensor classes are accepted with a warning.
    """
    if issubclass(ann, torch.Tensor):
        return True

    if issubclass(
        ann,
        (
            torch.LongTensor,
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.HalfTensor,
            torch.CharTensor,
            torch.ByteTensor,
            torch.BoolTensor,
        ),
    ):
        warnings.warn(
            "TorchScript will treat type annotations of Tensor "
            "dtype-specific subtypes as if they are normal Tensors. "
            "dtype constraints are not enforced in compilation either."
        )
        return True

    return False


def _fake_rcb(inp):
    """Resolution callback that resolves nothing (always returns ``None``)."""
    return None


def try_ann_to_type(ann, loc, rcb=None):
    """Convert a Python annotation to a TorchScript type.

    Args:
        ann: the annotation object (class, ``typing`` construct, ``None``,
            or ``inspect.Signature.empty``).
        loc: source location used for error reporting and class compilation.
        rcb: optional resolution callback forwarded to
            ``torch._C._resolve_type_from_object``; ``_fake_rcb`` is used
            when omitted.

    Returns:
        The corresponding type object, or whatever
        ``torch._C._resolve_type_from_object`` returns for annotations not
        handled explicitly here.

    Raises:
        ValueError: for a ``Dict`` whose key or value type is unknown.
        AssertionError: for an ``Optional``/``Union`` member that cannot be
            resolved.
    """
    ann_args = typing.get_args(ann)  # always returns a tuple!

    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        # An unresolved element type falls through to the checks below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # Pick whichever argument is not NoneType. An identity check is used
        # instead of `issubclass` because the argument may be a typing alias
        # (e.g. `List[int]`), for which `issubclass` raises TypeError.
        if ann_args[1] is type(None):
            contained = ann_args[0]
        else:
            contained = ann_args[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in typing.get_args(ann):
            if a is None:
                inner.append(NoneType.get())
                # Already handled; do not resolve (and append) it a second time.
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
            # Report the member annotation that failed, not the (None) result.
            assert maybe_type, msg.format(repr(ann), repr(a), repr(loc))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)


def ann_to_type(ann, loc, rcb=None):
    """Like ``try_ann_to_type`` but raise ``ValueError`` on an unknown annotation."""
    the_type = try_ann_to_type(ann, loc, rcb)
    if the_type is not None:
        return the_type
    raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")


__all__ = [
    "Any",
    "List",
    "BroadcastingList1",
    "BroadcastingList2",
    "BroadcastingList3",
    "Tuple",
    "is_tuple",
    "is_list",
    "Dict",
    "is_dict",
    "is_optional",
    "is_union",
    "TensorType",
    "TupleType",
    "FloatType",
    "ComplexType",
    "IntType",
    "ListType",
    "StringType",
    "DictType",
    "AnyType",
    "Module",
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    "get_signature",
    "check_fn",
    "get_param_names",
    "parse_type_line",
    "get_type_line",
    "split_type_line",
    "try_real_annotations",
    "try_ann_to_type",
    "ann_to_type",
]