You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

374 lines
13 KiB

5 months ago
import functools
import re
from collections import deque
from dataclasses import dataclass
from typing import Dict, List
from torch.autograd import _KinetoEvent
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
order = reversed if reverse else lambda x: x
remaining = deque(order(tree))
while remaining:
curr_event = next_fn(remaining)
yield curr_event
for child_event in order(children_fn(curr_event)):
remaining.append(child_event)
# Depth-first traversal: take from the right end of the pending deque and
# enqueue children reversed, so the leftmost child is visited first.
traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True)
# Breadth-first traversal: take from the left end of the pending deque (FIFO).
traverse_bfs = functools.partial(
    _traverse, next_fn=lambda x: x.popleft(), reverse=False
)
@dataclass
class EventMetrics:
    """Per-event profiling metrics; all times are in nanoseconds."""

    duration_time_ns: int = 0  # total wall time of the event
    self_time_ns: int = 0  # duration minus time spent in child ops
    idle_time_ns: int = 0  # overlap of this event with idle intervals
    queue_depth: int = 0  # queue depth recorded at this event

    @property
    def fraction_idle_time(self):
        """Idle time as a fraction of total duration; 0.0 when duration is 0."""
        return self.idle_time_ns / self.duration_time_ns if self.duration_time_ns != 0 else 0.0
@dataclass
class Interval:
    """A time span (timestamps in nanoseconds) tagged with the queue depth
    observed over it."""

    # start/end timestamps of the span
    start: int
    end: int
    # queue depth recorded for this span; defaults to 0 for plain spans
    queue_depth: int = 0
class EventKey:
    """Hashable wrapper around a profiler event; identity is the event's id."""

    def __init__(self, event):
        self.event = event

    def __hash__(self):
        return hash(self.event.id)

    def __eq__(self, other):
        # Two keys are equal iff they wrap events with the same id.
        return self.event.id == other.event.id

    def __repr__(self):
        return f"{self.event.name}"

    def intervals_overlap(self, intervals: List[Interval]):
        """Return the total time (ns) this key's event overlaps ``intervals``.

        Intervals are sorted by start and de-overlapped while scanning so
        shared spans are not double-counted.  NOTE(review): overlapping
        ``Interval`` objects are clipped in place (``curr_interval.start`` is
        mutated), so the caller's intervals may be modified.
        """
        overlap_time = 0
        intervals = sorted(intervals, key=lambda x: x.start)

        if intervals:
            # The first interval is handled before the pairwise merge loop.
            overlap_start = max(self.event.start_time_ns, intervals[0].start)
            overlap_end = min(self.event.end_time_ns, intervals[0].end)

            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start

        i, j = 0, 1
        while j < len(intervals):
            prev_interval = intervals[i]
            curr_interval = intervals[j]
            j += 1
            if prev_interval.end > curr_interval.start:
                # Completely subsumed by previous interval
                if prev_interval.end > curr_interval.end:
                    # NOTE(review): j was already advanced above, so this
                    # second increment also skips the interval after the
                    # subsumed one — confirm this is intended.
                    j += 1
                    continue
                else:
                    # Clip so only the non-overlapping tail of curr counts.
                    curr_interval.start = prev_interval.end
                    i = j

            overlap_start = max(self.event.start_time_ns, curr_interval.start)
            overlap_end = min(self.event.end_time_ns, curr_interval.end)

            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start
        return overlap_time
class BasicEvaluation:
    """Derives per-event metrics (self time, CUDA launch-queue depth, idle
    time) from a completed autograd ``profile`` and ranks events that look
    worth optimizing."""

    def __init__(self, prof: profile):
        self.profile = prof
        self.metrics: Dict[EventKey, EventMetrics] = {}
        self.compute_self_time()
        # Event keys sorted by start timestamp; self.events mirrors that order.
        self.event_keys = sorted(
            (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns
        )
        self.events = [e.event for e in self.event_keys]
        self.cuda_events: List[_KinetoEvent] = []
        self.queue_depth_list = self.compute_queue_depth()
        self.compute_idle_time()

    def compute_self_time(self):
        """
        Computes event's self time(total time - time in child ops).

        Populates self.metrics with one EventMetrics per tree event, setting
        self_time_ns and duration_time_ns.
        """
        assert self.profile.kineto_results is not None
        stack = deque(self.profile.kineto_results.experimental_event_tree())

        # standard iterating dfs
        while stack:
            curr_event = stack.pop()
            self_time = curr_event.duration_time_ns
            # Self time = own duration minus the durations of direct children.
            for child_event in curr_event.children:
                self_time -= child_event.duration_time_ns
                stack.append(child_event)
            assert (
                EventKey(curr_event) not in self.metrics
            ), f"Duplicate id: {curr_event.id}, {curr_event.name}"
            self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
            self.metrics[
                EventKey(curr_event)
            ].duration_time_ns = curr_event.duration_time_ns

    def compute_queue_depth(self):
        """
        Computes queue_depth at each event. This will calculate the queue depth data for
        All the events in the tree.
        This will return a list of Interval of queue depth data of cuda launch and kernels.
        """
        assert self.profile.kineto_results is not None
        cuda_event_list = self.profile.kineto_results.events()

        def is_cuda_launch_kernel(e):
            # TODO: find a better way to identify cudaLaunchKernel
            return e.name == "cudaLaunchKernel"

        def is_cuda_kernel(e):
            # TODO: find a better way to identify CUDA Kernel
            return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower()

        cuda_launch_events = sorted(
            (e for e in cuda_event_list if is_cuda_launch_kernel(e)),
            key=lambda x: x.start_us(),
        )
        cuda_kernel_events = sorted(
            (e for e in cuda_event_list if is_cuda_kernel(e)),
            key=lambda x: x.start_us(),
        )

        self.cuda_events = sorted(
            cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_us()
        )

        # Map each launch event to the index of the kernel it spawned, matched
        # by correlation id. Searching from the last matched index keeps the
        # overall scan roughly linear when kernels run in launch order.
        kernel_mapping: Dict[_KinetoEvent, int] = {}
        last_mapped_kernel = 0
        for cuda_launch_event in cuda_launch_events:
            index = index_of_first_match(
                cuda_kernel_events,
                lambda x: x.linked_correlation_id()
                == cuda_launch_event.linked_correlation_id(),
                start=last_mapped_kernel,
            )
            kernel_mapping[cuda_launch_event] = index
            last_mapped_kernel = index if index is not None else last_mapped_kernel

        current_kernel_index = 0
        spawned_kernel_index = -1

        all_events = cuda_launch_events + cuda_kernel_events + self.events

        def new_old_event_comparator(event):
            # Kineto events expose start_us() (microseconds); tree events
            # expose start_time_ns. Normalize both to nanoseconds for sorting.
            if hasattr(event, "start_us"):
                return event.start_us() * 1000
            if hasattr(event, "start_time_ns"):
                return event.start_time_ns
            raise Exception("Unknown Event Type")

        queue_depth_list: List[Interval] = []
        all_events.sort(key=new_old_event_comparator)
        for event in all_events:
            # Find latest cuda kernel event
            if hasattr(event, "start_us"):
                start_time = event.start_us() * 1000
                end_time = (event.start_us() + event.duration_us()) * 1000
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
                    spawned_kernel_index = kernel_mapping[event]
            elif hasattr(event, "start_time_ns"):
                start_time = event.start_time_ns  # type: ignore[attr-defined]
                end_time = event.end_time_ns  # type: ignore[attr-defined]

            # Advance past kernels that have already started by start_time.
            while (
                current_kernel_index < len(cuda_kernel_events)
                and (cuda_kernel_events[current_kernel_index].start_us()) * 1000
                <= start_time  # type: ignore[possibly-undefined]
            ):
                current_kernel_index += 1

            # Depth = index of the most recently spawned kernel minus the
            # index of the next kernel to start; clamped at zero.
            current_queue_depth = spawned_kernel_index - current_kernel_index + 1
            current_queue_depth = max(current_queue_depth, 0)

            if hasattr(event, "start_us"):
                queue_depth_list.append(
                    Interval(start_time, end_time, current_queue_depth)  # type: ignore[possibly-undefined]
                )
            elif hasattr(event, "start_time_ns"):
                self.metrics[EventKey(event)].queue_depth = current_queue_depth

        return queue_depth_list

    def compute_idle_time(self):
        """
        Computes idle time of the profile.

        Builds the list of idle intervals (queue depth == 0) from
        self.queue_depth_list and stores each event's overlap with them as
        idle_time_ns in self.metrics.
        """
        # Based on queue_depth_list, we can calculate idle time for all the events
        idle = False
        idle_start = 0
        idle_intervals: List[Interval] = []
        if self.queue_depth_list and self.events:
            # The spans before the first and after the last queue-depth
            # sample are counted as idle.
            idle_intervals += [
                Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start),
                Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns),
            ]

        # An idle span runs from the end of a zero-depth sample to the start
        # of the next positive-depth sample.
        for data_point in self.queue_depth_list:
            if data_point.queue_depth == 0 and not idle:
                idle_start = data_point.end
                idle = True
            if data_point.queue_depth > 0 and idle:
                idle_intervals.append(Interval(idle_start, data_point.start))
                idle = False

        event_list = [e.event for e in self.metrics.keys()]
        for event in event_list:
            self.metrics[EventKey(event)].idle_time_ns = EventKey(
                event
            ).intervals_overlap(idle_intervals)

    def rank_events(self, length):
        """
        Filter and Rank the events based on some heuristics:
        1) Events that are in the falling phase of the queue depth.
        2) Events that have a high idle_time, self_time difference.

        Parameters:
            length: The number of events to return.
        """
        # Find the interval when qd is falling to 0
        import torch

        # NOTE: the list is reversed, so larger indices are earlier in time.
        queue_depth_list = list(reversed(self.queue_depth_list))
        qd_values = [e.queue_depth for e in queue_depth_list]

        bottom_threashold = 0  # (sic) depth at which the queue counts as empty
        top_threashold = 4  # (sic) minimum peak depth for a valid interval
        decrease_interval = []
        i = 0
        while i < len(qd_values):
            if qd_values[i] > bottom_threashold:
                i += 1
                continue
            for j in range(i + 1, len(qd_values)):
                # Find next zero and if the max value between them exceeds
                # the threshold, then we have a falling interval
                next_minimum_idx = index_of_first_match(
                    qd_values, lambda x: x <= bottom_threashold, start=j
                )
                peak_idx = argmax(qd_values, start=j, end=next_minimum_idx)

                # if is a valid peak, we add to list and continue
                if peak_idx is not None and qd_values[peak_idx] >= top_threashold:
                    decrease_interval.append(
                        Interval(
                            queue_depth_list[peak_idx].start, queue_depth_list[i].start
                        )
                    )
                    # Jump past this falling phase and look for the next one.
                    i = next_minimum_idx if next_minimum_idx is not None else i
                    break
            i += 1

        # Filter out events that are not in the decrease interval
        event_list = [
            event
            for event in self.metrics.keys()
            if event.intervals_overlap(decrease_interval)
        ]
        if event_list:
            self_time = torch.tensor(
                [self.metrics[event].self_time_ns for event in event_list],
                dtype=torch.float32,
            )
            idle_time = torch.tensor(
                [self.metrics[event].fraction_idle_time for event in event_list],
                dtype=torch.float32,
            )
            # Z-score both signals and combine; idle fraction is weighted
            # more heavily than self time (1.0 vs 0.6).
            normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time)
            normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time)
            heuristic_score_list = normalized_gain + 0.6 * normalized_self

            # Sort events by heuristic
            event_list = [
                event
                for _, event in sorted(
                    zip(heuristic_score_list, event_list),
                    key=lambda x: x[0],
                    reverse=True,
                )
            ]
            event_list = event_list[:length]
        return event_list

    def get_optimizable_events(self, length: int = 1, print_enable: bool = True):
        """Return the top ``length`` ranked events; when ``print_enable`` is
        True, also print a human-readable summary for each."""
        event_list = self.rank_events(length)
        if not print_enable:
            return event_list
        output = "Optimizable events:\n" if event_list else "No events to optimize\n"

        output += "\n".join(
            [
                f"""{'-'*80}
Event: {event}
Source code location: {source_code_location(event.event)}
Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
{'-'*80}"""
                for event in event_list
            ]
        )
        if print_enable:
            print(output)
        return event_list
def index_of_first_match(seq, predicate, start=0, end=None):
    """Return the index of the first element of ``seq[start:end]`` for which
    ``predicate`` is true, or ``None`` when nothing matches.

    ``end`` defaults to (and is clamped at) ``len(seq)``.
    """
    stop = len(seq) if end is None or end >= len(seq) else end
    return next((idx for idx in range(start, stop) if predicate(seq[idx])), None)
def argmax(seq, key=lambda x: x, start=0, end=None):
    """Index (into the original ``seq``) of the maximum of ``seq[start:end]``
    under ``key``; ``None`` when the slice is empty. Ties resolve to the
    first element equal to the maximum."""
    window = seq[start:end]
    if not window:
        return None
    return start + window.index(max(window, key=key))
def source_code_location(event):
    """Walk up the event's parent chain to the first event whose name
    contains a Python source location (a ``.py(...)`` fragment) and return
    that name; fall back to a fixed message when no ancestor has one."""
    curr = event
    while curr is not None:
        if re.search(r"\.py\(.*\)", curr.name) is not None:
            return curr.name
        curr = curr.parent
    return "No source code location found"
# Provide an OSS workaround for cudagraphs + CUPTI issue
# https://github.com/pytorch/pytorch/issues/75504
# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when
# we stop supporting older CUDA versions.
def _init_for_cuda_graphs():
    """Run one empty profile to initialize profiler state before CUDA graph
    capture (see the linked issue above)."""
    from torch.autograd.profiler import profile

    # Entering and exiting an empty profile is sufficient; its result is
    # discarded.
    with profile():
        pass