torch.mtia module for MTIA device backend (#123612)

The MTIA device backend now has its own module in PyTorch.
torch.mtia exposes the following APIs, mirroring the other device backends; lazy_init is also supported.
```
__all__ = [
    "init",
    "is_available",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "set_stream",
    "stream",
    "device",
]

```
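As a quick usage sketch (assuming these torch.mtia APIs mirror their torch.cuda counterparts; the device() context manager and stream handling below are based on that assumption):
```
import torch

# Guard on availability before touching the device, as with torch.cuda.
if torch.mtia.is_available():
    torch.mtia.init()  # optional: lazy_init also runs on first use
    print(f"MTIA devices: {torch.mtia.device_count()}")
    idx = torch.mtia.current_device()
    with torch.mtia.device(idx):  # assumed to behave like torch.cuda.device
        stream = torch.mtia.current_stream()
        torch.mtia.synchronize()  # wait for queued work on the current device
```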
------------
For device management, we extend AcceleratorHooksInterface to support generic device management, usable from both C++ and Python.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int: ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
```
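A sketch of how these bindings fit together (assuming _accelerator_hooks_exchange_device follows the usual ExchangeDevice convention of returning the previously current device index):
```
import torch

# Query devices through the generic accelerator hooks.
n = torch._C._accelerator_hooks_device_count()
if n > 1:
    # Switch to device 1, remembering what was current before.
    prev = torch._C._accelerator_hooks_exchange_device(1)
    assert torch._C._accelerator_hooks_get_current_device() == 1
    # Restore the previous device.
    torch._C._accelerator_hooks_set_current_device(prev)
```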

---------
Add a get_device_module API to retrieve the device module for a given device type.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None): ...
```
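For example (a sketch, assuming the function is exposed at the top level as torch.get_device_module):
```
import torch

# Look up the backend module by device type string or torch.device.
mod = torch.get_device_module("cuda")
assert mod is torch.cuda

# Equivalent lookup via a torch.device object.
assert torch.get_device_module(torch.device("cuda", 0)) is torch.cuda
```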
---------
@exported-using-ghexport

Differential Revision: [D52923602](https://our.internmc.facebook.com/intern/diff/D52923602/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
Author: egienvalue
Date: 2024-04-17 12:05:22 -07:00
Committed by: PyTorch MergeBot
Commit: d7e1bf9ff9 (parent cb17721899)
20 changed files with 655 additions and 21 deletions


From torch/_utils.py:
```
@@ -713,6 +713,8 @@ def _get_available_device_type():
         return "cuda"
     if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
         return "xpu"
+    if hasattr(torch, "mtia") and torch.mtia.is_available():
+        return "mtia"
     custom_backend_name = torch._C._get_privateuse1_backend_name()
     custom_device_mod = getattr(torch, custom_backend_name, None)
     if custom_device_mod and custom_device_mod.is_available():
@@ -727,6 +729,8 @@ def _get_device_attr(get_member):
         return get_member(torch.cuda)
     if device_type and device_type.lower() == "xpu":
         return get_member(torch.xpu)  # type: ignore[attr-defined]
+    if device_type and device_type.lower() == "mtia":
+        return get_member(torch.mtia)
     if device_type == torch._C._get_privateuse1_backend_name():
         return get_member(getattr(torch, device_type))
     # add more available device types here
```
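With these hunks applied, the generic device helpers resolve MTIA like any other backend. A sketch (these are private helpers in torch._utils, so treat this as illustrative):
```
import torch
from torch._utils import _get_available_device_type, _get_device_attr

# On a machine whose only available accelerator is MTIA, the generic
# helpers now report it and dispatch accessors to torch.mtia.
if hasattr(torch, "mtia") and torch.mtia.is_available():
    assert _get_available_device_type() == "mtia"
    num_devices = _get_device_attr(lambda mod: mod.device_count())
```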