mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 20:46:42 +08:00
@ -87,7 +87,7 @@ class Device:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (`str`):
|
type (`str`):
|
||||||
The device type (e.g., "cuda", "mps", "cpu").
|
The device type (e.g., "cuda", "mps", "rocm").
|
||||||
properties ([`CUDAProperties`], *optional*):
|
properties ([`CUDAProperties`], *optional*):
|
||||||
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
|
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
|
||||||
|
|
||||||
@ -531,7 +531,7 @@ class _ROCMRepos(_DeviceRepos):
|
|||||||
|
|
||||||
def _validate_device_type(device_type: str) -> None:
|
def _validate_device_type(device_type: str) -> None:
|
||||||
"""Validate that the device type is supported."""
|
"""Validate that the device type is supported."""
|
||||||
supported_devices = {"cuda", "rocm", "mps", "cpu"}
|
supported_devices = {"cuda", "rocm", "mps"}
|
||||||
if device_type not in supported_devices:
|
if device_type not in supported_devices:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
|
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
|
||||||
@ -789,7 +789,7 @@ def kernelize(
|
|||||||
The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
|
The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
|
||||||
kernelizes the model for training with `torch.compile`.
|
kernelizes the model for training with `torch.compile`.
|
||||||
device (`Union[str, torch.device]`, *optional*):
|
device (`Union[str, torch.device]`, *optional*):
|
||||||
The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
|
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
|
||||||
The device type will be inferred from the model parameters when not provided.
|
The device type will be inferred from the model parameters when not provided.
|
||||||
use_fallback (`bool`, *optional*, defaults to `True`):
|
use_fallback (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to use the original forward method of modules when no compatible kernel could be found.
|
Whether to use the original forward method of modules when no compatible kernel could be found.
|
||||||
|
@ -110,24 +110,20 @@ def test_arg_kinds():
|
|||||||
|
|
||||||
@pytest.mark.cuda_only
|
@pytest.mark.cuda_only
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
||||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
def test_hub_forward(cls):
|
||||||
def test_hub_forward(cls, device):
|
|
||||||
torch.random.manual_seed(0)
|
torch.random.manual_seed(0)
|
||||||
|
|
||||||
silu_and_mul = SiluAndMul()
|
silu_and_mul = SiluAndMul()
|
||||||
X = torch.randn((32, 64), device=device)
|
X = torch.randn((32, 64), device="cuda")
|
||||||
Y = silu_and_mul(X)
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
|
silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
|
||||||
Y_kernel = silu_and_mul_with_kernel(X)
|
Y_kernel = silu_and_mul_with_kernel(X)
|
||||||
|
|
||||||
torch.testing.assert_close(Y_kernel, Y)
|
torch.testing.assert_close(Y_kernel, Y)
|
||||||
|
|
||||||
assert silu_and_mul.n_calls == 1
|
assert silu_and_mul.n_calls == 1
|
||||||
if device == "cuda":
|
assert silu_and_mul_with_kernel.n_calls == 0
|
||||||
assert silu_and_mul_with_kernel.n_calls == 0
|
|
||||||
else:
|
|
||||||
assert silu_and_mul_with_kernel.n_calls == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.rocm_only
|
@pytest.mark.rocm_only
|
||||||
|
Reference in New Issue
Block a user