#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

#include <pybind11/chrono.h>
#include <pybind11/operators.h>
namespace torch::distributed::rpc {

namespace {

constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
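
// rpc_init() is registered as "_rpc_init" in the method table at the bottom
// of this file. It creates the torch._C._distributed_rpc submodule and
// installs all RPC bindings on it.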
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

  auto module = py::handle(m).cast<py::module>();
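
  // Abstract base options type; the TensorPipe-specific options bound further
  // below inherit from this binding.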
  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
            backend. An instance of this class can be passed in to
            :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
            with specific configurations, such as the RPC timeout and
            ``init_method`` to be used. )")
          .def(py::init<>())
          .def(
              py::init<float, std::string>(),
              py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
              py::arg("init_method") = kDefaultInitMethod)
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
                RPCs. If an RPC does not complete in this timeframe, it will
                complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
                Default is ``env://``)");

  // The following C++ constants need to be cast so they can be used from
  // python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);
  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);
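
  // WorkerInfo gets value-like Python semantics: equality, hashing, repr, and
  // pickling, so instances can be used as dict keys and sent across processes.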
  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(A structure that encapsulates information about a worker in the system.
            Contains the name and ID of the worker. This class is not meant to
            be constructed directly, rather, an instance can be retrieved
            through :meth:`~torch.distributed.rpc.get_worker_info` and the
            result can be passed in to functions such as
            :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
            :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
            every invocation.)")
          .def(
              py::init<std::string, worker_id_t>(),
              py::arg("name"),
              py::arg("id"))
          .def_readonly(
              "name", &WorkerInfo::name_, R"(The name of the worker.)")
          .def_readonly(
              "id",
              &WorkerInfo::id_,
              R"(Globally unique id to identify the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax .def(hash(py::self)), with the
          // unqualified "hash" function call. However the
          // argument-dependent lookup for the function "hash" doesn't get
          // triggered in this context because it conflicts with the struct
          // c10::hash, so we need to use the qualified name
          // py::detail::hash, which unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
          .def(
              "__repr__",
              [](const WorkerInfo& workerInfo) {
                std::ostringstream os;
                os << workerInfo;
                return os.str();
              })
          .def(py::pickle(
              /* __getstate__ */
              [](const WorkerInfo& workerInfo) {
                return py::make_tuple(workerInfo.name_, workerInfo.id_);
              },
              /* __setstate__ */
              [](const py::tuple& t) {
                TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");

                WorkerInfo info(
                    t[0].cast<std::string>(), t[1].cast<worker_id_t>());
                return info;
              }));
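
  // Base RpcAgent bindings. Every method releases the GIL while it runs,
  // since agent calls may block on network I/O.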
  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("shutdown") = false,
              py::arg("timeout") = 0)
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "shutdown",
              &RpcAgent::shutdown,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              static_cast<const WorkerInfo& (RpcAgent::*)(void) const>(
                  &RpcAgent::getWorkerInfo),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              static_cast<const WorkerInfo& (RpcAgent::*)(const std::string&)
                              const>(&RpcAgent::getWorkerInfo),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_device_map",
              &RpcAgent::getDeviceMap,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());
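
  // PyRRef is the Python-facing wrapper around the C++ RRef. Cheap getters
  // below deliberately keep the GIL to avoid needless context switches.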
  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
          A class encapsulating a reference to a value of some type on a remote
          worker. This handle will keep the referenced remote value alive on the
          worker. A ``UserRRef`` will be deleted when 1) there are no references
          to it in either the application code or the local RRef context, or
          2) the application has called a graceful shutdown. Invoking methods on
          a deleted RRef leads to undefined behavior. The RRef implementation only
          offers best-effort error detection, and applications should not use
          ``UserRRefs`` after ``rpc.shutdown()``.

          .. warning::
              RRefs can only be serialized and deserialized by the RPC module.
              Serializing and deserializing RRefs without RPC (e.g., Python
              pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
              JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
              lead to errors.

          Args:
              value (object): The value to be wrapped by this RRef.
              type_hint (Type, optional): Python type that should be passed to
                  ``TorchScript`` compiler as type hint for ``value``.

          Example::
              Following examples skip RPC initialization and shutdown code
              for simplicity. Refer to RPC docs for those details.

              1. Create an RRef using rpc.remote

              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
              >>> # get a copy of value from the RRef
              >>> x = rref.to_here()

              2. Create an RRef from a local object

              >>> import torch
              >>> from torch.distributed.rpc import RRef
              >>> x = torch.zeros(2, 2)
              >>> rref = RRef(x)

              3. Share an RRef with other workers

              >>> # On both worker0 and worker1:
              >>> def f(rref):
              >>>     return rref.to_here() + 1

              >>> # On worker0:
              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> from torch.distributed.rpc import RRef
              >>> rref = RRef(torch.zeros(2, 2))
              >>> # the following RPC shares the rref with worker1, reference
              >>> # count is automatically updated.
              >>> rpc.rpc_sync("worker1", f, args=(rref,))
          )")
          .def(
              py::init<const py::object&, const py::object&>(),
              py::arg("value"),
              py::arg("type_hint") = py::none())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of this
                  ``RRef``.
              )")
          .def(
              "confirmed_by_owner",
              &PyRRef::confirmedByOwner,
              R"(
                  Returns whether this ``RRef`` has been confirmed by the owner.
                  ``OwnerRRef`` always returns true, while ``UserRRef`` only
                  returns true when the owner knows about this ``UserRRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner_name",
              &PyRRef::ownerName,
              R"(
                  Returns worker name of the node that owns this ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::arg("timeout") = py::cast(kUnsetRpcTimeout),
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the owner
                  to the local node and returns it. If the current node is the
                  owner, returns a reference to the local value.

                  Args:
                      timeout (float, optional): Timeout for ``to_here``. If
                          the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this
                          argument is not provided, the default RPC timeout
                          (60s) will be used.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          .def(
              "rpc_sync",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_SYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_sync`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_sync()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_sync().size()  # returns torch.Size([2, 2])
                      >>> rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "rpc_async",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_ASYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_async`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_async()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
                      >>> rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "remote",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::REMOTE, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch a ``remote`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.remote().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.remote()``. If
                          the creation of this :class:`~torch.distributed.rpc.RRef`
                          is not successfully completed within the timeout, then the
                          next time there is an attempt to use the RRef
                          (such as ``to_here``), a timeout will be raised. If not
                          provided, the default RPC timeout will be used. Please see
                          ``rpc.remote()`` for specific timeout semantics for
                          :class:`~torch.distributed.rpc.RRef`.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.remote().size().to_here()  # returns torch.Size([2, 2])
                      >>> rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              py::pickle(
                  /* __getstate__ */
                  [](const PyRRef& /* unused */) {
                    TORCH_CHECK(
                        false,
                        "Can not pickle rref in python pickler, rref can only be "
                        "pickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy Pybind API's
                    // requirement.
                    return py::make_tuple();
                  },
                  /* __setstate__ */
                  [](py::tuple /* unused */) { // NOLINT
                    TORCH_CHECK(
                        false,
                        "Can not unpickle rref in python pickler, rref can only be "
                        "unpickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy PyBind's API
                    // requirement.
                    return PyRRef(
                        py::cast<py::none>(Py_None),
                        py::cast<py::none>(Py_None));
                  }),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_serialize",
              &PyRRef::pickle,
              py::call_guard<py::gil_scoped_release>())
          .def_static(
              "_deserialize",
              &PyRRef::unpickle,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_type",
              // Intentionally not releasing GIL, as most accesses just
              // retrieve cached type py::object
              &PyRRef::getRRefType,
              py::arg("timeout") = kUnsetRpcTimeout,
              py::arg("blocking") = true,
              R"(
                  If ``blocking=True``, returns the type of the data object
                  referenced by this ``RRef``. On the owner, this is the same as
                  ``type(rref.local_value())``. Otherwise, returns a future to
                  this result. On a user, this will trigger an RPC to fetch the
                  ``type`` object from the owner. After this function is run
                  once, the ``type`` object is cached by the ``RRef``, and
                  subsequent invocations no longer trigger RPC. Note that this is
                  true regardless of the ``blocking`` argument of subsequent
                  calls.

                  Args:
                      rref (torch.distributed.rpc.RRef): The RRef to get type of.
                      timeout (float, optional): Timeout, in seconds, for
                          ``_get_type``. If the call does not complete within
                          this timeframe, an exception indicating so will be
                          raised. If this argument is not provided, the default
                          RPC timeout will be used.
                      blocking (bool, optional): Whether to synchronously wait on
                          the RPC triggered by the first call and return the
                          type. If ``False``, will return a future. Default is
                          ``True``.
              )")
          .def(
              "_get_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getFuture());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Returns the future that corresponds to the creation of this RRef
                  on the remote node. This is for internal use cases such as profiling
                  only.
              )")
          .def(
              "_get_profiling_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getProfilingFuture());
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Returns future that completes when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "_set_profiling_future",
              [](PyRRef& self,
                 const std::shared_ptr<jit::PythonFutureWrapper>&
                     wrappedFuture) {
                self.setProfilingFuture(wrappedFuture->fut);
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Set future that is completed when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "backward",
              [](PyRRef& self,
                 int64_t dist_autograd_ctx_id,
                 bool retain_graph) {
                self.backward(dist_autograd_ctx_id, retain_graph);
              },
              py::arg("dist_autograd_ctx_id") = -1,
              py::arg("retain_graph") = false,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Runs the backward pass using the RRef as the root of the
                  backward pass. If ``dist_autograd_ctx_id`` is provided,
                  we perform a distributed backward pass using the provided
                  ctx_id starting from the owner of the RRef. In this case,
                  :meth:`~torch.distributed.autograd.get_gradients` should be
                  used to retrieve the gradients. If ``dist_autograd_ctx_id``
                  is ``None``, it is assumed that this is a local autograd graph
                  and we only perform a local backward pass. In the local case,
                  the node calling this API has to be the owner of the RRef.
                  The value of the RRef is expected to be a scalar Tensor.

                  Args:
                      dist_autograd_ctx_id (int, optional): The distributed
                          autograd context id for which we should retrieve the
                          gradients (default: -1).
                      retain_graph (bool, optional): If ``False``, the graph used to
                          compute the grad will be freed. Note that in nearly all
                          cases setting this option to ``True`` is not needed and
                          often can be worked around in a much more efficient way.
                          Usually, you need to set this to ``True`` to run backward
                          multiple times (default: False).

                  Example::
                      >>> import torch.distributed.autograd as dist_autograd
                      >>> with dist_autograd.context() as context_id:
                      >>>     rref.backward(context_id)
              )")
          // not releasing GIL to avoid context switch
          .def("__repr__", &PyRRef::str);
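
  // TensorPipe-backed bindings, compiled only when PyTorch is built with
  // USE_TENSORPIPE. Both classes subclass the generic bindings above.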
#ifdef USE_TENSORPIPE

  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
      .def(
          py::init<
              int,
              std::optional<std::vector<std::string>>,
              std::optional<std::vector<std::string>>,
              float,
              std::string,
              std::unordered_map<std::string, DeviceMap>,
              std::vector<c10::Device>>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = std::optional<std::vector<std::string>>(),
          py::arg("_channels") = std::optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod,
          py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
          py::arg("devices") = std::vector<c10::Device>())
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
              The number of threads in the thread-pool used by
              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
              requests.
          )")
      .def_readwrite(
          "device_maps",
          &TensorPipeRpcBackendOptions::deviceMaps,
          R"(The device map locations.)")
      .def_readwrite(
          "devices",
          &TensorPipeRpcBackendOptions::devices,
          R"(All devices used by the local agent.)")
      .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);
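
  // The custom py::init lambda hands the agent to a shared_ptr whose deleter
  // is impl::destroy_without_gil, so the agent's destructor (which may join
  // network threads) never runs while holding the GIL.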
  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store>& store,
                 std::string selfName,
                 worker_id_t selfId,
                 std::optional<int> worldSize,
                 TensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                 std::vector<c10::Device> devices) {
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(
                        store,
                        std::move(selfName),
                        selfId,
                        worldSize,
                        std::move(opts),
                        std::move(reverseDeviceMaps),
                        std::move(devices),
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<TensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("rpc_backend_options"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          static_cast<const WorkerInfo& (TensorPipeAgent::*)(void) const>(
              &RpcAgent::getWorkerInfo),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          static_cast<const WorkerInfo& (TensorPipeAgent::*)(const std::string&)
                          const>(&TensorPipeAgent::getWorkerInfo),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          static_cast<const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id)
                          const>(&TensorPipeAgent::getWorkerInfo),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          static_cast<std::vector<WorkerInfo> (TensorPipeAgent::*)() const>(
              &TensorPipeAgent::getWorkerInfos),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_device_map",
          static_cast<DeviceMap (TensorPipeAgent::*)(const WorkerInfo& dst)
                          const>(&TensorPipeAgent::getDeviceMap),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_backend_options",
          &TensorPipeAgent::getBackendOptions,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_group_membership",
          &TensorPipeAgent::updateGroupMembership,
          py::call_guard<py::gil_scoped_release>())
      .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
      .def_property_readonly("store", &TensorPipeAgent::getStore);

#endif // USE_TENSORPIPE
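
  // Free functions: agent lifecycle, RRef context management, and the
  // internal entry points behind rpc_sync / rpc_async / remote.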
module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);
|
|
|
|
module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);
|
|
|
|
module.def(
|
|
"_set_and_start_rpc_agent",
|
|
[](const std::shared_ptr<RpcAgent>& rpcAgent) {
|
|
RpcAgent::setCurrentRpcAgent(rpcAgent);
|
|
// Initializing typeResolver inside RpcAgent constructor will make
|
|
// RpcAgent have python dependency. To avoid RpcAgent to have python
|
|
// dependency, setTypeResolver() here.
|
|
std::shared_ptr<TypeResolver> typeResolver =
|
|
std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
|
|
auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
|
|
qn.qualifiedName());
|
|
return c10::StrongTypePtr(
|
|
PythonRpcHandler::getInstance().jitCompilationUnit(),
|
|
std::move(typePtr));
|
|
});
|
|
rpcAgent->setTypeResolver(typeResolver);
|
|
rpcAgent->start();
|
|
},
|
|
py::call_guard<py::gil_scoped_release>());
|
|
|
|
module.def(
|
|
"_reset_current_rpc_agent",
|
|
[]() { RpcAgent::setCurrentRpcAgent(nullptr); },
|
|
py::call_guard<py::gil_scoped_release>());
|
|
|
|
module.def(
|
|
"_delete_all_user_and_unforked_owner_rrefs",
|
|
[](std::chrono::milliseconds timeoutMillis) {
|
|
RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
|
|
},
|
|
py::arg("timeout") = kDeleteAllUsersTimeout,
|
|
py::call_guard<py::gil_scoped_release>());
|
|
|
|
module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
|
|
// NB: do not release GIL in the function. The destroyInstance() method
|
|
// returns a list of deleted OwnerRRefs that hold py::object instances.
|
|
// Clearing those OwnerRRefs are likely to trigger Python deref, which
|
|
// requires GIL.
|
|
RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
|
|
});
|
|
|
|
module.def("_rref_context_get_debug_info", []() {
|
|
return RRefContext::getInstance().getDebugInfo();
|
|
});
|
|
|
|
module.def(
|
|
"_cleanup_python_rpc_handler",
|
|
[]() { PythonRpcHandler::getInstance().cleanup(); },
|
|
py::call_guard<py::gil_scoped_release>());
|
|
|
|
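
  // Note the differing call guards below: the builtin variants run under
  // gil_scoped_acquire, while the Python-UDF and TorchScript variants release
  // the GIL while the RPC is issued.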
  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF,
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());
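
  // getRpcTimeout() reports milliseconds; convert to float seconds, the unit
  // used throughout the Python API.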
  module.def(
      "get_rpc_timeout",
      []() {
        return static_cast<float>(
                   RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count()) /
            kSecToMsConversion;
      },
      R"(
          Retrieve the default timeout for all RPCs that was set during RPC initialization.
          The returned value will be in seconds.
          Returns:
              ``float`` indicating the RPC timeout in seconds.
      )");

  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
          Set whether GIL wait times should be recorded or not. This incurs a slight
          overhead cost. Default is disabled for performance reasons.

          Args:
              flag (bool): True to enable GIL profiling, False to disable.
      )");

  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the default timeout for all RPCs. The input unit is expected to be
          in seconds. If an RPC is not completed within this time, an exception
          indicating it has timed out will be raised. To control timeout for
          specific RPCs, a timeout parameter can be passed into
          :meth:`~torch.distributed.rpc.rpc_sync` and
          :meth:`~torch.distributed.rpc.rpc_async`.

          Args:
              rpcTimeoutSeconds (float): Timeout value in seconds.
      )");

  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
          Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
          pickled RRefs out of RPC contexts.

          .. warning::
              This is dangerous. If the module contains RRefs, the pickled
              result must be sent over RPC and get unpickled on the receiving side
              to restore the module. Otherwise, there will be RRef leaks, which
              can potentially lead to program hang. When using this API, it is
              the application's responsibility to make sure that the above
              assumption always holds.
      )");
  module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);

  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}

} // namespace
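
// The method table below is picked up during torch._C module initialization.
// Python-side usage (a sketch; the actual call site lives in
// torch/distributed/rpc/__init__.py, which invokes torch._C._rpc_init()
// before re-exporting the bindings):
//
//   import torch
//   torch._C._rpc_init()  # installs the torch._C._distributed_rpc submodule
//   from torch._C import _distributed_rpc
//   _distributed_rpc._is_current_rpc_agent_set()  # False before init_rpc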
static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace torch::distributed::rpc