You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
3.1 KiB

5 months ago
import functools
import hashlib
import os
from torch._dynamo.device_interface import get_interface_for_device
@functools.lru_cache(None)
def has_triton_package() -> bool:
try:
import triton
return triton is not None
except ImportError:
return False
@functools.lru_cache(None)
def has_triton() -> bool:
def cuda_extra_check(device_interface):
return device_interface.Worker.get_device_properties().major >= 7
triton_supported_devices = {"cuda": cuda_extra_check}
def is_device_compatible_with_triton():
for device, extra_check in triton_supported_devices.items():
device_interface = get_interface_for_device(device)
if device_interface.is_available() and extra_check(device_interface):
return True
return False
return is_device_compatible_with_triton() and has_triton_package()
@functools.lru_cache(None)
def triton_backend_hash():
from triton.common.backend import get_backend, get_cuda_version_key
import torch
if torch.version.hip:
# Does not work with ROCm
return None
if not torch.cuda.is_available():
return None
backend = get_backend("cuda")
if backend is None:
return get_cuda_version_key()
else:
return backend.get_version_key()
@functools.lru_cache
def triton_key():
import pkgutil
import triton
TRITON_PATH = os.path.dirname(os.path.abspath(triton.__file__))
contents = []
# This is redundant. Doing it to be consistent with upstream.
# frontend
with open(os.path.join(TRITON_PATH, "compiler", "compiler.py"), "rb") as f:
contents += [hashlib.sha256(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(TRITON_PATH, "compiler")
backends_path = os.path.join(TRITON_PATH, "compiler", "backends")
for lib in pkgutil.iter_modules([compiler_path, backends_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: # type: ignore[call-arg, union-attr, arg-type]
contents += [hashlib.sha256(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha256()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, "language")
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: # type: ignore[call-arg, union-attr, arg-type]
contents += [hashlib.sha256(f.read()).hexdigest()]
from triton import __version__
return f"{__version__}" + "-".join(contents)
@functools.lru_cache(None)
def triton_hash_with_backend():
import torch
if torch.version.hip:
# Does not work with ROCm
return None
backend_hash = triton_backend_hash()
key = f"{triton_key()}-{backend_hash}"
return hashlib.sha256(key.encode("utf-8")).hexdigest()