You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
2.3 KiB
60 lines
2.3 KiB
5 months ago
|
import contextlib
|
||
|
|
||
|
import torch
|
||
|
|
||
|
__all__ = ["start", "stop", "profile"]
|
||
|
|
||
|
|
||
|
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
|
||
|
r"""Start OS Signpost tracing from MPS backend.
|
||
|
|
||
|
The generated OS Signposts could be recorded and viewed in
|
||
|
XCode Instruments Logging tool.
|
||
|
|
||
|
Args:
|
||
|
mode(str): OS Signpost tracing mode could be "interval", "event",
|
||
|
or both "interval,event".
|
||
|
The interval mode traces the duration of execution of the operations,
|
||
|
whereas event mode marks the completion of executions.
|
||
|
See document `Recording Performance Data`_ for more info.
|
||
|
wait_until_completed(bool): Waits until the MPS Stream complete
|
||
|
executing each encoded GPU operation. This helps generating single
|
||
|
dispatches on the trace's timeline.
|
||
|
Note that enabling this option would affect the performance negatively.
|
||
|
|
||
|
.. _Recording Performance Data:
|
||
|
https://developer.apple.com/documentation/os/logging/recording_performance_data
|
||
|
"""
|
||
|
mode_normalized = mode.lower().replace(" ", "")
|
||
|
torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed)
|
||
|
|
||
|
|
||
|
def stop():
|
||
|
r"""Stops generating OS Signpost tracing from MPS backend."""
|
||
|
torch._C._mps_profilerStopTrace()
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
|
||
|
def profile(mode: str = "interval", wait_until_completed: bool = False):
|
||
|
r"""Context Manager to enabling generating OS Signpost tracing from MPS backend.
|
||
|
|
||
|
Args:
|
||
|
mode(str): OS Signpost tracing mode could be "interval", "event",
|
||
|
or both "interval,event".
|
||
|
The interval mode traces the duration of execution of the operations,
|
||
|
whereas event mode marks the completion of executions.
|
||
|
See document `Recording Performance Data`_ for more info.
|
||
|
wait_until_completed(bool): Waits until the MPS Stream complete
|
||
|
executing each encoded GPU operation. This helps generating single
|
||
|
dispatches on the trace's timeline.
|
||
|
Note that enabling this option would affect the performance negatively.
|
||
|
|
||
|
.. _Recording Performance Data:
|
||
|
https://developer.apple.com/documentation/os/logging/recording_performance_data
|
||
|
"""
|
||
|
try:
|
||
|
start(mode, wait_until_completed)
|
||
|
yield
|
||
|
finally:
|
||
|
stop()
|