Enable set sequence nr (#114120)

Summary:
In some cases (especially those involving collective calls), we want to always kick off the collective call first before going down another path.

For example:

```
tbe lookup -> a2a ->
                     overarch
dense ------------->
```

If the forward code is written as:

    a2a_out = a2a(...)
    dense = dense_net(...)
    out = overarch(a2a_out, dense)
    out.backward()

By default, backward runs nodes in the reverse of the order in which the forward ops were called. However, there is no data dependency between a2a and dense, so in reality either of them could run first. We would like the a2a backward to run first because on average it provides optimal overlap.

Setting the seq_nr of a2a_out's grad_fn to a large enough value allows the autograd engine to kick it off first.
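Below is a minimal runnable sketch of that idea using the private `_set_sequence_nr` hook added here. The `a2a`/`dense_net`/`overarch` pieces from the example above are stood in for by plain tensor ops, and the particular sequence number is illustrative only:

```python
import torch

x = torch.randn(8, requires_grad=True)

a2a_out = x * 2                      # stand-in for the a2a (collective) branch
dense_out = x + 1                    # stand-in for the dense branch
out = (a2a_out + dense_out).sum()    # stand-in for overarch

# There is no data dependency between the two branches, so when both backward
# nodes become ready the engine picks the one with the larger sequence number.
# Bumping the collective branch's grad_fn makes it the preferred candidate.
a2a_out.grad_fn._set_sequence_nr(2**40)

out.backward()
```

The unit test added below checks the resulting execution order by registering gradient hooks on each branch.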

Test Plan: Tests incoming

Differential Revision: D51445261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114120
Approved by: https://github.com/ezyang, https://github.com/albanD
Author: Ying Liu
Date: 2023-11-21 19:47:24 +00:00
Committed by: PyTorch MergeBot
Parent: 1a3dbf57ca
Commit: 85b97605ab
5 changed files with 73 additions and 3 deletions


@@ -11820,6 +11820,47 @@ class TestAutogradMultipleDispatch(TestCase):
        TestFn.apply(inp, None).sum().backward()
        self.assertEqual(local.my_obj[10], 5)

    def test_set_sequence_nr(self):
        x = torch.randn((10,), dtype=torch.float32, requires_grad=True)
        y = torch.randn((10,), dtype=torch.float32, requires_grad=True)
        z = torch.randn((10,), dtype=torch.float32, requires_grad=True)

        a = x + y
        b = y + z
        c = a + b

        self.assertIsNotNone(a.grad_fn)
        self.assertIsNotNone(b.grad_fn)
        self.assertIsNotNone(c.grad_fn)

        a.grad_fn._set_sequence_nr(100)
        b.grad_fn._set_sequence_nr(99)
        c.grad_fn._set_sequence_nr(98)

        self.assertEqual(a.grad_fn._sequence_nr(), 100)
        self.assertEqual(b.grad_fn._sequence_nr(), 99)
        self.assertEqual(c.grad_fn._sequence_nr(), 98)

        def log_grad_order(grad: torch.Tensor, name: str, order):
            order.append(name)
            return grad

        order = []
        a.register_hook(partial(log_grad_order, name="a", order=order))
        b.register_hook(partial(log_grad_order, name="b", order=order))
        c.register_hook(partial(log_grad_order, name="c", order=order))

        c.sum().backward()

        # Expect that even though c has the smallest sequence number, it is
        # still the first node to run in autograd (it is the only root of the
        # graph). Also check that although a is created before b in the
        # forward, giving it priority via sequence_nr makes its autograd node
        # run before b's.
        self.assertEqual(order, ['c', 'a', 'b'])

        self.assertEqual(x.grad, torch.ones_like(x))
        self.assertEqual(y.grad, 2 * torch.ones_like(x))
        self.assertEqual(z.grad, torch.ones_like(x))

# Import test cases from below autograd/ here. These are found
# implicitly by the loader, so Flake8 thinks they are unused, hence


@@ -328,6 +328,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    return sequence_nr_;
  }

  void set_sequence_nr(uint64_t sequence_nr) {
    sequence_nr_ = sequence_nr;
  }

  // NOTE [ Topological Number ]
  //
  // topological_nr is used to prune branches in the DAG during autograd
@@ -590,7 +594,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
  // Sequence number used to correlate backward nodes with forward ops in the
  // profiler and provide determinism in the engine.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const uint64_t sequence_nr_;
  uint64_t sequence_nr_;

  // See NOTE [ Topological Number ]
  uint64_t topological_nr_ = 0;


@@ -201,6 +201,17 @@ PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
  auto& fn = *((THPCppFunction*)self)->cdata;
  return THPUtils_packUInt64(fn.sequence_nr());
}

PyObject* THPCppFunction_set_sequence_nr(
    PyObject* self,
    PyObject* sequence_nr) {
  HANDLE_TH_ERRORS
  auto& fn = *((THPCppFunction*)self)->cdata;
  fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyMethodDef default_methods[] = {
    THP_FUNCTION_DEFAULT_METHODS,


@@ -43,8 +43,13 @@ PyObject* CppFunction_pynew(
THPCppFunction_register_prehook, \
METH_O, \
nullptr}, \
{(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, { \
(char*)"_sequence_nr", THPCppFunction_sequence_nr, METH_NOARGS, nullptr \
{(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, \
{(char*)"_sequence_nr", \
THPCppFunction_sequence_nr, \
METH_NOARGS, \
nullptr}, \
{ \
(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
}
#define THP_FUNCTION_DEFAULT_PROPERTIES \


@@ -964,6 +964,14 @@ PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
  HANDLE_TH_ERRORS;
  auto cdata = ((THPFunction*)self)->cdata.lock();
  cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_maybe_clear_saved_tensors(
    PyObject* self,
    PyObject* noargs) {
@@ -1532,6 +1540,7 @@ static struct PyGetSetDef THPFunction_properties[] = {
static struct PyMethodDef THPFunction_methods[] = {
    {(char*)"name", THPFunction_name, METH_NOARGS, nullptr},
    {(char*)"_sequence_nr", THPFunction_sequence_nr, METH_NOARGS, nullptr},
    {(char*)"_set_sequence_nr", THPFunction_set_sequence_nr, METH_O, nullptr},
    {(char*)"maybe_clear_saved_tensors",
     THPFunction_maybe_clear_saved_tensors,
     METH_NOARGS,
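
Since the same method is also registered in `THPFunction_methods`, the grad_fn of a Python-defined `torch.autograd.Function` should expose the private hook as well. A small sketch under that assumption; the `Double` Function below is illustrative and not part of this PR:

```python
import torch

class Double(torch.autograd.Function):
    # Illustrative custom Function, not part of this PR.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(4, requires_grad=True)
y = Double.apply(x)

# The backward node of a Python-defined Function gets the same private hooks.
y.grad_fn._set_sequence_nr(1000)
assert y.grad_fn._sequence_nr() == 1000

y.sum().backward()
```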