Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49906
This commit modifies RPC Message to inherit from `torch::CustomClassHolder`,
and wraps a Message in an IValue in `RpcAgent::send()`.
Test Plan: Imported from OSS
Reviewed By: lw
Differential Revision: D25719518
Pulled By: mrshenli
fbshipit-source-id: 694e40021e49e396da1620a2f81226522341550b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46568
This PR adds support for an RRef.backward() API. This would be useful
in applications like pipeline parallelism as described here:
https://github.com/pytorch/pytorch/issues/44827
This PR only adds support for local RRefs; remote RRef support will be added in
a follow-up PR.
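A minimal sketch of the local-RRef case (assuming an initialized RPC agent; names are illustrative):
```python
import torch
import torch.distributed.rpc as rpc

# Wrap a local scalar loss in an RRef and run backward through it.
rref = rpc.RRef(torch.ones(2, 2, requires_grad=True).sum())
rref.backward()  # for a local RRef, behaves like calling backward() on the value
```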
ghstack-source-id: 115100729
Test Plan:
1) unit tests.
2) waitforbuildbot
Reviewed By: mrshenli
Differential Revision: D24406311
fbshipit-source-id: fb0b4e185d9721bf57f4dea9847e0aaa66b3e513
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44663
The new API returns the type of the data object referenced by this
`RRef`. On the owner, this is same as `type(rref.local_value())`.
On a user, this will trigger an RPC to fetch the `type` object from
the owner. After this function is run once, the `type` object is
cached by the `RRef`, and subsequent invocations no longer trigger
RPC.
Closes #33210
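A rough usage sketch (assuming workers named "worker0"/"worker1" and an initialized agent):
```python
import torch
import torch.distributed.rpc as rpc

rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
t = rref._get_type()  # on a user, the first call may issue an RPC to the owner
t = rref._get_type()  # cached by the RRef; no further RPC
print(t)              # e.g. <class 'torch.Tensor'>
```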
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D23691990
Pulled By: mrshenli
fbshipit-source-id: a2d87cd601a691dd75164b6bcd7315245e9cf6bd
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40902
See the bottom of this stack for context.
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D22360210
Pulled By: suo
fbshipit-source-id: 4275127173a36982ce9ad357aa344435b98e1faf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39974
# Problem
When this assertion happens, I don't know
- which worker_id it is on, even with the worker_name "trainer:0".
- which rref is throwing this exception.
```shell
File "/mnt/xarfuse/uid-213229/96b122e4-seed-df64b884-e2b4-4520-b7a8-777e79c829ac-ns-4026532900/caffe2/torch/fb/training_toolkit/backend/training_strategies/parameter_server_strategy.py", line 246, in _initialize_trainers
trainer_name: fut.wait() for trainer_name, fut in model_rref_futs.items()
File "/mnt/xarfuse/uid-213229/96b122e4-seed-df64b884-e2b4-4520-b7a8-777e79c829ac-ns-4026532900/caffe2/torch/fb/training_toolkit/backend/training_strategies/parameter_server_strategy.py", line 246, in <dictcomp>
trainer_name: fut.wait() for trainer_name, fut in model_rref_futs.items()
File "/mnt/xarfuse/uid-213229/96b122e4-seed-df64b884-e2b4-4520-b7a8-777e79c829ac-ns-4026532900/torch/distributed/rpc/internal.py", line 158, in _handle_exception
raise result.exception_type(result.msg)
RuntimeError: RuntimeError('Cannot call localValue() on a non-local reference. Call it on trainer:0')
Traceback (most recent call last):
File "/mnt/xarfuse/uid-213229/96b122e4-seed-21bc7792-3714-4e62-a1c1-32a7c38ed984-ns-4026533058/torch/distributed/rpc/internal.py", line 148, in _run_function
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
File "/mnt/xarfuse/uid-213229/96b122e4-seed-21bc7792-3714-4e62-a1c1-32a7c38ed984-ns-4026533058/torch/distributed/rpc/rref_proxy.py", line 5, in _local_invoke
return getattr(rref.local_value(), func_name)(*args, **kwargs)
RuntimeError: Cannot call localValue() on a non-local reference. Call it on trainer:0
```
Changes:
- Make `WorkerInfo` stringifiable.
- Make the `localValue()` assertion message clearer about which case failed.
ghstack-source-id: 105840918
Test Plan:
buck test mode/dev-nosan //caffe2/test/distributed/rpc/:rpc_fork -- test_local_value_not_on_owner
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit/:rpc_fork
Reviewed By: mrshenli
Differential Revision: D5690653
fbshipit-source-id: ca6a8b1ff6e09f8644303a0f82f9b1a546a11170
Summary:
Clearly expressing that a type was inferred by PyTorch, rather than explicitly annotated by the user, makes many error messages more user-friendly.
Currently `Type` has two string-conversion methods: `str()` for IR printing and `python_str()` for serialization and error-message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split `python_str()` into `annotation_str()` and `repr_str()`.
`annotation_str()` is solely responsible for serialization and strictly matches the format of a Python type annotation. `repr_str()` is responsible for generating a human-readable error message that includes information like "this type is inferred, not explicitly annotated".
Closes https://github.com/pytorch/pytorch/issues/39449
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544
Differential Revision: D21978759
Pulled By: gmagogsfm
fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38590
This PR implements timeout semantics for RRef for parity with rpc_sync and rpc_async. How it works:
- A timeout parameter is added to `rpc.remote`. If the `rpc.remote` call times out, the error won't be raised to the user in that call, as it is not blocking (similar to `rpc_async`). Instead, the timeout error will be raised the next time the RRef is used (either by pickling or a `to_here` call).
- Error-handling semantics are added to RRef to deal with timeout errors. Previously, if there was an error creating the OwnerRRef, a callback on the local user would throw, resulting in `std::terminate`. Instead, the error is now caught and surfaced to the user the next time the RRef is used. As part of this, we added an `RPCErrorType` enum and defined RRef error handlers for these error types (currently just timeout and unknown).
- A timeout parameter is added to `to_here()`, which gives the user control over the maximum amount of time it can block (see the sketch after this list).
- `ctx.prepareChildForFork()`, which is called when the RRef is pickled (i.e. used as an arg over RPC), checks whether the `rpc.remote()` call timed out and, if so, raises that error to the user.
- Tests are added, primarily via delay injection.
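A sketch of these semantics (worker names are illustrative):
```python
import torch
import torch.distributed.rpc as rpc

rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1), timeout=0.5)
try:
    result = rref.to_here(timeout=0.5)  # a creation timeout surfaces here
except RuntimeError as err:
    print(f"RRef timed out: {err}")
```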
ghstack-source-id: 105232837
Test Plan: CI
Differential Revision: D21588165
fbshipit-source-id: c9f9e8aa3521012ea1de3e0f152a41afdf8b23f3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38352
Fixes RPC profiling by using the `then()` API added in https://github.com/pytorch/pytorch/pull/37311. Instead of adding a regular callback, we return a new future that completes when the profiling callback is finished. This is transparent to the user, as the future still completes with the value of the original future (i.e. the RPC's return value).
To make this work for RRef, we add a `_set_profiling_future` to set the profiling future, and `_get_profiling_future` to retrieve this future and wait on it in the tests.
Re-enabled profiling tests and stress tested them 1000 times to verify the fix
ghstack-source-id: 104086114
Test Plan: Re-enabled profiling tests
Differential Revision: D21506940
fbshipit-source-id: 35cde22f0551c825c9bc98ddc24cca412878a63a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35154
This is for issue https://github.com/pytorch/pytorch/issues/34999, which it closes.
https://github.com/pytorch/pytorch/issues/34997 needs more work.
This will make a few work items easier, such as 1) the dist autograd profiler and 2) JIT annotation for Future.
Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_rref_forward_chain --stress-runs 100
buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork && \
buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par \
-r test_call_method_on_rref
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- 'test_rref_proxy_class \(fb\.test_rpc_fork\.RpcTestWithFork\)' --stress-runs 100
```
Also covered: test_rref_proxy_reuse, test_handle_send_exceptions
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork && \
buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par \
-r test_script_call_python_return_future
```
Differential Revision: D7722184
fbshipit-source-id: bd92b855bfea4913d6672700590c57622fa86e0e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37519
Closes #37446
Currently FutureMessage is used in several places:
1. `rpc_async` returns a `FutureMessage` object and we expose it
as `torch.distributed.rpc.Future`. From the application's perspective,
it expects a `py::object` instead of a `Message`, and we
do the conversion in the `Future.wait()` pybind method.
2. The RPC autograd profiler takes a `FutureMessage` and installs
callbacks on it. The profiler actually only needs a `Future<T>`
and does not care what `T` is.
3. `OwnerRRef` exposes a `getFuture()` API which returns a
`FutureMessage`. This `FutureMessage` is marked completed
when the value referenced by the `OwnerRRef` is ready.
`OwnerRRef` does not need it to be a Message type either; it
actually creates an empty `Message` just to mark the `Future` complete.
The above places use `FutureMessage` but don't really
need a `Message`; `Message` is a communication-layer type that
the application, the profiler, and the RRef shouldn't be aware of.
Another motivation for making this change is that for async RPC
UDF #36071 we are going to allow applications to call
`markCompleted` in Python. If we still used `FutureMessage`, the
`markCompleted` pybind function would need to convert the
provided `py::object` into a specific message type, which
leaks communication-layer code into pybind functions. Even if
this is doable, we would have two entities (the RPC agent and the
pybind Python frontend) accessing the same request callback logic. This is too messy.
This commit replaces all surface `FutureMessage` with `FutureIValue`,
so that `FutureMessage` is no longer visible from Python land. Note
that this does not cause BC issues, as the Python Future type name
and its API stay intact. Internally, we still have `FutureMessage`
in the communication layer.
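As a minimal illustration of point 1, the user-facing behavior is unchanged (sketch, assuming an initialized agent and a worker named "worker1"):
```python
import torch
import torch.distributed.rpc as rpc

fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
result = fut.wait()  # still yields the Python return value, never a raw Message
```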
Test Plan: Imported from OSS
Reviewed By: xush6528
Differential Revision: D21308887
Pulled By: mrshenli
fbshipit-source-id: 4f574f38e83125081f142813cfdde56119522089
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36619
With this PR, applications no longer need to create dedicated helpers
to run functions on the object referenced by an RRef. Instead,
`rref.rpc_sync().some_func()` will use `rpc_sync` to run `some_func`
on the owner of the RRef using the object referenced by the RRef.
Similar helpers for `rref.rpc_async().some_func()` and
`rref.remote().some_func()` are also added.
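A sketch of the new helpers (assuming an initialized agent and a worker named "worker1"):
```python
import torch
import torch.distributed.rpc as rpc

rref = rpc.remote("worker1", torch.nn.Linear, args=(4, 4))
x = torch.ones(2, 4)
out = rref.rpc_sync().forward(x)      # run forward() on the owner, block for result
fut = rref.rpc_async().forward(x)     # same call, but returns a Future
res_rref = rref.remote().forward(x)   # same call, but returns an RRef to the result
print(out.shape, fut.wait().shape, res_rref.to_here().shape)
```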
An alternative design is to expose PyRRef as RRefBase and then
implement everything in a new Python RRef class. However, the RRef
class cannot directly inherit from PyRRef/RRefBase; otherwise we
would need to let the pyRemote* C++ functions load RRef from Python
and return an RRef instance. It is possible to let RRef hold an
instance of PyRRef instead of inheriting from it, but this does not
look like an elegant design, as we would have RRef holding PyRRef and
PyRRef holding the C++ RRef. Another alternative is dynamic
method loading, installing member methods on PyRRef instances.
However, this would require different solutions to handle
RRef(data) and rpc.remote(...). Based on the above, we
decided to go with the current implementation for simplicity, and we
can also keep all RRef-related APIs in one place.
Test Plan: Imported from OSS
Differential Revision: D21028333
Pulled By: mrshenli
fbshipit-source-id: fe90f56ef7183d18874e357900093755e1601eb4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35055
This is the first step to improving the way RPCs are profiled as suggested by Ilia. For now, since RPC can return two different types of futures, we have to implement two different code paths, one for the python eager mode future and one for the jit future.
This diff implements the python eager part. We have defined a method `_call_end_callbacks_on_future` that takes in a future and schedules a `RecordFunction` to be completed as a callback on the future.
Once https://github.com/pytorch/pytorch/pull/35039 lands, we can implement the JIT codepath by registering an operator that takes a `Future(t)` as well.
These code paths will be merged once the futures are merged.
ghstack-source-id: 102478180
Test Plan: Added unit tests
Differential Revision: D20452003
fbshipit-source-id: 1acdcb073bd1f63d6fb2e78277ac0be00fd6671d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35943
This change adds a message that tells why the concrete Module type is not a subtype of the Interface type by naming the missing method. For example, users may have forgotten to tag that method with `torch.jit.export`.
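A minimal sketch of the case this targets (the interface and module names are illustrative):
```python
import torch

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def run(self, x: torch.Tensor) -> torch.Tensor:
        pass

class Impl(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def run(self, x: torch.Tensor) -> torch.Tensor:  # missing @torch.jit.export
        return x + 1

class Holder(torch.nn.Module):
    mod: ModuleInterface

    def __init__(self):
        super().__init__()
        self.mod = Impl()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mod.run(x)

# Scripting now explains that Impl is missing 'run' (only forward/exported
# methods are compiled), pointing at the forgotten @torch.jit.export.
torch.jit.script(Holder())
```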
Test Plan:
Differential Revision: D7993693
fbshipit-source-id: 1a5b1d9ef483e5e120ab53c2427586560fbb9bcd
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34755
This diff disallows using the Python pickler to pickle an RRef. An RRef can only be pickled in the scope of an RPC call, using `_InternalRPCPickler`.
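A sketch of the new behavior (assuming an initialized agent):
```python
import pickle
import torch
import torch.distributed.rpc as rpc

rref = rpc.RRef(torch.ones(2))
try:
    pickle.dumps(rref)  # plain Python pickling is now rejected
except Exception as err:
    print(err)          # RRefs may only be pickled inside an RPC call
```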
ghstack-source-id: 100481337
Test Plan: unit tests
Differential Revision: D20453806
fbshipit-source-id: ebd4115ee01457ba6958cde805afd0a87c686612
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34896
Make TorchScript support calling `rref.owner()` to get the owner's worker info and `rref.owner_name()` to get the owner's worker name.
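A rough TorchScript sketch (assuming `RRef` is importable as a type annotation):
```python
import torch
from torch import Tensor
from torch.distributed.rpc import RRef

@torch.jit.script
def owner_info(rref: RRef[Tensor]):
    return rref.owner(), rref.owner_name()  # owner worker info and name
```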
Differential Revision: D7652208
fbshipit-source-id: a60125bb316ac2cf19a993cbd2affc933c0af7c9
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34497
Use a thread_local table to intercept UserRRefs created during user
function args deserialization, and then wait for confirmations of
those UserRRefs before launching the given user function.
Differential Revision: D20347464
Test Plan: Imported from OSS
Pulled By: mrshenli
fbshipit-source-id: 087484a2d2f03fbfb156752ab25653f39b412a07
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515
Once upon a time we thought this was necessary. In reality it is not, so
removing it.
For backcompat, our public interface (defined in `api/`) still has
typedefs to the old `script::` names.
There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph
transform. I renamed one of them.
Test Plan: Imported from OSS
Differential Revision: D20353503
Pulled By: suo
fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34183
https://github.com/pytorch/pytorch/pull/33263 enhanced the RRef Python constructor to infer most types via `jit::tryToInferType(..)`.
But that helper function can't infer the `ScriptModule` type due to `ScriptModule`'s special per-Module type singleton logic, so it is still not possible for a Python-created RRef to know the JIT type of its contained `ScriptModule`.
Instead of inferring the specific type of a Module, which could lead to too many candidate types (due to a Module's possible multiple inheritance), it is more straightforward to set its type to a user-specified `ModuleInterface` type.
We added an optional argument `type_hint` for users to mark what `ModuleInterface` type an `RRef` holds.
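A sketch of the new argument (the interface and module are illustrative):
```python
import torch
import torch.distributed.rpc as rpc

@torch.jit.interface
class MyModuleInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

class MyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

scripted = torch.jit.script(MyModule())
rref = rpc.RRef(scripted, type_hint=MyModuleInterface)  # RRef now knows the JIT type
```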
ghstack-source-id: 99649379
(Note: this ignores all push blocking failures!)
Test Plan:
Aspects that need to be confirmed in the test cases
https://fb.quip.com/aGxRAh2lCg05
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_create_local_script_class_rref
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_create_local_script_module_rref
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_return_local_script_class_rref_in_py_and_use_in_script
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_return_local_script_module_rref_in_py_and_use_in_script
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_torchscript_function_exception
```
Differential Revision: D7065050
fbshipit-source-id: e10210c0996622969e499e4a35b0659b36787c1c
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33263
This PR allows PyRRef local creation to inspect the pyobject: if it
finds that the object can be converted to an IValue, it converts it to
an IValue first; otherwise it holds it as a PyObjectType.
Test Plan:
Imported from OSS
https://fb.quip.com/aGxRAh2lCg05
Differential Revision: D19871243
Pulled By: wanchaol
fbshipit-source-id: ae5be3c52fb1e6db33c64e64ef64bc8b9ea63a9a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32959
In the RPC TorchScript call path we need to pickle/unpickle RRefs, so this diff makes the JIT pickler/unpickler able to pickle/unpickle an RRef. It is similar to what is implemented for `PyRRef::pickle()` and `PyRRef::unpickle()`.
The pickling/unpickling design assumes it is always coupled with RPC calls. It is not intended for checkpointing a model with an RRef; before checkpointing the model, the user should call `rref.to_here()` to get the value inside the RRef.
The pickling process is:
1. Push the `torch.distributed.rpc.rref` global string.
2. Call `rref.fork()` to create an rrefForkData, which contains a few IDs and the type str of the value held inside the rref; the IDs include the rref id, fork id, caller worker id, callee worker id, and owner worker id.
3. Push the rrefForkData.
The unpickling process is:
1. Read the `torch.distributed.rpc.rref` global string and retrieve the cached global lambda function.
2. The global lambda function receives the rrefForkData.
3. If the callee is also the owner worker, get the owner rref based on the IDs inside the rrefForkData and return the OwnerRRef.
4. If the callee is not the owner worker, create a user rref using the rrefForkData and return the UserRRef.
5. Meanwhile, the owner rref is notified and does reference counting correctly.
During unpickling, a type_resolver is needed to parse the type str. This type_resolver has a Python dependency, so we get it from the rpc_agent and pass it to the unpickler during construction. Accordingly, this diff adds a type_resolver argument to the JIT unpickler constructor.
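A sketch of the path this enables: an RRef crossing an RPC boundary is forked by the pickler on send and rebuilt by the unpickler on receive (worker names are illustrative):
```python
import torch
from torch import Tensor
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef

@torch.jit.script
def double_it(rref: RRef[Tensor]) -> Tensor:
    return rref.to_here() * 2

rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
out = rpc.rpc_sync("worker2", double_it, args=(rref,))  # rref is (un)pickled in transit
```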
ghstack-source-id: 98814793
Test Plan: unit test
Differential Revision: D19713293
fbshipit-source-id: 4fd776cdd4ce8f457c4034d79acdfb4cd095c52e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33190
This enables the initial RRef type to be used inside TorchScript: a user
can pass a Python RRef into a TorchScript function and call `to_here`
inside. Specifically, this PR:
- Add RRef schema type parsing
- Add python interop for RRef in Python and into JIT
- register to_here op in register_distributed_ops
More support for RRef in TorchScript will be added in future PRs
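A minimal sketch of what this enables (assuming an initialized agent):
```python
import torch
from torch import Tensor
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef

@torch.jit.script
def fetch(rref: RRef[Tensor]) -> Tensor:
    return rref.to_here()  # to_here() is now a registered op usable in script

rref = rpc.RRef(torch.ones(2))  # a Python-created RRef passed into TorchScript
print(fetch(rref))
```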
Test Plan: Imported from OSS
Differential Revision: D19871244
Pulled By: wanchaol
fbshipit-source-id: 7eca6c491a84666b261c70806254b705603bd663
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33189
Add RRefInterface to ATen/core, which will later be used by IValue.
Switch the whole RPC code base to use `intrusive_ptr` instead of `shared_ptr`,
so that we can add it to IValue.
Actually adding it to IValue and the JIT will happen in the next PR.
Test Plan: Imported from OSS
Differential Revision: D19871241
Pulled By: wanchaol
fbshipit-source-id: d7e1fd04b46320e0f26c18591b49c92ad30a4032
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32753
Functions bound as ATen operators cannot have Python dependencies.
This refactors the code to remove the Python dependency.
ghstack-source-id: 97485800
Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_script_functions_not_supported
buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork
buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par -r test_script_functions_not_supported
```
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:dist_autograd_fork
buck build mode/dev-nosan //caffe2/test/distributed/rpc:dist_autograd_fork
buck-out/gen/caffe2/test/distributed/rpc/dist_autograd_fork\#binary.par -r test_backward_simple_script_call
```
Differential Revision: D5741675
fbshipit-source-id: 31ee60955be8d815d0773f3699e3ff2f1f9d8849
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32807
After this commit, RRefContext no longer depends on pybind.
Test Plan: Imported from OSS
Differential Revision: D19636316
Pulled By: mrshenli
fbshipit-source-id: 88faa101c32e9019e979ae8e5da6706e49842726
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32785
Add `PythonRpcHandler::handleExceptionWithGIL()` so that in `PyRRef::localValue()`
we don't need to release the GIL and re-acquire it on the following line.
ghstack-source-id: 97418465
Test Plan: existing test coverage
Differential Revision: D19626195
fbshipit-source-id: db694d04b078811f819626789e1e86f1b35adb5b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32748
This follows up PR #30630. We need to hold the GIL when calling `jit::toPyObject()`, and some bound functions need to be tagged with a GIL-release guard if the underlying C++ code acquires the GIL itself. So:
1. `pyRef::to_here()` and `pyRef::local_value()` now acquire the GIL.
2. `pyRef::pickle()` and `pyRef::unpickle()` are tagged with a GIL-release guard.
3. `request_callback_impl` also acquires the GIL as needed.
4. The type parser now uses the cached `jitCompilationUnit_`, which is also cleaned up in the `cleanUp()` function.
ghstack-source-id: 97373011
Test Plan: unit test
Differential Revision: D19612337
fbshipit-source-id: 4d09f9b52ba626545ae7d31fea6b671301ed3890
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30630
This removes the template and all of its specializations from RPC; we
universally use IValue as the inner value, since we support holding a
Python object inside an IValue.
This also ensures that we have the correct type information when
creating an RRef: we use the return type from the schema when creating
a UserRRef or OwnerRRef, which will let IValue always carry the correct
type when the IValue is an RRef object (next PR).
Test Plan: Imported from OSS
Differential Revision: D19502235
fbshipit-source-id: 0d5decae8a9767e0893f3b8b6456b231653be3c5
Summary:
Closes https://github.com/pytorch/pytorch/issues/31198, see the issue for more details. We throw an error when `local_value()` is called on a non-owned rref, but the incorrect node name is printed in the error message. This PR fixes that and adds a relevant unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31199
Differential Revision: D19072014
Pulled By: rohan-varma
fbshipit-source-id: 760c20bfd2fbf286eaaca19500469509a575cfec
Summary:
Given that pybind11 implements these GIL functions, I don't think it makes sense for PyTorch to have its own bespoke versions.
Fixes https://github.com/pytorch/pytorch/issues/29065
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29095
Differential Revision: D18301806
Pulled By: ezyang
fbshipit-source-id: 03da6a26c41ee65aaadf7b67b9f0b14d2def2a5a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29930
Right now, rethrowing remote exceptions from Python calls is coupled with deserialization.
For an owner RRef, `setValue()` and `getValue()` do not use serialization and deserialization, so when a user creates an RRef to a local value and calls `ownerRef.to_here()`, the remote exception from the Python call is not rethrown.
This diff moves remote-exception rethrowing out of deserialization, so the exception can be handled for `ownerRef.localValue()` or `ownerRef.to_here()`.
Closes #29924
ghstack-source-id: 94210894
Test Plan: unit tests
Differential Revision: D18541916
fbshipit-source-id: 7cda93f623d52c740b3c1b1fa9a442f866984340
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28948
Add the constructor `RRef(value)` in Python. This allows wrapping a local object in an RRef and passing or returning that RRef to users.
This enables returning, for example, a list of RRefs containing the parameters of a module to the user of the module, as sketched below.
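A sketch of the parameter-list use case (module name illustrative):
```python
import torch
import torch.distributed.rpc as rpc

net = torch.nn.Linear(4, 4)
# Wrap each local parameter in an RRef; the list can be returned over RPC.
param_rrefs = [rpc.RRef(p) for p in net.parameters()]
```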
ghstack-source-id: 93565010
Test Plan: unit test.
Differential Revision: D18241227
fbshipit-source-id: b9e9b958f40623348d62ee6fc9e7f0414b4215b7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29396
The return types of RRef.to_here()/local_value() were recently
changed to Future, which triggers flakiness as the RRef could be
deleted before the future.wait() finishes. While we are still
discussing how we'd like to solve it, this commit reverts the
return type to stop the bleeding in tests.
Closes #28885
Test Plan: Imported from OSS
Differential Revision: D18375571
Pulled By: mrshenli
fbshipit-source-id: 354dbf38b15ab804e44fc9968dd30888415c1fab
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28909
This allows chaining calls on an RRef, as exemplified in the newly added test case.
ghstack-source-id: 92996018
Test Plan: unit test.
Differential Revision: D18231081
fbshipit-source-id: deeac044ef6d63f18ea241760ac17a3e644cb3d7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28025
Add a PyFuture type, which is a wrapper of either an OwnerRRef or a
jit::Future. The difference between PyFuture and jit::Future is that
PyFuture can return a custom py::object type.
Test Plan: Imported from OSS
Differential Revision: D17936746
Pulled By: mrshenli
fbshipit-source-id: a7451af3993d98aeab462ffd5318fc6d28f915c8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27943
This is step 1 toward making PyRRef::toHere() non-blocking on the caller side.
Test Plan: Imported from OSS
Differential Revision: D17936747
Pulled By: mrshenli
fbshipit-source-id: 7cf60e5804e72bdc28f0135fed4d7fdce05ea38a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022
This change implements the "FAST" mode distributed autograd backward
pass as described in https://github.com/pytorch/pytorch/issues/23110.
At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine executes the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.
We have made the following changes to the local autograd engine for this
purpose:
1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose an `execute_with_graph_task` API which allows the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.
In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If `Future.wait()` yields a message of type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is in line with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
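A sketch of a FAST-mode backward pass end to end (worker names illustrative; the `backward()` signature shown is the current public one and may differ from this commit's):
```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

with dist_autograd.context() as context_id:
    t = torch.rand(3, 3, requires_grad=True)
    out = rpc.rpc_sync("worker1", torch.add, args=(t, t))  # records send/recv functions
    loss = out.sum()
    dist_autograd.backward(context_id, [loss])       # distributed backward pass
    grads = dist_autograd.get_gradients(context_id)  # Tensor -> gradient map
```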
ghstack-source-id: 91794926
Test Plan: unit tests.
Differential Revision: D17652615
fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3