Compare commits

...

2 Commits

Author SHA1 Message Date
7adbe421c8 Raise when CUDA not installed 2025-02-24 14:00:27 +01:00
ce6a283d2f package_name should not depend on build.toml 2025-02-24 13:55:59 +01:00

View File

@ -23,6 +23,9 @@ CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
def build_variant():
import torch
if torch.version.cuda is None:
raise AssertionError("This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled.")
torch_version = parse(torch.__version__)
cuda_version = parse(torch.version.cuda)
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
@ -50,9 +53,8 @@ def install_kernel(
repo_id: str, revision: str, local_files_only: bool = False
) -> Tuple[str, str]:
"""Download a kernel for the current environment to the cache."""
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
"torch"
]["name"]
package_name = repo_id.split('/')[-1]
package_name = package_name.replace('-', '_')
repo_path = snapshot_download(
repo_id,
allow_patterns=f"build/{build_variant()}/*",