pytorch/torch/csrc/distributed/autograd/init.cpp
Peter Bell 44af8ee6cd Add pybind11 exception translator (#30588)
Summary:
Closes https://github.com/pytorch/pytorch/issues/30027

The idea here is that you can bind a function with `pybind11` in a single line and without modifying the function:
```cpp
m.def("foo", foo, py::call_guard<torch::PyWarningHandler>());
```
Here, warnings are handled by the [`call_guard`](https://pybind11.readthedocs.io/en/stable/advanced/functions.html#call-guard) and exceptions are handled by the `pybind11` exception translator. To make this work, I have added support for handling C++ exceptions in `torch::PyWarningHandler`'s destructor without the Python error state having been set beforehand.
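
For reference, the sketch below shows the general pybind11 exception-translator mechanism this builds on. It is illustrative only (the `MyError` type and `example` module are invented for the example) and is not the translator added in this PR:

```cpp
#include <pybind11/pybind11.h>
#include <stdexcept>

namespace py = pybind11;

// Hypothetical C++ exception type, used only for illustration.
struct MyError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

PYBIND11_MODULE(example, m) {
  // Registered translators run whenever a bound call throws; this one maps
  // MyError onto Python's ValueError.
  py::register_exception_translator([](std::exception_ptr p) {
    try {
      if (p) {
        std::rethrow_exception(p);
      }
    } catch (const MyError& e) {
      PyErr_SetString(PyExc_ValueError, e.what());
    }
  });

  // No per-binding try/except wrapping is needed at the call site.
  m.def("fail", [] { throw MyError("something went wrong"); });
}
```

Because the translator (and the warning-handling `call_guard`) are registered once, each individual binding stays a one-liner as in the snippet above.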
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30588

Differential Revision: D19905626

Pulled By: albanD

fbshipit-source-id: 90c0a5e298b123cc0c8ab9c52c91be4e96ea47c6
2020-02-18 11:33:29 -08:00


#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
namespace torch {
namespace distributed {
namespace autograd {
namespace {
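// Convenience alias for binding classes held by std::shared_ptr.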
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
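// Attaches the C++ distributed autograd bindings to the
// torch.distributed.autograd Python module.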
PyObject* dist_autograd_init(PyObject* /* unused */) {
auto autograd_module =
THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
if (!autograd_module) {
throw python_error();
}
auto module = py::handle(autograd_module).cast<py::module>();
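// Expose DistAutogradContext so callers can inspect a context's id, its
// send/recv autograd functions, and the worker ids it knows about.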
auto distAutogradContext =
shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
.def(
"_context_id",
&DistAutogradContext::contextId,
py::call_guard<py::gil_scoped_release>())
.def(
"_recv_functions",
[](const DistAutogradContext& ctx) {
std::map<int64_t, py::object> funcs;
for (const auto& map_entry : ctx.recvFunctions()) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
})
.def(
"_send_functions",
[](const ContextPtr& ctx) {
std::map<int64_t, py::object> funcs;
for (const auto& map_entry : ctx->sendFunctions()) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
})
.def("_known_worker_ids", &DistAutogradContext::getKnownWorkerIds);
module.def(
"_new_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().newContext();
},
py::return_value_policy::reference);
module.def(
"_release_context",
[](int64_t context_id) {
return DistAutogradContainer::getInstance().releaseContext(context_id);
},
py::call_guard<py::gil_scoped_release>());
module.def("_get_max_id", []() {
return DistAutogradContainer::getInstance().getMaxId();
});
module.def(
"_retrieve_context",
[](int64_t context_id) -> const ContextPtr {
return DistAutogradContainer::getInstance().retrieveContext(context_id);
},
py::return_value_policy::reference);
module.def(
"_current_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().currentContext();
},
py::return_value_policy::reference);
module.def(
"_init",
[](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
py::call_guard<py::gil_scoped_release>());
module.def(
"_get_debug_info",
[]() { return DistEngine::getInstance().getDebugInfo(); },
py::call_guard<py::gil_scoped_release>());
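// The docstrings below include their own signature line, so suppress the
// signatures pybind11 would otherwise prepend.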
py::options options;
options.disable_function_signatures();
module.def(
"backward",
[](const std::vector<torch::Tensor>& roots, bool retainGraph = false) {
torch::autograd::variable_list variables;
for (const auto& root : roots) {
variables.emplace_back(root);
}
try {
DistEngine::getInstance().execute(variables, retainGraph);
} catch (python_error& e) {
// FIXME: crashes if exception type is not RuntimeError
throw std::runtime_error(e.what());
}
},
R"(
backward(roots: List[Tensor], retain_graph = False) -> None

Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which assumes all RPC
messages sent in the same distributed autograd context across workers would
be part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire autograd
computation is done.

We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The
autograd context used is the current autograd context of this node when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context, we throw an error. You can retrieve the accumulated
gradients using the :meth:`~torch.distributed.autograd.get_gradients` API.

Arguments:
    roots (list): Tensors which represent the roots of the autograd
        computation. All the tensors should be scalars.
    retain_graph (bool, optional): If ``False``, the graph used to compute
        the grad will be freed. Note that in nearly all cases setting this
        option to ``True`` is not needed and often can be worked around in a
        much more efficient way. Usually, you need to set this to ``True``
        to run backward multiple times.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     pred = model.forward()
    >>>     loss = loss_func(pred, target)
    >>>     dist_autograd.backward([loss])
)",
py::arg("roots"),
py::arg("retain_graph") = false,
py::call_guard<py::gil_scoped_release>());
module.def(
"get_gradients",
[](int64_t contextId) -> py::dict {
const auto& autogradContext =
DistAutogradContainer::getInstance().retrieveContext(contextId);
return torch::jit::toPyObject(IValue(autogradContext->getGradients()));
},
R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the provided ``context_id`` as part of the distributed
autograd backward pass.

Arguments:
    context_id (int): The autograd context id for which we should retrieve
        the gradients.

Returns:
    A map where the key is the Tensor and the value is the associated
    gradient for that Tensor.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     t1 = torch.rand((3, 3), requires_grad=True)
    >>>     t2 = torch.rand((3, 3), requires_grad=True)
    >>>     loss = t1 + t2
    >>>     dist_autograd.backward([loss.sum()])
    >>>     grads = dist_autograd.get_gradients(context_id)
    >>>     print(grads[t1])
    >>>     print(grads[t2])
)",
py::arg("context_id"));
Py_RETURN_TRUE;
}
} // namespace
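// Method table returned by python_functions(); it exposes
// _dist_autograd_init to the Python side.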
static PyMethodDef methods[] = { // NOLINT
{"_dist_autograd_init",
(PyCFunction)dist_autograd_init,
METH_NOARGS,
nullptr},
{nullptr, nullptr, 0, nullptr}};
PyMethodDef* python_functions() {
return methods;
}
} // namespace autograd
} // namespace distributed
} // namespace torch