from types import TracebackType from typing import List, Optional import tempfile import traceback import contextlib import inspect import os.path # This file contains utilities for ensuring dynamically compile()'d # code fragments display their line numbers in backtraces. # # The constraints: # # - We don't have control over the user exception printer (in particular, # we cannot assume the linecache trick will work, c.f. # https://stackoverflow.com/q/50515651/23845 ) # # - We don't want to create temporary files every time we compile() # some code; file creation should happen lazily only at exception # time. Arguably, you *should* be willing to write out your # generated Python code to file system, but in some situations # (esp. library code) it would violate user expectation to write # to the file system, so we try to avoid it. In particular, we'd # like to keep the files around, so users can open up the files # mentioned in the trace; if the file is invisible, we want to # avoid clogging up the filesystem. # # If this is not a constraint for you, there is a substantially simpler # way to implement the functionality in this PR: instead of using # eval/exec directly, just always write a Python file to filesystem # and compile that. # # - You have control over a context where the compiled code will get # executed, so that we can interpose while the stack is unwinding # (otherwise, we have no way to interpose on the exception printing # process.) # # There are two things you have to do to make use of the utilities here: # # - When you compile your source code, you must save its string source # in its f_globals under the magic name "__compile_source__" # # - Before running the compiled code, enter the # report_compile_source_on_error() context manager. @contextlib.contextmanager def report_compile_source_on_error(): try: yield except Exception as exc: tb = exc.__traceback__ # Walk the traceback, looking for frames that have # source attached stack = [] while tb is not None: filename = tb.tb_frame.f_code.co_filename source = tb.tb_frame.f_globals.get("__compile_source__") if filename == "" and source is not None: # What black magic are we doing here? Intuitively, what # we would like to do is overwrite the co_filename on any # frames that were generated from exec/eval so that they # point to a temporary file that has the actual line # information, so Python's default error printer can print # useful line information on it. # # Writing out the temporary file is easy. But overwriting # co_filename is not! You can't modify the code object # associated with a frame. You can, however, reconstruct # a traceback with entirely new frames from scratch, so that's # what we do. But there's another problem, which is how to # make the frame? # # The black magic is we make a frankenstein frame and code # object which resembles the original frame/code enough so # that it will print properly under traceback and the default # error printer, but IT IS NOT THE ORIGINAL FRAME (you # couldn't, e.g., execute its code with different variables # and expect it to work.) # Don't delete the temporary file so the user can inspect it # TODO: This creates a temporary file for every frame, but we # technically only need one per distinct __compile_source__ with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f: f.write(source) # Create a frame. Python doesn't let you construct # FrameType directly, so just make one with compile frame = tb.tb_frame code = compile('__inspect_currentframe()', f.name, 'eval') code = code.replace(co_name=frame.f_code.co_name) # Python 3.11 only if hasattr(frame.f_code, 'co_linetable'): # We can't copy ALL of the metadata over, because you # can cause Python to segfault this way. What exactly # do we need? We need enough information for # traceback to be able to print the exception # correctly. Code reading Lib/traceback.py reveals # that traceback calls code.co_positions() in order to # get the augmented line/col numbers. Objects/codeobject.c, # specifically _PyCode_InitAddressRange, reveals that # this iterator is initialized from co_linetable and # co_firstfileno. So copy these we must! code = code.replace( # type: ignore[call-arg] co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined] co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined] ) fake_frame = eval( code, frame.f_globals, { **frame.f_locals, '__inspect_currentframe': inspect.currentframe } ) fake_tb = TracebackType( None, fake_frame, tb.tb_lasti, tb.tb_lineno ) stack.append(fake_tb) else: stack.append(tb) tb = tb.tb_next # Reconstruct the linked list tb_next = None for tb in reversed(stack): tb.tb_next = tb_next tb_next = tb raise exc.with_traceback(tb_next) # noqa: TRY200 def shorten_filename(fn, *, base=None): """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.""" if base is None: base = os.path.dirname(os.path.dirname(__file__)) # Truncate torch/foo.py to foo.py try: prefix = os.path.commonpath([fn, base]) except ValueError: return fn else: return fn[len(prefix) + 1:] def format_frame(frame, *, base=None, line=False): """ Format a FrameSummary in a short way, without printing full absolute path or code. The idea is the result fits on a single line. """ extra_line = "" if line: extra_line = f"{frame.line} # " return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}" def format_traceback_short(tb): """Format a TracebackType in a short way, printing only the inner-most frame.""" return format_frame(traceback.extract_tb(tb)[-1]) class CapturedTraceback: __slots__ = ['tb', 'skip'] def __init__(self, tb, skip=0): self.tb = tb self.skip = skip def cleanup(self): self.tb = None def summary(self): import torch._C._profiler if self.tb is None: # TODO: Maybe indicate that the traceback was elided? return traceback.StackSummary() return _extract_symbolized_tb( torch._C._profiler.symbolize_tracebacks([self.tb])[0], self.skip ) def __getstate__(self): return (None, { 'tb': None, # TB is not pickleable 'skip': self.skip, }) @staticmethod def extract(*, script=False, cpp=False, skip=0): """ Like traceback.extract_stack(), but faster (approximately 20x faster); it is fast enough that you can unconditionally log stacks this way as part of normal execution. It returns a torch._C._profiler.CapturedTraceback object that must be formatted specially with format_captured_tb. By default, this only reports Python backtraces (like extract_stack). You can set the script/cpp kwargs to also turn on TorchScript/C++ trace reporting. """ import torch._C._profiler if script or cpp: assert skip == 0, "skip with script/cpp NYI" return CapturedTraceback( torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp), # Elide extract() frame if we don't have script/cpp frames. If # we do have those frames, it doesn't work so force zero. 0 if script or cpp else skip + 1 ) def format(self): """ Formats a single torch._C._profiler.CapturedTraceback into a list of strings equivalent to the output of traceback.format_list. Note that if pass it CapturedTraceback with C++ traces, it is better not to use this function and use the batch formatting API format_captured_tbs to amortize the cost of symbolization """ return traceback.format_list(self.summary()) @staticmethod def format_all(tbs): """ Bulk version of CapturedTraceback.format. Returns a list of list of strings. """ import torch._C._profiler # Directly populate tracebacks that already have cached summaries rs: List[Optional[List[str]]] = [] delayed_idxs = [] for i, tb in enumerate(tbs): if tb.tb is None: rs.append([]) else: rs.append(None) delayed_idxs.append(i) stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs]) for i, stb in zip(delayed_idxs, stbs): rs[i] = traceback.format_list(tbs[i].summary()) return rs def _extract_symbolized_tb(tb, skip): """ Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of pre-processed stack trace entries. """ stack = traceback.StackSummary() for f in reversed(tb[skip:]): stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name'])) return stack