cpu is not (yet) a supported device type (#132)

Fixes #131.
This commit is contained in:
Daniël de Kok
2025-08-25 16:25:58 +02:00
committed by GitHub
parent 767e7ccf13
commit 7611021100
2 changed files with 7 additions and 11 deletions

View File

@ -87,7 +87,7 @@ class Device:
Args: Args:
type (`str`): type (`str`):
The device type (e.g., "cuda", "mps", "cpu"). The device type (e.g., "cuda", "mps", "rocm").
properties ([`CUDAProperties`], *optional*): properties ([`CUDAProperties`], *optional*):
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices. Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
@ -531,7 +531,7 @@ class _ROCMRepos(_DeviceRepos):
def _validate_device_type(device_type: str) -> None: def _validate_device_type(device_type: str) -> None:
"""Validate that the device type is supported.""" """Validate that the device type is supported."""
supported_devices = {"cuda", "rocm", "mps", "cpu"} supported_devices = {"cuda", "rocm", "mps"}
if device_type not in supported_devices: if device_type not in supported_devices:
raise ValueError( raise ValueError(
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}" f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@ -789,7 +789,7 @@ def kernelize(
The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE` The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
kernelizes the model for training with `torch.compile`. kernelizes the model for training with `torch.compile`.
device (`Union[str, torch.device]`, *optional*): device (`Union[str, torch.device]`, *optional*):
The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu". The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
The device type will be inferred from the model parameters when not provided. The device type will be inferred from the model parameters when not provided.
use_fallback (`bool`, *optional*, defaults to `True`): use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the original forward method of modules when no compatible kernel could be found. Whether to use the original forward method of modules when no compatible kernel could be found.

View File

@ -110,24 +110,20 @@ def test_arg_kinds():
@pytest.mark.cuda_only @pytest.mark.cuda_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_hub_forward(cls):
def test_hub_forward(cls, device):
torch.random.manual_seed(0) torch.random.manual_seed(0)
silu_and_mul = SiluAndMul() silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device=device) X = torch.randn((32, 64), device="cuda")
Y = silu_and_mul(X) Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE) silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X) Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y) torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1 assert silu_and_mul.n_calls == 1
if device == "cuda": assert silu_and_mul_with_kernel.n_calls == 0
assert silu_and_mul_with_kernel.n_calls == 0
else:
assert silu_and_mul_with_kernel.n_calls == 1
@pytest.mark.rocm_only @pytest.mark.rocm_only