mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-26 16:44:27 +08:00
Compare commits
2 Commits
v0.9.0
...
fix-packag
| Author | SHA1 | Date | |
|---|---|---|---|
| 7adbe421c8 | |||
| ce6a283d2f |
@ -23,6 +23,9 @@ CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
def build_variant():
|
||||
import torch
|
||||
|
||||
if torch.version.cuda is None:
|
||||
raise AssertionError("This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled.")
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
cuda_version = parse(torch.version.cuda)
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
@ -50,9 +53,8 @@ def install_kernel(
|
||||
repo_id: str, revision: str, local_files_only: bool = False
|
||||
) -> Tuple[str, str]:
|
||||
"""Download a kernel for the current environment to the cache."""
|
||||
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
|
||||
"torch"
|
||||
]["name"]
|
||||
package_name = repo_id.split('/')[-1]
|
||||
package_name = package_name.replace('-', '_')
|
||||
repo_path = snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=f"build/{build_variant()}/*",
|
||||
|
||||
Reference in New Issue
Block a user