You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
5.3 KiB
170 lines
5.3 KiB
5 months ago
|
import ctypes
|
||
|
|
||
|
import torch
|
||
|
from torch._streambase import _EventBase, _StreamBase
|
||
|
from .._utils import _dummy_type
|
||
|
|
||
|
|
||
|
if not hasattr(torch._C, "_XpuStreamBase"):
|
||
|
# Define dummy base classes
|
||
|
torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
|
||
|
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
|
||
|
|
||
|
|
||
|
class Stream(torch._C._XpuStreamBase, _StreamBase):
|
||
|
r"""Wrapper around a XPU stream.
|
||
|
|
||
|
A XPU stream is a linear sequence of execution that belongs to a specific
|
||
|
device, independent from other streams.
|
||
|
|
||
|
Args:
|
||
|
device(torch.device or int, optional): a device on which to allocate
|
||
|
the stream. If :attr:`device` is ``None`` (default) or a negative
|
||
|
integer, this will use the current device.
|
||
|
priority(int, optional): priority of the stream, should be 0 or
|
||
|
negative, where negative numbers indicate higher priority. By default,
|
||
|
streams have priority 0.
|
||
|
"""
|
||
|
|
||
|
def __new__(cls, device=None, priority=0, **kwargs):
|
||
|
# setting device manager is expensive, so we avoid it unless necessary
|
||
|
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
|
||
|
return super().__new__(cls, priority=priority, **kwargs)
|
||
|
else:
|
||
|
with torch.xpu.device(device):
|
||
|
return super().__new__(cls, priority=priority, **kwargs)
|
||
|
|
||
|
def wait_event(self, event):
|
||
|
r"""Make all future work submitted to the stream wait for an event.
|
||
|
|
||
|
Args:
|
||
|
event (torch.xpu.Event): an event to wait for.
|
||
|
"""
|
||
|
event.wait(self)
|
||
|
|
||
|
def wait_stream(self, stream):
|
||
|
r"""Synchronize with another stream.
|
||
|
|
||
|
All future work submitted to this stream will wait until all kernels
|
||
|
submitted to a given stream at the time of call complete.
|
||
|
|
||
|
Args:
|
||
|
stream (Stream): a stream to synchronize.
|
||
|
"""
|
||
|
self.wait_event(stream.record_event())
|
||
|
|
||
|
def record_event(self, event=None):
|
||
|
r"""Record an event.
|
||
|
|
||
|
Args:
|
||
|
event (torch.xpu.Event, optional): event to record. If not given, a new one
|
||
|
will be allocated.
|
||
|
|
||
|
Returns:
|
||
|
Recorded event.
|
||
|
"""
|
||
|
if event is None:
|
||
|
event = Event()
|
||
|
event.record(self)
|
||
|
return event
|
||
|
|
||
|
def query(self):
|
||
|
r"""Check if all the work submitted has been completed.
|
||
|
|
||
|
Returns:
|
||
|
A boolean indicating if all kernels in this stream are completed.
|
||
|
"""
|
||
|
return super().query()
|
||
|
|
||
|
def synchronize(self):
|
||
|
r"""Wait for all the kernels in this stream to complete."""
|
||
|
super().synchronize()
|
||
|
|
||
|
@property
|
||
|
def _as_parameter_(self):
|
||
|
return ctypes.c_void_p(self.sycl_queue)
|
||
|
|
||
|
def __eq__(self, o):
|
||
|
if isinstance(o, Stream):
|
||
|
return super().__eq__(o)
|
||
|
return False
|
||
|
|
||
|
def __hash__(self):
|
||
|
return hash((self.sycl_queue, self.device))
|
||
|
|
||
|
def __repr__(self):
|
||
|
return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
|
||
|
|
||
|
|
||
|
class Event(torch._C._XpuEventBase, _EventBase):
|
||
|
r"""Wrapper around a XPU event.
|
||
|
|
||
|
XPU events are synchronization markers that can be used to monitor the
|
||
|
device's progress, and to synchronize XPU streams.
|
||
|
|
||
|
The underlying XPU events are lazily initialized when the event is first
|
||
|
recorded. After creation, only streams on the same device may record the
|
||
|
event. However, streams on any device can wait on the event.
|
||
|
|
||
|
Args:
|
||
|
enable_timing (bool, optional): indicates if the event should measure time
|
||
|
(default: ``False``)
|
||
|
"""
|
||
|
|
||
|
def __new__(cls, enable_timing=False):
|
||
|
return super().__new__(cls, enable_timing=enable_timing)
|
||
|
|
||
|
def record(self, stream=None):
|
||
|
r"""Record the event in a given stream.
|
||
|
|
||
|
Uses ``torch.xpu.current_stream()`` if no stream is specified. The
|
||
|
stream's device must match the event's device.
|
||
|
"""
|
||
|
if stream is None:
|
||
|
stream = torch.xpu.current_stream()
|
||
|
super().record(stream)
|
||
|
|
||
|
def wait(self, stream=None):
|
||
|
r"""Make all future work submitted to the given stream wait for this event.
|
||
|
|
||
|
Use ``torch.xpu.current_stream()`` if no stream is specified.
|
||
|
"""
|
||
|
if stream is None:
|
||
|
stream = torch.xpu.current_stream()
|
||
|
super().wait(stream)
|
||
|
|
||
|
def query(self):
|
||
|
r"""Check if all work currently captured by event has completed.
|
||
|
|
||
|
Returns:
|
||
|
A boolean indicating if all work currently captured by event has
|
||
|
completed.
|
||
|
"""
|
||
|
return super().query()
|
||
|
|
||
|
def elapsed_time(self, end_event):
|
||
|
r"""Return the time elapsed.
|
||
|
|
||
|
Time reported in milliseconds after the event was recorded and
|
||
|
before the end_event was recorded.
|
||
|
"""
|
||
|
return super().elapsed_time(end_event)
|
||
|
|
||
|
def synchronize(self):
|
||
|
r"""Wait for the event to complete.
|
||
|
|
||
|
Waits until the completion of all work currently captured in this event.
|
||
|
This prevents the CPU thread from proceeding until the event completes.
|
||
|
"""
|
||
|
super().synchronize()
|
||
|
|
||
|
@property
|
||
|
def _as_parameter_(self):
|
||
|
return ctypes.c_void_p(self.sycl_event)
|
||
|
|
||
|
def __repr__(self):
|
||
|
if self.sycl_event:
|
||
|
return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
|
||
|
else:
|
||
|
return "torch.xpu.Event(uninitialized)"
|