#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

#include <pybind11/chrono.h>
#include <pybind11/operators.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);
constexpr float kSecToMsConversion = 1000;
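// Note: the Python-facing API expresses timeouts as float seconds, while
// RpcAgent works in std::chrono::milliseconds; kSecToMsConversion is the
// factor used by get_rpc_timeout / _set_rpc_timeout below to convert between
// the two representations.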

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
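// This alias binds each class with std::shared_ptr<T> as its pybind11 holder
// type, so Python and C++ can share ownership of the same instance.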

PyObject* rpc_init(PyObject* /* unused */) {
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto module = py::handle(rpc_module).cast<py::module>();

  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
              backend. An instance of this class can be passed in to
              :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
              with specific configurations, such as the RPC timeout and
              ``init_method`` to be used. )")
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
                  RPCs. If an RPC does not complete in this timeframe, it will
                  complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
                  Default is ``env://``)");
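
  // Illustrative usage from Python (a sketch; the concrete subclass, worker
  // name, and world size are assumptions, and any RpcBackendOptions-derived
  // instance works the same way):
  //
  //   >>> opts = rpc.ProcessGroupRpcBackendOptions()
  //   >>> opts.rpc_timeout = 120.0     # inherited field, in seconds
  //   >>> opts.init_method = "env://"  # inherited field
  //   >>> rpc.init_rpc("worker0", rank=0, world_size=2,
  //   >>>              rpc_backend_options=opts)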

  // The following C++ constants need to be cast so they can be used from
  // python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);

  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(A structure that encapsulates information of a worker in the system.
              Contains the name and ID of the worker. This class is not meant to
              be constructed directly; rather, an instance can be retrieved
              through :meth:`~torch.distributed.rpc.get_worker_info`, and the
              result can be passed in to functions such as
              :meth:`~torch.distributed.rpc.rpc_sync`,
              :meth:`~torch.distributed.rpc.rpc_async`, and
              :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
              every invocation.)")
          .def(
              py::init<std::string, worker_id_t>(),
              py::arg("name"),
              py::arg("id"))
          .def_readonly(
              "name", &WorkerInfo::name_, R"(The name of the worker.)")
          .def_readonly(
              "id",
              &WorkerInfo::id_,
              R"(Globally unique id to identify the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax .def(hash(py::self)), with the
          // unqualified "hash" function call. However, argument-dependent
          // lookup for the function "hash" doesn't get triggered in this
          // context because it conflicts with the struct torch::hash, so we
          // need to use the qualified name py::detail::hash, which
          // unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
          .def("__repr__", [](const WorkerInfo& workerInfo) {
            std::ostringstream os;
            os << workerInfo;
            return os.str();
          });
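
  // Illustrative usage from Python (a sketch; "worker1" is an assumed peer):
  //
  //   >>> info = rpc.get_worker_info("worker1")
  //   >>> # reuse the handle to avoid re-resolving the name on every call
  //   >>> rpc.rpc_sync(info, torch.add, args=(torch.ones(2), 1))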

  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());
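
  // Note: these agent methods may block on network or queue activity, so each
  // binding releases the GIL via py::call_guard<py::gil_scoped_release>,
  // letting other Python threads make progress while the call is in flight.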

  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
          A class encapsulating a reference to a value of some type on a remote
          worker. This handle will keep the referenced remote value alive on the
          worker. A ``UserRRef`` will be deleted when 1) there are no references
          to it in either the application code or the local RRef context, or
          2) the application has called a graceful shutdown. Invoking methods on
          a deleted RRef leads to undefined behavior. The RRef implementation
          only offers best-effort error detection, and applications should not
          use ``UserRRefs`` after ``rpc.shutdown()``.

          .. warning::
              RRefs can only be serialized and deserialized by the RPC module.
              Serializing and deserializing RRefs without RPC (e.g., Python
              pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
              JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
              lead to errors.

          Example::
              The following examples skip RPC initialization and shutdown code
              for simplicity. Refer to the RPC docs for those details.

              1. Create an RRef using rpc.remote

              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
              >>> # get a copy of value from the RRef
              >>> x = rref.to_here()

              2. Create an RRef from a local object

              >>> import torch
              >>> from torch.distributed.rpc import RRef
              >>> x = torch.zeros(2, 2)
              >>> rref = RRef(x)

              3. Share an RRef with other workers

              >>> # On both worker0 and worker1:
              >>> def f(rref):
              >>>     return rref.to_here() + 1

              >>> # On worker0:
              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> from torch.distributed.rpc import RRef
              >>> rref = RRef(torch.zeros(2, 2))
              >>> # the following RPC shares the rref with worker1, reference
              >>> # count is automatically updated.
              >>> rpc.rpc_sync("worker1", f, args=(rref,))
          )")
          .def(
              py::init<const py::object&, const py::object&>(),
              py::arg("value"),
              py::arg("type_hint") = py::none())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of this
                  ``RRef``.
              )")
          .def(
              "confirmed_by_owner",
              &PyRRef::confirmedByOwner,
              R"(
                  Returns whether this ``RRef`` has been confirmed by the owner.
                  ``OwnerRRef`` always returns true, while ``UserRRef`` only
                  returns true when the owner knows about this ``UserRRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner_name",
              &PyRRef::ownerName,
              R"(
                  Returns the worker name of the node that owns this ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::arg("timeout") = py::cast(kUnsetRpcTimeout),
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the owner
                  to the local node and returns it. If the current node is the
                  owner, returns a reference to the local value.

                  Arguments:
                      timeout (float, optional): Timeout for ``to_here``. If
                          the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this
                          argument is not provided, the default RPC timeout
                          (60s) will be used.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          .def(
              "rpc_sync",
              [](const PyRRef& self) {
                return self.createRRefProxy(RRefProxyType::RPC_SYNC);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_sync`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_sync().size()  # returns torch.Size([2, 2])
                      >>> rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "rpc_async",
              [](const PyRRef& self) {
                return self.createRRefProxy(RRefProxyType::RPC_ASYNC);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_async`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
                      >>> rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "remote",
              [](const PyRRef& self) {
                return self.createRRefProxy(RRefProxyType::REMOTE);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch a ``remote`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.remote().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.remote().size().to_here()  # returns torch.Size([2, 2])
                      >>> rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              py::pickle(
                  /* __getstate__ */
                  [](const PyRRef& /* unused */) {
                    TORCH_CHECK(
                        false,
                        "Can not pickle rref in python pickler, rref can only be "
                        "pickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw; it's only here to satisfy pybind11's API
                    // requirement.
                    return py::make_tuple();
                  },
                  /* __setstate__ */
                  [](py::tuple /* unused */) { // NOLINT
                    TORCH_CHECK(
                        false,
                        "Can not unpickle rref in python pickler, rref can only be "
                        "unpickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw; it's only here to satisfy pybind11's API
                    // requirement.
                    return PyRRef(
                        py::cast<py::none>(Py_None),
                        py::cast<py::none>(Py_None));
                  }),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_serialize",
              &PyRRef::pickle,
              py::call_guard<py::gil_scoped_release>())
          .def_static(
              "_deserialize",
              &PyRRef::unpickle,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getFuture());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Returns the future that corresponds to the creation of this
                  RRef on the remote node. This is for internal use cases, such
                  as profiling, only.
              )")
          .def(
              "_get_profiling_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getProfilingFuture());
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Returns the future that completes when the profiling event
                  corresponding to the creation of this RRef on the remote node
                  has been recorded.
              )")
          .def(
              "_set_profiling_future",
              [](PyRRef& self,
                 const std::shared_ptr<jit::PythonFutureWrapper>&
                     wrappedFuture) {
                self.setProfilingFuture(wrappedFuture->fut);
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Sets the future that is completed when the profiling event
                  corresponding to the creation of this RRef on the remote node
                  has been recorded.
              )")
          // not releasing GIL to avoid context switch
          .def("__repr__", &PyRRef::str);

  shared_ptr_class_<ProcessGroupRpcBackendOptions>(
      module,
      "ProcessGroupRpcBackendOptions",
      rpcBackendOptions,
      R"(
          The backend options class for ``ProcessGroupAgent``, which is derived
          from ``RpcBackendOptions``.

          Arguments:
              num_send_recv_threads (int, optional): The number of threads in
                  the thread-pool used by ``ProcessGroupAgent`` (default: 4).
              rpc_timeout (float, optional): The default timeout, in seconds,
                  for RPC requests (default: 60 seconds). If the
                  RPC has not completed in this timeframe, an exception
                  indicating so will be raised. Callers can override this
                  timeout for individual RPCs in
                  :meth:`~torch.distributed.rpc.rpc_sync` and
                  :meth:`~torch.distributed.rpc.rpc_async` if necessary.
              init_method (str, optional): The URL to initialize
                  ``ProcessGroupGloo`` (default: ``env://``).
      )")
      .def(
          py::init<int, float, std::string>(),
          py::arg("num_send_recv_threads") = kDefaultNumSendRecvThreads,
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod)
      .def_readwrite(
          "num_send_recv_threads",
          &ProcessGroupRpcBackendOptions::numSendRecvThreads,
          R"(
              The number of threads in the thread-pool used by ProcessGroupAgent.
          )");

  module.attr("_DEFAULT_NUM_SEND_RECV_THREADS") =
      py::cast(kDefaultNumSendRecvThreads);

  shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
      .def(py::init([](std::string workerName,
                       const std::shared_ptr<::c10d::ProcessGroup>& pg,
                       int numSendRecvThreads,
                       std::chrono::milliseconds rpcTimeout) {
        return std::make_unique<ProcessGroupAgent>(
            std::move(workerName),
            pg,
            numSendRecvThreads,
            rpcTimeout,
            std::make_unique<RequestCallbackImpl>());
      }))
      .def(
          "get_worker_info",
          (const WorkerInfo& (ProcessGroupAgent::*)(void)const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) &
              ProcessGroupAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) &
              ProcessGroupAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "join",
          &ProcessGroupAgent::join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "shutdown",
          &ProcessGroupAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "sync",
          &ProcessGroupAgent::sync,
          py::call_guard<py::gil_scoped_release>());

#ifdef USE_TENSORPIPE

  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module,
      "TensorPipeRpcBackendOptions",
      rpcBackendOptions,
      R"(
          The backend options for
          :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
          :class:`~torch.distributed.rpc.RpcBackendOptions`.

          Arguments:
              num_worker_threads (int, optional): The number of threads in the
                  thread-pool used by
                  :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
                  requests (default: 16).
              rpc_timeout (float, optional): The default timeout, in seconds,
                  for RPC requests (default: 60 seconds). If the RPC has not
                  completed in this timeframe, an exception indicating so will
                  be raised. Callers can override this timeout for individual
                  RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
                  :meth:`~torch.distributed.rpc.rpc_async` if necessary.
              init_method (str, optional): The URL to initialize the distributed
                  store used for rendezvous. It takes any value accepted for the
                  same argument of :meth:`~torch.distributed.init_process_group`
                  (default: ``env://``).
      )")
      .def(
          py::init<
              int,
              optional<std::vector<std::string>>,
              optional<std::vector<std::string>>,
              float,
              std::string>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = optional<std::vector<std::string>>(),
          py::arg("_channels") = optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod)
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
              The number of threads in the thread-pool used by
              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
              requests.
          )");

  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);

  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init([](const std::shared_ptr<::c10d::Store>& store,
                      std::string selfName,
                      worker_id_t selfId,
                      int worldSize,
                      std::shared_ptr<::c10d::ProcessGroup> processGroup,
                      TensorPipeRpcBackendOptions opts) {
            return std::make_shared<TensorPipeAgent>(
                store,
                std::move(selfName),
                selfId,
                worldSize,
                std::move(processGroup),
                std::move(opts),
                std::make_unique<RequestCallbackImpl>());
          }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("process_group"),
          py::arg("rpc_backend_options"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void)const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&)const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>());

#endif // USE_TENSORPIPE

  module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);

  module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);

  module.def(
      "_set_and_start_rpc_agent",
      [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        RpcAgent::setCurrentRpcAgent(rpcAgent);
        // Initializing typeResolver inside the RpcAgent constructor would give
        // RpcAgent a Python dependency. To avoid that dependency, we call
        // setTypeResolver() here instead.
        std::shared_ptr<TypeResolver> typeResolver =
            std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
              auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                  qn.qualifiedName());
              return c10::StrongTypePtr(
                  PythonRpcHandler::getInstance().jitCompilationUnit(),
                  std::move(typePtr));
            });
        rpcAgent->setTypeResolver(typeResolver);
        rpcAgent->start();
      },
      py::call_guard<py::gil_scoped_release>());

  module.def("_reset_current_rpc_agent", []() {
    RpcAgent::setCurrentRpcAgent(nullptr);
  });

  module.def(
      "_delete_all_user_and_unforked_owner_rrefs",
      [](std::chrono::milliseconds timeoutMillis) {
        RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
      },
      py::arg("timeout") = kDeleteAllUsersTimeout,
      py::call_guard<py::gil_scoped_release>());

  module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
    // NB: do not release the GIL in this function. The destroyInstance()
    // method returns a list of deleted OwnerRRefs that hold py::object
    // instances. Clearing those OwnerRRefs is likely to trigger Python
    // derefs, which require the GIL.
    RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
  });

  module.def("_rref_context_get_debug_info", []() {
    return RRefContext::getInstance().getDebugInfo();
  });

  module.def(
      "_cleanup_python_rpc_handler",
      []() { PythonRpcHandler::getInstance().cleanup(); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcPythonUdf(
                dst,
                pickledPythonUDF,
                tensors,
                rpcTimeoutSeconds,
                isAsyncExecution),
            /* unwrap_func */ [](const py::object& value) {
              py::gil_scoped_release release;
              auto& pythonRpcHandler = PythonRpcHandler::getInstance();
              // This will unwrap RemoteException and raise the contained
              // server-side Python exception on the client side. A caveat here
              // is that the exception must be raised in the client thread
              // calling the pybind "wait" API, so that it can be correctly
              // shown to the user. Raising it in the RPC server thread would
              // be wrong, as the exception would be swallowed by the
              // ThreadPool task, and no pybind handling code could surface the
              // Python exception.
              pythonRpcHandler.handleException(value);
            });
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "get_rpc_timeout",
      []() {
        return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
            kSecToMsConversion;
      },
      R"(
          Retrieve the default timeout for all RPCs that was set during RPC
          initialization. The returned value will be in seconds.

          Returns:
              ``float`` indicating the RPC timeout in seconds.
      )");

  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
          Set whether recording of GIL wait times should be enabled or not.
          This incurs a slight overhead cost, so it is disabled by default for
          performance reasons.

          Arguments:
              flag (bool): True to enable GIL profiling, False to disable.
      )");

  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the default timeout for all RPCs. The input unit is expected to
          be in seconds. If an RPC does not complete within this time, an
          exception indicating that it has timed out will be raised. To control
          the timeout for specific RPCs, a timeout parameter can be passed into
          :meth:`~torch.distributed.rpc.rpc_sync` and
          :meth:`~torch.distributed.rpc.rpc_async`.

          Arguments:
              rpcTimeoutSeconds (float): Timeout value in seconds.
      )");

  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
          Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
          pickled RRefs out of RPC contexts.

          .. warning::
              This is dangerous. If the module contains RRefs, the pickled
              result must be sent over RPC and get unpickled on the receiving
              side to restore the module. Otherwise, there will be RRef leaks,
              which can potentially lead to program hangs. When using this API,
              it is the application's responsibility to make sure that the
              above assumption always holds.
      )");
  module.def(
      "_disable_jit_rref_pickle",
      &disableJitRRefPickle);

  Py_RETURN_TRUE;
}

} // namespace

static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", (PyCFunction)rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace rpc
} // namespace distributed
} // namespace torch