mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
torch.mtia module for MTIA device backend (#123612)
MTIA device has its own Module in PyTorch now. torch.mtia has following APIs similar to other backends. The lazy_init is also supported. ``` __all__ = [ "init", "is_available", "synchronize", "device_count", "current_device", "current_stream", "default_stream", "set_stream", "stream", "device", ] ``` ------------ For device management. We expand AccleratorHooksInterface to support generic device management and it can be used in both C++ and PyThon. ``` def _accelerator_hooks_device_count() -> _int: ... def _accelerator_hooks_set_current_device(device_index: _int) -> None: ... def _accelerator_hooks_get_current_device() -> _int : ... def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ... def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ... ``` --------- Adding get_device_module API to retrieve device modules for different device types. ``` def get_device_module(device: Optional[Union[torch.device, str]] = None) ``` --------- @exported-using-ghexport Differential Revision: [D52923602](https://our.internmc.facebook.com/intern/diff/D52923602/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612 Approved by: https://github.com/albanD ghstack dependencies: #123611
This commit is contained in:
committed by
PyTorch MergeBot
parent
cb17721899
commit
d7e1bf9ff9
@ -713,6 +713,8 @@ def _get_available_device_type():
|
||||
return "cuda"
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined]
|
||||
return "xpu"
|
||||
if hasattr(torch, "mtia") and torch.mtia.is_available():
|
||||
return "mtia"
|
||||
custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
custom_device_mod = getattr(torch, custom_backend_name, None)
|
||||
if custom_device_mod and custom_device_mod.is_available():
|
||||
@ -727,6 +729,8 @@ def _get_device_attr(get_member):
|
||||
return get_member(torch.cuda)
|
||||
if device_type and device_type.lower() == "xpu":
|
||||
return get_member(torch.xpu) # type: ignore[attr-defined]
|
||||
if device_type and device_type.lower() == "mtia":
|
||||
return get_member(torch.mtia)
|
||||
if device_type == torch._C._get_privateuse1_backend_name():
|
||||
return get_member(getattr(torch, device_type))
|
||||
# add more available device types here
|
||||
|
Reference in New Issue
Block a user