diff --git a/docs/source/accelerator/autoload.md b/docs/source/accelerator/autoload.md
new file mode 100644
index 000000000000..97664adcd735
--- /dev/null
+++ b/docs/source/accelerator/autoload.md
@@ -0,0 +1,86 @@
+# Autoload Mechanism
+
+The **Autoload** mechanism in PyTorch simplifies the integration of a custom backend by enabling automatic discovery and initialization at runtime. This eliminates the need for explicit imports or manual initialization, allowing developers to seamlessly integrate a new accelerator or backend into PyTorch.
+
+## Background
+
+The **Autoload Device Extension** proposal in PyTorch is centered on improving support for various hardware backends, especially those implemented as out-of-the-tree extensions (not part of PyTorch's main codebase). Currently, users must manually import or load these device-specific extensions to use them, which complicates the experience and increases cognitive overhead.
+
+In contrast, in-tree devices (devices officially supported within PyTorch) are seamlessly integrated: users need no extra imports or steps. The goal of autoloading is to make out-of-the-tree devices just as easy to use, so that users can follow the standard PyTorch device programming model without explicit loading or code changes. This allows existing PyTorch applications to run on new devices without any modification, making hardware support more user-friendly and reducing barriers to adoption.
+
+For more information about the background of **Autoload**, please refer to its [RFC](https://github.com/pytorch/pytorch/issues/122468).
+
+## Design
+
+The core idea of **Autoload** is to use Python's plugin-discovery mechanism (entry points), so that PyTorch automatically loads out-of-the-tree device extensions when `torch` is imported, with no explicit user import needed.
+
+For more details about the design of **Autoload**, please refer to [**How it works**](https://docs.pytorch.org/tutorials/unstable/python_extension_autoload.html#how-it-works).
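+
+Conceptually, the discovery step looks like the minimal sketch below. This is an illustration only: it assumes Python 3.10+ (for the keyword form of `importlib.metadata.entry_points`), and `_load_backend_extensions` is a hypothetical name, not PyTorch's actual internal symbol. The group name `torch.backends` is the one registered in `setup.py` later in this tutorial.
+
+```python
+from importlib.metadata import entry_points
+
+
+def _load_backend_extensions():
+    # Scan every installed distribution for entry points registered
+    # under the "torch.backends" group.
+    for backend in entry_points(group="torch.backends"):
+        # Import the extension module and call the hook it registered,
+        # letting the out-of-the-tree backend initialize itself.
+        backend.load()()
+```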
+
+## Implementation
+
+This tutorial takes **OpenReg** as an example of a new out-of-the-tree device and walks through the steps to enable and use the **Autoload** mechanism.
+
+### Entry Point Setup
+
+To enable **Autoload**, register the `_autoload` function as an entry point in the `setup.py` file.
+
+::::{tab-set}
+
+:::{tab-item} Python
+
+```{eval-rst}
+.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/setup.py
+   :language: python
+   :start-after: LITERALINCLUDE START: SETUP
+   :end-before: LITERALINCLUDE END: SETUP
+   :linenos:
+   :emphasize-lines: 9-13
+```
+
+:::
+
+::::
+
+### Backend Setup
+
+Define the initialization hook `_autoload`. PyTorch will invoke this hook automatically during startup.
+
+::::{tab-set-code}
+```{eval-rst}
+.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py
+   :language: python
+   :start-after: LITERALINCLUDE START: AUTOLOAD
+   :end-before: LITERALINCLUDE END: AUTOLOAD
+   :linenos:
+   :emphasize-lines: 10-12
+```
+
+::::
+
+## Result
+
+After setting up the entry point and the hook, build and install your backend. The new accelerator can then be used without explicitly importing it:
+
+```{eval-rst}
+.. grid:: 2
+
+    .. grid-item-card:: :octicon:`terminal;1em;` Without Autoload
+       :class-card: card-prerequisites
+
+       ::
+
+          >>> import torch
+          >>> import torch_openreg
+          >>> torch.tensor(1, device="openreg")
+          tensor(1, device='openreg:0')
+
+    .. grid-item-card:: :octicon:`terminal;1em;` With Autoload
+       :class-card: card-prerequisites
+
+       ::
+
+          >>> import torch  # Automatically imports torch_openreg
+          >>> torch.tensor(1, device="openreg")
+          tensor(1, device='openreg:0')
+```
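+
+## Disabling Autoload
+
+Autoload can be turned off by setting the environment variable `TORCH_DEVICE_BACKEND_AUTOLOAD` to `0` before `torch` is first imported; `torch_openreg`'s own `setup.py` relies on this to avoid a circular import while building against PyTorch. A minimal sketch of the opt-out:
+
+```python
+import os
+
+# Must be set before the first `import torch`; "0" disables the
+# entry-point scan, so no out-of-the-tree backend is loaded.
+os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
+
+import torch  # torch_openreg is not imported automatically here
+```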
diff --git a/docs/source/accelerator/index.md b/docs/source/accelerator/index.md
index 68db62e07597..70f25812bb9e 100644
--- a/docs/source/accelerator/index.md
+++ b/docs/source/accelerator/index.md
@@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
 
 :glob:
 :maxdepth: 1
 
+autoload
 operators
 ```
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/README.md b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md
index 83ec85b1055c..9474c85a1b84 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/README.md
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md
@@ -124,6 +124,16 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
 - Per-operator Fallback: See `sub.Tensor`
 - Global Fallback: See `wrapper_cpu_fallback`
 
+### Autoload
+
+- Autoload Mechanism
+
+  When `torch` is imported, installed accelerators (such as `torch_openreg`) are loaded automatically, providing the same experience as the built-in backends.
+
+  - Registering the backend with Python `entry points`: See `setup` in `setup.py`
+  - Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
+  - Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
+
 ## Installation and Usage
 
 ### Installation
@@ -139,7 +149,6 @@ After installation, you can use the `openreg` device in Python just like any oth
 
 ```python
 import torch
-import torch_openreg
 
 if not torch.openreg.is_available():
     print("OpenReg backend is not available in this build.")
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py
index 0768653e1ac4..8c1496387570 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py
@@ -28,6 +28,12 @@ def make_relative_rpath_args(path):
 
 
 def get_pytorch_dir():
+    # Disable autoload of the accelerator. We must do this for two reasons:
+    # 1. We only need the PyTorch installation directory, so whether the
+    #    accelerator is loaded is irrelevant.
+    # 2. If the accelerator was built previously and not uninstalled,
+    #    importing torch here would raise a circular import error.
+    os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
     import torch
 
     return os.path.dirname(os.path.realpath(torch.__file__))
@@ -127,6 +133,7 @@ def main():
         ]
     }
 
+    # LITERALINCLUDE START: SETUP
     setup(
         packages=find_packages(),
         package_data=package_data,
@@ -135,7 +142,13 @@
             "clean": BuildClean,  # type: ignore[misc]
         },
         include_package_data=False,
+        entry_points={
+            "torch.backends": [
+                "torch_openreg = torch_openreg:_autoload",
+            ],
+        },
     )
+    # LITERALINCLUDE END: SETUP
 
 
 if __name__ == "__main__":
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py
index 45b2343070fe..18cee1615705 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py
@@ -9,7 +9,7 @@
 if sys.platform == "win32":
     _load_dll_libraries()
     del _load_dll_libraries
-
+# LITERALINCLUDE START: AUTOLOAD
 import torch_openreg._C  # type: ignore[misc]
 import torch_openreg.openreg
 
@@ -17,3 +17,11 @@ import torch_openreg.openreg
 torch.utils.rename_privateuse1_backend("openreg")
 torch._register_device_module("openreg", torch_openreg.openreg)
 torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
+
+
+def _autoload():
+    # A placeholder whose only purpose is to be registered as an entry point.
+    pass
+
+
+# LITERALINCLUDE END: AUTOLOAD
diff --git a/test/test_openreg.py b/test/test_openreg.py
index 7ee8ccefcd09..c0d99f5a6ac1 100644
--- a/test/test_openreg.py
+++ b/test/test_openreg.py
@@ -10,7 +10,6 @@ from unittest.mock import patch
 
 import numpy as np
 import psutil
-import torch_openreg  # noqa: F401
 
 import torch
 from torch.serialization import safe_globals
diff --git a/test/test_transformers_privateuse1.py b/test/test_transformers_privateuse1.py
index 0aa15260d094..31023875f886 100644
--- a/test/test_transformers_privateuse1.py
+++ b/test/test_transformers_privateuse1.py
@@ -4,8 +4,6 @@ import unittest
 from collections import namedtuple
 from functools import partial
 
-import torch_openreg  # noqa: F401
-
 import torch
 from torch.nn.attention import SDPBackend
 from torch.testing._internal.common_nn import NNTestCase