mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable set sequence nr (#114120)
Summary: In some cases (especially those involving collective calls) - we would want to always kick off a collective call first before running going down another path. For example: ``` tbe lookup -> a2a -> overarch dense -------------> ``` if the forward code is written as a2a_out = a2a dense = dense_net out = overarch(a2a_out, dense) out.backward() The current default is running backwards in the opposite order the forward is called. However, there is no data dependency between a2a and dense, so in reality either of them could be run first. We would like the a2a to run first because it provides optimal (on average) overlap. Changing the seq_nr of a2a_out to something large enough would allow autograd engine to kick it off first. Test Plan: Tests incoming Differential Revision: D51445261 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114120 Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
1a3dbf57ca
commit
85b97605ab
@ -11820,6 +11820,47 @@ class TestAutogradMultipleDispatch(TestCase):
|
||||
TestFn.apply(inp, None).sum().backward()
|
||||
self.assertEqual(local.my_obj[10], 5)
|
||||
|
||||
def test_set_sequence_nr(self):
|
||||
x = torch.randn((10,), dtype=torch.float32, requires_grad=True)
|
||||
y = torch.randn((10,), dtype=torch.float32, requires_grad=True)
|
||||
z = torch.randn((10,), dtype=torch.float32, requires_grad=True)
|
||||
|
||||
a = x + y
|
||||
b = y + z
|
||||
c = a + b
|
||||
|
||||
self.assertIsNotNone(a.grad_fn)
|
||||
self.assertIsNotNone(b.grad_fn)
|
||||
self.assertIsNotNone(c.grad_fn)
|
||||
|
||||
a.grad_fn._set_sequence_nr(100)
|
||||
b.grad_fn._set_sequence_nr(99)
|
||||
c.grad_fn._set_sequence_nr(98)
|
||||
|
||||
self.assertEqual(a.grad_fn._sequence_nr(), 100)
|
||||
self.assertEqual(b.grad_fn._sequence_nr(), 99)
|
||||
self.assertEqual(c.grad_fn._sequence_nr(), 98)
|
||||
|
||||
def log_grad_order(grad: torch.Tensor, name: str, order):
|
||||
order.append(name)
|
||||
return grad
|
||||
|
||||
order = []
|
||||
a.register_hook(partial(log_grad_order, name="a", order=order))
|
||||
b.register_hook(partial(log_grad_order, name="b", order=order))
|
||||
c.register_hook(partial(log_grad_order, name="c", order=order))
|
||||
|
||||
c.sum().backward()
|
||||
|
||||
# Expect to see that even though c has the smallest sequence number, it is still the first node to get run in autograd.
|
||||
# Also check that although a comes first during the forward, after giving it priority with sequence_nr,
|
||||
# its autograd node is run before that of b.
|
||||
self.assertEqual(order, ['c', 'a', 'b'])
|
||||
|
||||
self.assertEqual(x.grad, torch.ones_like(x))
|
||||
self.assertEqual(y.grad, 2 * torch.ones_like(x))
|
||||
self.assertEqual(z.grad, torch.ones_like(x))
|
||||
|
||||
|
||||
# Import test cases from below autograd/ here. These are found
|
||||
# implicitly by the loader, so Flake8 thinks they are unused, hence
|
||||
|
@ -328,6 +328,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
return sequence_nr_;
|
||||
}
|
||||
|
||||
void set_sequence_nr(uint64_t sequence_nr) {
|
||||
sequence_nr_ = sequence_nr;
|
||||
}
|
||||
|
||||
// NOTE [ Topological Number ]
|
||||
//
|
||||
// topological_nr is used to prune branches in the DAG during autograd
|
||||
@ -590,7 +594,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
// Sequence number used to correlate backward nodes with forward ops in the
|
||||
// profiler and provide determinism in the engine.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||
const uint64_t sequence_nr_;
|
||||
uint64_t sequence_nr_;
|
||||
|
||||
// See NOTE [ Topological Number ]
|
||||
uint64_t topological_nr_ = 0;
|
||||
|
@ -201,6 +201,17 @@ PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
|
||||
auto& fn = *((THPCppFunction*)self)->cdata;
|
||||
return THPUtils_packUInt64(fn.sequence_nr());
|
||||
}
|
||||
|
||||
PyObject* THPCppFunction_set_sequence_nr(
|
||||
PyObject* self,
|
||||
PyObject* sequence_nr) {
|
||||
HANDLE_TH_ERRORS
|
||||
auto& fn = *((THPCppFunction*)self)->cdata;
|
||||
fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
||||
static struct PyMethodDef default_methods[] = {
|
||||
THP_FUNCTION_DEFAULT_METHODS,
|
||||
|
@ -43,8 +43,13 @@ PyObject* CppFunction_pynew(
|
||||
THPCppFunction_register_prehook, \
|
||||
METH_O, \
|
||||
nullptr}, \
|
||||
{(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, { \
|
||||
(char*)"_sequence_nr", THPCppFunction_sequence_nr, METH_NOARGS, nullptr \
|
||||
{(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, \
|
||||
{(char*)"_sequence_nr", \
|
||||
THPCppFunction_sequence_nr, \
|
||||
METH_NOARGS, \
|
||||
nullptr}, \
|
||||
{ \
|
||||
(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
|
||||
}
|
||||
|
||||
#define THP_FUNCTION_DEFAULT_PROPERTIES \
|
||||
|
@ -964,6 +964,14 @@ PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
|
||||
HANDLE_TH_ERRORS;
|
||||
auto cdata = ((THPFunction*)self)->cdata.lock();
|
||||
cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPFunction_maybe_clear_saved_tensors(
|
||||
PyObject* self,
|
||||
PyObject* noargs) {
|
||||
@ -1532,6 +1540,7 @@ static struct PyGetSetDef THPFunction_properties[] = {
|
||||
static struct PyMethodDef THPFunction_methods[] = {
|
||||
{(char*)"name", THPFunction_name, METH_NOARGS, nullptr},
|
||||
{(char*)"_sequence_nr", THPFunction_sequence_nr, METH_NOARGS, nullptr},
|
||||
{(char*)"_set_sequence_nr", THPFunction_set_sequence_nr, METH_O, nullptr},
|
||||
{(char*)"maybe_clear_saved_tensors",
|
||||
THPFunction_maybe_clear_saved_tensors,
|
||||
METH_NOARGS,
|
||||
|
Reference in New Issue
Block a user