Make RRefContext get devices from RPC agent when creating OwnerRRef (#57443)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57443

Based on the comments in https://github.com/pytorch/pytorch/pull/57355, I looked at the callsites of `getOrCreateOwnerRRef` and `createOwnerRRef` and noticed that many of them didn't specify the `devices` argument. Since that argument was optional, it defaulted to `{}`, which created a CPU-only Future inside the OwnerRRef. Such callsites were, for example, in `processPythonRemoteCall`, `processBaseScriptRemoteCall`, and `PyRRef::unpickle`, among others.

Some (or all?) of these callsites might still have worked thanks to the RRef's own handling of CUDA streams and events; however, we intend to remove that in https://github.com/pytorch/pytorch/pull/57355. I think it is safer and more generic to always create OwnerRRefs with the full set of devices supported by the RPC agent. This is easy to do, since the RRefContext has access to the RPC agent. As a result, all OwnerRRefs, no matter how they're created, will support CUDA if the agent does. It also means that devices no longer need to be specified when creating an OwnerRRef by hand in Python.
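
To illustrate the idea, here is a minimal, self-contained sketch of the pattern. It is not the actual PyTorch code: `Device`, `Agent`, `OwnerRRef`, `Context` and their members are simplified, hypothetical stand-ins for `c10::Device`, the RPC agent, `OwnerRRef` and `RRefContext`. The point it shows is that the factory method takes no `devices` argument and instead asks the agent it already holds.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for c10::Device (assumption, not the real type).
struct Device {
  std::string type; // e.g. "cpu" or "cuda"
  int index;
};

// Hypothetical stand-in for the RPC agent: it knows which devices it supports.
class Agent {
 public:
  explicit Agent(std::vector<Device> devices) : devices_(std::move(devices)) {}
  const std::vector<Device>& getDevices() const {
    return devices_;
  }

 private:
  std::vector<Device> devices_;
};

// Hypothetical stand-in for OwnerRRef: it is bound to a fixed set of devices
// at construction time, like the Future it would hold.
class OwnerRRef {
 public:
  explicit OwnerRRef(std::vector<Device> devices)
      : devices_(std::move(devices)) {}

  bool supportsCuda() const {
    for (const auto& d : devices_) {
      if (d.type == "cuda") {
        return true;
      }
    }
    return false;
  }

  std::size_t numDevices() const {
    return devices_.size();
  }

 private:
  std::vector<Device> devices_;
};

// Hypothetical stand-in for RRefContext: it holds the agent, so callers never
// pass devices and therefore cannot forget them.
class Context {
 public:
  explicit Context(std::shared_ptr<Agent> agent) : agent_(std::move(agent)) {}

  std::shared_ptr<OwnerRRef> createOwnerRRef() const {
    // Always use the full set of devices supported by the agent.
    return std::make_shared<OwnerRRef>(agent_->getDevices());
  }

 private:
  std::shared_ptr<Agent> agent_;
};
```

With this shape, every callsite gets a CUDA-capable OwnerRRef whenever the agent supports CUDA, and a CPU-only one otherwise, without any per-callsite plumbing.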
ghstack-source-id: 128184665

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28144365

fbshipit-source-id: 1f2d446873f31ee297415c46b94126b6502b12d3
Luca Wehrstedt
2021-05-06 01:10:10 -07:00
committed by Facebook GitHub Bot
parent 7ffadf6e46
commit 7d4121d1d2
7 changed files with 15 additions and 47 deletions


@@ -123,14 +123,10 @@ PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
   TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
 }
 
-PyRRef::PyRRef(
-    const py::object& value,
-    const py::object& type_hint,
-    std::vector<c10::Device> devices)
-    : PyRRef([&value, &type_hint, devices{std::move(devices)}]() mutable {
+PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
+    : PyRRef([&value, &type_hint]() mutable {
         TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
-        auto rref = RRefContext::getInstance().createOwnerRRef(
-            elem_type, std::move(devices));
+        auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
         // jit::toIValue takes a py::handle as the first argument, and it calls
         // py::handle.cast<py::object>() to incref of provided value. The
         // returned ivalue will keep the reference alive.