mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Make RRefContext get devices from RPC agent when creating OwnerRRef (#57443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57443

Based on the comments in https://github.com/pytorch/pytorch/pull/57355, I started looking at the callsites of `getOrCreateOwnerRRef` and `createOwnerRRef`, and noticed that many of them didn't specify the `devices` argument, which was optional and thus defaulted to `{}`, which created a CPU-only Future inside the OwnerRRef. (Such callsites were, for example, in `processPythonRemoteCall` and `processBaseScriptRemoteCall`, or `PyRRef::unpickle`, ...). Some (or all?) of these callsites might still have worked thanks to the RRef's own handling of CUDA streams and events; however, we intend to remove that in https://github.com/pytorch/pytorch/pull/57355.

I think it would be a safer and more generic solution to always create OwnerRRefs with the full set of devices supported by the RPC agent, and this is in fact easy to do since the RRefContext has access to the RPC agent. This means that all OwnerRRefs, no matter how they're created, will support CUDA if the agent does. This also allows us to stop requiring devices to be specified when creating an OwnerRRef by hand in Python.

ghstack-source-id: 128184665

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28144365

fbshipit-source-id: 1f2d446873f31ee297415c46b94126b6502b12d3
committed by Facebook GitHub Bot
parent 7ffadf6e46
commit 7d4121d1d2
@@ -123,14 +123,10 @@ PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
   TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
 }
 
-PyRRef::PyRRef(
-    const py::object& value,
-    const py::object& type_hint,
-    std::vector<c10::Device> devices)
-    : PyRRef([&value, &type_hint, devices{std::move(devices)}]() mutable {
+PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
+    : PyRRef([&value, &type_hint]() mutable {
         TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
-        auto rref = RRefContext::getInstance().createOwnerRRef(
-            elem_type, std::move(devices));
+        auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
         // jit::toIValue takes a py::handle as the first argument, and it calls
         // py::handle.cast<py::object>() to incref of provided value. The
         // returned ivalue will keep the reference alive.
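
For readers who want the idea in isolation, here is a minimal, self-contained sketch of the pattern this commit describes: the context asks the agent for its devices when it creates an owner RRef, so callers never pass a `devices` argument. Everything in it (`Device`, `ToyAgent`, `ToyOwnerRRef`, `ToyRRefContext`, `getDevices`, `supportsCuda`) is a hypothetical stand-in written for illustration; it is not the PyTorch implementation and does not use the real PyTorch types or signatures.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a device descriptor (NOT c10::Device).
struct Device {
  std::string type;  // e.g. "cpu" or "cuda"
  int index;
};

// Hypothetical agent that knows the full set of devices it supports.
class ToyAgent {
 public:
  explicit ToyAgent(std::vector<Device> devices)
      : devices_(std::move(devices)) {}

  const std::vector<Device>& getDevices() const { return devices_; }

 private:
  std::vector<Device> devices_;
};

// Toy owner RRef that only records its element type and devices.
struct ToyOwnerRRef {
  std::string elemType;
  std::vector<Device> devices;

  // True if at least one of the recorded devices is a CUDA device.
  bool supportsCuda() const {
    for (const auto& d : devices) {
      if (d.type == "cuda") {
        return true;
      }
    }
    return false;
  }
};

// Toy context holding a reference to the agent. Note that createOwnerRRef
// takes no devices parameter: the devices always come from the agent.
class ToyRRefContext {
 public:
  explicit ToyRRefContext(std::shared_ptr<ToyAgent> agent)
      : agent_(std::move(agent)) {}

  std::shared_ptr<ToyOwnerRRef> createOwnerRRef(
      const std::string& elemType) const {
    return std::make_shared<ToyOwnerRRef>(
        ToyOwnerRRef{elemType, agent_->getDevices()});
  }

 private:
  std::shared_ptr<ToyAgent> agent_;
};
```

In this sketch, a callsite that forgets about devices cannot end up with a CPU-only RRef on a CUDA-capable agent, because the callsite has no way to specify (or omit) the devices in the first place.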