You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

304 lines
11 KiB

import itertools
from warnings import warn
import torch
import torch.cuda
from torch.autograd import (
_disable_profiler_legacy,
_enable_profiler_legacy,
DeviceType,
ProfilerConfig,
ProfilerState,
)
from torch.autograd.profiler_util import (
_filter_name,
_filter_stack_entry,
_rewrite_name,
EventList,
FunctionEvent,
MEMORY_EVENT_NAME,
)
# Names exported by `from <this module> import *`: only the `profile` class.
__all__ = ["profile"]
class profile:
    """DEPRECATED: use torch.profiler instead.

    Context manager around the legacy autograd profiler.  Entering the
    context enables the legacy profiler with the configuration built by
    :meth:`config`; leaving it disables the profiler, parses the raw records
    and stores them as an ``EventList`` in ``self.function_events``.

    Args:
        enabled: when false the context manager is a no-op (``__enter__``
            returns ``None`` and nothing is recorded).
        use_cuda: profile with ``ProfilerState.CUDA`` instead of
            ``ProfilerState.CPU``; silently downgraded (with a warning) when
            CUDA is unavailable.
        record_shapes: forwarded to ``ProfilerConfig``; also switched on
            whenever ``with_flops`` is true.
        with_flops: forwarded to ``ProfilerConfig`` and ``EventList``.
        profile_memory: forwarded to ``ProfilerConfig`` and ``EventList``.
        with_stack: forwarded to ``ProfilerConfig``; required by
            :meth:`export_stacks`.
        with_modules: forwarded to ``ProfilerConfig``.
    """

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
    ):
        self.enabled: bool = enabled

        # Plain state is initialised before the `enabled` guard so that a
        # disabled profiler can still be printed or queried: repr(), str()
        # and the reporting helpers all read `function_events`, and would
        # otherwise fail with AttributeError.
        self.function_events = None
        self.entered = False
        self.use_cuda = use_cuda
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules

        if not self.enabled:
            # Skip the CUDA probe (and its warning) when nothing will run.
            return

        if self.use_cuda and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False

        if self.use_cuda:
            self.profiler_kind = ProfilerState.CUDA
        else:
            self.profiler_kind = ProfilerState.CPU

    def config(self):
        """Build the ``ProfilerConfig`` handed to the legacy profiler."""
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            # avoid exposing _ExperimentalConfig in this legacy public API
            torch._C._profiler._ExperimentalConfig(),
        )

    def __enter__(self):
        """Start profiling and return ``self``; return ``None`` when disabled.

        Raises:
            RuntimeError: if this instance has already been entered.
        """
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self.entered = True
        self._start_trace()
        return self

    def _start_trace(self):
        """Enable the legacy profiler with the current configuration."""
        _enable_profiler_legacy(self.config())

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop profiling and populate ``self.function_events``.

        Returns ``False`` (or ``None`` when disabled), so an exception raised
        inside the ``with`` body is never suppressed.
        """
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()

        records = _disable_profiler_legacy()
        parsed_results = _parse_legacy_records(records)
        self.function_events = EventList(
            parsed_results,
            use_cuda=self.use_cuda,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False

    def __repr__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return "<unfinished profile.profiler_legacy.profile>"
        return str(self.function_events)

    def _check_finish(self):
        """Raise ``RuntimeError`` unless profiling results are available."""
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    # The wrapper forwards everything, so it reuses the EventList docstring.
    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.export_chrome_trace(path)

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Forward to ``EventList.export_stacks(path, metric)``.

        Raises:
            RuntimeError: if profiling has not finished.
            AssertionError: if the profiler was created without
                ``with_stack=True``.
        """
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Return CPU time as the sum of self times across all events."""
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total
def _parse_legacy_records(thread_records):
def _get_record_key(record):
"""Return a tuple for correlating start and end records in `_parse_legacy_records`."""
return (record.handle(), record.node_id())
next_id = 0
start_record = None
functions = []
record_stack = []
# '__start_profile' is not guaranteed to be first, so we must find it here
for record in itertools.chain.from_iterable(thread_records):
name = record.name()
if start_record is None and name == "__start_profile":
start_record = record
assert start_record is not None and not start_record.is_remote()
for thread_record_list in thread_records:
# accumulated memory allocations per handle
cpu_memory_allocs = {}
cuda_memory_allocs = {}
# ranges per handle
range_starts = {}
filtered_handles = set()
prev_record = None
for record in thread_record_list:
record_key = _get_record_key(record)
if _filter_name(record.name()) or record_key in filtered_handles:
filtered_handles.add(record_key)
continue
if record.kind() == "push":
# workaround to reduce double logging from operator
# wrappers and redispatch
if prev_record is not None:
duplicate = (
prev_record.name() == record.name()
and prev_record.kind() == record.kind()
and prev_record.node_id() == record.node_id()
)
if duplicate:
filtered_handles.add(record_key)
continue
range_starts[record_key] = record
cpu_memory_allocs[record_key] = 0
cuda_memory_allocs[record_key] = 0
elif record.kind() == "pop":
assert (
record_key in range_starts
), f"""Expected record with key {record_key} to exist in range_starts.
This means that the pop event did not have a corresponding push."""
start = range_starts[record_key]
cpu_memory_usage = cpu_memory_allocs[record_key]
cuda_memory_usage = cuda_memory_allocs[record_key]
is_async = start.is_async() or (start.thread_id() != record.thread_id())
is_remote_event = record.is_remote()
start_flops = start.flops()
fe = FunctionEvent(
id=record.handle(),
node_id=record.node_id(),
name=_rewrite_name(name=start.name(), with_wildcard=True),
trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
thread=start.thread_id(),
start_us=start_record.cpu_elapsed_us(start),
end_us=start_record.cpu_elapsed_us(record),
fwd_thread=start.fwd_thread_id(),
input_shapes=start.shapes(),
stack=[
entry for entry in start.stack() if _filter_stack_entry(entry)
],
scope=start.scope(),
cpu_memory_usage=cpu_memory_usage,
cuda_memory_usage=cuda_memory_usage,
is_async=is_async,
is_remote=is_remote_event,
sequence_nr=start.sequence_nr(),
device_type=DeviceType.CPU,
is_legacy=True,
flops=start_flops,
)
# note: async events have only cpu total time
if not is_async and start.has_cuda():
duration = start.cuda_elapsed_us(record)
if duration > 0:
fe.append_kernel(start.name(), start.device(), duration)
functions.append(fe)
del range_starts[record_key]
del cpu_memory_allocs[record_key]
del cuda_memory_allocs[record_key]
elif record.kind() == "memory_alloc":
num_open_handles_cpu = len(cpu_memory_allocs)
num_open_handles_cuda = len(cuda_memory_allocs)
assert num_open_handles_cpu == num_open_handles_cuda
for handle in cpu_memory_allocs.keys():
cpu_memory_allocs[handle] += record.cpu_memory_usage()
for handle in cuda_memory_allocs.keys():
cuda_memory_allocs[handle] += record.cuda_memory_usage()
if num_open_handles_cpu == 0:
# output event as a top-level memory event
fe = FunctionEvent(
id=0,
name=MEMORY_EVENT_NAME,
trace_name=None,
thread=0,
start_us=0,
end_us=0,
stack=[],
cpu_memory_usage=record.cpu_memory_usage(),
cuda_memory_usage=record.cuda_memory_usage(),
is_legacy=True,
)
functions.append(fe)
prev_record = record
# Sort functions by start time then by end time ascending.
# This ensures that--in the case of nested events which
# have the same start time (which may happen due to the
# granularity of the given clock tick)--we always show
# the outermost nested call first. This adds stability
# in how FunctionEvents appear
functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
return functions