import collections
import copyreg
import io
import pickle
import sys
import threading
import traceback
from enum import Enum

import torch
import torch.distributed as dist

from torch._C._distributed_rpc import _get_current_rpc_agent


# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects
_thread_local_tensor_tables = threading.local()


class RPCExecMode(Enum):
    SYNC = "sync"
    ASYNC = "async"
    ASYNC_JIT = "async_jit"
    REMOTE = "remote"


class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data into a "binary string + tensor table" format.
    For an RPC Python UDF function and its args, non-tensor data is serialized
    into a regular binary string, while tensor data is put into thread local
    tensor tables. This serialization format is consistent with builtin operators
    and args using the JIT pickler, and it makes tensor handling in C++ much
    easier, e.g. attaching a tensor to the distributed autograd graph in C++.
    """

    def __init__(self):
        # Ignore type error because dispatch_table is defined in third-party package
        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
        self._dispatch_table[torch.Tensor] = self._tensor_reducer

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        global _thread_local_tensor_tables
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, tensor):
        global _thread_local_tensor_tables
        _thread_local_tensor_tables.send_tables.append(tensor)
        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index,))

    @classmethod
    def _py_rref_receiver(cls, rref_fork_data):
        return dist.rpc.PyRRef._deserialize(rref_fork_data)

    def _py_rref_reducer(self, py_rref):
        rref_fork_data = py_rref._serialize()
        return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))

    def _rref_reducer(self, rref):
        return self._py_rref_reducer(rref)

    def serialize(self, obj):
        r"""
        Serialize non-tensor data into a binary string, and tensor data into
        the tensor table.
        """
        f = io.BytesIO()
        p = pickle.Pickler(f)
        p.dispatch_table = self._dispatch_table

        # The RPC API can accept user picklers inheriting from _InternalRPCPickler
        # to serialize RRefs. User picklers may have a different initialization
        # function from _InternalRPCPickler, but they should all call serialize()
        # and use _rref_reducer to pickle RRefs in Python. Also, when
        # _internal_rpc_pickler is imported into rpc/api.py, rpc.RRef is not
        # compiled yet, so the _InternalRPCPickler constructor is not a good place
        # to access rpc.RRef; hence RRef's dispatch table entries are installed here.
        #
        # The return value of a `rpc.remote(..)` call is of type `rpc.PyRRef`.
        # The deserialized RRef object on the RPC receiver side is of type `rpc.PyRRef`.
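        # A user pickler might therefore look like the following (an illustrative
        # sketch only; MyType and my_reducer are placeholders, not part of this
        # module):
        #
        #   class _MyRPCPickler(_InternalRPCPickler):
        #       def __init__(self):
        #           super().__init__()
        #           self._dispatch_table[MyType] = my_reducer
        #
        # Such a subclass still relies on this serialize() to install the RRef
        # reducers registered just below.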
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
        # An RRef created locally by the RRef Python constructor is of type `rpc.RRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]

        # Save _thread_local_tensor_tables.send_tables if this is a nested call.
        global _thread_local_tensor_tables
        if hasattr(_thread_local_tensor_tables, "send_tables"):
            old_send_tables = _thread_local_tensor_tables.send_tables
        else:
            old_send_tables = None
        _thread_local_tensor_tables.send_tables = []

        p.dump(obj)

        # Restore _thread_local_tensor_tables.send_tables if returning from a
        # nested call, otherwise clean up the table.
        tensors = _thread_local_tensor_tables.send_tables
        if old_send_tables is not None:
            _thread_local_tensor_tables.send_tables = old_send_tables
        else:
            del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize a binary string + tensor table back into the original obj.
        """
        # Save _thread_local_tensor_tables.recv_tables if this is a nested call.
        global _thread_local_tensor_tables
        if hasattr(_thread_local_tensor_tables, "recv_tables"):
            old_recv_tables = _thread_local_tensor_tables.recv_tables
        else:
            old_recv_tables = None
        _thread_local_tensor_tables.recv_tables = tensor_table

        try:
            ret = pickle.loads(binary_data)
        except AttributeError as e:
            # Occurs when the function is not found on the module/class during
            # unpickling.
            except_str = (
                str(e)
                + """ Default RPC pickler does not serialize
            function code. Ensure that UDFs are defined on both caller and
            callee modules."""
            )
            ret = AttributeError(except_str)

        # Restore _thread_local_tensor_tables.recv_tables if returning from a
        # nested call, otherwise clean up the table.
        if old_recv_tables is not None:
            _thread_local_tensor_tables.recv_tables = old_recv_tables
        else:
            del _thread_local_tensor_tables.recv_tables

        return ret


# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()


def serialize(obj):
    return _internal_rpc_pickler.serialize(obj)


def deserialize(binary_data, tensor_table):
    return _internal_rpc_pickler.deserialize(binary_data, tensor_table)


def _run_function(python_udf):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.

    Runs a Python UDF and returns its return value.
    Wraps any exception in ``RemoteException`` if the function raises.
    """
    try:
        if isinstance(python_udf, AttributeError):
            raise python_udf
        result = python_udf.func(*python_udf.args, **python_udf.kwargs)
    except Exception as e:
        # except_str = exception info + traceback string
        except_str = (
            f"On {_get_current_rpc_agent().get_worker_info()}:\n"
            f"{repr(e)}\n{traceback.format_exc()}"
        )
        print(except_str, file=sys.stderr)
        result = RemoteException(except_str, type(e))
    return result
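

# How these two pieces fit together (a sketch of the expected flow, not a spec):
# on the callee, _run_function() converts a raised exception into a RemoteException
# result instead of propagating it; the caller side then passes the deserialized
# result through _handle_exception() to re-raise the error locally. Calling them
# directly, just for illustration:
#
#   result = _run_function(PythonUDF(func=torch.add, args=(1, 2), kwargs={}))
#   _handle_exception(result)  # no-op on success; re-raises on a remote failure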
def _handle_exception(result):
    if isinstance(result, RemoteException):
        raise result.exception_type(result.msg)


def _build_rpc_profiling_key(
    exec_type, func_name, current_worker_name, dst_worker_name
):
    """
    Builds the key that RPC calls are profiled with using the autograd profiler.
    This will be the name of the corresponding Event recorded in the profiler.

    Arguments:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dst_worker_name (str): Name of the destination worker.

    Returns:
        String representing profiling key
    """
    profile_key = "rpc_{rpc_type}#{func_name}({current_worker} -> {dst_worker})".format(
        rpc_type=exec_type.value,
        func_name=func_name,
        current_worker=current_worker_name,
        dst_worker=dst_worker_name,
    )
    return profile_key


def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
    """
    This function should be called from RPC/RRef functions to create a
    RecordFunction object for profiling. This function also runs the before
    callbacks that start the profiling, though the user is responsible for
    running the appropriate callbacks when the function to be profiled finishes.

    Arguments:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dest_worker_name (str): Name of the destination worker.

    Returns:
        An instance of `torch.autograd._RecordFunction`.
    """
    assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
    profile_key = "rpc_{}#{}({} -> {})".format(
        exec_type.value, str(func_name), current_worker_name, dest_worker_name
    )
    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
    return rf
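

# PythonUDF bundles a Python user-defined function with its call arguments;
# RemoteException carries the formatted error string and the original exception
# type so the caller can re-raise it (see _run_function and _handle_exception above).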
PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])