[Openreg][PrivateUse1] Enable CI for openreg (#151007)

Changes:
- move test_openreg.py from test/cpp_extensions/open_registration_extension/ to test/
- update README.md for openreg
- enable CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151007
Approved by: https://github.com/albanD
ghstack dependencies: #151005
This commit is contained in:
FFFrog
2025-04-16 09:57:56 +08:00
committed by PyTorch MergeBot
parent a9dbbe1aee
commit abbca37fe8
3 changed files with 37 additions and 14 deletions

View File

@@ -1,29 +1,37 @@
# PyTorch OpenReg
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core.
## How to use
Install it standalone with `python setup.py develop` (or `install`) from this folder.
You can run test via `python test/test_openreg.py`.
You can run the tests via `python {PYTORCH_ROOT_PATH}/test/test_openreg.py`.
## Design principles
For simplicity, anything that can be implemented from Python is done so.
A real implementation will most likely want to call these different APIs from c++ directly.
The current version sends everything back to Python and contains enough implementation to run a basic model, transfer data between host and device, and print tensors.
The codebase is split as follows:
- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations to torch, imports `.C` to load our c++ extension that registers more ops, allocator and hooks and finally renames the PrivateUse1 backend and register our python-side module.
- `pytorch_openreg/_aten_impl.py` does two main things. Use the `_register_same_name()` function to register hooks from c++ (like getDevice, getStream, etc) and send them to our device daemon. Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation
- `pytorch_openreg/_device_daemon.py` contains the Allocator (responsible for allocating memory on the device side, as int8 buffers, and recreating nice looking Tensors on the device side to be able to use aten ops to run code there), `run_op` that is the logic running on the device side to perform compute (for simplicity of coverage, we are re-building full blown Tensors here and calling aten ops on them). It also contains the Daemon responsible for the device worker process and sending data back and forth.
- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor.
- `pytorch_openreg/__init__.py`
- imports torch to get core state initialized.
- imports `._aten_impl` to register our aten op implementations to torch.
- imports `.C` to load our c++ extension that registers more ops, the allocator, and hooks.
- renames the PrivateUse1 backend and registers our python-side module.
- `pytorch_openreg/_aten_impl.py`
- Defines a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kinds of native functions: computing the output metadata, allocating it, and only calling into the device daemon to perform the actual computation.
- `pytorch_openreg/_device_daemon.py`
- contains the Allocator (responsible for allocating memory on both the device side and the host side, as int8 buffers).
- contains `Driver`, the user-process driver that handles the work that must happen on the driver side.
- contains `Executor`, the device-process executor that runs the device-side logic.
- `pytorch_openreg/_meta_parser.py` mainly contains utilities to send objects over the wire from the user process to the device process.
- The main class there is `OpenRegTensorMeta`, which contains all the metadata sent to the device; this should be enough for it to populate the output Tensor.
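The Allocator idea above can be sketched in a few lines of plain Python: "device memory" is modeled as raw int8 buffers keyed by a fake pointer. The class and method names here are illustrative, not the actual `pytorch_openreg` code:

```python
# Hypothetical sketch of the Allocator described above: device memory is
# modeled as int8 buffers (bytearrays) keyed by an integer standing in for
# a device pointer. Illustrative only, not the real pytorch_openreg code.
class ToyAllocator:
    def __init__(self):
        self._next_ptr = 1
        self._buffers = {}  # ptr -> bytearray acting as an int8 buffer

    def malloc(self, nbytes):
        # Hand out a fresh "device pointer" backed by a zeroed int8 buffer.
        ptr = self._next_ptr
        self._next_ptr += 1
        self._buffers[ptr] = bytearray(nbytes)
        return ptr

    def free(self, ptr):
        del self._buffers[ptr]

    def data(self, ptr):
        return self._buffers[ptr]
```

A real backend would hand these pointers to `at::Tensor` storage via a custom allocator; here the point is only that the "device" never sees anything richer than a buffer and its metadata.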
## Next steps
Currently, the autograd test is disabled because it's missing the getStream implementation.
The main next steps would be to:
- Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimic which information is held on the user-process side and when we're actually communicating with the device. In particular, the current device or stream should be user-process information.
- Add a Stream/Event system, most likely by having multiple request queues going from the driver to the device.
- Add RNG Generator.
Longer term:
- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this.
- Build this module in the CI environment and enable Device-generic tests on this device.
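The driver/executor split and the request queues suggested above can be sketched without any PyTorch at all. The names `ToyDriver`/`ToyExecutor` and the use of a thread instead of a separate process are assumptions made for brevity; a real backend would cross a process boundary and dispatch real kernels:

```python
import queue
import threading

# Hypothetical sketch of the driver/executor split: the user-process "driver"
# pushes requests onto a queue, and a worker thread standing in for the
# device-process "executor" pops and runs them. Illustrative only.
class ToyExecutor(threading.Thread):
    def __init__(self, requests):
        super().__init__(daemon=True)
        self.requests = requests

    def run(self):
        while True:
            op, args, reply = self.requests.get()
            if op == "shutdown":
                break
            elif op == "add":  # stand-in for an actual device kernel
                reply.put(args[0] + args[1])

class ToyDriver:
    def __init__(self):
        self.requests = queue.Queue()
        self.executor = ToyExecutor(self.requests)
        self.executor.start()

    def run_op(self, op, *args):
        # Synchronous round trip; a Stream system would instead enqueue
        # onto one of several per-stream request queues and return early.
        reply = queue.Queue()
        self.requests.put((op, args, reply))
        return reply.get()

    def shutdown(self):
        self.requests.put(("shutdown", (), None))
        self.executor.join()
```

Holding the "current device" and "current stream" in `ToyDriver` (the user process) rather than in `ToyExecutor` is exactly the separation of responsibilities the first bullet above is asking for.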

View File

@@ -268,6 +268,7 @@ RUN_PARALLEL_BLOCKLIST = [
"test_multiprocessing",
"test_multiprocessing_spawn",
"test_namedtuple_return_api",
"test_openreg",
"test_overrides",
"test_show_pickle",
"test_tensorexpr",
@@ -1222,6 +1223,7 @@ CUSTOM_HANDLERS = {
"test_autoload_enable": test_autoload_enable,
"test_autoload_disable": test_autoload_disable,
"test_cpp_extensions_open_device_registration": run_test_with_openreg,
"test_openreg": run_test_with_openreg,
"test_transformers_privateuse1": run_test_with_openreg,
}
@@ -1512,10 +1514,14 @@ def get_selected_tests(options) -> list[str]:
# Filter to only run functorch tests when --functorch option is specified
if options.functorch:
selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS]
selected_tests = list(
filter(lambda test_name: test_name in FUNCTORCH_TESTS, selected_tests)
)
if options.cpp:
selected_tests = [tname for tname in selected_tests if tname in CPP_TESTS]
selected_tests = list(
filter(lambda test_name: test_name in CPP_TESTS, selected_tests)
)
else:
# Exclude all C++ tests otherwise as they are still handled differently
# than Python test at the moment

View File

@@ -1,18 +1,25 @@
# Owner(s): ["module: cpp"]
import os
import unittest
import psutil
import pytorch_openreg # noqa: F401
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
IS_LINUX,
run_tests,
skipIfTorchDynamo,
TestCase,
)
class TestOpenReg(TestCase):
def test_initializes(self):
self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")
@unittest.skipIf(not IS_LINUX, "Only works on linux")
def test_autograd_init(self):
# Make sure autograd is initialized
torch.ones(2, requires_grad=True, device="openreg").sum().backward()
@@ -70,6 +77,7 @@ class TestOpenReg(TestCase):
self.assertEqual(generator.device.type, "openreg")
self.assertEqual(generator.device.index, 1)
@skipIfTorchDynamo("unsupported aten.is_pinned.default")
def test_pin_memory(self):
cpu_a = torch.randn(10)
self.assertFalse(cpu_a.is_pinned())
@@ -108,6 +116,7 @@ class TestOpenReg(TestCase):
self.assertNotEqual(0, event2.event_id)
self.assertNotEqual(event1.event_id, event2.event_id)
@skipIfTorchDynamo()
def test_event_elapsed_time(self):
stream = torch.Stream(device="openreg:1")
e1 = torch.Event(device="openreg:1", enable_timing=True)