[ROCm] Adds initialization support for PyTorch when built from ROCm wheels. (#155285)

AMD is beginning to roll out ROCm distribution via Python wheels. This patch adds the `__init__.py` hook that is necessary to bootstrap ROCm correctly on Linux and Windows when built from these wheels.

See draft, developer documentation describing the mechanism here: https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md

This operates to similar effect as how Torch can depend on CUDA wheels, with some differences:

* ROCm libraries and checks are delegated to helpers in the `rocm_sdk` module, which knows how to find and configure access to the installed libraries. This limits the amount of plumbing and path machinations that must match up between the framework and ROCm.
* When building torch against ROCm, no ROCm system install is needed: instead the proper SDK development wheel is installed and the `CMAKE_PREFIX_PATH` is obtained via `rocm-sdk path --cmake`.
* It is expected that whoever produces such a build will also place a generated `_rocm_init.py` in the `torch` module with initialization logic to preload libraries, check versions, verify GPU compatibility, etc.
* See [build_prod_wheels.py](https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py) for an example build script that is being used to generate nightlies in this configuration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155285
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Stella Laurenzo
2025-06-07 02:59:00 +00:00
committed by PyTorch MergeBot
parent f140fac8dc
commit 30387ab2e4

View File

@ -155,6 +155,21 @@ assert __all__ == sorted(__all__)
# Load the extension module
################################################################################
# If PyTorch was built against the ROCm runtime wheels, then there will be
# a _rocm_init module and it will define an initialize() function which can
# prepare ROCm for use. See general documentation on ROCm runtime wheels:
# https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md
# Since this module is only ever added to the wheel if built for such a
# deployment, it is always safe to attempt.
try:
from . import _rocm_init # type: ignore[attr-defined]
except ImportError:
pass
else:
_rocm_init.initialize()
del _rocm_init
if sys.platform == "win32":
def _load_dll_libraries() -> None: