Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-20 21:10:02 +08:00)
Add ROCm device discovery (#122)
* Add ROCm device discovery
* Ruff
* Address review comments
* Ruff
* Reorg torch import
* Remove redundant import
* Apply suggestions from code review (Co-authored-by: Daniël de Kok <me@danieldk.eu>)
* Address review comments
* Validate device type
* Clean diff
* black
* Sync test with repo changes
* black again

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
@@ -135,6 +135,10 @@ kernel_layer_mapping = {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        "rocm": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}
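For context, such a mapping is activated with `use_kernel_mapping` and applied by `kernelize`. A minimal sketch, assuming `model` is a stand-in module (in practice its layers would be annotated for kernel substitution, which this diff does not show):

```python
import torch.nn as nn
from kernels import Mode, kernelize, use_kernel_mapping

# A stand-in model; real use requires layers annotated for substitution.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU())

# Inside the context, kernelize resolves layers through this mapping:
# the "rocm" entry on ROCm systems, the "cuda" entry on CUDA systems.
with use_kernel_mapping(kernel_layer_mapping):
    model = kernelize(model, mode=Mode.INFERENCE)
```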
@@ -261,7 +265,6 @@ Capabilities behave as follows:
  an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
  with the smaller capability interval will be used. E.g. given:

  - `KernelA` with `min_capability=80` and `max_capability=89`;
  - `KernelB` with `min_capability=75` and `max_capability=89`;
  - `kernelize` runs on a system with capability 8.6.
@@ -271,6 +274,12 @@ Capabilities behave as follows:
  tend to be more optimized for a specific set of GPUs. **This behavior
  might still change in the future.**
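To make the interval rule concrete, here is a hedged sketch of registering two kernels with overlapping capability ranges. It assumes `Device` objects are accepted as mapping keys (as the `ROCMProperties` example further down implies) and uses placeholder repository ids:

```python
from kernels import CUDAProperties, Device, LayerRepository

kernel_a_props = CUDAProperties(min_capability=80, max_capability=89)
kernel_b_props = CUDAProperties(min_capability=75, max_capability=89)

interval_mapping = {
    "SiluAndMul": {
        # On a capability-8.6 GPU both entries match, but the first wins
        # because its interval [80, 89] is narrower than [75, 89].
        Device(type="cuda", properties=kernel_a_props): LayerRepository(
            repo_id="example-org/kernel-a",  # placeholder repository id
            layer_name="SiluAndMul",
        ),
        Device(type="cuda", properties=kernel_b_props): LayerRepository(
            repo_id="example-org/kernel-b",  # placeholder repository id
            layer_name="SiluAndMul",
        ),
    }
}
```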

### Registering kernels for specific ROCm capabilities

Registering kernels for the ROCm architecture follows the exact same
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
a kernel to a range of ROCm capabilities.

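A minimal sketch of that pattern, under the same assumption that `Device` keys are accepted in a mapping (the capability bounds here are purely illustrative):

```python
from kernels import Device, LayerRepository, ROCMProperties

rocm_capability_mapping = {
    "SiluAndMul": {
        # Restrict this kernel to ROCm compute capabilities 9.0-9.4
        # (illustrative bounds, encoded as 90 and 94).
        Device(
            type="rocm",
            properties=ROCMProperties(min_capability=90, max_capability=94),
        ): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}
```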
### Loading from a local repository for testing

The `LocalLayerRepository` class is provided to load a repository from
@@ -1,4 +1,5 @@
[pytest]
markers =
    cuda_only: marks tests that should only run on hosts with CUDA GPUs
    rocm_only: marks tests that should only run on hosts with ROCm GPUs
    darwin_only: marks tests that should only run on macOS
@@ -37,7 +37,6 @@ if TYPE_CHECKING:
    import torch
    from torch import nn


_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))

@@ -122,6 +121,8 @@ class Device:
        """Create an appropriate repository set for this device type."""
        if self.type == "cuda":
            return _CUDARepos()
        elif self.type == "rocm":
            return _ROCMRepos()
        elif self.type == "mps":
            return _MPSRepos()
        else:
@@ -181,6 +182,51 @@ class CUDAProperties:
        return hash((self.min_capability, self.max_capability))


@dataclass(frozen=True)
class ROCMProperties:
    """
    ROCM-specific device properties for capability-based kernel selection.

    This class defines ROCM compute capability constraints for kernel selection, allowing kernels to specify
    minimum and maximum ROCM compute capabilities they support.

    Args:
        min_capability (`int`):
            Minimum ROCM compute capability required (e.g., 75 for compute capability 7.5).
        max_capability (`int`):
            Maximum ROCM compute capability supported (e.g., 90 for compute capability 9.0).

    Example:
        ```python
        from kernels import ROCMProperties, Device

        # Define ROCM properties for modern GPUs (compute capability 7.5 to 9.0)
        rocm_props = ROCMProperties(min_capability=75, max_capability=90)

        # Create a device with these properties
        device = Device(type="rocm", properties=rocm_props)
        ```

    Note:
        ROCM compute capabilities are represented as integers where the major and minor versions are concatenated.
        For example, compute capability 7.5 is represented as 75, and 8.6 is represented as 86.
    """

    min_capability: int
    max_capability: int

    def __eq__(self, other):
        if not isinstance(other, ROCMProperties):
            return NotImplemented
        return (
            self.min_capability == other.min_capability
            and self.max_capability == other.max_capability
        )

    def __hash__(self):
        return hash((self.min_capability, self.max_capability))


class LayerRepositoryProtocol(Protocol):
    @property
    def layer_name(self) -> str: ...
@@ -452,6 +498,46 @@ class _CUDARepos(_DeviceRepos):
        self.repos_by_capability.insert(min_capability, max_capability, repos)


class _ROCMRepos(_DeviceRepos):
    _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]

    def __init__(self):
        super().__init__()
        self.repos_by_capability = IntervalTree()

    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        capability = _find_capability()
        return self.repos_by_capability.find_smallest_interval(capability)

    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        assert device.properties is None or isinstance(
            device.properties, ROCMProperties
        )

        min_capability = (
            0 if device.properties is None else device.properties.min_capability
        )
        max_capability = (
            sys.maxsize
            if device.properties is None
            else device.properties.max_capability
        )

        self.repos_by_capability.insert(min_capability, max_capability, repos)


def _validate_device_type(device_type: str) -> None:
    """Validate that the device type is supported."""
    supported_devices = {"cuda", "rocm", "mps", "cpu"}
    if device_type not in supported_devices:
        raise ValueError(
            f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
        )


_KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
    "_KERNEL_MAPPING", default={}
)
@@ -703,8 +789,8 @@ def kernelize(
            The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
            kernelizes the model for training with `torch.compile`.
        device (`Union[str, torch.device]`, *optional*):
            The device type to load kernels for. The device type will be inferred from the model parameters
            when not provided.
            The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
            The device type will be inferred from the model parameters when not provided.
        use_fallback (`bool`, *optional*, defaults to `True`):
            Whether to use the original forward method of modules when no compatible kernel could be found.
            If set to `False`, an exception will be raised in such cases.
@@ -746,7 +832,6 @@ def kernelize(
    kernelized_model = kernelize(model)
    ```
    """
    import torch

    if mode == Mode.FALLBACK:
        raise ValueError("Mode.FALLBACK can only be used to register kernel mappings.")
@@ -760,7 +845,8 @@ def kernelize(
    if device is None:
        device_type = _find_device(model)
    elif isinstance(device, str):
        device_type = Device(type=torch.device(device).type)
        _validate_device_type(device)
        device_type = Device(type=device)
    else:
        device_type = Device(device.type)

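The switch from `torch.device(device).type` to validating the raw string is what lets the `"rocm"` shorthand through: ROCm builds of PyTorch expose their devices under the `"cuda"` type, and `torch.device` has no `"rocm"` device type at all. A small illustrative probe (not part of the change) shows why the old path could not work:

```python
import torch

# "rocm" is not a torch device type, so the previous round-trip through
# torch.device(...) would fail for the new shorthand; the exact exception
# type may vary between torch versions, hence the broad catch here.
try:
    torch.device("rocm")
except (RuntimeError, ValueError):
    print('torch rejects "rocm"; kernels now validates the string itself')
```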
@@ -948,6 +1034,18 @@ def _validate_layer(*, check_cls, cls):
        )


def _is_cuda_platform():
    import torch

    return torch.version.cuda is not None


def _is_rocm_platform():
    import torch

    return torch.version.hip is not None


def _find_device(model: "nn.Module") -> Device:
    try:
        param = next(model.parameters())
@@ -956,7 +1054,15 @@ def _find_device(model: "nn.Module") -> Device:
            "Cannot determine model device, provide as `device` argument to `kernelize`."
        )

    return Device(type=param.device.type)
    dev_type = param.device.type
    if dev_type == "cuda":
        # Refine based on actual platform
        if _is_rocm_platform():
            return Device(type="rocm")
        elif _is_cuda_platform():
            return Device(type="cuda")

    return Device(type=dev_type)


@lru_cache
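The refinement above is needed because a parameter on a ROCm GPU still reports `param.device.type == "cuda"`; only the build metadata (`torch.version.hip` vs. `torch.version.cuda`) tells the two platforms apart. A hypothetical standalone helper mirroring that logic:

```python
import torch

def describe_platform() -> str:
    # ROCm builds reuse the "cuda" device type, so the device type alone
    # cannot distinguish the platforms; the build metadata can.
    if torch.version.hip is not None:
        return "rocm"
    if torch.version.cuda is not None:
        return "cuda"
    return "cpu-only build"

print(describe_platform())
```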
@@ -3,11 +3,22 @@ import sys
import pytest
import torch

has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
has_cuda = (
    hasattr(torch.version, "cuda")
    and torch.version.cuda is not None
    and torch.cuda.device_count() > 0
)
has_rocm = (
    hasattr(torch.version, "hip")
    and torch.version.hip is not None
    and torch.cuda.device_count() > 0
)


def pytest_runtest_setup(item):
    if "cuda_only" in item.keywords and not has_cuda:
        pytest.skip("skipping CUDA-only test on host without CUDA")
    if "rocm_only" in item.keywords and not has_rocm:
        pytest.skip("skipping ROCm-only test on host without ROCm")
    if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
        pytest.skip("skipping macOS-only test on non-macOS platform")

@@ -34,7 +34,11 @@ kernel_layer_mapping = {
        "cuda": LayerRepository(
            repo_id="kernels-test/op-without-fake-test",
            layer_name="SiluAndMul",
        )
        ),
        "rocm": LayerRepository(
            repo_id="kernels-test/op-without-fake-test",
            layer_name="SiluAndMul",
        ),
    },
    "SiluAndMulStringDevice": {
        "cuda": LayerRepository(
@@ -126,6 +130,55 @@ def test_hub_forward(cls, device):
    assert silu_and_mul_with_kernel.n_calls == 1


@pytest.mark.rocm_only
def test_hub_forward_rocm():
    torch.manual_seed(0)

    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64))
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = kernelize(
        SiluAndMulNoCompileKernel(), device="rocm", mode=Mode.INFERENCE
    )
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)

    assert silu_and_mul.n_calls == 1
    # Should use kernel (n_calls == 0) if ROCm kernel is available, otherwise fallback (n_calls == 1)
    # The exact behavior depends on whether the test kernel exists for ROCm
    assert silu_and_mul_with_kernel.n_calls in [0, 1]


def test_rocm_kernel_mapping():
    """Test that ROCm shorthand device mapping works correctly."""
    kernel_layer_mapping = {
        "SiluAndMul": {
            "rocm": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            )
        }
    }

    # Test that the mapping is processed correctly
    with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
        mapping = _KERNEL_MAPPING.get()

        # Verify the mapping exists
        assert "SiluAndMul" in mapping
        assert "rocm" in mapping["SiluAndMul"]

        # Verify the repository is correctly stored
        rocm_repos = mapping["SiluAndMul"]["rocm"]
        assert rocm_repos is not None
        assert (
            rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-community/activation"
        )
        assert rocm_repos.repos[Mode.FALLBACK].layer_name == "SiluAndMul"


@pytest.mark.cuda_only
def test_capability():
    linear = TorchLinearWithCounter(32, 32).to("cuda")