You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.4 KiB
92 lines
2.4 KiB
5 months ago
|
r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
|
||
|
|
||
|
from contextlib import contextmanager
|
||
|
|
||
|
try:
    from torch._C import _nvtx
except ImportError:

    class _NVTXStub:
        """Placeholder installed as ``_nvtx`` when the NVTX bindings are missing.

        Every ``_nvtx`` entry point used by this module's wrappers is mapped
        to ``_fail``. On a build without NVTX support, any wrapper call then
        raises a descriptive ``RuntimeError`` rather than an ``AttributeError``.
        """

        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "NVTX functions not installed. Are you sure you have a CUDA build?"
            )

        rangePushA = _fail
        rangePop = _fail
        # Used by range_start()/range_end(); previously missing, so those
        # wrappers failed with AttributeError instead of the message above.
        rangeStartA = _fail
        rangeEnd = _fail
        markA = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]
|
||
|
|
||
|
# Names exported by `from <this module> import *`.
__all__ = [
    "range_push",
    "range_pop",
    "range_start",
    "range_end",
    "mark",
    "range",
]
|
||
|
|
||
|
|
||
|
def range_push(msg):
    """
    Push a new range onto the stack of nested range spans.

    Args:
        msg (str): ASCII message to associate with the range.

    Returns:
        The zero-based depth of the range that was just started.
    """
    depth = _nvtx.rangePushA(msg)
    return depth
|
||
|
|
||
|
|
||
|
def range_pop():
    """
    Pop the most recently pushed range off the stack of nested range spans.

    Returns:
        The zero-based depth of the range that was just ended.
    """
    depth = _nvtx.rangePop()
    return depth
|
||
|
|
||
|
|
||
|
def range_start(msg) -> int:
    """
    Mark the start of a range with a string message. It returns a unique
    handle for this range to pass to the corresponding call to range_end().

    A key difference between this and range_push/range_pop is that the
    range_start/range_end version supports ranges across threads (start on
    one thread and end on another thread).

    Args:
        msg (str): ASCII message to associate with the range.

    Returns:
        int: A range handle (uint64_t) that can be passed to range_end().
    """
    return _nvtx.rangeStartA(msg)
|
||
|
|
||
|
|
||
|
def range_end(range_id) -> None:
    """
    Mark the end of the range identified by ``range_id``.

    Args:
        range_id (int): the unique handle that range_start() returned for
            the range being closed.
    """
    end_range = _nvtx.rangeEnd
    end_range(range_id)
|
||
|
|
||
|
|
||
|
def mark(msg):
    """
    Record an instantaneous event at the point where this is called.

    Args:
        msg (str): ASCII message to associate with the event.
    """
    result = _nvtx.markA(msg)
    return result
|
||
|
|
||
|
|
||
|
@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    Args:
        msg (str): message to associate with the range. It is run through
            ``str.format`` only when positional or keyword arguments are
            supplied. Without them the message is used verbatim, so literal
            braces (e.g. ``"dict {key}"``) no longer raise.
    """
    # Format only on request, as the docstring promises. Unconditionally
    # calling msg.format() raised KeyError/IndexError for messages that
    # merely contained braces.
    if args or kwargs:
        msg = msg.format(*args, **kwargs)
    # Push outside the try: if the push itself fails there is nothing to pop.
    range_push(msg)
    try:
        yield
    finally:
        range_pop()
|