Files
pytorch/torch/distributed/internal_rpc_utils.py
Shen Li 2486b0ba82 Add Python RRef as args and return value (#25499)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499

See #23110 for model parallel design details, and #26759 for the RRef
protocol. This commit adds support for using RRef as Python UDF arguments
and return value. RRefs can now be shared from owner to user, from user to
owner, or from user to user.

Limitations:
1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure.
5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`.

Test Plan:
Imported from OSS

buck test mode/dev-nosan //caffe2/test:rpc_fork

Differential Revision: D17184146

Pulled By: mrshenli

fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
2019-10-03 17:47:12 -07:00

136 lines
4.9 KiB
Python

#!/usr/bin/env python3
import collections
import copyreg
import io
import pickle
import six
import threading
import traceback
import torch
# Thread local tensor tables used while pickling/unpickling torch.Tensor
# objects. _InternalRPCPickler sets a "send_tables" attribute on it while
# serializing and a "recv_tables" attribute while deserializing.
_thread_local_tensor_tables = threading.local()
class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data to be "binary string + tensor table" format

    So for RPC python UDF function and args, non tensor data will be serialized
    into regular binary string, tensor data will be put into thread local tensor
    tables, this serialization format is consistent with builtin operator and args
    using JIT pickler. This format will make tensor handling in C++ much easier,
    e.g. attach tensor to distributed autograd graph in C++
    """

    def __init__(self):
        # python2 does not have dispatch_table, add "if six.PY3" condition,
        # as _InternalRPCPickler still got build in python2 even
        # we skipped python 2 tests for rpc_test
        if six.PY3:
            self._dispatch_table = copyreg.dispatch_table.copy()
            # Route every torch.Tensor through _tensor_reducer instead of the
            # default tensor pickling
            self._dispatch_table[torch.Tensor] = self._tensor_reducer

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        r"""
        Unpickling hook: return the tensor stored at ``tensor_index`` in the
        current thread's ``recv_tables``
        """
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, obj):
        r"""
        Pickling hook: append ``obj`` to the current thread's ``send_tables``
        and pickle only its index, to be resolved by ``_tensor_receiver``
        """
        _thread_local_tensor_tables.send_tables.append(obj)
        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index, ))

    def serialize(self, obj):
        r"""
        Serialize non tensor data into binary string, tensor data into
        tensor table

        Returns a ``(binary string, tensor list)`` tuple. The thread local
        ``send_tables`` is restored (nested call) or removed (outermost call)
        even when pickling raises.
        """
        f = io.BytesIO()
        p = pickle.Pickler(f)
        p.dispatch_table = self._dispatch_table

        # save _thread_local_tensor_tables.send_tables if it is in nested call
        nested = hasattr(_thread_local_tensor_tables, "send_tables")
        old_send_tables = getattr(
            _thread_local_tensor_tables, "send_tables", None
        )
        _thread_local_tensor_tables.send_tables = []

        try:
            p.dump(obj)
            tensors = _thread_local_tensor_tables.send_tables
        finally:
            # restore _thread_local_tensor_tables.send_tables if return
            # from nested call, otherwise clean up the table. Doing this in
            # "finally" keeps a failed dump() from leaving a stale table
            # (and the tensors it references) attached to the thread.
            if nested:
                _thread_local_tensor_tables.send_tables = old_send_tables
            else:
                del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize binary string + tensor table to original obj

        ``tensor_table`` is exposed as the thread local ``recv_tables`` while
        unpickling so that ``_tensor_receiver`` can look tensors up by index.
        The previous ``recv_tables`` is restored (nested call) or the attribute
        is removed (outermost call) even when unpickling raises.
        """
        # save _thread_local_tensor_tables.recv_tables if it is in nested call
        nested = hasattr(_thread_local_tensor_tables, "recv_tables")
        old_recv_tables = getattr(
            _thread_local_tensor_tables, "recv_tables", None
        )
        _thread_local_tensor_tables.recv_tables = tensor_table

        try:
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # assumes binary_data only ever comes from a trusted peer.
            ret = pickle.loads(binary_data)
        finally:
            # restore _thread_local_tensor_tables.recv_tables if return
            # from nested call, otherwise clean up the table
            if nested:
                _thread_local_tensor_tables.recv_tables = old_recv_tables
            else:
                del _thread_local_tensor_tables.recv_tables

        return ret
# Create _internal_rpc_pickler only once, so that its _dispatch_table is
# initialized only once and the same pickler is reused by the module-level
# functions below
_internal_rpc_pickler = _InternalRPCPickler()
def serialize(obj):
    r"""
    Serialize ``obj`` with the shared ``_internal_rpc_pickler`` and return
    its result unchanged, i.e. the "binary string + tensor table" pair
    """
    serialized = _internal_rpc_pickler.serialize(obj)
    return serialized
def run_python_udf_internal(pickled_python_udf, tensors):
    r"""
    Internal python function will be imported and executed in C++ land
    it unpickles pickled python udf strings and tensors and runs the python
    udf

    Returns the udf's return value as is (it is not serialized here). If the
    udf raises, a RemoteException carrying "repr(exception) + traceback" is
    returned instead of propagating the exception.
    """
    udf = _internal_rpc_pickler.deserialize(pickled_python_udf, tensors)
    try:
        return udf.func(*udf.args, **udf.kwargs)
    except Exception as exc:
        # failure text = exception info + traceback string
        failure_text = "{}\n{}".format(repr(exc), traceback.format_exc())
        return RemoteException(failure_text)
def load_python_udf_result_internal(pickled_python_result, tensors):
    r"""
    Internal python function will be imported and executed in C++ land
    it unpickles the pickled python udf result and tensor tables

    Returns the unpickled python object, or raises Exception with the carried
    message when the unpickled object is a RemoteException.
    """
    outcome = _internal_rpc_pickler.deserialize(pickled_python_result, tensors)
    if not isinstance(outcome, RemoteException):
        return outcome
    # The unpickled object is a RemoteException: surface its message
    raise Exception(outcome.msg)
# Container for a python udf call: the callable plus its positional and
# keyword arguments
PythonUDF = collections.namedtuple("PythonUDF", "func args kwargs")
# Container for the error text (exception info + traceback) of a failed udf
RemoteException = collections.namedtuple("RemoteException", "msg")