Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60943
In https://github.com/pytorch/pytorch/pull/60470 we made Future store Storages rather than store references to their DataPtrs (because these references could go stale...). However this meant that the Future could keep the Storage alive, and thus keep its memory allocated, even after the user was done with it. We fix it here by instead storing a weak ptr to that Storage (well, in fact to the StorageImpl, but it's the same).
ghstack-source-id: 133295799
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D29454104
fbshipit-source-id: d36dee00a4841c087bb7b3f5bc39e0459f209cdb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60470
A Future needs to know what DataPtrs are used by its value, but it isn't always able to extract them (and even when it is, that's expensive) so they're cached. DataPtrs are kinda like unique_ptrs (movable only, cannot be copied) hence the Future can only hold _references_ to them. The Future's value, however, is unfortunately mutable (we'd wish that weren't the case, but we don't think we can prevent that), which means the tensor/storage that owned that DataPtr might be deleted and thus the DataPtr could be freed. This means our cached reference becomes stale! Which leads to all kinds of disaster, like reading garbage data or segfaulting.
Luckily all the DataPtrs we were dealing with were held inside Storages, which have a shared_ptr semantics, thus allowing us to hold a strong pointer to them which ensures they're kept alive.
ghstack-source-id: 132177396
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D29303570
fbshipit-source-id: d814754806fa58b24e45269e97d768485ef972ba
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59208
Reland of https://github.com/pytorch/pytorch/pull/58425
Now that callbacks can provide pre-extracted DataPtrs, let's do so. This will become of crucial importance in the next PR, where some of these futures will become CUDA-aware, and thus they will try to extract DataPtrs on their own, but they would fail to do so here because Message isn't "inspectable".
ghstack-source-id: 130202845
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28623888
fbshipit-source-id: 1aa4bde8014870c071685ba8f72d5f3f01f0a512
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59206
Reland of https://github.com/pytorch/pytorch/pull/58423
This is part 2 of the previous PR. Here we address the remaining occurrences of "raw" Message, namely the ones within toMessageImpl. And since they're the last ones, we make the constructor of Message private, to prevent new usages from emerging.
ghstack-source-id: 130202848
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28623892
fbshipit-source-id: f815cf6b93e488c118e5d2298473e6e9d9f4c132
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59205
Reland of https://github.com/pytorch/pytorch/pull/58422
Similar to Future (which I tackled recently), Message is an ivalue type (a "custom class" one), and the natural way to represent it is inside an intrusive_ptr. However in the RPC code we had a mix of usages, often passing Message by value. This has undesirable consequences, as it could easily trigger a copy by accident, which I believe is why in many places we accepted _rvalue references_ to Message, in order to force the caller to move. In my experience this is non-idiomatic in C++ (normally a function signature specifies how the function consumes its arguments, and it's up to the caller to then decide whether to copy or move).
By moving to intrusive_ptr everywhere I think we eliminate and simplify many of the problems above.
In this PR I do half of the migration, by updating everything except the `toMessageImpl` methods, which will come in the next PR.
ghstack-source-id: 130202849
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28623891
fbshipit-source-id: c9aeea3440679a11741ca78c06b03c57cb815a5e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58425
Now that callbacks can provide pre-extracted DataPtrs, let's do so. This will become of crucial importance in the next PR, where some of these futures will become CUDA-aware, and thus they will try to extract DataPtrs on their own, but they would fail to do so here because Message isn't "inspectable".
ghstack-source-id: 129567057
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28474877
fbshipit-source-id: e68d7d45f1c1dc6daa5e05cf984cfc93d2dce0d0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58423
This is part 2 of the previous PR. Here we address the remaining occurrences of "raw" Message, namely the ones within toMessageImpl. And since they're the last ones, we make the constructor of Message private, to prevent new usages from emerging.
ghstack-source-id: 129567049
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28474879
fbshipit-source-id: 498652a8b80a953396cd5d4b275c0b2e869c9ecf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58422
Similar to Future (which I tackled recently), Message is an ivalue type (a "custom class" one), and the natural way to represent it is inside an intrusive_ptr. However in the RPC code we had a mix of usages, often passing Message by value. This has undesirable consequences, as it could easily trigger a copy by accident, which I believe is why in many places we accepted _rvalue references_ to Message, in order to force the caller to move. In my experience this is non-idiomatic in C++ (normally a function signature specifies how the function consumes its arguments, and it's up to the caller to then decide whether to copy or move).
By moving to intrusive_ptr everywhere I think we eliminate and simplify many of the problems above.
In this PR I do half of the migration, by updating everything except the `toMessageImpl` methods, which will come in the next PR.
ghstack-source-id: 129567053
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28474878
fbshipit-source-id: 5b76d45e05f6fa58c831e369c5c964d126187a6c
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51698
Completely eliminates torch::utils::Future as we are now full relying on JitFuture.
ghstack-source-id: 122037612
Test Plan: CI
Reviewed By: kiukchung
Differential Revision: D26243735
fbshipit-source-id: 95010a730f9d35e618f74c5f9de482738cd57c15
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49906
This commit modifies RPC Message to inherit from `torch::CustomClassHolder`,
and wraps a Message in an IValue in `RpcAgent::send()`.
Test Plan: Imported from OSS
Reviewed By: lw
Differential Revision: D25719518
Pulled By: mrshenli
fbshipit-source-id: 694e40021e49e396da1620a2f81226522341550b
Summary:
Addresses https://github.com/pytorch/pytorch/issues/47145
Adds a new MessageTypeFlags enum so that checking for certain properties (e.g. isResponse, isRequest) can be done with a BITWISE AND instead of checking for each MessageType enum individually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48143
Reviewed By: mrshenli
Differential Revision: D25091008
Pulled By: H-Huang
fbshipit-source-id: 56a823747748633c1ef3fa07817ca0f08c7399a8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38748
This diff contains the message scaffolding and profiler changes in order to be able to remotely run the profiler across different nodes and aggregate the results on a single node.
As discussed, we have implemented this by creating new message types, that similar to autograd messages, wrap the profiling information with the original message, and send this new message over the wire. On the receiving end, this wrapped message is detected, we fetch the original message from it, and process the original message with the profiler enabled. When sending a response with profiling information, we serialize the profiled `Events` and send them back over RPC. When such a message is received, the events profiled on the remote node are stored (added back to the local profiler).
Changes in this PR:
- New message types (run_with_profiling_req, run_with_profiling_resp) to send profiling info over the wire. Message parsing logic is added to handle these wrapped types.
- Handling of sending profiler data over the wire, in particular, the attributes of the `ProfilerConfig` and the serialized profiled `Event`s
- The logic for wrapping RPC messages is deduped with that in `rpc_with_autograd`, and the common payload wrapping/unwrapping logic is moved to helper functions in `rpc/utils.cpp`
- Changes in `autograd/utils.cpp` to detect if we have enabled the profiler and are sending an RPC, if so, uses the above new message types
- Changes in request_callback to parse and turn on the profiler in a thread-local fashion
- Serialization and deserialization of profiling `Events`, and support to add the remote events to the thread-local profiler
- Introduction of the concept of `node_id`, which as discussed with ilia-cher , will be used along with the `Event`s handle attribute to distinguish between events. When there are events from different nodes, this node information is rendered in the profile output (e.g. when printing tables), otherwise, it is not, since it is irrelevant.
- Some changes to profiler.cpp to add useful helper methods/guards
- toHere() is now profiled for RRefs
- Unittests
ghstack-source-id: 106134626
Test Plan: Added unittests, existing profiler unittests.
Differential Revision: D19510010
fbshipit-source-id: 044347af992f19a9e3b357c9567f6fc73e988157
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38590
This PR implements timeout semantics for RRef for parity with rpc_sync and rpc_async. How it works:
- Timeout parameter is added to rpc.remote. If the rpc.remote call times out, note that the error won't be raised to the user in that call, as it is not blocking (similar to rpc_async). Instead, the timeout error will be raised the next time the RRef is used (either by pickling or to_here call).
- Error handling semantics are added to RRef to deal with the timeout errors. Previously, if there was an error creating the OwnerRRef, the callback on the local user would throw an error in a callback, resulting in an `std::terminate`. Instead of this, the error is now caught and surfaced to the user the next time the RRef is used. As part of this, we have added an `RPCErrorType` enum and defined RRef error handlers to handle the `RPCErrorrTypes` (currently just timeout and unknown)
- A timeout parameter is added to `to_here()` which gives the user control over the max amount of time it can block for.
- `ctx.prepareChildForFork()` which is called when the RRef is pickled (i.e. used as an arg over RPC) checks if the `rpc.remote()` call had timed out, and if so, raises that error to the user.
- Tests are added, primarily via delay injection.
ghstack-source-id: 105232837
Test Plan: CI
Differential Revision: D21588165
fbshipit-source-id: c9f9e8aa3521012ea1de3e0f152a41afdf8b23f3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39010
The initial version of the serialization for the TensorPipe RPC agent (i.e., the conversion from rpc::Message to tensorpipe::Message) worker around a limitation of TensorPipe of only allowing one payload per message by pickling each tensor separately and storing the pickles as metadata (which is a less efficient way of sending data over, as it goes through more copies). Having now lifter that limitation we can now improve the way we serialize. We now put the type and the id as their own payloads, we do a single pickling pass for all the tensors of the message (which allows us to deduplicate them) and store the pickle as a payload. My impression is that pickling is a somewhat costly operation, so reducing the number of times we do it should be beneficial for performance. For this same reason, another change I've done here is separate the allocation of the buffers from the deserialization. This will allow us (in the future) to perform the allocation on the I/O event loop but perform the unpickling in the worker thread, thus keeping the event loop more responsive.
ghstack-source-id: 104810740
Test Plan: RPC tests
Differential Revision: D21716067
fbshipit-source-id: c1475cc78afdcf0820a485ffd98c91abb35796c7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35154
This is for issue https://github.com/pytorch/pytorch/issues/34999.
close https://github.com/pytorch/pytorch/issues/34999.
https://github.com/pytorch/pytorch/issues/34997 need more work.
This will make a few work items easier, like 1) Dist autograd profiler, 2) JIT annotation for Future.
Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_rref_forward_chain --stress-runs 100
buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork && \
buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par \
-r test_call_method_on_rref
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- 'test_rref_proxy_class \(fb\.test_rpc_fork\.RpcTestWithFork\)' --stress-runs 100
test_rref_proxy_reuse
test_handle_send_exceptions
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork && \
buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par \
-r test_script_call_python_return_future
```
Differential Revision: D7722184
fbshipit-source-id: bd92b855bfea4913d6672700590c57622fa86e0e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37519closes#37446
Currently FutureMessage is used in several places:
1. `rpc_async` returns a `FutureMessage` object and we expose it
as `torch.distributed.rpc.Future`. From applications perspective,
they are expecting a `py::object` instead of a `Message`, and we
do the conversion in the `Future.wait()` pybind method.
2. RPC autograd profiler takes `FutureMessage` and installs
callbacks to it. The profiler actually only need a `Future<T>`
and does not care what `T` is.
3. `OwnerRRef` exposes a `getFuture()` API which returns a
`FutureMessage`. This `FutureMessage` will be marked completed
when the value referenced by the `OwnerRRef` is ready.
`OwnerRRef` does not need it to be a Message type either, it
actually creates an empty `Message` to mark the `Future`.
The above places are using `FutureMessage`, but they don't really
need a `Message`, and `Message` is a communication layer type that
applications or profiler or the RRef shouldn't be aware of.
Another motivation for making this change is that for async RPC
UDF #36071, we are going to allow application to call
`markCompleted` in Python. If we still use `FutureMessage`, then
in the `markCompleted` pybind function, it needs to convert the
provided `py::object` into a specific message type, which is
leaking communication layer code to pybind functions. Even if
this is doable, we will have two entities (RPC agent and pybind
Python frontend) accessing the same request callback logic. This is too messy.
This commit replaces all surface `FutureMessage` with `FutureIValue`,
so that `FutureMessage` is no longer visible from Python land. Note
that this does not cause BC issues, as the Python Future type name
and its API stay intact. Internally, we still have `FutureMessage`
in the communication layer.
Test Plan: Imported from OSS
Reviewed By: xush6528
Differential Revision: D21308887
Pulled By: mrshenli
fbshipit-source-id: 4f574f38e83125081f142813cfdde56119522089
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638
Fixes: https://github.com/pytorch/pytorch/issues/27643
This PR manages notifying workers in the event of a failure during distributed autograd. Gracefully handles propagating errors across all nodes in the backward pass and sets state in the local autograd engines accordingly.
(Note: this ignores all push blocking failures!)
Test Plan: Added 2 new tests checking errors when they are thrown in an intermediate node during distributed autograd. Ensured that all existing distributed autograd tests pass.
Differential Revision: D20164420
fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34413
In this diff we have made various improvements to ProcessGroupAgent in order to accomodate edge and error cases such as a "non-clean" shutdown (shutdowns in which we abort RPC as quickly as possible, and don't wait for all pending work across all RPC agents to be completed):
1. Catch and log exceptions in `enqueueRecv`. This prevents us from calling `std::terminate()` in a different thread and logs an error message indicating the issue. With this we no longer have crashes caused by exceptions in this thread during non-graceful shutdown.
2. Provide cleaner error messages everywhere (and use `c10::str` where possible). One example is in `agent::send()`.
3. Add the ability to abort pending sends that cause blocking waits in `handleSend`. The reason we need to abort this is since during a non-graceful shutdown, we could become blocked waiting for these since there is no guarantee the remote end is still active and this would result in a long wait and eventual timeout. We abort these by adding them to a map, and go through this map during `shutdown()`.
4. Fix flaky tests: `test_handle_send_exceptions` and `test_backward_node_failure` and `test_backward_node_failure_python_udf`. These tests were flaky since they dealt with non-graceful shutdown of workers which has chances for a bunch of edge cases explained above.
We have also refactored `createExceptionResponse`, `enqueueRecv`, and some test functions for the above reasons in this diff.
For testing:
Ensured that the tests are no longer flaky with 500 tests runs. Previously, these tests were flaky and disabled. Also added a unit test in the internal `ProcessGroupAgentTest.cpp`.
ghstack-source-id: 100311598
Test Plan: Ensured that the tests are no longer flaky with 500 tests runs. Previously, these tests were flaky and disabled. Also added a unit test in the internal `ProcessGroupAgentTest.cpp`.
Reviewed By: mrshenli
Differential Revision: D20269074
fbshipit-source-id: de9cad7f7185f9864ffbb6b14cd8ca9f6ff8f465
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34497
Use a thread_local table to intercept UserRRefs created during user
function args deserialization, and then wait for confirmations of
those UserRRefs before launching the given user function.
Differential Revision: D20347464
Test Plan: Imported from OSS
Pulled By: mrshenli
fbshipit-source-id: 087484a2d2f03fbfb156752ab25653f39b412a07
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32959
in rpc torch script call path, we need to pickle/unpickle rref, this diff is added to make jit pickler/unpickler be able to pickle/unpickle rref. It is similar to what is implemented for PyRef::pickle() and PyRef::unpickle().
The pickling/unpickling design assumes it is always coupled with RPC calls. It is not needed to checkpoint a model with rref, before checkpointing the model, user should call ref.to_here() to get value inside rref.
The pickling process is:
1. push torch.distributed.rpc.rref global string
1. call rref.fork() and create rrefForkData, which is a few IDs and type str of the value held inside the rref, the IDs includes rref id, fork id, caller work id, callee work id, owner work id
2. push the rrefForkData
The unpickling process is:
1. read torch.distributed.rpc.rref global string, and retrieve the cached global lamda function
2. the globa lamda function will get rrefForkData
3. if callee is also owner work id, then get owner rref based on Ids inside rrefFork data and return the ownerRRef
4. if callee is not owner work id, then create user rref using the rrefForkData and return the userRRef
5. meanwhile owner rref will be notified and do reference counting correctly
During unpickling, a type_resolver is needed to parse type str. This type_resolver has python dependency, so we get it from rpc_agent, and pass it to unpickler during construction. So we added a type_resolver argumenmt to jit unpickler constructor in this diff.
ghstack-source-id: 98814793
Test Plan: unit test
Differential Revision: D19713293
fbshipit-source-id: 4fd776cdd4ce8f457c4034d79acdfb4cd095c52e
Summary:
https://github.com/pytorch/pytorch/pull/30330 got rid of the need to send a `MessageType::SHUTDOWN` message, so we can now remove the logic/utils for this type of message.
I think we can also delete the enum entry in the `enum MessageType`, but we may want to keep it in case the logic in https://github.com/pytorch/pytorch/pull/30710 is ever moved to C++.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31270
Test Plan: All existing unit tests pass
Differential Revision: D19146983
Pulled By: rohan-varma
fbshipit-source-id: 35b185411f9446d7d4dfc37a6cb5477cf041e647
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29579
Per #28923, this diff is to move Future<Message> to torch::utils and extend it to be Future<T>, most of implementations are copied from FutureMessage and ivalue::Future. merge ivalue::Future with Future<T> will be done separately.
The main difference between Future<T> and FutureMessage is the error handling, instead of checking message type inside Future to handle error, this future<T> owns has_error_ and error_ states.
also this future passes value_, has_error_ and error_ states to callbacks for easily read future states.
In next diff, a torch script rpc async API will be created, before the API returns, it will create an ivalue::Future and passes it to Future<T>'s call back where state of ivalue::Future will be set. In this way, the torch script rpc async API can still return a ivalue::Future and call wait() to get its state appropriately afterwards.
ghstack-source-id: 95479525
Test Plan: unit tests
Differential Revision: D18263023
fbshipit-source-id: 48a65712656a72c2feb0bb3ec8b308c0528986a6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29605
Adds a wrapper around the existing createException function that
allows passing of an error string, instead of a regular C++ exception. This
allows us to createExceptions for errors that aren't necessarilu c++
exceptions. This function is used by
https://github.com/pytorch/pytorch/pull/29601 and
https://github.com/pytorch/pytorch/pull/26336.
ghstack-source-id: 93819039
Test Plan: Unit tests pass
Differential Revision: D18439216
fbshipit-source-id: 70b6a2e4f107304e322cdd2630847ad0071bc0c1
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27530
Per discussion in #27286, the `UDF` part is superfluous.
This makes the naming consistent with the `MessageType` enum.
Test Plan: Imported from OSS
Differential Revision: D17808211
Pulled By: pietern
fbshipit-source-id: 0ff925de26d027951ce285750ad276ed17fee4c6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27943
This is step 1 to make PyRRef::toHere() non-blocking on caller.
Test Plan: Imported from OSS
Differential Revision: D17936747
Pulled By: mrshenli
fbshipit-source-id: 7cf60e5804e72bdc28f0135fed4d7fdce05ea38a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27951
we want to clean up the distributed autograd context across the other nodes when a single node is done (here done means exited the context manager `with dist_autograd.context() as context_id: ...`).
This PR does a few things to implement the above:
1) Add classes to encapsulate messages for requesting this context release and the response
2) Handling of this request in `request_callback_impl.cpp`. When we receive this request, we get the context from a given context_id and release it.
3) RPC call in `DistAutogradContainer::releaseContext` to send this command. This currently does not wait for an ack or implement any sort of retrying. We send the RPC to all the workerIds we have come into contact with (implemented in https://github.com/pytorch/pytorch/pull/26324)
4) Relevant unit tests
In follow up PRs, we will add error checking + retries for this call.
ghstack-source-id: 92269279
Test Plan: Added/modified unit tests in `test/dist_autograd_test.py`
Differential Revision: D17920137
fbshipit-source-id: 7403512ab5fcbc28d21c548b2e45319dd472e26a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28326
- Message::type() should return a MessageType, not const MessageType&,
since MessageType is just an enum.
- Add moveTensors() method for parallelism with movePayload().
ghstack-source-id: 92236443
Test Plan: buck test mode/dev-nosan caffe2/test/...
Differential Revision: D18021692
fbshipit-source-id: 5b2f5806f104a221de8df0282f3e395d15e5bfe4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022
This change implements the "FAST" mode distributed autograd backward
pass as described in https://github.com/pytorch/pytorch/issues/23110.
At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.
We have made the following changes to the local autograd engine for this
purpose:
1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.
In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
ghstack-source-id: 91794926
Test Plan: unit tests.
Differential Revision: D17652615
fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499
See #23110 for model parallel design details, and #26759 for the RRef
protocol. This commit add support for using RRef as Python UDF arguments
and return value. RRefs can now be shared from owner to user, from user to
owner, or from user to user.
Limitations:
1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)
Main changes:
1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure.
5. Update `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) are tracked in `rref_context.h` and `rref_context.cpp`.
Test Plan:
Imported from OSS
buck test mode/dev-nosan //caffe2/test:rpc_fork
Differential Revision: D17184146
Pulled By: mrshenli
fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527
Master GH issue: https://github.com/pytorch/pytorch/issues/23110.
This change builds upon https://github.com/pytorch/pytorch/pull/24876 and
provides all the autograd hooks needed for a forward pass with distributed rpc
for builtin operators. This change does not address distributed rpc for python
UDFs and that will be addressed in follow up PRs.
Summary of changes:
1. Attach send autograd functions when a request is sent from the client and
response is sent from the server.
2. Attach receive autograd functions when a request is received on the server
and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd
function pair to uniquely identify them.
ghstack-source-id: 91240466
Test Plan: unit tests.
Differential Revision: D17148077
fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25169
See #23110 for RRef design details. This commit only implements
RRef as return value for builtin operators, and RRef will communicate
between a user and the owner. More specifically, a RRef is first
created on the `dist.remote` caller, which is a user of the RRef.
Then the RRef user sends and notification to the owner to report
the fork to the owner, and the owner uses a shared_ptr to keep
the RRef alive. When the user RRef is destructed on the caller,
another notification will be sent to the owner, and the owner
can then drop it's RRef as well.
Test Plan: Imported from OSS
Differential Revision: D17048343
Pulled By: mrshenli
fbshipit-source-id: 9dd3b3d0e4fd214c76fecdbed746a6d3029b3efd
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24138
catch exception thrown on server, send the exception message back to client and rethrow it.
Reviewed By: mrshenli
Differential Revision: D16748748
fbshipit-source-id: ce18b3ea1b1d28645ec292f58aa0c818d93e559e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24502
Files in distributed.rpc package mixes snake camel names. This
commit cleans that up and all files use snake names now.
ghstack-source-id: 88548990
Reviewed By: xush6528
Differential Revision: D16860155
fbshipit-source-id: 3a22a89bf6c4e11aac5849564fc53296a04d6a8b