Add support for fetching ROCm kernels (#59)

This commit is contained in:
Daniël de Kok
2025-03-25 15:11:03 +01:00
committed by GitHub
parent 3808108d62
commit ff55bc201b

View File

@ -23,18 +23,21 @@ CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
def build_variant() -> str:
    """Return the build variant string for the installed Torch build.

    The variant has the form
    ``torch{major}{minor}-{cxxabi}-{compute_framework}-{cpu}-{os}`` where
    ``compute_framework`` is ``cu{major}{minor}`` when Torch reports a CUDA
    version and ``rocm{major}{minor}`` when it reports a HIP version.

    Raises:
        AssertionError: If Torch reports neither a CUDA nor a HIP version.
    """
    import torch

    # CUDA is checked first; ROCm is only considered when no CUDA version
    # is reported by Torch.
    if torch.version.cuda is not None:
        cuda_version = parse(torch.version.cuda)
        compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
    elif torch.version.hip is not None:
        # Only the part before the first "-" is parsed as the version.
        rocm_version = parse(torch.version.hip.split("-")[0])
        compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
    else:
        raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")

    torch_version = parse(torch.__version__)
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    cpu = platform.machine()
    # Named os_name so the module-level `os` import is not shadowed locally.
    os_name = platform.system().lower()
    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os_name}"
def universal_build_variant() -> str: