cpu is not (yet) a supported device type (#132)

Fixes #131.
Daniël de Kok
2025-08-25 16:25:58 +02:00
committed by GitHub
parent 767e7ccf13
commit 7611021100
2 changed files with 7 additions and 11 deletions


@@ -87,7 +87,7 @@ class Device:
     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "cpu").
+            The device type (e.g., "cuda", "mps", "rocm").
         properties ([`CUDAProperties`], *optional*):
            Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
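
Aside (not part of the diff): a minimal sketch of how the documented constructor might be used. It assumes `Device` is exported at the top level of the `kernels` package; the import path is not shown in this diff.

    from kernels import Device  # assumed import path, not shown in the diff

    dev = Device(type="cuda")   # "mps" and "rocm" are the other documented types
    # The optional `properties` argument only applies to CUDA devices.
    # "cpu" is no longer listed as an example; the validator below rejects it.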
@@ -531,7 +531,7 @@ class _ROCMRepos(_DeviceRepos):
 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps", "cpu"}
+    supported_devices = {"cuda", "rocm", "mps"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -789,7 +789,7 @@ def kernelize(
            The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
            kernelizes the model for training with `torch.compile`.
        device (`Union[str, torch.device]`, *optional*):
-           The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
+           The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
            The device type will be inferred from the model parameters when not provided.
        use_fallback (`bool`, *optional*, defaults to `True`):
            Whether to use the original forward method of modules when no compatible kernel could be found.
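
Aside (not part of the diff): a hedged sketch of the documented call, mirroring the test below. It assumes `kernelize` and `Mode` are importable from the top level of the `kernels` package and that a CUDA device is available.

    import torch
    from kernels import Mode, kernelize  # assumed import path, as used in the tests

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).to("cuda")
    # Only "cuda", "mps" and "rocm" are accepted after this commit; omitting
    # `device` infers the type from the model's parameters instead.
    model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
    # kernelize(model, device="cpu", mode=Mode.INFERENCE) now raises the ValueError shown above.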


@@ -110,24 +110,20 @@ def test_arg_kinds():

 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
-@pytest.mark.parametrize("device", ["cuda", "cpu"])
-def test_hub_forward(cls, device):
+def test_hub_forward(cls):
     torch.random.manual_seed(0)

     silu_and_mul = SiluAndMul()
-    X = torch.randn((32, 64), device=device)
+    X = torch.randn((32, 64), device="cuda")
     Y = silu_and_mul(X)

-    silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
+    silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
     Y_kernel = silu_and_mul_with_kernel(X)

     torch.testing.assert_close(Y_kernel, Y)

     assert silu_and_mul.n_calls == 1
-    if device == "cuda":
-        assert silu_and_mul_with_kernel.n_calls == 0
-    else:
-        assert silu_and_mul_with_kernel.n_calls == 1
+    assert silu_and_mul_with_kernel.n_calls == 0


 @pytest.mark.rocm_only
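
Aside (not part of the diff): `cuda_only` and `rocm_only` are custom pytest markers defined elsewhere in the repository. A hypothetical conftest.py along these lines illustrates how such a marker typically skips the test when no matching device is present; this is not the repository's actual implementation.

    # Hypothetical conftest.py sketch (assumption, not the repo's real conftest).
    import pytest
    import torch

    def pytest_runtest_setup(item):
        if "cuda_only" in item.keywords and not torch.cuda.is_available():
            pytest.skip("requires a CUDA device")
        if "rocm_only" in item.keywords and not (
            torch.cuda.is_available() and torch.version.hip is not None
        ):
            pytest.skip("requires a ROCm device")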