Summary: Closes https://github.com/pytorch/pytorch/issues/30027

The idea here is that you can bind a function with `pybind11` in a single line and without modifying the function:

```cpp
m.def("foo", foo, py::call_guard<torch::PyWarningHandler>());
```

where warnings are handled by the [`call_guard`](https://pybind11.readthedocs.io/en/stable/advanced/functions.html#call-guard) and exceptions are handled by the `pybind11` exception translator. To do this, I have added support for handling C++ exceptions in `torch::PyWarningHandler`'s destructor without setting the Python error state beforehand.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/30588

Differential Revision: D19905626

Pulled By: albanD

fbshipit-source-id: 90c0a5e298b123cc0c8ab9c52c91be4e96ea47c6
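For context, a minimal self-contained sketch of the binding pattern described above; `foo` and `register_foo` are illustrative names, not part of this commit. Since `py::call_guard<...>` accepts several guard types in one binding, the warning handler can also sit alongside a GIL-release guard such as the `py::gil_scoped_release` guards used throughout the file below.

```cpp
#include <torch/csrc/Exceptions.h>   // torch::PyWarningHandler
#include <torch/csrc/utils/pybind.h> // pybind11 / py::

// Illustrative function to expose. Warnings it raises through the C++ warning
// handler are flushed back to Python when the call guard is destroyed after
// the call returns; C++ exceptions go through pybind11's exception translator.
void foo();

void register_foo(py::module& m) {
  m.def("foo", &foo, py::call_guard<torch::PyWarningHandler>());
}
```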
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

namespace torch {
namespace distributed {
namespace autograd {

namespace {

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

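// Defines the C++-backed bindings on the torch.distributed.autograd Python
// module and returns True on success. Exposed to Python as _dist_autograd_init
// via the method table at the bottom of this file.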
PyObject* dist_autograd_init(PyObject* /* unused */) {
  auto autograd_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
  if (!autograd_module) {
    throw python_error();
  }

  auto module = py::handle(autograd_module).cast<py::module>();

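  // Bind DistAutogradContext, the per-worker record of a distributed autograd
  // pass: its context id, the send/recv autograd functions recorded for RPCs
  // in this context, and the worker ids the context has been propagated to.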
  auto distAutogradContext =
      shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
          .def(
              "_context_id",
              &DistAutogradContext::contextId,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_recv_functions",
              [](const DistAutogradContext& ctx) {
                std::map<int64_t, py::object> funcs;
                for (const auto& map_entry : ctx.recvFunctions()) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              })
          .def(
              "_send_functions",
              [](const ContextPtr& ctx) {
                std::map<int64_t, py::object> funcs;
                for (const auto& map_entry : ctx->sendFunctions()) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              })
          .def("_known_worker_ids", &DistAutogradContext::getKnownWorkerIds);

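  // Module-level context management: create, release, look up, and query the
  // distributed autograd contexts owned by the singleton
  // DistAutogradContainer, plus per-worker container initialization and
  // distributed engine debug info.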
  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference);

  module.def(
      "_release_context",
      [](int64_t context_id) {
        return DistAutogradContainer::getInstance().releaseContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());

  module.def("_get_max_id", []() {
    return DistAutogradContainer::getInstance().getMaxId();
  });

  module.def(
      "_retrieve_context",
      [](int64_t context_id) -> const ContextPtr {
        return DistAutogradContainer::getInstance().retrieveContext(context_id);
      },
      py::return_value_policy::reference);

  module.def(
      "_current_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().currentContext();
      },
      py::return_value_policy::reference);

  module.def(
      "_init",
      [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_get_debug_info",
      []() { return DistEngine::getInstance().getDebugInfo(); },
      py::call_guard<py::gil_scoped_release>());

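  // Turn off pybind11's auto-generated signatures so the hand-written
  // docstrings below, which carry their own signature lines, are used as-is.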
  py::options options;
  options.disable_function_signatures();

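  // backward(): collect the root tensors into a variable_list and run the
  // distributed engine synchronously. A python_error raised during the pass is
  // rethrown as std::runtime_error (see the FIXME below).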
  module.def(
      "backward",
      [](const std::vector<torch::Tensor>& roots, bool retainGraph = false) {
        torch::autograd::variable_list variables;
        for (const auto& root : roots) {
          variables.emplace_back(root);
        }
        try {
          DistEngine::getInstance().execute(variables, retainGraph);
        } catch (python_error& e) {
          // FIXME: crashes if exception type is not RuntimeError
          throw std::runtime_error(e.what());
        }
      },
      R"(
backward(roots: List[Tensor], retain_graph = False) -> None

Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which
assumes all RPC messages sent in the same distributed autograd context
across workers would be part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire
autograd computation is done.

We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The autograd
context used is the current autograd context of this node when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context, we throw an error. You can retrieve the accumulated
gradients using the :meth:`~torch.distributed.autograd.get_gradients` API.

Arguments:
    roots (list): Tensors which represent the roots of the autograd
        computation. All the tensors should be scalars.
    retain_graph (bool, optional): If False, the graph used to compute the grad
        will be freed. Note that in nearly all cases setting this
        option to True is not needed and often can be worked around
        in a much more efficient way. Usually, you need to set this
        to True to run backward multiple times.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     pred = model.forward()
    >>>     loss = loss_func(pred, target)
    >>>     dist_autograd.backward([loss])
)",
      py::arg("roots"),
      py::arg("retain_graph") = false,
      py::call_guard<py::gil_scoped_release>());

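  // get_gradients(): look up the given context in the container and return its
  // accumulated Tensor -> gradient map as a Python dict.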
  module.def(
      "get_gradients",
      [](int64_t contextId) -> py::dict {
        const auto& autogradContext =
            DistAutogradContainer::getInstance().retrieveContext(contextId);
        return torch::jit::toPyObject(IValue(autogradContext->getGradients()));
      },
      R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the provided ``context_id`` as part of the distributed autograd
backward pass.

Arguments:
    context_id (int): The autograd context id for which we should retrieve the
        gradients.

Returns:
    A map where the key is the Tensor and the value is the associated gradient
    for that Tensor.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     t1 = torch.rand((3, 3), requires_grad=True)
    >>>     t2 = torch.rand((3, 3), requires_grad=True)
    >>>     loss = t1 + t2
    >>>     dist_autograd.backward([loss.sum()])
    >>>     grads = dist_autograd.get_gradients(context_id)
    >>>     print(grads[t1])
    >>>     print(grads[t2])
)",
      py::arg("context_id"));

  Py_RETURN_TRUE;
}
} // namespace

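// Method table exposing dist_autograd_init() to Python as _dist_autograd_init;
// it is retrieved through python_functions() below.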
static PyMethodDef methods[] = { // NOLINT
    {"_dist_autograd_init",
     (PyCFunction)dist_autograd_init,
     METH_NOARGS,
     nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace autograd
} // namespace distributed
} // namespace torch