mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add get_device_index for custom device (#98804)
Fixes #ISSUE_NUMBER as the title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98804 Approved by: https://github.com/ngimel
This commit is contained in:
@ -846,6 +846,8 @@ class TestExtensionUtils(TestCase):
|
||||
with torch.autocast(device_type=custom_backend_name):
|
||||
pass
|
||||
|
||||
self.assertEqual(torch._utils._get_device_index('foo:1'), 1)
|
||||
self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2)
|
||||
|
||||
class TestDeviceUtils(TestCase):
|
||||
def test_basic(self):
|
||||
|
@ -650,6 +650,10 @@ def _get_available_device_type():
|
||||
return "cuda"
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined]
|
||||
return "xpu"
|
||||
custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
custom_device_mod = getattr(torch, custom_backend_name, None)
|
||||
if custom_device_mod and custom_device_mod.is_available():
|
||||
return custom_backend_name
|
||||
# add more available device types here
|
||||
return None
|
||||
|
||||
@ -660,6 +664,8 @@ def _get_device_attr(get_member):
|
||||
return get_member(torch.cuda)
|
||||
if device_type and device_type.lower() == "xpu":
|
||||
return get_member(torch.xpu) # type: ignore[attr-defined]
|
||||
if device_type == torch._C._get_privateuse1_backend_name():
|
||||
return get_member(getattr(torch, device_type))
|
||||
# add more available device types here
|
||||
return None
|
||||
|
||||
|
@ -56,6 +56,9 @@ def rename_privateuse1_backend(backend_name: str) -> None:
|
||||
(5) set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
|
||||
Sets the random number generator state of the specified `foo` device.
|
||||
|
||||
And there are some common funcs:
|
||||
(1) is_available() -> bool:
|
||||
Returns a bool indicating if `foo` is currently available.
|
||||
For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
|
||||
For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
|
||||
|
||||
|
Reference in New Issue
Block a user