mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-22 05:48:52 +08:00
Compare commits
5 Commits
add-get-lo
...
type-kerne
Author | SHA1 | Date | |
---|---|---|---|
e7173d9a28 | |||
e487945cf6 | |||
338752c367 | |||
53723ea986 | |||
95d3a758a9 |
@ -71,6 +71,24 @@ def universal_build_variant() -> str:
|
||||
return "torch-universal"
|
||||
|
||||
|
||||
# Metaclass to allow overriding the `__repr__` method for kernel modules.
|
||||
class _KernelModuleMeta(type):
|
||||
def __repr__(self):
|
||||
return "<class 'kernel_module'>"
|
||||
|
||||
|
||||
# Custom module type to identify dynamically loaded kernel modules.
|
||||
# Using a subclass lets us distinguish these from regular imports.
|
||||
class _KernelModuleType(ModuleType, metaclass=_KernelModuleMeta):
|
||||
"""Marker class for modules loaded dynamically from a path."""
|
||||
|
||||
module_name: str
|
||||
is_kernel: bool = True
|
||||
|
||||
def __repr__(self):
|
||||
return f"<kernel_module '{self.module_name}' from '{self.__file__}'>"
|
||||
|
||||
|
||||
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
|
||||
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
||||
# it would also be used for other imports. So, we make a module name that
|
||||
@ -84,6 +102,9 @@ def import_from_path(module_name: str, file_path: Path) -> ModuleType:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
if module is None:
|
||||
raise ImportError(f"Cannot load module {module_name} from spec")
|
||||
module.__class__ = _KernelModuleType
|
||||
assert isinstance(module, _KernelModuleType) # for mypy type checking
|
||||
module.module_name = module_name
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
return module
|
||||
@ -248,24 +269,8 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
"""
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
# Presume we were given the top level path of the kernel repository.
|
||||
for base_path in [repo_path, repo_path / "build"]:
|
||||
# Prefer the universal variant if it exists.
|
||||
for v in [universal_variant, variant]:
|
||||
package_path = base_path / v / package_name / "__init__.py"
|
||||
if package_path.exists():
|
||||
return import_from_path(package_name, package_path)
|
||||
|
||||
# If we didn't find the package in the repo we may have a explicit
|
||||
# package path.
|
||||
package_path = repo_path / package_name / "__init__.py"
|
||||
if package_path.exists():
|
||||
return import_from_path(package_name, package_path)
|
||||
|
||||
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
|
||||
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(
|
||||
|
@ -10,16 +10,10 @@ def kernel():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_kernel_path():
|
||||
def local_kernel():
|
||||
package_name, path = install_kernel("kernels-community/activation", "main")
|
||||
# Path is the build variant path (build/torch-<...>), so the grandparent
|
||||
# is the kernel repository path.
|
||||
return package_name, path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_kernel(local_kernel_path):
|
||||
package_name, path = local_kernel_path
|
||||
return get_local_kernel(path.parent.parent, package_name)
|
||||
|
||||
|
||||
@ -72,39 +66,6 @@ def test_local_kernel(local_kernel, device):
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
def test_local_kernel_path_types(local_kernel_path, device):
|
||||
package_name, path = local_kernel_path
|
||||
|
||||
# Top-level repo path
|
||||
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
|
||||
kernel = get_local_kernel(path.parent.parent, package_name)
|
||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
kernel.gelu_fast(y, x)
|
||||
expected = torch.tensor(
|
||||
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
# Build directory path
|
||||
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
|
||||
kernel = get_local_kernel(path.parent.parent / "build", package_name)
|
||||
y = torch.empty_like(x)
|
||||
kernel.gelu_fast(y, x)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
# Explicit package path
|
||||
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
|
||||
kernel = get_local_kernel(path, package_name)
|
||||
y = torch.empty_like(x)
|
||||
kernel.gelu_fast(y, x)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.darwin_only
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_relu_metal(metal_kernel, dtype):
|
||||
|
Reference in New Issue
Block a user