Add device daemon (#131814)

Base implementation aiming towards https://github.com/pytorch/rfcs/pull/64

Details of the implementation and next steps in https://github.com/pytorch/pytorch/blob/gh/albanD/3/head/test/cpp_extensions/open_registration_extension/README.md

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131814
Approved by: https://github.com/ezyang
This commit is contained in:
albanD
2024-08-27 16:12:01 -04:00
committed by PyTorch MergeBot
parent d6091c8726
commit 3b33f26513
10 changed files with 614 additions and 41 deletions

View File

@ -57,7 +57,7 @@ per-file-ignores =
torch/distributed/_tensor/_collective_utils.py: TOR901 torch/distributed/_tensor/_collective_utils.py: TOR901
# This is a full package that happen to live within the test # This is a full package that happen to live within the test
# folder, so ok to skip # folder, so ok to skip
test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901 test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py: TOR901
optional-ascii-coding = True optional-ascii-coding = True
exclude = exclude =
./.git, ./.git,

View File

@ -1,4 +1,4 @@
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core. This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core.
## How to use ## How to use
Install as standalone with `python setup.py develop` (or install) from this folder. Install as standalone with `python setup.py develop` (or install) from this folder.
@ -8,6 +8,23 @@ You can run test via `python test/test_openreg.py`.
For simplicity anything that can be implemented from python is done so. For simplicity anything that can be implemented from python is done so.
A real implementation will most likely want to call these different APIs from c++ directly. A real implementation will most likely want to call these different APIs from c++ directly.
The current version send everything back to python and is missing most implementations in python. The only one available is the one used by the autograd engine to check how many workers to spawn. The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing.
Next step is to create the device daemon so we can actually provide and allocator and create memory, then start using features and re-route all missing methods to daemon as appropriate. The codebase is split as follows:
- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations to torch, imports `.C` to load our c++ extension that registers more ops, allocator and hooks and finally renames the PrivateUse1 backend and register our python-side module.
- `pytorch_openreg/_aten_impl.py` does two main things. Use the `_register_same_name()` function to register hooks from c++ (like getDevice, getStream, etc) and send them to our device daemon. Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation
- `pytorch_openreg/_device_daemon.py` contains the Allocator (responsible for allocating memory on the device side, as int8 buffers, and recreating nice looking Tensors on the device side to be able to use aten ops to run code there), `run_op` that is the logic running on the device side to perform compute (for simplicity of coverage, we are re-building full blown Tensors here and calling aten ops on them). It also contains the Daemon responsible for the device worker process and sending data back and forth.
- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor.
## Next steps
Currently, the autograd test is disabled because it's missing the getStream implementation.
The main next step would be to:
- Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimick which information is held on the user-process side and when we're actually communicating with the device. In particular current device or stream should be user-process informations.
- Add Stream/Event system. Most likely by having multiple requests queue that go to the device from the driver.
- Add RNG Generator.
- Add Pinned memory and HostAllocator.
Longer term:
- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this.
- Build this module in the CI environment and enable Device-generic tests on this device.

View File

@ -1,26 +1,13 @@
import torch import torch
# Global properties of our device
NUM_DEVICES = 7
# Create our python implementation dict so that the C++ module # Create our python implementation dict so that the C++ module
# can access it during its initialization # can access it during its initialization
_IMPL_REGISTRY = {} # Also register aten impls
from ._aten_impl import _IMPL_REGISTRY as _IMPL_REGISTRY # noqa: F401
# Load the C++ Module # Load the C++ Module
import pytorch_openreg._C # noqa: F401 import pytorch_openreg._C # noqa: F401 # usort: skip
# Define all the implementations in the registry
def register(fn):
_IMPL_REGISTRY[fn.__name__[1:]] = fn
return fn
@register
def _deviceCount():
return NUM_DEVICES
# Module used for our backend # Module used for our backend
@ -31,15 +18,3 @@ class _OpenRegMod:
# Set all the appropriate state on PyTorch # Set all the appropriate state on PyTorch
torch.utils.rename_privateuse1_backend("openreg") torch.utils.rename_privateuse1_backend("openreg")
torch._register_device_module("openreg", _OpenRegMod()) torch._register_device_module("openreg", _OpenRegMod())
_openreg_lib = torch.library.Library("_", "IMPL") # ignore TOR901
def _openreg_kernel_fallback(op, *args, **kwargs):
print("Calling ", op)
assert op is torch.ops.aten.empty.memory_format
# FIXME: this returns a cpu Tensor which is NOT ok.
return torch.empty(args[0])
_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")

View File

@ -0,0 +1,153 @@
import logging
import torch
from torch.utils._pytree import tree_any
log = logging.getLogger(__name__)
from ._device_daemon import daemon
from ._meta_parser import prepare_for_sending, to_device_no_copy
_IMPL_REGISTRY = {}
# Define all the implementations in the registry
def _register_same_name(name, with_log=False):
def _(*args, **kwargs):
if with_log:
log.info("Calling hook %s", name)
return daemon.exec(name, *args, **kwargs)
_IMPL_REGISTRY[name] = _
_register_same_name("deviceCount")
_register_same_name("getDevice")
_register_same_name("uncheckedSetDevice")
_register_same_name("exchangeDevice")
_register_same_name("malloc", True)
_register_same_name("free", True)
_openreg_lib = torch.library.Library("_", "IMPL")
def _openreg_kernel_fallback(op, *args, **kwargs):
log.info("Calling kernel %s", op)
# Special ops needed to avoid infinite recursion
if op is torch.ops.aten._copy_from.default:
from_, to_ = args
if from_.device.type == to_.device.type:
assert from_.device.type == "openreg"
op = torch.ops.aten.copy_.default
# handled below as a regular copy
elif from_.device.type == "openreg":
args, _ = prepare_for_sending((from_,), {})
host_mem = daemon.exec("send_data", *args)
return to_.copy_(host_mem)
elif to_.device.type == "openreg":
args, _ = prepare_for_sending((to_,), {})
daemon.exec("recv_data", from_, *args)
return to_
else:
raise RuntimeError("Should not happen")
elif op is torch.ops.aten.set_.source_Tensor:
return torch.ops.aten.set_.source_Storage_storage_offset(
args[0],
args[1].untyped_storage(),
args[1].storage_offset(),
args[1].size(),
args[1].stride(),
)
elif op is torch.ops.aten._local_scalar_dense.default:
args, _ = prepare_for_sending(args, {})
host_mem = daemon.exec("send_data", *args)
return host_mem.item()
op_name = None
post_process = None
if "out" in op._overloadname:
# Note that all structured native op will call here
if isinstance(kwargs["out"], tuple):
raise RuntimeError(f"out= variant {op} with tuple out= not supported")
if kwargs["out"].nelement() == 0:
# Out variant that needs a resize, convert to an out of place
# and handle generically below
orig_out = kwargs["out"]
del kwargs["out"]
if op._overloadname != "out":
raise RuntimeError(
"Cannot retranslate non-default out= variant form 0 size"
)
op = op.overloadpacket.default
def _post_process():
nonlocal real_res
orig_out.set_(real_res)
real_res = orig_out
post_process = _post_process
else:
# No metadata update to do, just run the op on the device
op_name = op.overloadpacket._qualified_op_name
real_res = kwargs["out"]
elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
# No Tensor argument means factory function
# They should decompose and be handled in our c++ side directly
raise RuntimeError(f"{op} not handled yet.")
elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
# Only handle inplace ops returning their first arg
assert len(args) >= 1, f"Inplace {op} needs at least one arg"
assert (
len(op._schema.returns) == 1
), f"NYI Inplace {op} with more than one return"
op_name = op.overloadpacket._qualified_op_name
real_res = args[0]
elif any(r.alias_info is not None for r in op._schema.returns):
# View ops
if op is torch.ops.aten.view.default:
return torch.ops.aten._unsafe_view(*args, **kwargs)
raise RuntimeError(f"{op} view op is not handled yet")
if op_name is None:
# 1. Compute updated metadata
if torch.Tag.dynamic_output_shape not in op.tags:
# Usual case: run the meta op to see the output metadata
meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
meta_res = op(*meta_args, **meta_kwargs)
# 2. Allocate the output
real_res, _ = to_device_no_copy("openreg", meta_res, {})
else:
# Slow version for data-dependent functions:
# Run the op on the device just to get the output shape
args_, kwargs_ = prepare_for_sending(args, kwargs)
shape = daemon.exec(
"get_op_output_shape",
op.overloadpacket._qualified_op_name,
args_,
kwargs_,
)
# 2. Allocate the output
real_res = args[0].new(shape)
# 3. Move to out variant
kwargs["out"] = real_res
# Let overload resolution find the out= overload
op_name = op.overloadpacket._qualified_op_name
# 4. Run the compute and populate the output on the device
args, kwargs = prepare_for_sending(args, kwargs)
daemon.exec("run_op", op_name, args, kwargs)
if post_process is not None:
post_process()
return real_res
_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")

View File

@ -0,0 +1,168 @@
import logging
import torch
from ._meta_parser import (
OpenRegTensorData,
receive_after_sending,
safe_str,
validate_send_queue_args,
)
log = logging.getLogger(__name__)
mp_context = torch.multiprocessing.get_context("spawn")
# Constant properties of our device
NUM_DEVICES = 7
# Global state of our driver
CURR_DEVICE_IDX = 0
CURR_STREAM = 0
# Our allocator
class Allocator:
def __init__(self):
self.allocated = {}
def malloc(self, size):
new_data = torch.empty(size, dtype=torch.uint8)
ptr = new_data.data_ptr()
self.allocated[ptr] = new_data
return ptr
def free(self, ptr):
if ptr not in self.allocated:
return False
else:
del self.allocated[ptr]
return True
def tensor_from_meta(self, meta):
# Usual case, we're receiving a known Tensor
found_base = self.allocated.get(meta.data_ptr, None)
# Might be a rewrap of another storage at a different offset
# Slow path to try and find the corresponding storage
if found_base is None:
for tag, t in self.allocated.items():
# t is always a 1D uint8 storage!
if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement():
# Blame @ngimel for this
slice_size = t.nelement() - (meta.data_ptr - tag)
found_base = torch.tensor((), dtype=torch.uint8).set_(
t.untyped_storage()[meta.data_ptr - tag :],
size=(slice_size,),
stride=(1,),
storage_offset=0,
)
# This pointer is not allocated here, segfault !
if found_base is None:
log.info("Currently allocated blocks:\n %s", safe_str(self.allocated))
log.info("Trying to access %s", meta)
raise RuntimeError("SEGFAULT!")
# Raw 1d uint8 data
raw = found_base
# Slice the right storage part
raw_slice = raw.narrow(0, 0, meta.nelem_in_bytes)
# Reinterpret cast in the right dtype
as_dtype = raw_slice.view(dtype=meta.dtype)
# View to the right shape/stride/offset
view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset)
return view
def run_op(allocator, op_name, args, kwargs):
op, _ = torch._C._jit_get_operation(op_name)
args, kwargs = receive_after_sending(allocator, args, kwargs)
return op(*args, **kwargs)
class _Daemon:
def __init__(self):
super().__init__()
self.is_initialized = False
def _lazy_init(self):
if self.is_initialized:
return
self.req_queue = mp_context.Queue()
self.ans_queue = mp_context.Queue()
self.runner = mp_context.Process(
target=self.run_forever, args=(self.req_queue, self.ans_queue), daemon=True
)
self.runner.start()
self.is_initialized = True
def exec(self, cmd, *args):
self._lazy_init()
log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
validate_send_queue_args(cmd, args)
self.req_queue.put((cmd,) + args)
res = self.ans_queue.get()
log.info("Main process result for %s received: %s", cmd, safe_str(res))
if res == "ERROR":
raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
else:
return res
@staticmethod
def run_forever(req_queue, ans_queue):
# Initialize our device
global CURR_DEVICE_IDX
empty_res = object()
allocator = Allocator()
# Serve all requests
while True:
cmd, *args = req_queue.get()
log.info("Worker executing: %s", cmd)
res = empty_res
if cmd == "deviceCount":
assert len(args) == 0
res = NUM_DEVICES
elif cmd == "getDevice":
res = CURR_DEVICE_IDX
elif cmd == "uncheckedSetDevice":
assert len(args) == 1
CURR_DEVICE_IDX = int(args[0])
res = None
elif cmd == "exchangeDevice":
assert len(args) == 1
res = CURR_DEVICE_IDX
CURR_DEVICE_IDX = int(args[0])
elif cmd == "malloc":
res = allocator.malloc(*args)
elif cmd == "free":
res = allocator.free(*args)
elif cmd == "run_op":
op_name, args, kwargs = args
run_op(allocator, op_name, args, kwargs)
res = None
elif cmd == "send_data":
assert len(args) == 1
res = OpenRegTensorData.from_meta(allocator, args[0])
elif cmd == "recv_data":
assert len(args) == 2
host_tensor, dev_mem = args
dev_tensor = OpenRegTensorData.from_meta(allocator, dev_mem)
dev_tensor.copy_(host_tensor)
res = None
elif cmd == "get_op_output_shape":
op_name, args, kwargs = args
res = run_op(allocator, op_name, args, kwargs).size()
else:
log.warning("Bad command in worker")
res = "ERROR"
if res == empty_res:
raise RuntimeError("Bad impl didn't return anything")
log.info("Worker answering to: %s", cmd)
ans_queue.put(res)
daemon = _Daemon()

View File

@ -0,0 +1,104 @@
import pprint
import torch
from torch.utils._pytree import tree_map, tree_map_only
class OpenRegTensorMeta:
def __init__(self, tensor, checked=True):
if checked and not tensor.device.type == "openreg":
raise RuntimeError(
"Creating OpenRegTensorMeta is only for Tensors on openreg device"
)
self.data_ptr = tensor.untyped_storage().data_ptr()
self.size = tensor.size()
self.stride = tensor.stride()
self.storage_offset = tensor.storage_offset()
self.dtype = tensor.dtype
self.nelem_in_bytes = tensor.nelement() * tensor.element_size()
def __repr__(self):
return (
f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, "
f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})"
)
class OpenRegTensorData(torch.Tensor):
@staticmethod
def from_meta(allocator, tensor_meta):
return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta))
VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float}
VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str}
def safe_str(args):
def convert(obj):
if isinstance(obj, torch.Tensor):
return str(OpenRegTensorMeta(obj, checked=False))
else:
return obj
new_args = tree_map(convert, args)
return pprint.pformat(new_args)
def validate_send_queue_args(cmd, args):
def check(obj):
if type(obj) not in VALID_QUEUE_TYPES_OUT:
if (
cmd == "recv_data"
and type(obj) is torch.Tensor
and obj.device.type == "cpu"
):
# Only HtoD copy command can send cpu Tensors over
return
raise RuntimeError(
f"Trying to send invalid object through queue: {type(obj)}"
)
tree_map(check, args)
def prepare_for_sending(args, kwargs):
def convert(obj):
if type(obj) not in VALID_QUEUE_TYPES_IN:
raise RuntimeError(
f"Cannot send object of type {type(obj)} " "over openreg device pipe."
)
if isinstance(obj, torch.Tensor):
return OpenRegTensorMeta(obj)
else:
return obj
return tree_map(convert, (args, kwargs))
def receive_after_sending(allocator, args, kwargs):
def convert(obj):
if type(obj) not in VALID_QUEUE_TYPES_OUT:
raise RuntimeError(
f"Received invalid object of type {type(obj)} "
"over openreg device pipe."
)
if isinstance(obj, OpenRegTensorMeta):
return allocator.tensor_from_meta(obj)
else:
return obj
return tree_map(convert, (args, kwargs))
def to_device_no_copy(device, args, kwargs):
def safe_to(t):
if device == "meta":
return t.to(device=device)
else:
return torch.empty_like(t, device=device)
return tree_map_only(torch.Tensor, safe_to, (args, kwargs))

View File

@ -6,5 +6,6 @@
namespace openreg { namespace openreg {
void set_impl_registry(PyObject* registry); void set_impl_registry(PyObject* registry);
py::function get_method(const char* name);
} }

View File

@ -11,10 +11,6 @@ namespace {
// Python dictionary where real implementations can be found // Python dictionary where real implementations can be found
PyObject* py_registry; PyObject* py_registry;
py::function get_method(const char* name) {
return py::cast<py::dict>(py_registry)[name];
}
// C++ hooks implementation // C++ hooks implementation
struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {}; struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {};
@ -243,4 +239,12 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
void set_impl_registry(PyObject* registry) { void set_impl_registry(PyObject* registry) {
py_registry = registry; py_registry = registry;
} }
py::function get_method(const char* name) {
auto dict = py::cast<py::dict>(py_registry);
TORCH_CHECK(dict.contains(name), "OpenReg registry does not contain ",
"an implementation for '", name, "' make sure to add it in the __init__.py "
"file and register it.")
return dict[name];
}
} // openreg } // openreg

View File

@ -0,0 +1,125 @@
#include <OpenReg.h>
#include <torch/library.h>
#include <c10/core/Allocator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/EmptyTensor.h>
#include <c10/util/ArrayRef.h>
#include <c10/core/TensorOptions.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/set_cpu_dispatch.h>
namespace openreg {
namespace {
using openreg_ptr_t = uint64_t;
// A dummy allocator for our custom device, that secretly uses the CPU
struct OpenRegAllocator final : at::Allocator {
OpenRegAllocator() = default;
at::DataPtr allocate(size_t nbytes) override {
py::gil_scoped_acquire acquire;
auto curr_device_idx = get_method("getDevice")().cast<c10::DeviceIndex>();
auto curr_device = c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx);
void* data = nullptr;
if (nbytes > 0) {
data = reinterpret_cast<void*>(get_method("malloc")(nbytes).cast<openreg_ptr_t>());
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on openreg device.");
}
return {data, data, &ReportAndDelete, curr_device};
}
static void ReportAndDelete(void* ptr) {
if (!ptr) {
return;
}
py::gil_scoped_acquire acquire;
TORCH_CHECK(
get_method("free")(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
"Failed to free memory pointer at ", ptr
);
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
py::gil_scoped_acquire acquire;
get_method("copy_data")(reinterpret_cast<openreg_ptr_t>(dest), reinterpret_cast<openreg_ptr_t>(src), count);
}
};
// Register our dummy allocator
static OpenRegAllocator global_openreg_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc);
// Empty op needs C++ code and cannot be handled by python side fallback
at::Tensor empty_openreg(
c10::IntArrayRef size,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
const auto device = c10::device_or_default(device_opt);
const auto dtype = c10::dtype_or_default(dtype_opt);
TORCH_CHECK(device.is_privateuseone());
TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported");
TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU");
const c10::DeviceGuard device_guard(device);
constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(
size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt);
}
at::Tensor empty_strided_openreg(
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt) {
const auto device = c10::device_or_default(device_opt);
const auto dtype = c10::dtype_or_default(dtype_opt);
TORCH_CHECK(device.is_privateuseone());
TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported");
TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU");
const c10::DeviceGuard device_guard(device);
constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_strided_generic(
size, stride, &global_openreg_alloc, pu1_dks, dtype);
}
at::Tensor as_strided_openreg(
const at::Tensor& self,
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<int64_t> storage_offset_) {
// Metadata-only change so we re-use the cpu impl
return at::cpu::as_strided(self, size, stride, storage_offset_);
}
at::Tensor& set_openreg(
at::Tensor& result,
at::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride) {
return at::cpu::set_(result, storage, storage_offset, size, stride);
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("empty.memory_format", empty_openreg);
m.impl("empty_strided", empty_strided_openreg);
m.impl("as_strided", as_strided_openreg);
m.impl("set_.source_Storage_storage_offset", set_openreg);
}
} // anonymous namspaces
} // openreg

View File

@ -7,17 +7,17 @@ import psutil
import pytorch_openreg import pytorch_openreg
import torch import torch
from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
class TestOpenReg(TestCase): class TestOpenReg(TestCase):
def test_initializes(self): def test_initializes(self):
self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg") self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")
@unittest.skipIf(not IS_LINUX, "Only works on linux") @unittest.SkipTest
def test_autograd_init(self): def test_autograd_init(self):
# Make sure autograd is initialized # Make sure autograd is initialized
torch.rand(2, requires_grad=True, device="openreg").sum().backward() torch.ones(2, requires_grad=True, device="openreg").sum().backward()
pid = os.getpid() pid = os.getpid()
task_path = f"/proc/{pid}/task" task_path = f"/proc/{pid}/task"
@ -30,9 +30,35 @@ class TestOpenReg(TestCase):
thread_name = file.read().strip() thread_name = file.read().strip()
all_thread_names.add(thread_name) all_thread_names.add(thread_name)
for i in range(pytorch_openreg.NUM_DEVICES): for i in range(pytorch_openreg._device_daemon.NUM_DEVICES):
self.assertIn(f"pt_autograd_{i}", all_thread_names) self.assertIn(f"pt_autograd_{i}", all_thread_names)
def test_factory(self):
a = torch.empty(50, device="openreg")
self.assertEqual(a.device.type, "openreg")
a.fill_(3.5)
self.assertTrue(a.eq(3.5).all())
def test_printing(self):
a = torch.ones(20, device="openreg")
# Does not crash!
str(a)
def test_cross_device_copy(self):
a = torch.rand(10)
b = a.to(device="openreg").add(2).to(device="cpu")
self.assertEqual(b, a + 2)
def test_data_dependent_output(self):
cpu_a = torch.randn(10)
a = cpu_a.to(device="openreg")
mask = a.gt(0)
out = torch.masked_select(a, mask)
self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0)))
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()