Add basic OpenReg module scaffolding with autograd (#131708)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131708
Approved by: https://github.com/ezyang
albanD
2024-08-02 17:17:55 -04:00
committed by PyTorch MergeBot
parent df59084012
commit 3d87dfc088
8 changed files with 435 additions and 0 deletions

View File

@@ -55,6 +55,9 @@ per-file-ignores =
torch/distributed/_functional_collectives.py: TOR901
torch/distributed/_spmd/data_parallel.py: TOR901
torch/distributed/_tensor/_collective_utils.py: TOR901
# This is a full package that happens to live within the test
# folder, so ok to skip
test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901
optional-ascii-coding = True
exclude =
./.git,

View File

@@ -0,0 +1,13 @@
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core.
## How to use
Install as a standalone package with `python setup.py develop` (or `install`) from this folder.
You can run the tests via `python test/test_openreg.py`.
## Design principles
For simplicity, anything that can be implemented from Python is implemented in Python.
A real implementation will most likely want to call these different APIs from C++ directly.
The current version sends everything back to Python and is missing most implementations there. The only one available is the one used by the autograd engine to check how many workers to spawn.
The next step is to create the device daemon so we can actually provide an allocator and allocate memory, then start using features and re-route all missing methods to the daemon as appropriate.
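To make the registry mechanism concrete, here is a minimal sketch of the pattern used in `pytorch_openreg/__init__.py` (the `_hasPrimaryContext` body here is illustrative only; the C++ hooks look implementations up by name with the leading underscore stripped):

```python
# Sketch of the Python-side registry that C++ reads during module init.
_IMPL_REGISTRY = {}

def register(fn):
    # Strip the leading "_" so C++ can look the function up by its public name
    _IMPL_REGISTRY[fn.__name__[1:]] = fn
    return fn

@register
def _hasPrimaryContext(device_index):
    # Illustrative only: queried by the C++ hooks' hasPrimaryContext()
    return True
```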

View File

@@ -0,0 +1,45 @@
import torch

# Global properties of our device
NUM_DEVICES = 7

# Create our python implementation dict so that the C++ module
# can access it during its initialization
_IMPL_REGISTRY = {}

# Load the C++ Module
import pytorch_openreg._C  # noqa: F401


# Define all the implementations in the registry
def register(fn):
    _IMPL_REGISTRY[fn.__name__[1:]] = fn
    return fn


@register
def _deviceCount():
    return NUM_DEVICES


# Module used for our backend
class _OpenRegMod:
    pass


# Set all the appropriate state on PyTorch
torch.utils.rename_privateuse1_backend("openreg")
torch._register_device_module("openreg", _OpenRegMod())

_openreg_lib = torch.library.Library("_", "IMPL")  # ignore TOR901


def _openreg_kernel_fallback(op, *args, **kwargs):
    print("Calling ", op)
    assert op is torch.ops.aten.empty.memory_format
    # FIXME: this returns a cpu Tensor which is NOT ok.
    return torch.empty(args[0])


_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
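With the fallback installed, every ATen op dispatched to PrivateUse1 is routed to Python. A quick way to exercise it (a sketch; note that, per the FIXME above, the result is currently a CPU tensor):

import torch
import pytorch_openreg  # importing installs the rename, device module, and fallback

# Dispatches aten.empty.memory_format to _openreg_kernel_fallback,
# which prints the op and (for now) allocates on CPU.
t = torch.empty(2, device="openreg")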

View File

@@ -0,0 +1,17 @@
#include <OpenReg.h>

// Make this a proper CPython module
static struct PyModuleDef openreg_C_module = {
    PyModuleDef_HEAD_INIT,
    .m_name = "pytorch_openreg._C",
};

PyMODINIT_FUNC PyInit__C(void) {
  PyObject* mod = PyModule_Create(&openreg_C_module);

  py::object openreg_mod = py::module_::import("pytorch_openreg");
  // Only borrowed from the python side!
  openreg::set_impl_registry(openreg_mod.attr("_IMPL_REGISTRY").ptr());

  return mod;
}
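Since `set_impl_registry` only borrows the `_IMPL_REGISTRY` pointer rather than copying the dict, entries added from Python after `import pytorch_openreg._C` remain visible to the C++ `get_method` lookups. A sketch of what that permits (the replacement implementation is hypothetical):

import pytorch_openreg

# The same dict object was handed to C++ during module init, so this late
# registration is observed by subsequent get_method("hasPrimaryContext") calls.
pytorch_openreg._IMPL_REGISTRY["hasPrimaryContext"] = lambda device_index: False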

View File

@@ -0,0 +1,10 @@
#pragma once

// Shared header for OpenReg module
#include <torch/csrc/utils/pybind.h>

namespace openreg {

void set_impl_registry(PyObject* registry);

} // namespace openreg

View File

@@ -0,0 +1,246 @@
#include <OpenReg.h>

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

#include <iostream>

namespace openreg {
namespace {

// Python dictionary where real implementations can be found
PyObject* py_registry;

py::function get_method(const char* name) {
  return py::cast<py::dict>(py_registry)[name];
}

// C++ hooks implementation
struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {};

struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  OpenRegHooksInterface(OpenRegHooksArgs) {}
  ~OpenRegHooksInterface() override = default;

  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    return get_method("hasPrimaryContext")(device_index).cast<bool>();
  }
};

TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs);
// Use the Create function to get an OpenRegHooksInterface pointer from the
// PrivateUse1HooksRegistry.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "OpenRegHooks", OpenRegHooksInterface);
// Device guard registration
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;

  OpenRegGuardImpl() = default;
  explicit OpenRegGuardImpl(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == static_type);
  }

  /**
   * Return the type of device managed by this guard implementation.
   */
  c10::DeviceType type() const override {
    return static_type;
  }

  /**
   * Set the current device to Device, and return the previous c10::Device.
   */
  c10::Device exchangeDevice(c10::Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_privateuseone());
    py::gil_scoped_acquire acquire;
    auto old_device_index = get_method("exchangeDevice")(d.index()).cast<c10::DeviceIndex>();
    return c10::Device(static_type, old_device_index);
  }

  /**
   * Get the current device.
   */
  c10::Device getDevice() const override {
    py::gil_scoped_acquire acquire;
    auto device = get_method("getDevice")().cast<c10::DeviceIndex>();
    return c10::Device(static_type, device);
  }

  /**
   * Set the current device to c10::Device.
   */
  void setDevice(c10::Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_privateuseone());
    py::gil_scoped_acquire acquire;
    get_method("setDevice")(d.index());
  }

  /**
   * Set the current device to c10::Device, without checking for errors
   * (so, e.g., this can be called from a destructor).
   */
  void uncheckedSetDevice(c10::Device d) const noexcept override {
    py::gil_scoped_acquire acquire;
    get_method("uncheckedSetDevice")(d.index());
  }

  /**
   * Get the current stream for a given device.
   */
  c10::Stream getStream(c10::Device d) const noexcept override {
    py::gil_scoped_acquire acquire;
    return get_method("getStream")(d.index()).cast<c10::Stream>();
  }

  /**
   * Get the default stream for a given device.
   */
  c10::Stream getDefaultStream(c10::Device d) const override {
    py::gil_scoped_acquire acquire;
    return get_method("getDefaultStream")(d.index()).cast<c10::Stream>();
  }

  /**
   * Get a stream from the global pool for a given device.
   */
  c10::Stream getStreamFromGlobalPool(c10::Device d, bool isHighPriority = false) const override {
    py::gil_scoped_acquire acquire;
    return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority).cast<c10::Stream>();
  }

  /**
   * Return a new stream for a given device and priority. The stream will be
   * copied and shared around, so the device backend should be able to correctly
   * handle the lifetime of the stream.
   */
  c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
    py::gil_scoped_acquire acquire;
    return get_method("getNewStream")(d.index(), priority).cast<c10::Stream>();
  }
  /**
   * Set a stream to be the thread local current stream for its device.
   * Return the previous stream for that device. You are NOT required
   * to set the current device to match the device of this stream.
   */
  c10::Stream exchangeStream(c10::Stream s) const noexcept override {
    py::gil_scoped_acquire acquire;
    return get_method("exchangeStream")(s).cast<c10::Stream>();
  }

  /**
   * Destroys the given event.
   */
  void destroyEvent(void* event, const c10::DeviceIndex device_index)
      const noexcept override {
    py::gil_scoped_acquire acquire;
    get_method("destroyEvent")(event, device_index);
  }

  /**
   * Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream processes that job,
   * it notifies all streams waiting on / blocked by that version of the
   * event to continue, and marks that version as recorded.
   */
  void record(
      void** event,
      const c10::Stream& stream,
      const c10::DeviceIndex device_index,
      const c10::EventFlag flag) const override {
    py::gil_scoped_acquire acquire;
    get_method("record")(event, stream, device_index, flag);
  }

  /**
   * Does nothing if the event has not been scheduled to be recorded.
   * If the event was previously enqueued to be recorded, a command
   * to wait for the version of the event that exists at the time of this call
   * is inserted in the stream's work queue.
   * When the stream reaches this command it will stop processing
   * additional commands until that version of the event is marked as recorded.
   */
  void block(void* event, const c10::Stream& stream) const override {
    py::gil_scoped_acquire acquire;
    get_method("block")(event, stream);
  }

  /**
   * Returns true if (and only if)
   * (1) the event has never been scheduled to be recorded, or
   * (2) the current version is marked as recorded.
   * Returns false otherwise.
   */
  bool queryEvent(void* event) const override {
    py::gil_scoped_acquire acquire;
    return get_method("queryEvent")(event).cast<bool>();
  }

  /**
   * Get the number of devices. WARNING: This is REQUIRED to not raise
   * an exception. If there is some sort of problem, e.g., driver error,
   * you should report that there are zero available devices.
   */
  c10::DeviceIndex deviceCount() const noexcept override {
    py::gil_scoped_acquire acquire;
    return get_method("deviceCount")().cast<c10::DeviceIndex>();
  }
  /**
   * Return true if all the work previously enqueued on the stream for
   * asynchronous execution has completed running on the device.
   */
  bool queryStream(const c10::Stream& stream) const override {
    py::gil_scoped_acquire acquire;
    return get_method("queryStream")(stream).cast<bool>();
  }

  /**
   * Wait (by blocking the calling thread) until all the work previously
   * enqueued on the stream has completed running on the device.
   */
  void synchronizeStream(const c10::Stream& stream) const override {
    py::gil_scoped_acquire acquire;
    get_method("synchronizeStream")(stream);
  }

  /**
   * Wait (by blocking the calling thread) until all the work previously
   * recorded on the event has completed running on the device.
   */
  void synchronizeEvent(void* event) const override {
    py::gil_scoped_acquire acquire;
    get_method("synchronizeEvent")(event);
  }

  /**
   * Ensure the caching allocator (if any) is aware that the given DataPtr is
   * being used on the given stream, and that it should thus avoid recycling the
   * DataPtr until all work on that stream is done.
   */
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const c10::Stream& stream)
      const override {
    py::gil_scoped_acquire acquire;
    get_method("recordDataPtrOnStream")(data_ptr, stream);
  }

  /**
   * Fetch the elapsed time between two recorded events.
   */
  double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index)
      const override {
    py::gil_scoped_acquire acquire;
    return get_method("elapsedTime")(event1, event2, device_index).cast<double>();
  }
};

// Register our device guard
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);

} // anonymous namespace

// Setter for the python dictionary with implementations
void set_impl_registry(PyObject* registry) {
  py_registry = registry;
}

} // namespace openreg
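Each guard method above acquires the GIL and looks up a same-named callable in `_IMPL_REGISTRY`. Of these, only `deviceCount` is implemented in this commit; a sketch of what the missing device-index methods could look like on the Python side (the thread-local storage is an assumption, not part of this commit):

import threading

_state = threading.local()  # hypothetical per-thread current-device index

@register  # `register` from pytorch_openreg strips the leading "_"
def _getDevice():
    return getattr(_state, "device", 0)

@register
def _exchangeDevice(new_index):
    old_index = getattr(_state, "device", 0)
    _state.device = new_index
    return old_index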

View File

@@ -0,0 +1,63 @@
import distutils.command.clean
import shutil
from pathlib import Path

from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CppExtension

PACKAGE_NAME = "pytorch_openreg"
version = "1.0"

ROOT_DIR = Path(__file__).absolute().parent
CSRC_DIR = ROOT_DIR / "pytorch_openreg/csrc"


class clean(distutils.command.clean.clean):
    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove pytorch_openreg extension
        for path in (ROOT_DIR / "pytorch_openreg").glob("**/*.so"):
            path.unlink()
        # Remove build directory
        build_dirs = [
            ROOT_DIR / "build",
        ]
        for path in build_dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)


if __name__ == "__main__":
    sources = list(CSRC_DIR.glob("*.cpp"))

    # Note that we always compile with debug info
    ext_modules = [
        CppExtension(
            name="pytorch_openreg._C",
            sources=sorted(str(s) for s in sources),
            include_dirs=[str(CSRC_DIR)],
            extra_compile_args={"cxx": ["-g", "-Wall", "-Werror"]},
        )
    ]

    setup(
        name=PACKAGE_NAME,
        version=version,
        author="PyTorch Core Team",
        description="Example for PyTorch out-of-tree registration",
        packages=find_packages(exclude=("test",)),
        package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]},
        install_requires=[
            "torch",
        ],
        ext_modules=ext_modules,
        python_requires=">=3.8",
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            "clean": clean,
        },
    )
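After `python setup.py develop`, a quick smoke test (this mirrors `test_initializes` in the test file below):

import pytorch_openreg  # triggers rename_privateuse1_backend and module registration
import torch

assert torch._C._get_privateuse1_backend_name() == "openreg"
print(torch.device("openreg", 0))  # the renamed backend resolves as a device type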

View File

@@ -0,0 +1,38 @@
# Owner(s): ["module: cpp"]

import os
import unittest

import psutil
import pytorch_openreg

import torch
from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase


class TestOpenReg(TestCase):
    def test_initializes(self):
        self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")

    @unittest.skipIf(not IS_LINUX, "Only works on linux")
    def test_autograd_init(self):
        # Make sure autograd is initialized
        torch.rand(2, requires_grad=True, device="openreg").sum().backward()

        pid = os.getpid()
        task_path = f"/proc/{pid}/task"
        all_threads = psutil.Process(pid).threads()

        all_thread_names = set()

        for t in all_threads:
            with open(f"{task_path}/{t.id}/comm") as file:
                thread_name = file.read().strip()
            all_thread_names.add(thread_name)

        for i in range(pytorch_openreg.NUM_DEVICES):
            self.assertIn(f"pt_autograd_{i}", all_thread_names)


if __name__ == "__main__":
    run_tests()