mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499 See #23110 for model parallel design details, and #26759 for the RRef protocol. This commit adds support for using RRef as Python UDF arguments and return value. RRefs can now be shared from owner to user, from user to owner, or from user to user. Limitations: 1. No implicit type conversion yet. (#27099) 2. No failure handling and retry. (#26116) 3. UDF is not yet blocked until all RRefs are confirmed. (#27098) 4. Internal RRef control messages are not idempotent yet. (#26116) 5. Cannot delete RRefs correctly when there are circular dependencies. (#27096) Main changes: 1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations. 2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages. 3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`. 4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure. 5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs. 6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`. Test Plan: Imported from OSS buck test mode/dev-nosan //caffe2/test:rpc_fork Differential Revision: D17184146 Pulled By: mrshenli fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
136 lines
4.9 KiB
Python
136 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import collections
|
|
import copyreg
|
|
import io
|
|
import pickle
|
|
import six
|
|
import threading
|
|
import traceback
|
|
|
|
import torch
|
|
|
|
# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects. During serialize() the attribute `send_tables` holds outgoing
# tensors; during deserialize() `recv_tables` holds incoming ones. Thread-local
# storage keeps concurrent RPC threads from clobbering each other's tables.
_thread_local_tensor_tables = threading.local()
|
|
|
|
|
|
class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data to be "binary string + tensor table" format.

    So for RPC python UDF function and args, non tensor data will be serialized
    into a regular binary string, tensor data will be put into thread local
    tensor tables; this serialization format is consistent with builtin
    operators and args using the JIT pickler. This format will make tensor
    handling in C++ much easier, e.g. attaching tensors to the distributed
    autograd graph in C++.
    """

    def __init__(self):
        # python2 does not have dispatch_table, add "if six.PY3" condition,
        # as _InternalRPCPickler still gets built in python2 even though
        # we skipped python 2 tests for rpc_test. Under python2
        # `_dispatch_table` is therefore never set.
        if six.PY3:
            self._dispatch_table = copyreg.dispatch_table.copy()
            # Route torch.Tensor pickling through our reducer so tensors go
            # into the out-of-band tensor table instead of the byte stream.
            self._dispatch_table[torch.Tensor] = self._tensor_reducer

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        # Unpickling callback: look the tensor up by index in the thread-local
        # receive table that deserialize() installed before pickle.loads ran.
        # (No `global` needed: the module-level name is only read here.)
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, obj):
        # Pickling reducer for torch.Tensor: append the tensor to the
        # thread-local send table and pickle only its index; the tensor itself
        # travels out-of-band in the tensor table.
        _thread_local_tensor_tables.send_tables.append(obj)
        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index, ))

    def serialize(self, obj):
        r"""
        Serialize non tensor data into binary string, tensor data into
        tensor table.

        Arguments:
            obj: python object to serialize; any torch.Tensor it contains is
                diverted to the tensor table.
        Returns:
            Tuple of (binary string, list of tensors referenced by index from
            the binary string).
        """
        f = io.BytesIO()
        p = pickle.Pickler(f)
        p.dispatch_table = self._dispatch_table

        # save _thread_local_tensor_tables.send_tables if it is in nested call
        # (a UDF that itself calls serialize), so the outer table survives.
        if hasattr(_thread_local_tensor_tables, "send_tables"):
            old_send_tables = _thread_local_tensor_tables.send_tables
        else:
            old_send_tables = None
        _thread_local_tensor_tables.send_tables = []

        p.dump(obj)

        # restore _thread_local_tensor_tables.send_tables if returning
        # from a nested call, otherwise clean up the table
        tensors = _thread_local_tensor_tables.send_tables
        if old_send_tables is not None:
            _thread_local_tensor_tables.send_tables = old_send_tables
        else:
            del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize binary string + tensor table back into the original obj.

        Arguments:
            binary_data: pickled byte string produced by serialize().
            tensor_table: list of tensors indexed by the pickle stream.
        Returns:
            The reconstructed python object.
        """
        # save _thread_local_tensor_tables.recv_tables if it is in nested call
        if hasattr(_thread_local_tensor_tables, "recv_tables"):
            old_recv_tables = _thread_local_tensor_tables.recv_tables
        else:
            old_recv_tables = None
        _thread_local_tensor_tables.recv_tables = tensor_table

        # NOTE(review): pickle.loads on RPC payloads assumes peers are
        # trusted; pickle can execute arbitrary code on untrusted input.
        ret = pickle.loads(binary_data)

        # restore _thread_local_tensor_tables.recv_tables if returning
        # from a nested call, otherwise clean up the table
        if old_recv_tables is not None:
            _thread_local_tensor_tables.recv_tables = old_recv_tables
        else:
            del _thread_local_tensor_tables.recv_tables

        return ret
|
|
|
|
|
# Create _internal_rpc_pickler only once to initialize _dispatch_table only
# once; the instance is stateless beyond that table, so sharing it across
# threads is fine (per-call state lives in _thread_local_tensor_tables).
_internal_rpc_pickler = _InternalRPCPickler()
|
|
|
|
def serialize(obj):
    r"""
    Module-level convenience wrapper: serialize ``obj`` with the shared
    ``_internal_rpc_pickler`` into (binary string, tensor table).
    """
    return _internal_rpc_pickler.serialize(obj)
|
|
|
|
def run_python_udf_internal(pickled_python_udf, tensors):
    r"""
    Internal python function imported and executed from C++ land:
    it unpickles the pickled python udf string and tensor table and runs the
    python udf.

    Arguments:
        pickled_python_udf: byte string holding a pickled ``PythonUDF``.
        tensors: tensor table accompanying the pickled payload.
    Returns:
        The UDF's raw (unserialized) return value; if the UDF raised, a
        ``RemoteException`` carrying the exception repr plus traceback string
        instead. Serialization of the result is left to the caller.
    """
    python_udf = _internal_rpc_pickler.deserialize(pickled_python_udf, tensors)
    try:
        result = python_udf.func(*python_udf.args, **python_udf.kwargs)
    except Exception as e:
        # Do not propagate: the exception must cross the RPC boundary, so
        # capture it as "exception info + traceback string" for the caller
        # side to re-raise.
        except_str = "{}\n{}".format(repr(e), traceback.format_exc())
        result = RemoteException(except_str)
    return result
|
|
|
|
|
|
def load_python_udf_result_internal(pickled_python_result, tensors):
    r"""
    Internal python function imported and executed from C++ land:
    it unpickles a pickled python udf result plus its tensor table and
    returns the resulting python object.
    """
    obj = _internal_rpc_pickler.deserialize(pickled_python_result, tensors)
    # A RemoteException payload means the remote UDF raised; surface it here.
    if not isinstance(obj, RemoteException):
        return obj
    raise Exception(obj.msg)
|
|
|
|
|
|
# Payload describing a python UDF invocation: the callable plus its
# positional and keyword arguments.
PythonUDF = collections.namedtuple("PythonUDF", "func args kwargs")

# Wraps the stringified exception info + traceback of a failed remote UDF.
RemoteException = collections.namedtuple("RemoteException", "msg")
|