Enable set sequence nr (#114120)

Summary:
In some cases (especially those involving collective calls) - we would want to always kick off a collective call first before running going down another path.

For  example:

```
tbe lookup -> a2a ->
                     overarch
dense ------------->
```

if the forward code is written as
a2a_out = a2a
dense = dense_net
out = overarch(a2a_out, dense)
out.backward()

The current default is running backwards in the opposite order the forward is called. However, there is no data dependency between a2a and dense, so in reality either of them could be run first. We would like the a2a to run first because it provides optimal (on average) overlap.

Changing the seq_nr of a2a_out to something large enough would allow autograd engine to kick it off first.

Test Plan: Tests incoming

Differential Revision: D51445261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114120
Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
Ying Liu
2023-11-21 19:47:24 +00:00
committed by PyTorch MergeBot
parent 1a3dbf57ca
commit 85b97605ab
5 changed files with 73 additions and 3 deletions

View File

@ -201,6 +201,17 @@ PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
auto& fn = *((THPCppFunction*)self)->cdata;
return THPUtils_packUInt64(fn.sequence_nr());
}
PyObject* THPCppFunction_set_sequence_nr(
PyObject* self,
PyObject* sequence_nr) {
HANDLE_TH_ERRORS
auto& fn = *((THPCppFunction*)self)->cdata;
fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyMethodDef default_methods[] = {
THP_FUNCTION_DEFAULT_METHODS,