Add device daemon (#131814)
Base implementation aiming towards https://github.com/pytorch/rfcs/pull/64

Details of the implementation and next steps in https://github.com/pytorch/pytorch/blob/gh/albanD/3/head/test/cpp_extensions/open_registration_extension/README.md

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131814
Approved by: https://github.com/ezyang
.flake8:
@@ -57,7 +57,7 @@ per-file-ignores =
     torch/distributed/_tensor/_collective_utils.py: TOR901
     # This is a full package that happen to live within the test
     # folder, so ok to skip
-    test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901
+    test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py: TOR901
 optional-ascii-coding = True
 exclude =
     ./.git,
test/cpp_extensions/open_registration_extension/README.md:

@@ -1,4 +1,4 @@
-This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core.
+This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core.

 ## How to use
 Install as standalone with `python setup.py develop` (or install) from this folder.
@@ -8,6 +8,23 @@ You can run test via `python test/test_openreg.py`.
 For simplicity anything that can be implemented from python is done so.
 A real implementation will most likely want to call these different APIs from c++ directly.

-The current version send everything back to python and is missing most implementations in python. The only one available is the one used by the autograd engine to check how many workers to spawn.
+The current version sends everything back to Python and contains enough implementation to run a basic model, do host/device transfers, and print tensors.

-Next step is to create the device daemon so we can actually provide and allocator and create memory, then start using features and re-route all missing methods to daemon as appropriate.
+The codebase is split as follows:
+
+- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations to torch, imports `._C` to load our C++ extension that registers more ops, the allocator and the hooks, and finally renames the PrivateUse1 backend and registers our Python-side module.
+
+- `pytorch_openreg/_aten_impl.py` does two main things. It uses the `_register_same_name()` function to register the hooks called from C++ (like getDevice, getStream, etc.) and forwards them to our device daemon. It also defines a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. That fallback contains the logic to handle all kinds of native functions: computing the output metadata, allocating the output, and only calling into the device daemon to perform the actual computation.
+
+- `pytorch_openreg/_device_daemon.py` contains the `Allocator` (responsible for allocating memory on the device side, as uint8 buffers, and recreating proper Tensors on the device side to be able to use aten ops to run code there) and `run_op`, the logic running on the device side to perform compute (for simplicity of coverage, we are re-building full-blown Tensors here and calling aten ops on them). It also contains the `_Daemon` responsible for the device worker process and for sending data back and forth.
+
+- `pytorch_openreg/_meta_parser.py` mainly contains utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta`, which contains all the metadata sent to the device, which should be enough for it to populate the output Tensor.
+
+## Next steps
+
+Currently, the autograd test is disabled because it's missing the getStream implementation.
+
+The main next steps would be to:
+
+- Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimic which information is held on the user-process side and when we're actually communicating with the device. In particular, the current device or stream should be user-process information.
+- Add a Stream/Event system, most likely by having multiple request queues that go from the driver to the device.
+- Add an RNG Generator.
+- Add Pinned memory and a HostAllocator.
+
+Longer term:
+
+- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this.
+- Build this module in the CI environment and enable Device-generic tests on this device.
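To make the workflow described in the README concrete, here is a minimal sketch of a session with the installed extension (assuming the editable install from this folder described above; the ops exercised are the ones the tests in this PR cover):

```python
# Sketch: exercising the backend after `python setup.py develop` in this folder.
import pytorch_openreg  # noqa: F401  # renames PrivateUse1 to "openreg"
import torch

a = torch.ones(4, device="openreg")  # allocated through the device daemon
b = a.add(2).to(device="cpu")        # compute on device, copy back to host
print(b)                             # tensor([3., 3., 3., 3.])
```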
test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py:

@@ -1,26 +1,13 @@
 import torch


-# Global properties of our device
-NUM_DEVICES = 7
-
 # Create our python implementation dict so that the C++ module
 # can access it during its initialization
-_IMPL_REGISTRY = {}
+# Also register aten impls
+from ._aten_impl import _IMPL_REGISTRY as _IMPL_REGISTRY  # noqa: F401

 # Load the C++ Module
-import pytorch_openreg._C  # noqa: F401
-
-
-# Define all the implementations in the registry
-def register(fn):
-    _IMPL_REGISTRY[fn.__name__[1:]] = fn
-    return fn
-
-
-@register
-def _deviceCount():
-    return NUM_DEVICES
+import pytorch_openreg._C  # noqa: F401  # usort: skip


 # Module used for our backend
@@ -31,15 +18,3 @@ class _OpenRegMod:

 # Set all the appropriate state on PyTorch
 torch.utils.rename_privateuse1_backend("openreg")
 torch._register_device_module("openreg", _OpenRegMod())
-
-
-_openreg_lib = torch.library.Library("_", "IMPL")  # ignore TOR901
-
-
-def _openreg_kernel_fallback(op, *args, **kwargs):
-    print("Calling ", op)
-    assert op is torch.ops.aten.empty.memory_format
-    # FIXME: this returns a cpu Tensor which is NOT ok.
-    return torch.empty(args[0])
-
-
-_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
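For context on what the rename/registration pair at the end of this file buys us, a small sketch of the observable effect (a hypothetical session, not part of the commit):

```python
# After `import pytorch_openreg` runs the __init__.py above:
import pytorch_openreg  # noqa: F401
import torch

print(torch._C._get_privateuse1_backend_name())  # "openreg"
d = torch.device("openreg", 0)                   # the device string now parses
print(d.type, d.index)                           # openreg 0
```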
test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py (new file):

@@ -0,0 +1,153 @@
import logging

import torch
from torch.utils._pytree import tree_any


log = logging.getLogger(__name__)

from ._device_daemon import daemon
from ._meta_parser import prepare_for_sending, to_device_no_copy


_IMPL_REGISTRY = {}


# Define all the implementations in the registry
def _register_same_name(name, with_log=False):
    def _(*args, **kwargs):
        if with_log:
            log.info("Calling hook %s", name)
        return daemon.exec(name, *args, **kwargs)

    _IMPL_REGISTRY[name] = _


_register_same_name("deviceCount")
_register_same_name("getDevice")
_register_same_name("uncheckedSetDevice")
_register_same_name("exchangeDevice")
_register_same_name("malloc", True)
_register_same_name("free", True)

_openreg_lib = torch.library.Library("_", "IMPL")


def _openreg_kernel_fallback(op, *args, **kwargs):
    log.info("Calling kernel %s", op)

    # Special ops needed to avoid infinite recursion
    if op is torch.ops.aten._copy_from.default:
        from_, to_ = args
        if from_.device.type == to_.device.type:
            assert from_.device.type == "openreg"
            op = torch.ops.aten.copy_.default
            # handled below as a regular copy
        elif from_.device.type == "openreg":
            args, _ = prepare_for_sending((from_,), {})
            host_mem = daemon.exec("send_data", *args)
            return to_.copy_(host_mem)
        elif to_.device.type == "openreg":
            args, _ = prepare_for_sending((to_,), {})
            daemon.exec("recv_data", from_, *args)
            return to_
        else:
            raise RuntimeError("Should not happen")
    elif op is torch.ops.aten.set_.source_Tensor:
        return torch.ops.aten.set_.source_Storage_storage_offset(
            args[0],
            args[1].untyped_storage(),
            args[1].storage_offset(),
            args[1].size(),
            args[1].stride(),
        )
    elif op is torch.ops.aten._local_scalar_dense.default:
        args, _ = prepare_for_sending(args, {})
        host_mem = daemon.exec("send_data", *args)
        return host_mem.item()

    op_name = None
    post_process = None
    if "out" in op._overloadname:
        # Note that all structured native ops will call here
        if isinstance(kwargs["out"], tuple):
            raise RuntimeError(f"out= variant {op} with tuple out= not supported")
        if kwargs["out"].nelement() == 0:
            # Out variant that needs a resize: convert to an out-of-place op
            # and handle generically below
            orig_out = kwargs["out"]
            del kwargs["out"]
            if op._overloadname != "out":
                raise RuntimeError(
                    "Cannot retranslate non-default out= variant from 0 size"
                )
            op = op.overloadpacket.default

            def _post_process():
                nonlocal real_res
                orig_out.set_(real_res)
                real_res = orig_out

            post_process = _post_process

        else:
            # No metadata update to do, just run the op on the device
            op_name = op.overloadpacket._qualified_op_name
            real_res = kwargs["out"]
    elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
        # No Tensor argument means factory function.
        # They should decompose and be handled on our c++ side directly.
        raise RuntimeError(f"{op} not handled yet.")
    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
        # Only handle inplace ops returning their first arg
        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
        assert (
            len(op._schema.returns) == 1
        ), f"NYI Inplace {op} with more than one return"
        op_name = op.overloadpacket._qualified_op_name
        real_res = args[0]
    elif any(r.alias_info is not None for r in op._schema.returns):
        # View ops
        if op is torch.ops.aten.view.default:
            return torch.ops.aten._unsafe_view(*args, **kwargs)
        raise RuntimeError(f"{op} view op is not handled yet")

    if op_name is None:
        # 1. Compute updated metadata
        if torch.Tag.dynamic_output_shape not in op.tags:
            # Usual case: run the meta op to see the output metadata
            meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
            meta_res = op(*meta_args, **meta_kwargs)

            # 2. Allocate the output
            real_res, _ = to_device_no_copy("openreg", meta_res, {})
        else:
            # Slow version for data-dependent functions:
            # Run the op on the device just to get the output shape
            args_, kwargs_ = prepare_for_sending(args, kwargs)
            shape = daemon.exec(
                "get_op_output_shape",
                op.overloadpacket._qualified_op_name,
                args_,
                kwargs_,
            )

            # 2. Allocate the output
            real_res = args[0].new(shape)

        # 3. Move to out variant
        kwargs["out"] = real_res
        # Let overload resolution find the out= overload
        op_name = op.overloadpacket._qualified_op_name

    # 4. Run the compute and populate the output on the device
    args, kwargs = prepare_for_sending(args, kwargs)
    daemon.exec("run_op", op_name, args, kwargs)

    if post_process is not None:
        post_process()

    return real_res


_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
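The `op_name is None` branch above leans on the meta device: running an op on meta tensors computes output shape, dtype, and strides without touching any data. A standalone sketch of that trick, using plain stock PyTorch (not part of the commit):

```python
import torch

# Run an op on "meta" tensors: kernels compute only metadata, no real data.
a = torch.empty(4, 3, device="meta")
b = torch.empty(3, 5, device="meta")
out = torch.mm(a, b)
print(out.shape, out.dtype, out.device)  # torch.Size([4, 5]) torch.float32 meta
```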
test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py (new file):

@@ -0,0 +1,168 @@
import logging

import torch

from ._meta_parser import (
    OpenRegTensorData,
    receive_after_sending,
    safe_str,
    validate_send_queue_args,
)


log = logging.getLogger(__name__)
mp_context = torch.multiprocessing.get_context("spawn")

# Constant properties of our device
NUM_DEVICES = 7

# Global state of our driver
CURR_DEVICE_IDX = 0
CURR_STREAM = 0


# Our allocator
class Allocator:
    def __init__(self):
        self.allocated = {}

    def malloc(self, size):
        new_data = torch.empty(size, dtype=torch.uint8)
        ptr = new_data.data_ptr()
        self.allocated[ptr] = new_data
        return ptr

    def free(self, ptr):
        if ptr not in self.allocated:
            return False
        else:
            del self.allocated[ptr]
            return True

    def tensor_from_meta(self, meta):
        # Usual case, we're receiving a known Tensor
        found_base = self.allocated.get(meta.data_ptr, None)

        # Might be a rewrap of another storage at a different offset
        # Slow path to try and find the corresponding storage
        if found_base is None:
            for tag, t in self.allocated.items():
                # t is always a 1D uint8 storage!
                if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement():
                    # Blame @ngimel for this
                    slice_size = t.nelement() - (meta.data_ptr - tag)
                    found_base = torch.tensor((), dtype=torch.uint8).set_(
                        t.untyped_storage()[meta.data_ptr - tag :],
                        size=(slice_size,),
                        stride=(1,),
                        storage_offset=0,
                    )

        # This pointer is not allocated here, segfault!
        if found_base is None:
            log.info("Currently allocated blocks:\n %s", safe_str(self.allocated))
            log.info("Trying to access %s", meta)
            raise RuntimeError("SEGFAULT!")

        # Raw 1d uint8 data
        raw = found_base
        # Slice the right storage part
        raw_slice = raw.narrow(0, 0, meta.nelem_in_bytes)
        # Reinterpret cast in the right dtype
        as_dtype = raw_slice.view(dtype=meta.dtype)
        # View to the right shape/stride/offset
        view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset)
        return view


def run_op(allocator, op_name, args, kwargs):
    op, _ = torch._C._jit_get_operation(op_name)
    args, kwargs = receive_after_sending(allocator, args, kwargs)
    return op(*args, **kwargs)


class _Daemon:
    def __init__(self):
        super().__init__()
        self.is_initialized = False

    def _lazy_init(self):
        if self.is_initialized:
            return
        self.req_queue = mp_context.Queue()
        self.ans_queue = mp_context.Queue()

        self.runner = mp_context.Process(
            target=self.run_forever, args=(self.req_queue, self.ans_queue), daemon=True
        )
        self.runner.start()
        self.is_initialized = True

    def exec(self, cmd, *args):
        self._lazy_init()
        log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
        validate_send_queue_args(cmd, args)
        self.req_queue.put((cmd,) + args)
        res = self.ans_queue.get()
        log.info("Main process result for %s received: %s", cmd, safe_str(res))
        if res == "ERROR":
            raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
        else:
            return res

    @staticmethod
    def run_forever(req_queue, ans_queue):
        # Initialize our device
        global CURR_DEVICE_IDX
        empty_res = object()
        allocator = Allocator()

        # Serve all requests
        while True:
            cmd, *args = req_queue.get()
            log.info("Worker executing: %s", cmd)
            res = empty_res
            if cmd == "deviceCount":
                assert len(args) == 0
                res = NUM_DEVICES
            elif cmd == "getDevice":
                res = CURR_DEVICE_IDX
            elif cmd == "uncheckedSetDevice":
                assert len(args) == 1
                CURR_DEVICE_IDX = int(args[0])
                res = None
            elif cmd == "exchangeDevice":
                assert len(args) == 1
                res = CURR_DEVICE_IDX
                CURR_DEVICE_IDX = int(args[0])
            elif cmd == "malloc":
                res = allocator.malloc(*args)
            elif cmd == "free":
                res = allocator.free(*args)
            elif cmd == "run_op":
                op_name, args, kwargs = args
                run_op(allocator, op_name, args, kwargs)
                res = None
            elif cmd == "send_data":
                assert len(args) == 1
                res = OpenRegTensorData.from_meta(allocator, args[0])
            elif cmd == "recv_data":
                assert len(args) == 2
                host_tensor, dev_mem = args
                dev_tensor = OpenRegTensorData.from_meta(allocator, dev_mem)
                dev_tensor.copy_(host_tensor)
                res = None
            elif cmd == "get_op_output_shape":
                op_name, args, kwargs = args
                res = run_op(allocator, op_name, args, kwargs).size()
            else:
                log.warning("Bad command in worker")
                res = "ERROR"

            if res == empty_res:
                raise RuntimeError("Bad impl didn't return anything")
            log.info("Worker answering to: %s", cmd)
            ans_queue.put(res)


daemon = _Daemon()
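`_Daemon` above is, at its core, a blocking request/answer loop over two spawn-context queues. A self-contained sketch of just that skeleton, with everything but a single command stripped out (names mirror the file above; this is an illustration, not the shipped code):

```python
import torch

mp_context = torch.multiprocessing.get_context("spawn")
NUM_DEVICES = 7


def run_forever(req_queue, ans_queue):
    # Worker side: block on a request tuple, put the answer on the other queue.
    while True:
        cmd, *args = req_queue.get()
        ans_queue.put(NUM_DEVICES if cmd == "deviceCount" else "ERROR")


if __name__ == "__main__":
    req_queue, ans_queue = mp_context.Queue(), mp_context.Queue()
    worker = mp_context.Process(
        target=run_forever, args=(req_queue, ans_queue), daemon=True
    )
    worker.start()
    req_queue.put(("deviceCount",))          # driver side: send a command...
    assert ans_queue.get() == NUM_DEVICES    # ...and block on the round trip
```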
test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py (new file):

@@ -0,0 +1,104 @@
import pprint

import torch
from torch.utils._pytree import tree_map, tree_map_only


class OpenRegTensorMeta:
    def __init__(self, tensor, checked=True):
        if checked and not tensor.device.type == "openreg":
            raise RuntimeError(
                "Creating OpenRegTensorMeta is only for Tensors on openreg device"
            )
        self.data_ptr = tensor.untyped_storage().data_ptr()
        self.size = tensor.size()
        self.stride = tensor.stride()
        self.storage_offset = tensor.storage_offset()
        self.dtype = tensor.dtype
        self.nelem_in_bytes = tensor.nelement() * tensor.element_size()

    def __repr__(self):
        return (
            f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, "
            f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})"
        )


class OpenRegTensorData(torch.Tensor):
    @staticmethod
    def from_meta(allocator, tensor_meta):
        return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta))


VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float}

VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str}


def safe_str(args):
    def convert(obj):
        if isinstance(obj, torch.Tensor):
            return str(OpenRegTensorMeta(obj, checked=False))
        else:
            return obj

    new_args = tree_map(convert, args)
    return pprint.pformat(new_args)


def validate_send_queue_args(cmd, args):
    def check(obj):
        if type(obj) not in VALID_QUEUE_TYPES_OUT:
            if (
                cmd == "recv_data"
                and type(obj) is torch.Tensor
                and obj.device.type == "cpu"
            ):
                # Only HtoD copy command can send cpu Tensors over
                return
            raise RuntimeError(
                f"Trying to send invalid object through queue: {type(obj)}"
            )

    tree_map(check, args)


def prepare_for_sending(args, kwargs):
    def convert(obj):
        if type(obj) not in VALID_QUEUE_TYPES_IN:
            raise RuntimeError(
                f"Cannot send object of type {type(obj)} over openreg device pipe."
            )

        if isinstance(obj, torch.Tensor):
            return OpenRegTensorMeta(obj)
        else:
            return obj

    return tree_map(convert, (args, kwargs))


def receive_after_sending(allocator, args, kwargs):
    def convert(obj):
        if type(obj) not in VALID_QUEUE_TYPES_OUT:
            raise RuntimeError(
                f"Received invalid object of type {type(obj)} "
                "over openreg device pipe."
            )

        if isinstance(obj, OpenRegTensorMeta):
            return allocator.tensor_from_meta(obj)
        else:
            return obj

    return tree_map(convert, (args, kwargs))


def to_device_no_copy(device, args, kwargs):
    def safe_to(t):
        if device == "meta":
            return t.to(device=device)
        else:
            return torch.empty_like(t, device=device)

    return tree_map_only(torch.Tensor, safe_to, (args, kwargs))
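All the converters above share one pattern: `tree_map` walks an arbitrarily nested `(args, kwargs)` structure and swaps Tensors for picklable stand-ins while passing ints and floats through. A minimal sketch of that pattern, with a hypothetical stand-in tuple playing the role of `OpenRegTensorMeta`:

```python
import torch
from torch.utils._pytree import tree_map

args, kwargs = (torch.ones(2, 3), 2.0), {"alpha": 1}

def convert(obj):
    # Stand-in for OpenRegTensorMeta: keep only what describes the tensor.
    return ("meta", obj.shape, obj.dtype) if isinstance(obj, torch.Tensor) else obj

print(tree_map(convert, (args, kwargs)))
# ((('meta', torch.Size([2, 3]), torch.float32), 2.0), {'alpha': 1})
```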
test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h:

@@ -6,5 +6,6 @@
 namespace openreg {

 void set_impl_registry(PyObject* registry);
+py::function get_method(const char* name);

 }
test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp:

@@ -11,10 +11,6 @@ namespace {
 // Python dictionary where real implementations can be found
 PyObject* py_registry;

-py::function get_method(const char* name) {
-  return py::cast<py::dict>(py_registry)[name];
-}
-
 // C++ hooks implementation
 struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {};
@@ -243,4 +239,12 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
 void set_impl_registry(PyObject* registry) {
   py_registry = registry;
 }

+py::function get_method(const char* name) {
+  auto dict = py::cast<py::dict>(py_registry);
+  TORCH_CHECK(dict.contains(name), "OpenReg registry does not contain ",
+      "an implementation for '", name, "', make sure to add it in the __init__.py ",
+      "file and register it.")
+  return dict[name];
+}
 } // openreg
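`get_method` is the C++ half of the `_IMPL_REGISTRY` handshake: `set_impl_registry` stores the dict handed over from `__init__.py`, and every hook call looks up a callable by name. A Python sketch of the same contract (illustrative only; the real `get_method` lives in C++ as shown above):

```python
# Python-side view of the registry contract that get_method enforces.
_IMPL_REGISTRY = {"getDevice": lambda: 0, "deviceCount": lambda: 7}

def get_method(name):
    # Same failure mode the TORCH_CHECK above guards against.
    if name not in _IMPL_REGISTRY:
        raise RuntimeError(
            f"OpenReg registry does not contain an implementation for '{name}', "
            "make sure to add it in the __init__.py file and register it."
        )
    return _IMPL_REGISTRY[name]

assert get_method("deviceCount")() == 7
```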
test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp (new file):

@@ -0,0 +1,125 @@
#include <OpenReg.h>

#include <torch/library.h>
#include <c10/core/Allocator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/EmptyTensor.h>
#include <c10/util/ArrayRef.h>
#include <c10/core/TensorOptions.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/set_cpu_dispatch.h>

namespace openreg {

namespace {

using openreg_ptr_t = uint64_t;

// A dummy allocator for our custom device, that secretly uses the CPU
struct OpenRegAllocator final : at::Allocator {
  OpenRegAllocator() = default;

  at::DataPtr allocate(size_t nbytes) override {
    py::gil_scoped_acquire acquire;
    auto curr_device_idx = get_method("getDevice")().cast<c10::DeviceIndex>();
    auto curr_device = c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx);
    void* data = nullptr;
    if (nbytes > 0) {
      data = reinterpret_cast<void*>(get_method("malloc")(nbytes).cast<openreg_ptr_t>());
      TORCH_CHECK(data, "Failed to allocate ", nbytes, " bytes on openreg device.");
    }
    return {data, data, &ReportAndDelete, curr_device};
  }

  static void ReportAndDelete(void* ptr) {
    if (!ptr) {
      return;
    }
    py::gil_scoped_acquire acquire;
    TORCH_CHECK(
        get_method("free")(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
        "Failed to free memory pointer at ", ptr);
  }

  at::DeleterFnPtr raw_deleter() const override {
    return &ReportAndDelete;
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    py::gil_scoped_acquire acquire;
    get_method("copy_data")(reinterpret_cast<openreg_ptr_t>(dest), reinterpret_cast<openreg_ptr_t>(src), count);
  }
};

// Register our dummy allocator
static OpenRegAllocator global_openreg_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc);


// Empty op needs C++ code and cannot be handled by python side fallback
at::Tensor empty_openreg(
    c10::IntArrayRef size,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  const auto device = c10::device_or_default(device_opt);
  const auto dtype = c10::dtype_or_default(dtype_opt);
  TORCH_CHECK(device.is_privateuseone());
  TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported");
  TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU");
  const c10::DeviceGuard device_guard(device);
  constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_generic(
      size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt);
}

at::Tensor empty_strided_openreg(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  const auto device = c10::device_or_default(device_opt);
  const auto dtype = c10::dtype_or_default(dtype_opt);
  TORCH_CHECK(device.is_privateuseone());
  TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported");
  TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU");
  const c10::DeviceGuard device_guard(device);
  constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_strided_generic(
      size, stride, &global_openreg_alloc, pu1_dks, dtype);
}

at::Tensor as_strided_openreg(
    const at::Tensor& self,
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<int64_t> storage_offset_) {
  // Metadata-only change so we re-use the cpu impl
  return at::cpu::as_strided(self, size, stride, storage_offset_);
}

at::Tensor& set_openreg(
    at::Tensor& result,
    at::Storage storage,
    int64_t storage_offset,
    c10::IntArrayRef size,
    c10::IntArrayRef stride) {
  return at::cpu::set_(result, storage, storage_offset, size, stride);
}


TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("empty.memory_format", empty_openreg);
  m.impl("empty_strided", empty_strided_openreg);
  m.impl("as_strided", as_strided_openreg);
  m.impl("set_.source_Storage_storage_offset", set_openreg);
}

} // anonymous namespace

} // openreg
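`allocate` above takes the GIL, asks the Python `malloc` hook for an opaque integer handle, and wraps it in an `at::DataPtr` whose deleter calls back into `free`. The Python half of that contract, as a sketch mirroring `Allocator` in `_device_daemon.py`:

```python
import torch

allocated = {}

def malloc(size):
    # "Device memory" is faked with a host uint8 buffer; its data_ptr is the
    # opaque handle handed back to the C++ allocator.
    buf = torch.empty(size, dtype=torch.uint8)
    allocated[buf.data_ptr()] = buf  # keep the buffer alive, keyed by pointer
    return buf.data_ptr()

def free(ptr):
    # The C++ deleter TORCH_CHECKs that this returns True for known pointers.
    return allocated.pop(ptr, None) is not None

ptr = malloc(64)
assert free(ptr) and not free(ptr)  # second free of the same pointer fails
```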
test/test_openreg.py:

@@ -7,17 +7,17 @@ import psutil
 import pytorch_openreg

 import torch
-from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, TestCase


 class TestOpenReg(TestCase):
     def test_initializes(self):
         self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")

-    @unittest.skipIf(not IS_LINUX, "Only works on linux")
+    @unittest.SkipTest
     def test_autograd_init(self):
         # Make sure autograd is initialized
-        torch.rand(2, requires_grad=True, device="openreg").sum().backward()
+        torch.ones(2, requires_grad=True, device="openreg").sum().backward()

         pid = os.getpid()
         task_path = f"/proc/{pid}/task"
@@ -30,9 +30,35 @@ class TestOpenReg(TestCase):
                 thread_name = file.read().strip()
             all_thread_names.add(thread_name)

-        for i in range(pytorch_openreg.NUM_DEVICES):
+        for i in range(pytorch_openreg._device_daemon.NUM_DEVICES):
             self.assertIn(f"pt_autograd_{i}", all_thread_names)

+    def test_factory(self):
+        a = torch.empty(50, device="openreg")
+        self.assertEqual(a.device.type, "openreg")
+
+        a.fill_(3.5)
+
+        self.assertTrue(a.eq(3.5).all())
+
+    def test_printing(self):
+        a = torch.ones(20, device="openreg")
+        # Does not crash!
+        str(a)
+
+    def test_cross_device_copy(self):
+        a = torch.rand(10)
+        b = a.to(device="openreg").add(2).to(device="cpu")
+        self.assertEqual(b, a + 2)
+
+    def test_data_dependent_output(self):
+        cpu_a = torch.randn(10)
+        a = cpu_a.to(device="openreg")
+        mask = a.gt(0)
+        out = torch.masked_select(a, mask)
+
+        self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0)))
+

 if __name__ == "__main__":
     run_tests()