You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

623 lines
22 KiB

r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.
It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a python program and a possible data race is detected, a
detailed warning will be printed and the program will exit.
It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
"""
import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
import torch
import torch.utils._cuda_trace as cuda_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
# Stream id that StreamSynchronizations registers as soon as it is constructed
# (presumably the id of CUDA's default stream -- confirm against the tracer).
DEFAULT_STREAM_ID = 0
# Generic key / value type parameters used in the signature of zip_by_key().
TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")
# Readability aliases: all of these are plain ints.
# DataPtr holds the value returned by Tensor.data_ptr().
DataPtr = int
# StreamId holds the ``cuda_stream`` value of a torch.cuda stream object.
StreamId = int
# EventId identifies an event passed to the event callbacks.
EventId = int
# SeqNum is the per-kernel-launch counter kept by EventHandler.
SeqNum = int
# Module-level logger; used for the "backfilling the trace" info messages.
logger = logging.getLogger(__name__)
class AccessType(enum.Enum):
    """Kind of access a kernel performs on a tensor.

    ``str()`` of a member yields the verb phrase used in CSAN reports:
    ``"reading from"`` for READ and ``"writing to"`` for WRITE.
    """

    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        if self is AccessType.READ:
            return "reading from"
        return "writing to"
@dataclass
class Access:
    r"""Record of one kernel's access to one tensor.

    Args:
        type: ``AccessType.READ`` or ``AccessType.WRITE``.
        seq_num: sequence number of the kernel that performed the access.
        stream: id of the stream the kernel was launched on.
        operator: schema of the launched kernel (its arguments and return
            type); it is interpolated into the error report.
        aliases: names of the schema arguments that refer to this tensor.
        is_output: ``True`` when the tensor was returned by the kernel.
        stack_trace: stack summary captured when the access was recorded.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary
class SynchronizationError(Exception):
    """Common base for every error CUDA Sanitizer reports."""
class UnsynchronizedAccessError(SynchronizationError):
    """Describes two accesses to the same data pointer that were not ordered.

    ``str()`` renders a multi-line report containing both accesses (operator,
    access kind, argument names and stack trace) followed by the stack trace
    of the tensor's allocation when one is available.
    """

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        def describe(access: Access) -> str:
            # Render one access: operator, what was touched, and where from.
            has_aliases = bool(access.aliases)
            pieces = [f"{access.operator}\n{access.type}"]
            if has_aliases:
                pieces.append(" argument(s) " + ", ".join(access.aliases))
            if has_aliases and access.is_output:
                # Joins the argument list with the trailing " the output".
                pieces.append(", and to")
            if access.is_output:
                pieces.append(" the output")
            pieces.append(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )
            return "".join(pieces)

        sections = [
            textwrap.dedent(
                f"""\
                ============================
                CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                Access by stream {self.current_access.stream} during kernel:
                """
            ),
            describe(self.current_access),
            f"Previous access by stream {self.previous_access.stream} during kernel:\n",
            describe(self.previous_access),
        ]

        # A missing *or empty* allocation trace both count as "not found".
        if not self.allocation_stack_trace:
            sections.append("Trace for tensor allocation not found.")
        else:
            sections.append(
                "Tensor was allocated with stack trace:\n"
                f"{''.join(self.allocation_stack_trace.format())}"
            )
        return "".join(sections)
class CUDASanitizerErrors(Exception):
    """Aggregate exception carrying every error found for one kernel launch.

    The individual errors are available through the ``errors`` attribute;
    ``str()`` only reports how many there are.
    """

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        count = len(self.errors)
        return f"detected {count} errors"
@dataclass
class TensorInfo:
    r"""Bookkeeping for one tensor: where it was allocated and who touched it.

    Args:
        allocation_stack_trace: stack summary captured when the tensor was
            allocated, or ``None`` when CSAN did not observe the allocation.
        reads: read accesses recorded since the most recent write.
        write: the most recent write access, if any.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None
class _TensorsAccessed:
    """Maps data pointers to the ``TensorInfo`` describing their accesses."""

    def __init__(self):
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        """Register ``data_ptr`` without an allocation trace if it is unknown."""
        if data_ptr in self.accesses:
            return
        logger.info(
            "Found tensor with pointer: %s, but no matching tensor "
            "allocation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            data_ptr,
        )
        self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        """Drop a stale entry for ``data_ptr`` if one is still registered."""
        if data_ptr not in self.accesses:
            return
        logger.info(
            "Found duplicate tensor allocation in the trace for tensor with "
            "pointer: %s. Assuming the trace for tensor deallocation "
            "wasn't caught and backfilling it now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            data_ptr,
        )
        self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        """Start tracking ``data_ptr``, replacing any previous entry."""
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        """Stop tracking ``data_ptr``; raises ``KeyError`` if it is unknown."""
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        """Return whether any read was recorded after the latest write."""
        return bool(self.accesses[data_ptr].reads)

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        """Return the allocation trace stored for ``data_ptr`` (may be None)."""
        info = self.accesses[data_ptr]
        return info.allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        """Return the latest write access to ``data_ptr``, if any."""
        info = self.accesses[data_ptr]
        return info.write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        """Return the list of reads recorded since the latest write."""
        info = self.accesses[data_ptr]
        return info.reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        """Append ``access`` to the reads recorded for ``data_ptr``."""
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        """Record ``access`` as the latest write and reset the read list."""
        info = self.accesses[data_ptr]
        info.write = access
        info.reads = []
class StreamSynchronizations:
    """Tracks, per stream, which kernels of every stream it is ordered after.

    A "sync state" is a dict mapping a stream id to the highest kernel
    sequence number on that stream that is known to have been waited for.
    One such state is kept for every stream (``current_sync_states``), for
    every event (``recorded_sync_states``, a snapshot taken at record time)
    and for the host (``host_sync_state``).
    """

    def __init__(self):
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        """Create ``stream`` on the fly if its creation was never observed."""
        if stream in self.current_sync_states:
            return
        logger.info(
            "Found Stream with id: %s, but no matching stream "
            "creation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            stream,
        )
        self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        """Create ``event`` on the fly if its creation was never observed."""
        if event in self.recorded_sync_states:
            return
        logger.info(
            "Found Event with id: %s, but no matching event "
            "creation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            event,
        )
        self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        """Delete ``event`` if a stale entry for it is still registered."""
        if event not in self.recorded_sync_states:
            return
        logger.info(
            "Found duplicate event creation in the trace for event with "
            "id: %s. Assuming the trace for event deletion wasn't caught "
            "and backfilling it now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            event,
        )
        self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        """Register ``stream``; a duplicate registration is logged and ignored."""
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
            return
        # The new stream starts out with a copy of everything the host has
        # already synchronized with, plus sequence number 0 for itself.
        self.host_sync_state[stream] = 0
        self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        """Register ``event`` with an empty recorded state."""
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        """Forget ``event`` and its recorded state."""
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        """Note that ``stream`` has launched the kernel numbered ``seq_num``."""
        self._ensure_stream_exists(stream)
        own_state = self.current_sync_states[stream]
        own_state[stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        """Snapshot the current state of ``stream`` into ``event``."""
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        snapshot = self.current_sync_states[stream].copy()
        self.recorded_sync_states[event] = snapshot

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        """Merge ``other`` into ``state`` in place, keeping the larger number."""
        for stream_id, other_seq in other.items():
            known_seq = state.get(stream_id, -1)
            state[stream_id] = other_seq if other_seq > known_seq else known_seq

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        """Make ``stream`` ordered after everything recorded in ``event``."""
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        stream_state = self.current_sync_states[stream]
        event_state = self.recorded_sync_states[event]
        self._state_wait_for_other(stream_state, event_state)

    def all_streams_wait_for_event(self, event: EventId) -> None:
        """Make every stream and the host ordered after ``event``."""
        self._ensure_event_exists(event)
        for stream in self.current_sync_states:
            self.stream_wait_for_event(stream, event)
        event_state = self.recorded_sync_states[event]
        self._state_wait_for_other(self.host_sync_state, event_state)

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        """Make every stream and the host ordered after ``stream``."""
        self._ensure_stream_exists(stream)
        for waiting_state in self.current_sync_states.values():
            self._state_wait_for_other(
                waiting_state, self.current_sync_states[stream]
            )
        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        """Make the host and every stream ordered after all streams."""
        # First lift the host up to each stream's own latest kernel ...
        for stream_id, stream_state in self.current_sync_states.items():
            self.host_sync_state[stream_id] = stream_state[stream_id]
        # ... then propagate the host's view back to every stream.
        for stream_state in self.current_sync_states.values():
            self._state_wait_for_other(stream_state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        """Return whether ``current_stream`` has waited for kernel ``seq_num``
        of ``other_stream``."""
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        waited_for = self.current_sync_states[current_stream].get(other_stream, -1)
        return seq_num <= waited_for
class EventHandler:
    """Consumes the CSAN trace and looks for synchronization errors.

    Combines the per-stream synchronization state with the record of tensor
    accesses to decide whether a kernel launch might race with an earlier
    access to the same data pointer.
    """

    def __init__(self):
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        """Record one kernel launch and return the races it may cause.

        Args:
            stream: id of the stream the kernel runs on.
            read_only: data pointers the kernel only reads.
            read_write: data pointers the kernel writes.
            outputs: data pointers returned by the kernel.
            operator: schema of the kernel, stored on every recorded access.
            tensor_aliases: argument names for each data pointer; must have
                an entry for every pointer in ``read_only`` / ``read_write``.

        Returns:
            One ``UnsynchronizedAccessError`` per conflicting pair of
            accesses; empty when nothing suspicious was found.
        """
        tracker = self.tensors_accessed
        found_errors: List[SynchronizationError] = []

        def build_access(kind: AccessType, ptr: DataPtr) -> Access:
            return Access(
                kind,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[ptr],
                ptr in outputs,
                launch_trace,
            )

        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            is_ordered = self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            )
            if is_ordered:
                return
            found_errors.append(
                UnsynchronizedAccessError(
                    data_ptr,
                    tracker.get_allocation_stack_trace(data_ptr),
                    current_access,
                    previous_access,
                )
            )

        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        launch_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # walk_stack yields the innermost frame first; flip it so the trace
        # reads outermost-first.
        launch_trace.reverse()

        # A read only has to be ordered after the latest write.
        for data_ptr in read_only:
            tracker.ensure_tensor_exists(data_ptr)
            read_access = build_access(AccessType.READ, data_ptr)
            check_conflict(data_ptr, read_access, tracker.get_write(data_ptr))
            tracker.add_read(data_ptr, read_access)

        # A write is checked against every read since the latest write, or,
        # when there were none, against that write itself.
        for data_ptr in read_write:
            tracker.ensure_tensor_exists(data_ptr)
            write_access = build_access(AccessType.WRITE, data_ptr)
            if tracker.were_there_reads_since_last_write(data_ptr):
                earlier_accesses = tracker.get_reads(data_ptr)
            else:
                earlier_accesses = [tracker.get_write(data_ptr)]
            for earlier in earlier_accesses:
                check_conflict(data_ptr, write_access, earlier)
            tracker.set_write(data_ptr, write_access)

        return found_errors

    def _handle_event_creation(self, event: EventId) -> None:
        """Forward an event-creation trace entry to the sync tracker."""
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        """Forward an event-deletion trace entry to the sync tracker."""
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        """Snapshot the state of ``stream`` into ``event``."""
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        """Make ``stream`` wait for ``event``."""
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        """Start tracking a newly allocated pointer along with its stack trace."""
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        allocation_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # walk_stack yields the innermost frame first; flip it so the trace
        # reads outermost-first.
        allocation_trace.reverse()
        self.tensors_accessed.create_tensor(data_ptr, allocation_trace)

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        """Stop tracking a pointer that has been freed."""
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        """Forward a stream-creation trace entry to the sync tracker."""
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        """Treat a device synchronization as a sync of all streams."""
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        """Make every stream and the host wait for ``stream``."""
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        """Make every stream and the host wait for ``event``."""
        self.syncs.all_streams_wait_for_event(event)
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    """Yield ``(key, a[key], b[key])`` for every key of ``a`` that ``b`` also
    has, following the iteration order of ``a``."""
    yield from (
        (key, left_value, b[key]) for key, left_value in a.items() if key in b
    )
def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    """Pair every value passed to an operator with its schema argument.

    Positional values are matched to the leading schema arguments in order.
    The remaining schema arguments are matched by name against ``kwargs``,
    in schema order, and those not present in ``kwargs`` are skipped.
    """
    positional_count = len(args)
    declared = schema.arguments
    by_name = {param.name: param for param in declared[positional_count:]}

    yield from zip(declared[:positional_count], args)
    for param_name, param in by_name.items():
        if param_name in kwargs:
            yield (param, kwargs[param_name])
class ArgumentHandler:
def __init__(self):
self.dataptrs_read: Set[DataPtr] = set()
self.dataptrs_written: Set[DataPtr] = set()
self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
self.outputs: Set[DataPtr] = set()
def _handle_argument(
self,
value: Any,
is_write: bool,
name: Optional[str] = None,
is_output: bool = False,
) -> None:
if isinstance(value, torch.Tensor) and value.is_cuda:
data_ptr = value.data_ptr()
if is_write:
self.dataptrs_written.add(data_ptr)
else:
self.dataptrs_read.add(data_ptr)
self.tensor_aliases.setdefault(data_ptr, [])
if name is not None:
self.tensor_aliases[data_ptr].append(name)
if is_output:
self.outputs.add(data_ptr)
def parse_inputs(
self,
schema: torch.FunctionSchema,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> None:
for argument, value in zip_arguments(schema, args, kwargs):
is_write = argument.alias_info is not None and argument.alias_info.is_write
pytree.tree_map_(
functools.partial(
self._handle_argument, is_write=is_write, name=argument.name
),
value,
)
def parse_outputs(self, outputs: Any) -> None:
pytree.tree_map_(
functools.partial(self._handle_argument, is_write=True, is_output=True),
outputs,
)
class CUDASanitizerDispatchMode(TorchDispatchMode):
    """Dispatch mode that feeds every dispatched operator call to CSAN.

    Construction activates CUDA tracing and hooks the ``EventHandler`` up to
    each of the ``cuda_trace`` callbacks; ``__torch_dispatch__`` then checks
    each operator call for unsynchronized tensor accesses.
    """

    def __init__(self):
        self.event_handler = EventHandler()
        torch._C._activate_cuda_trace()

        handler = self.event_handler
        # (registration function, callback) pairs, registered in this order.
        hooks = (
            (
                cuda_trace.register_callback_for_cuda_event_creation,
                handler._handle_event_creation,
            ),
            (
                cuda_trace.register_callback_for_cuda_event_deletion,
                handler._handle_event_deletion,
            ),
            (
                cuda_trace.register_callback_for_cuda_event_record,
                handler._handle_event_record,
            ),
            (
                cuda_trace.register_callback_for_cuda_event_wait,
                handler._handle_event_wait,
            ),
            (
                cuda_trace.register_callback_for_cuda_memory_allocation,
                handler._handle_memory_allocation,
            ),
            (
                cuda_trace.register_callback_for_cuda_memory_deallocation,
                handler._handle_memory_deallocation,
            ),
            (
                cuda_trace.register_callback_for_cuda_stream_creation,
                handler._handle_stream_creation,
            ),
            (
                cuda_trace.register_callback_for_cuda_device_synchronization,
                handler._handle_device_synchronization,
            ),
            (
                cuda_trace.register_callback_for_cuda_stream_synchronization,
                handler._handle_stream_synchronization,
            ),
            (
                cuda_trace.register_callback_for_cuda_event_synchronization,
                handler._handle_event_synchronization,
            ),
        )
        for register, callback in hooks:
            register(callback)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and report any data race its launch may cause.

        Every detected error is printed to stderr before a single
        ``CUDASanitizerErrors`` wrapping all of them is raised. When nothing
        is detected the result of ``func`` is returned unchanged.
        """
        kwargs = {} if kwargs is None else kwargs
        schema = func._schema

        parsed = ArgumentHandler()
        parsed.parse_inputs(schema, args, kwargs)
        outputs = func(*args, **kwargs)
        parsed.parse_outputs(outputs)

        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            # Pointers that are both read and written count as written only.
            parsed.dataptrs_read - parsed.dataptrs_written,
            parsed.dataptrs_written,
            parsed.outputs,
            schema,
            parsed.tensor_aliases,
        )
        if not errors:
            return outputs

        for error in errors:
            print(error, file=sys.stderr)
        raise CUDASanitizerErrors(errors)
class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the
    dispatch mode context manager in the enable function/destructor,
    respectively. This is to explicitly set the lifetime of the dispatch mode
    object to that of the application. This approach was deemed more elegant
    than using the atexit module.

    ``enable()`` and ``disable()`` are idempotent, so the mode is entered at
    most once at any time and every enter is matched by exactly one exit.
    """

    def __init__(self):
        # Set the flag first so that __del__ finds it even when constructing
        # the dispatch mode raises.
        self.enabled = False
        self.dispatch = CUDASanitizerDispatchMode()

    def enable(self):
        """Enter the dispatch mode; a no-op when it is already entered."""
        if self.enabled:
            # Entering again would leave one entry that is never exited,
            # because the destructor exits the mode only once.
            return
        self.dispatch.__enter__()
        self.enabled = True

    def disable(self):
        """Exit the dispatch mode; a no-op when it is not entered."""
        if not self.enabled:
            return
        self.dispatch.__exit__(None, None, None)
        self.enabled = False

    def __del__(self):
        # This object usually lives until interpreter shutdown. By then module
        # globals (including ``sys``) may already have been cleared and the
        # torch modules may be partially torn down, so exiting the mode is not
        # safe; skip it in that case.
        if sys is None or sys.is_finalizing():
            return
        # getattr guards against a partially constructed instance.
        if getattr(self, "enabled", False):
            self.disable()
def enable_cuda_sanitizer():
    """Turn on CUDA Sanitizer for the rest of the program.

    Once enabled, the sanitizer analyzes the low-level CUDA calls made by
    torch functions for synchronization errors. Every data race it finds is
    printed to standard error together with the stack traces of the suspected
    causes. Enable it at the very beginning of the program for best results.
    """
    cuda_sanitizer.enable()
# Module-level sanitizer instance; enable_cuda_sanitizer() calls its enable().
cuda_sanitizer = CUDASanitizer()