import ast
import dataclasses
import inspect
import re
import string
import sys
from collections import namedtuple
from textwrap import dedent
from typing import List, Tuple  # noqa: F401

import torch
import torch.jit.annotations
from torch import _jit_internal
from torch._C._jit_tree_views import (
    Apply,
    Assert,
    Assign,
    Attribute,
    AugAssign,
    BinOp,
    Break,
    ClassDef,
    Const,
    Continue,
    Decl,
    Def,
    Delete,
    DictComp,
    DictLiteral,
    Dots,
    EmptyTypeAnnotation,
    ExprStmt,
    FalseLiteral,
    For,
    Ident,
    If,
    ListComp,
    ListLiteral,
    NoneLiteral,
    Param,
    Pass,
    Property,
    Raise,
    Return,
    Select,
    SliceExpr,
    Starred,
    Stmt,
    StringLiteral,
    Subscript,
    TernaryIf,
    TrueLiteral,
    TupleLiteral,
    UnaryOp,
    Var,
    While,
    With,
    WithItem,
)
from torch._jit_internal import (  # noqa: F401
    _is_drop_fn,
    FunctionModifiers,
    is_static_fn,
    should_drop,
)
from torch._sources import (
    get_source_lines_and_file,
    make_source_context,
    parse_def,
    ParsedDef as _ParsedDef,
)
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace

_IS_ASTUNPARSE_INSTALLED = False
try:
    import astunparse  # type: ignore[import]

    _IS_ASTUNPARSE_INSTALLED = True
except ImportError:
    pass

# Borrowed from cPython implementation
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#

_reserved_prefix = "__jit"
_reserved_names = {"print"}
_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)


def is_reserved_name(name):
    return name.startswith(_reserved_prefix) or name in _reserved_names


pretty_node_names = {
    ast.FunctionDef: "function definitions",
    ast.For: "for loops",
    ast.Delete: "del statements",
    ast.ClassDef: "class definitions",
    ast.With: "with statements",
    ast.Raise: "raise statements",
    ast.Assert: "assertions",
    ast.Import: "import statements",
    ast.ImportFrom: "import statements",
    ast.Global: "global variables",
    ast.Break: "break statements",
    ast.Continue: "continue statements",
}

node_start_tokens = {
    ast.FunctionDef: "def",
    ast.For: "for",
    ast.Delete: "del",
    ast.ClassDef: "class",
    ast.With: "with",
    ast.Raise: "raise",
    ast.Assert: "assert",
    ast.Import: "import",
    ast.ImportFrom: "from",
    ast.Global: "global",
    ast.Break: "break",
    ast.Continue: "continue",
}

pretty_node_names.update(
    {
        ast.AsyncFunctionDef: "async function definitions",
        ast.AsyncFor: "async for loops",
        ast.AsyncWith: "async with statements",
        ast.Try: "try blocks",
        ast.Nonlocal: "nonlocal variables",
    }
)

node_start_tokens.update(
    {
        ast.AsyncFunctionDef: "async def",
        ast.AsyncFor: "async for",
        ast.AsyncWith: "async with",
        ast.Try: "try",
        ast.Nonlocal: "nonlocal",
    }
)

pretty_node_names.update(
    {
        ast.AnnAssign: "annotated assignments",
    }
)
# NB: no specific token for AnnAssign


class FrontendError(Exception):
    def __init__(self, source_range, msg):
        self.source_range = source_range
        self.msg = msg

        # This has to be instantiated here so the ErrorReport is accurate to the
        # call stack when the FrontendError was raised
        self.error_report = torch._C.ErrorReport(self.source_range)

    def __str__(self):
        return self.msg + self.error_report.what().lstrip()


class NotSupportedError(FrontendError):
    pass


class UnsupportedNodeError(NotSupportedError):
    def __init__(self, ctx, offending_node, reason=""):
        # If we don't have a specific token, we default to length of 1
        node_type = type(offending_node)
        range_len = len(node_start_tokens.get(node_type, " "))
        source_range = ctx.make_range(
            offending_node.lineno,
            offending_node.col_offset,
            offending_node.col_offset + range_len,
        )
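        # Build a readable message such as "try blocks aren't supported",
        # anchored at the start token of the offending node.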
        feature_name = pretty_node_names.get(node_type, node_type.__name__)
        msg = f"{feature_name} {reason + ' ' if reason else ''}aren't supported"
        super().__init__(source_range, msg)


class FrontendTypeError(FrontendError):
    pass


def build_withitems(ctx, items):
    items = [build_withitem(ctx, i) for i in items]
    return list(items)


def build_stmts(ctx, stmts):
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))


def get_class_properties(cls, self_name):
    """
    Get a list of Property objects representing the properties of a class.

    Args:
        cls: The class to get properties of.
        self_name: The name of the class that the properties should belong to.

    Returns:
        A list of Property objects corresponding to the properties of cls. Property
        here refers to the subclass of TreeView.
    """
    props = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property))
    # Any property that should not be compiled must be in this list on the Module.
    unused_properties = getattr(cls, "__jit_unused_properties__", [])

    # Create Property TreeView objects from inspected property objects.
    properties = []
    for prop in props:
        if prop[0] not in unused_properties and not should_drop(prop[1].fget):
            getter = get_jit_def(
                prop[1].fget, f"__{prop[0]}_getter", self_name=self_name
            )
            setter = (
                get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name)
                if prop[1].fset
                else None
            )
            properties.append(
                Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter)
            )

    return properties


def get_class_assigns(ctx, cls_ast):
    assigns = []

    def maybe_build_assign(builder, entry):
        nonlocal assigns
        try:
            assigns.append(builder(ctx, entry))
        except NotSupportedError:
            pass

    for entry in cls_ast.body:
        if isinstance(entry, ast.Assign):
            maybe_build_assign(StmtBuilder.build_Assign, entry)
        elif isinstance(entry, ast.AnnAssign):
            maybe_build_assign(StmtBuilder.build_AnnAssign, entry)
    return assigns


def get_jit_class_def(cls, self_name):
    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__
        and not _is_drop_fn(m),
    )

    def is_classmethod(fn):
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

    # Get and parse the source code for this class
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack()
    )
    source = "".join(sourcelines)

    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Special case for dataclasses. In general we need access to the source code
    # for an object in order to JIT compile it. But the dataclasses module
    # dynamically synthesizes magic methods for classes, and we can't get the
    # source code for these methods. As a workaround, we synthesize
    # TorchScript-friendly implementations ourselves.
    if dataclasses.is_dataclass(cls):
        # Detect whether the user manually implemented any of the magic methods.
        # If they did, we don't want to synthesize/override them.
        overrides = {
            method.name
            for method in class_ast.body
            if isinstance(method, ast.FunctionDef)
            and method.name in DATACLASS_MAGIC_METHODS
        }
        for i, (name, _) in enumerate(methods):
            # Is this a magic method we can synthesize?
            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
            if synthesizer_fn and name not in overrides:
                parsed_def = synthesizer_fn(cls)
                methods[i] = name, parsed_def
                func = getattr(cls, name)
                _jit_internal.loader.cache(func, parsed_def.source)

    method_defs = [
        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
        dedent_src.split("\n", 1)[0]
    )
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, False
    )
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)


def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile or a pre-parsed ParsedDef object
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is
            "_forward", but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
    """
    parsed_def = parse_def(fn) if not isinstance(fn, _ParsedDef) else fn
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]

    if is_classmethod:
        arg_name = fn_def.args.args[0].arg
        # Insert a statement that assigns the first argument to the class
        assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
        fn_def.body.insert(0, assign_stmt)

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        unused_fn_def = ast.parse(
            'def unused_fn(self: Any):\n\traise RuntimeError("Cannot call @unused methods")'
        )
        if len(unused_fn_def.body) != 1 or not isinstance(
            unused_fn_def.body[0], ast.FunctionDef
        ):
            raise RuntimeError(
                f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
            )
        unused_def = unused_fn_def.body[0]
        fn_def.body = unused_def.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            # Replace potentially unsupported type annotations by "Any"
            arg.annotation = unused_def.args.args[0].annotation
        if _is_drop_fn(fn):
            # Dropping potentially unsupported return type annotation for jit._drop
            fn_def.returns = None
            fn_def.type_comment = None

    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace and not isinstance(fn, _ParsedDef):  # type: ignore[truthy-function]
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)

    return build_def(
        parsed_def.ctx,
        fn_def,
        type_line,
        def_name,
        self_name=self_name,
        pdt_arg_types=pdt_arg_types,
    )


# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
    # checks if the statement is torch.jit.ignore context manager
    if isinstance(stmt.items[0].context_expr, ast.Call):
        # extract torch part
        function = stmt.items[0].context_expr.func
        if isinstance(function, ast.Attribute):
            attr_name = function.attr
            attr_value = function.value
            if attr_name == "_IgnoreContextManager" and isinstance(
                attr_value, ast.Attribute
            ):
                # there should be at most two nested attributes
                # (e.g. torch.jit._IgnoreContextManager)
                if attr_value.attr == "jit" and isinstance(attr_value.value, ast.Name):
                    if attr_value.value.id == "torch":
                        return True
    return False


class Builder:
    def __call__(self, ctx, node):
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)


def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
    r = ctx.make_range(
        py_def.lineno, py_def.col_offset, py_def.col_offset + len("class")
    )
    return ClassDef(
        Ident(r, self_name), [Stmt(method) for method in methods], properties, assigns
    )


def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    body = py_def.body
    r = ctx.make_range(py_def.lineno, py_def.col_offset, py_def.col_offset + len("def"))

    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    return_type = None
    if getattr(py_def, "returns", None) is not None:
        return_type = build_expr(ctx, py_def.returns)

    decl = Decl(r, param_list, return_type)
    is_method = self_name is not None
    if type_line is not None:
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)

    return Def(Ident(r, def_name), decl, build_stmts(ctx, body))


_vararg_kwarg_err = (
    "Compiled functions can't take variable number of arguments "
    "or use keyword-only arguments with defaults"
)


def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
    if py_args.kwarg is not None:
        expr = py_args.kwarg
        ctx_range = ctx.make_range(
            expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)
        )
        raise NotSupportedError(ctx_range, _vararg_kwarg_err)
    if py_args.vararg is not None:
        expr = py_args.vararg
        ctx_range = ctx.make_range(
            expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)
        )
        raise NotSupportedError(ctx_range, _vararg_kwarg_err)
    if len(py_args.kw_defaults) > 0:
        # kw_defaults is a list of the values for the kwargs (which default to None),
        # so they don't actually have line numbers.
        for arg in py_args.kw_defaults:
            if arg is not None:
                ctx_range = build_expr(ctx, arg).range()
                raise NotSupportedError(ctx_range, _vararg_kwarg_err)

    # List of Tuple of args and type as inferred by profile directed typing
    arg_and_types = [
        (
            arg,
            pdt_arg_types[arg.arg]
            if pdt_arg_types and bool(pdt_arg_types[arg.arg])
            else None,
        )
        for arg in py_args.args
    ]
    arg_and_types_kwonlyargs = [
        (
            arg,
            pdt_arg_types[arg.arg]
            if pdt_arg_types and bool(pdt_arg_types[arg.arg])
            else None,
        )
        for arg in py_args.kwonlyargs
    ]

    result = [
        build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type)
        for arg, arg_type in arg_and_types
    ]
    result += [
        build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=arg_type)
        for arg, arg_type in arg_and_types_kwonlyargs
    ]
    return result


def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
    # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
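    # Annotation precedence: an explicit Python annotation wins, then a
    # MonkeyType-inferred (profile-directed) type, then the enclosing class type
    # for `self`, and finally an empty annotation.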
    name = py_arg.arg
    r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
    if getattr(py_arg, "annotation", None) is not None:
        annotation_expr = build_expr(ctx, py_arg.annotation)
    elif pdt_arg_type:
        annotation_expr = Var(Ident(r, pdt_arg_type))
    elif self_name is not None and name == "self":
        annotation_expr = Var(Ident(r, self_name))
    else:
        annotation_expr = EmptyTypeAnnotation(r)
    return Param(annotation_expr, Ident(r, name), kwarg_only)


def build_ignore_context_manager(ctx, stmt):
    InputType = namedtuple("InputType", ["name", "ann"])
    OutputType = namedtuple("OutputType", ["name", "ann"])

    def process_ins_outs(args):
        # parse the context manager to figure out inputs and outputs
        # with their annotated types
        # TODO: add input, output validator
        inputs = []
        outputs = []
        for arg in args:
            var_name = arg.arg
            var_ann = arg.value.value
            var_decl_type, var_ann = var_ann.split(":")
            if var_decl_type == "inp":
                inputs.append(InputType(var_name, var_ann))
            if var_decl_type == "out":
                outputs.append(OutputType(var_name, var_ann))
        return inputs, outputs

    def create_unique_name_ext(ctx, stmt):
        # extension will be based on the full path filename plus
        # the line number of original context manager
        fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)
        return f"{fn}_{stmt.lineno}"

    def build_return_ann_stmt(outputs):
        return_type_ann = ""
        return_statement_str = "return "
        if len(outputs) == 0:
            return_type_ann += " -> None"
        if len(outputs) == 1:
            return_type_ann = " -> " + outputs[0].ann
            return_statement_str += outputs[0].name
        if len(outputs) > 1:
            return_type_ann = " -> Tuple"
            return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]"
            return_statement_str += ", ".join([var.name for var in outputs])
        return return_type_ann, return_statement_str

    def build_args(args):
        return ", ".join([arg.name for arg in args])

    inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)

    # build the replacement function str with given inputs and outputs
    ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
    ignore_function_str = "\ndef " + ignore_function_name
    ignore_function_str += (
        "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")"
    )
    return_ann, return_stmt = build_return_ann_stmt(outputs)
    ignore_function_str += return_ann + ": pass"

    # first create the FunctionDef object from just the declaration
    ignore_function = ast.parse(ignore_function_str).body[0]

    # dump the body of context manager to dummy function
    ignore_function.body = stmt.body  # type: ignore[attr-defined]

    # insert return statement to the function
    return_stmt = ast.parse(return_stmt).body[0]
    ignore_function.body.append(return_stmt)  # type: ignore[attr-defined]

    # registers the custom function in the global context
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
    exec(ignore_func_str)  # noqa: P204

    # build the statements as:
    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
    assign_str_lhs = build_args(outputs)
    # this function will be registered in torch.jit.frontend module by default
    assign_str_rhs = (
        f"torch.jit.frontend.{ignore_function_name}(" + build_args(inputs) + ")"
    )

    if len(outputs) > 0:
        assign_str = assign_str_lhs + " = " + assign_str_rhs
    else:
        assign_str = assign_str_rhs
    assign_ast = ast.parse(assign_str).body[0]
    return assign_ast


def get_default_args(fn):
    if fn is None:
        return {}

    signature = inspect.signature(fn)

    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }


def get_default_args_for_class(cls):
    """
    Get default arguments for all methods in a class (except for static methods).

    Args:
        cls: type - The class type to inspect for default arguments.

    Returns:
        A Dict[str, Dict[str, Any]] which maps each method name to a Dict[str, Any]
        that maps each argument name to its default value.
    """
    # Get methods (except static methods because those are compiled separately as
    # if they were independent script functions).
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__,
    )

    # Get method defaults. Property defaults do not need to be considered
    # because setters cannot be invoked without a value.
    defaults = {
        method_name: get_default_args(method_impl)
        for method_name, method_impl in methods
    }

    return defaults


class WithItemBuilder(Builder):
    @staticmethod
    def build_withitem(ctx, item):
        lineno = item.context_expr.lineno
        start = item.context_expr.col_offset
        end = start + len(pretty_node_names[ast.With])
        op_vars = item.optional_vars
        r = ctx.make_range(lineno, start, end)

        return WithItem(
            r,
            build_expr(ctx, item.context_expr),
            build_expr(ctx, op_vars) if op_vars else None,
        )


class StmtBuilder(Builder):
    augassign_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Mod: "%",
        ast.BitOr: "|",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.LShift: "<<",
        ast.RShift: ">>",
        ast.Pow: "**",
    }

    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == "Str":
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_AnnAssign(ctx, stmt):
        if stmt.value is None:
            raise UnsupportedNodeError(ctx, stmt, reason="without assigned value")

        # Disallow type annotations on instance attributes outside of __init__
        if (
            type(stmt.target) == ast.Attribute
            and stmt.target.value.id == "self"  # type: ignore[attr-defined]
            and ctx.funcname != "__init__"
        ):
            start = stmt.col_offset
            end = start + len(f"self.{stmt.target.attr}")

            if hasattr(stmt.annotation, "id"):
                end += len(f": {stmt.annotation.id}")

            sr = ctx.make_range(stmt.lineno, start, end)
            raise ValueError(
                "Type annotations on instance attributes must be declared in "
                f"__init__, not '{ctx.funcname}': {sr}"
            )

        rhs = build_expr(ctx, stmt.value)
        lhs = build_expr(ctx, stmt.target)
        the_type = build_expr(ctx, stmt.annotation)
        return Assign([lhs], rhs, the_type)

    @staticmethod
    def build_Delete(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del"))
        return Delete(r, [build_expr(ctx, target) for target in stmt.targets])

    @staticmethod
    def build_Return(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("return")
        )
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))

    @staticmethod
    def build_Raise(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
        expr = build_expr(ctx, stmt.exc)
        return Raise(r, expr)

    @staticmethod
    def build_Assert(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert")
        )
        test = build_expr(ctx, stmt.test)
        msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None
        return Assert(r, test, msg)

    @staticmethod
    def build_AugAssign(ctx, stmt):
        lhs = build_expr(ctx, stmt.target)
        rhs = build_expr(ctx, stmt.value)
        op = type(stmt.op)
        if op in StmtBuilder.augassign_map:
            op_token = StmtBuilder.augassign_map[op]
        else:
            raise NotSupportedError(
                find_before(ctx, rhs.range().start, "=", offsets=(-1, 0)),
                "unsupported kind of augmented assignment: " + op.__name__,
            )
        return AugAssign(lhs, op_token, rhs)

    @staticmethod
    def build_While(ctx, stmt):
        if stmt.orelse:
            # TODO: try to recover the location of else:? Python doesn't give us useful
            # annotations in this case
            raise NotSupportedError(
                None, "else branches of while loops aren't supported"
            )
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
        return While(r, build_expr(ctx, stmt.test), build_stmts(ctx, stmt.body))

    @staticmethod
    def build_For(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for"))
        if stmt.orelse:
            raise NotSupportedError(r, "else branches of for loops aren't supported")

        return For(
            r,
            [build_expr(ctx, stmt.target)],
            [build_expr(ctx, stmt.iter)],
            build_stmts(ctx, stmt.body),
        )

    @staticmethod
    def build_If(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
        return If(
            r,
            build_expr(ctx, stmt.test),
            build_stmts(ctx, stmt.body),
            build_stmts(ctx, stmt.orelse),
        )

    @staticmethod
    def build_Print(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print"))
        if stmt.dest:
            raise NotSupportedError(
                r, "print statements with non-default destinations aren't supported"
            )
        args = [build_expr(ctx, val) for val in stmt.values]
        return ExprStmt(Apply(Var(Ident(r, "print")), args, []))

    @staticmethod
    def build_Pass(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass"))
        return Pass(r)

    @staticmethod
    def build_Break(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("break"))
        return Break(r)

    @staticmethod
    def build_Continue(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("continue")
        )
        return Continue(r)

    @staticmethod
    def build_With(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
        # Handle ignore context manager
        if is_torch_jit_ignore_context_manager(stmt):
            if not _IS_ASTUNPARSE_INSTALLED:
                raise RuntimeError(
                    "torch.jit._IgnoreContextManager requires installing Python "
                    "library `astunparse`, please install it in your Python environment"
                )
            assign_ast = build_ignore_context_manager(ctx, stmt)
            return build_stmt(ctx, assign_ast)
        return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))


class ExprBuilder(Builder):
    binop_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Pow: "**",
        ast.Mod: "%",
        ast.FloorDiv: "//",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.BitOr: "|",
        ast.LShift: "<<",
        ast.RShift: ">>",
    }

    binop_map[ast.MatMult] = "@"

    unop_map = {
        ast.Not: "not",
        ast.USub: "-",
        ast.Invert: "~",
    }

    boolop_map = {
        ast.And: "and",
        ast.Or: "or",
    }

    cmpop_map = {
        ast.Eq: "==",
        ast.NotEq: "!=",
        ast.LtE: "<=",
        ast.Lt: "<",
        ast.GtE: ">=",
        ast.Gt: ">",
        ast.Is: "is",
        ast.IsNot: "is not",
        ast.In: "in",
        ast.NotIn: "not in",
    }

    @staticmethod
    def build_Attribute(ctx, expr):
        base = build_expr(ctx, expr.value)
        # expr.attr is just a string, so it's not annotated in any way, so we have
        # to build the range manually
        source = ctx.source.encode("utf-8")

        def get_char(index):
            return chr(source[index])

        start_pos = base.range().end + 1
        while get_char(start_pos) in string.whitespace:  # Skip whitespace
            start_pos += 1
        end_pos = start_pos + len(expr.attr)
        name_range = ctx.make_raw_range(start_pos, end_pos)
        return Select(base, Ident(name_range, expr.attr))

    @staticmethod
    def build_Call(ctx, expr):
        func = build_expr(ctx, expr.func)
        args = [build_expr(ctx, py_arg) for py_arg in expr.args]
        if hasattr(expr, "starargs") and expr.starargs:
            stararg_expr = build_expr(ctx, expr.starargs)
            args += [Starred(stararg_expr.range(), stararg_expr)]
        kwargs = []
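        # Keyword arguments become (name, value) Attribute nodes; bare `**kwargs`
        # expansion has no `arg` and is rejected below.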
        for kw in expr.keywords:
            kw_expr = build_expr(ctx, kw.value)
            # XXX: we could do a better job at figuring out the range for the name here
            if not kw.arg:
                raise NotSupportedError(
                    kw_expr.range(), "keyword-arg expansion is not supported"
                )
            kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
        return Apply(func, args, kwargs)

    @staticmethod
    def build_Ellipsis(ctx, expr):
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + 3
        )  # len("...") == 3
        return Dots(r)

    @staticmethod
    def build_Name(ctx, expr):
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
        if expr.id.startswith(_reserved_prefix):
            raise NotSupportedError(
                r,
                "names of variables used in JIT-ed functions "
                "can't start with " + _reserved_prefix,
            )
        if expr.id == "True":
            return TrueLiteral(r)
        elif expr.id == "False":
            return FalseLiteral(r)
        elif expr.id == "None":
            return NoneLiteral(r)
        elif expr.id == "Ellipsis":
            return Dots(r)
        return Var(Ident(r, expr.id))

    @staticmethod
    def build_NameConstant(ctx, expr):
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value))
        )
        if expr.value is True:
            return TrueLiteral(r)
        elif expr.value is False:
            return FalseLiteral(r)
        elif expr.value is None:
            return NoneLiteral(r)
        elif expr.value == Ellipsis:
            return Dots(r)
        else:
            raise ValueError("Name constant value unsupported: " + str(expr.value))

    @staticmethod
    def build_BinOp(ctx, expr):
        lhs = build_expr(ctx, expr.left)
        rhs = build_expr(ctx, expr.right)
        op = type(expr.op)

        if op == ast.Div and not ctx.uses_true_division:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise FrontendError(
                err_range,
                "Division of ints in TorchScript uses Python 3 true "
                "division semantics. Please put `from __future__ "
                "import division` at the top of your file",
            )
        op_token = ExprBuilder.binop_map.get(op)
        if op_token is None:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise NotSupportedError(
                err_range, "unsupported binary operator: " + op.__name__
            )
        return BinOp(op_token, lhs, rhs)

    @staticmethod
    def build_UnaryOp(ctx, expr):
        sub_expr = build_expr(ctx, expr.operand)
        op = type(expr.op)
        op_token = ExprBuilder.unop_map.get(op)
        if op_token is None:
            raise NotSupportedError(
                expr.range(), "unsupported unary operator: " + op.__name__
            )
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(op_token)
        )
        return UnaryOp(r, op_token, sub_expr)

    @staticmethod
    def build_BoolOp(ctx, expr):
        if len(expr.values) < 2:
            raise AssertionError(
                "expected at least 2 values in BoolOp, but got " + str(len(expr.values))
            )
        sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
        op = type(expr.op)
        op_token = ExprBuilder.boolop_map.get(op)
        if op_token is None:
            err_range = ctx.make_raw_range(
                sub_exprs[0].range().end, sub_exprs[1].range().start
            )
            raise NotSupportedError(
                err_range, "unsupported boolean operator: " + op.__name__
            )
        lhs = sub_exprs[0]
        for rhs in sub_exprs[1:]:
            lhs = BinOp(op_token, lhs, rhs)
        return lhs

    @staticmethod
    def build_IfExp(ctx, expr):
        return TernaryIf(
            build_expr(ctx, expr.test),
            build_expr(ctx, expr.body),
            build_expr(ctx, expr.orelse),
        )

    @staticmethod
    def build_Compare(ctx, expr):
        operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]
        result = None
        for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]):
            op = type(op_)
            op_token = ExprBuilder.cmpop_map.get(op)
            r = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            if op_token is None:
                raise NotSupportedError(
                    r, "unsupported comparison operator: " + op.__name__
                )
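            # Chained comparisons (e.g. `a < b <= c`) are lowered below into
            # pairwise comparisons joined with `and`.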
            if op == ast.NotIn:
                # NB: `not in` is just `not( in )`, so we don't introduce new tree view
                # but just make it a nested call in our tree view structure
                in_expr = BinOp("in", lhs, rhs)
                cmp_expr = UnaryOp(r, "not", in_expr)
            else:
                cmp_expr = BinOp(op_token, lhs, rhs)

            if result is None:
                result = cmp_expr
            else:
                result = BinOp("and", result, cmp_expr)
        return result

    @staticmethod
    def build_Subscript(ctx, expr):
        def build_SliceExpr(ctx, base, slice_expr):
            lower = (
                build_expr(ctx, slice_expr.lower)
                if slice_expr.lower is not None
                else None
            )
            upper = (
                build_expr(ctx, slice_expr.upper)
                if slice_expr.upper is not None
                else None
            )
            step = (
                build_expr(ctx, slice_expr.step)
                if slice_expr.step is not None
                else None
            )
            return SliceExpr(base.range(), lower, upper, step)

        def build_Index(ctx, base, index_expr):
            if isinstance(index_expr.value, ast.Tuple):
                raise NotSupportedError(
                    base.range(),
                    "slicing multiple dimensions with tuples not supported yet",
                )
            return build_expr(ctx, index_expr.value)

        def build_ExtSlice(ctx, base, extslice):
            sub_exprs = []
            for expr in extslice.dims:
                sub_type = type(expr)
                if sub_type is ast.Index:
                    sub_exprs.append(build_Index(ctx, base, expr))
                elif sub_type is ast.Slice:
                    sub_exprs.append(build_SliceExpr(ctx, base, expr))
                elif sub_type is ast.Ellipsis:
                    sub_exprs.append(Dots(base.range()))
                else:
                    raise NotSupportedError(
                        base.range(),
                        f"slicing multiple dimensions with {sub_type} not supported",
                    )
            return sub_exprs

        base = build_expr(ctx, expr.value)
        sub_type = type(expr.slice)
        if sub_type is ast.Index:
            if isinstance(expr.slice.value, ast.Tuple):
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                # XXX: Indexing using a list is **different**! It triggers advanced indexing.
                indices = [
                    build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts
                ]
                if not indices:
                    # `col_offset` is an int, but `end_col_offset` is
                    # `Optional[int]`. The magic number is here to make
                    # sure we can parse `()` on any machine
                    r = ctx.make_range(
                        expr.lineno,
                        expr.slice.value.col_offset,
                        expr.slice.value.col_offset + 2,
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            else:
                return Subscript(base, [build_expr(ctx, expr.slice.value)])
        elif sub_type is ast.Slice:
            return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
        elif sub_type is ast.ExtSlice:
            return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
        elif sys.version_info >= (
            3,
            9,
        ):  # In Python 3.9, array indices are not wrapped in ast.Index
            if sub_type is ast.Tuple:
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                indices = []
                for index_expr in expr.slice.elts:
                    if isinstance(index_expr, ast.Slice):
                        indices.append(build_SliceExpr(ctx, base, index_expr))
                    else:
                        indices.append(build_expr(ctx, index_expr))
                # Special-case logic for `typing.Tuple[()]`
                if not indices:
                    # See note above r.e. magic number
                    r = ctx.make_range(
                        expr.lineno, expr.slice.col_offset, expr.slice.col_offset + 2
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            return Subscript(base, [build_expr(ctx, expr.slice)])
        else:  # Ellipsis (can only happen in Python 2)
            raise NotSupportedError(base.range(), "ellipsis is not supported")

    @staticmethod
    def build_List(ctx, expr):
        return ListLiteral(
            ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
            [build_expr(ctx, e) for e in expr.elts],
        )

    @staticmethod
    def build_Tuple(ctx, expr):
        return TupleLiteral(
            ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
            [build_expr(ctx, e) for e in expr.elts],
        )

    @staticmethod
    def build_Dict(ctx, expr):
        range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        if expr.keys and not expr.keys[0]:
            raise NotSupportedError(
                range, "Dict expansion (e.g. `{**dict}`) is not supported"
            )
        return DictLiteral(
            range,
            [build_expr(ctx, e) for e in expr.keys],
            [build_expr(ctx, e) for e in expr.values],
        )

    @staticmethod
    def build_Num(ctx, expr):
        value = str(expr.value)
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
        return Const(r, value)

    @staticmethod
    def build_Constant(ctx, expr):
        value = expr.value
        if value is None or isinstance(value, bool):
            # NB: this check has to happen before the int check because bool is
            # a subclass of int
            return ExprBuilder.build_NameConstant(ctx, expr)
        if isinstance(value, (int, float, complex)):
            return ExprBuilder.build_Num(ctx, expr)
        elif isinstance(value, str):
            return ExprBuilder.build_Str(ctx, expr)
        elif isinstance(value, type(Ellipsis)):
            return ExprBuilder.build_Ellipsis(ctx, expr)
        else:
            error_range = ctx.make_range(
                expr.lineno, expr.col_offset, expr.col_offset + len(str(value))
            )
            raise FrontendError(error_range, "Unknown Constant expression type")

    @staticmethod
    def build_Str(ctx, expr):
        value = str(expr.value)
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(value) + 1
        )
        return StringLiteral(r, value)

    @staticmethod
    def build_JoinedStr(ctx, expr):
        s = ""
        args = []
        for value in expr.values:
            r = ctx.make_range(value.lineno, value.col_offset, value.col_offset + 1)
            if isinstance(value, ast.FormattedValue):
                if value.conversion != -1:
                    raise NotSupportedError(r, "Don't support conversion in JoinedStr")
                if value.format_spec is not None:
                    raise NotSupportedError(r, "Don't support formatting in JoinedStr")
                s += "{}"
                args.append(build_expr(ctx, value.value))
            elif isinstance(value, ast.Str):
                s += value.s
            else:
                raise NotSupportedError(r, "Unsupported value in JoinedStr")

        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        return Apply(Select(StringLiteral(r, s), Ident(r, "format")), args, [])

    @staticmethod
    def build_ListComp(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
        if len(stmt.generators) != 1:
            raise NotSupportedError(r, "Only a single generator is currently supported")

        if len(stmt.generators[0].ifs) != 0:
            raise NotSupportedError(r, "Comprehension ifs are not supported yet")

        elt_expr = build_expr(ctx, stmt.elt)
        target_expr = build_expr(ctx, stmt.generators[0].target)
        iter_expr = build_expr(ctx, stmt.generators[0].iter)

        return ListComp(r, elt_expr, target_expr, iter_expr)

    @staticmethod
    def build_GeneratorExp(ctx, stmt):
        # Convert Generator expression to ListComp
        return ExprBuilder.build_ListComp(ctx, stmt)

    @staticmethod
    def build_DictComp(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
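        # Dict comprehensions carry the same restrictions as list comprehensions:
        # exactly one generator and no `if` filters.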
        if len(stmt.generators) != 1:
            raise NotSupportedError(r, "Only a single generator is currently supported")

        if len(stmt.generators[0].ifs) != 0:
            raise NotSupportedError(r, "Comprehension ifs are not supported yet")

        key_expr = build_expr(ctx, stmt.key)
        value_expr = build_expr(ctx, stmt.value)
        target_expr = build_expr(ctx, stmt.generators[0].target)
        iter_expr = build_expr(ctx, stmt.generators[0].iter)

        return DictComp(r, key_expr, value_expr, target_expr, iter_expr)

    @staticmethod
    def build_Starred(ctx, expr):
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        return Starred(r, build_expr(ctx, expr.value))


build_expr = ExprBuilder()
build_stmt = StmtBuilder()
build_withitem = WithItemBuilder()


def find_before(ctx, pos, substr, offsets=(0, 0)):
    new_pos = ctx.source[:pos].rindex(substr)
    return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1])
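

# The block below is an illustrative, non-authoritative usage sketch added for
# documentation purposes only: it shows how `get_jit_def` turns an ordinary
# Python function into a TorchScript AST (a `torch._C._jit_tree_views.Def`),
# which is what `torch.jit.script` ultimately hands to the compiler. The name
# `_example_fn` is made up for the demo and is not part of this module's API.
if __name__ == "__main__":

    def _example_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    _example_def = get_jit_def(_example_fn, _example_fn.__name__)
    # `range()` reports the source range the AST points back to; the Def itself
    # is what torch._C consumes when building a ScriptFunction.
    print(type(_example_def))
    print(_example_def.range())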