Determine autograd engine ready queue based on InputMetadata instead of InputBuffer (#135633)

Thanks @awgu for raising this issue and providing the small repro.

From offline discussion with @albanD: in the case where a forward returns multiple outputs with different devices, we'd want to select the ready queue based on the device of the first one. Even though this is somewhat arbitrary, we prefer it over deciding which ready queue to push to based on whichever input buffer we happen to compute last, which can vary depending on more factors and is thus harder to reason about. This is in theory BC-breaking, but it seems unlikely that anyone depends on this behavior.
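
To make the scenario concrete, here is a minimal sketch (not part of this PR; `MultiDeviceFn` is a hypothetical custom Function) of a forward that returns outputs on two different devices. With this change, the backward node is pushed to the ready queue of the first non-CPU device recorded in its InputMetadata (CUDA below), rather than to whichever input buffer happened to be filled last:

```python
import torch

class MultiDeviceFn(torch.autograd.Function):
    # Hypothetical example: the forward returns one CUDA and one CPU output.
    @staticmethod
    def forward(ctx, x):
        return x * 2, x.sum().cpu()

    @staticmethod
    def backward(ctx, g0, g1):
        # Scheduled on the CUDA device thread, because CUDA is the first
        # non-CPU device among this node's output metadata.
        return g0 * 2 + g1.to(g0.device)

if torch.cuda.is_available():
    x = torch.randn(4, device="cuda", requires_grad=True)
    a, b = MultiDeviceFn.apply(x)
    (a.sum() + b.to("cuda")).backward()
```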

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135633
Approved by: https://github.com/albanD
Author: soulitzer
Date:   2024-10-04 13:00:27 -07:00
Committed by: PyTorch MergeBot
Parent: 79562f3af8
Commit: d6f340f66c

6 changed files with 64 additions and 23 deletions

@@ -2916,6 +2916,7 @@ known_failing_tests = {
     "test_grad_mode_restored_reentrant", # hangs with graph breaks
     "test_no_grad_copy", # setting static member in lifted backward
     "test_no_grad_copy_sparse", # setting static member in lifted backward
+    "test_node_ordering_when_none_returned", # torch._dynamo.exc.Unsupported: TypeError <built-in method clone
     "test_reentrant_priority", # hangs with graph breaks
     "test_reentrant_with_callbacks_both_depths", # hangs with graph breaks
     "test_reentrant_with_callbacks_depth_0", # probably hangs with graph breaks

@@ -4379,6 +4379,49 @@ class TestAutograd(TestCase):
         run_test((10, 10), torch.zeros(10, 10))
         run_test((10,), 0)
 
+    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
+    def test_node_ordering_when_none_returned(self):
+        class Matmul(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, w):
+                # x: [M, N]
+                # w: [N, K]
+                ctx.save_for_backward(x, w)
+                return x @ w
+
+            @staticmethod
+            def backward(ctx, g_out):
+                # g_out: [M, K]
+                x, w = ctx.saved_tensors
+                g_x = g_out @ w.T
+                g_w = x.T @ g_out
+                w.main_grad = g_w.float()
+                return g_x, None
+
+        executed = []
+
+        class HookFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, g):
+                executed.append("A")
+                return g
+
+        def hook(*args, **kwargs):
+            executed.append("B")
+
+        x = torch.randn((3, 3), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+        x = HookFunction.apply(x)
+        w = torch.randn((3, 3), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+        w.register_hook(hook)
+        o = Matmul.apply(x, w)
+        o.sum().backward()
+        self.assertEqual(executed, ["B", "A"])
+
     def test_current_graph_task_id(self):
         id = [-1]

@@ -1110,7 +1110,7 @@ void Engine::evaluate_function(
           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
 
       if (is_ready) {
-        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
+        auto queue = ready_queue(cpu_ready_queue, next.function->device());
         queue->push(
             NodeTask(graph_task, next.function, std::move(input_buffer)));
       } else {
@@ -1125,7 +1125,7 @@ void Engine::evaluate_function(
       input_buffer.add(
           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
       if (is_ready) {
-        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
+        auto queue = ready_queue(cpu_ready_queue, next.function->device());
         queue->push(
             NodeTask(graph_task, next.function, std::move(input_buffer)));
         not_ready.erase(not_ready_it);
@@ -1310,7 +1310,7 @@ c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
   // Lock mutex for GraphTask.
   std::unique_lock<std::mutex> lock(graph_task->mutex_);
 
-  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
+  auto queue = ready_queue(graph_task->cpu_ready_queue_, graph_root->device());
 
   // worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the
   // autograd engine with corresponding GraphTask, and its NOT a re-entrant call

@@ -252,6 +252,23 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return std::nullopt;
   }
 
+  // Used by the engine to determine what device thread to run on
+  at::Device device() {
+    // Since we pick the first non-CPU tensor, this won't work with
+    // mixed device-type operations (e.g., an op that is both CUDA
+    // and XLA). This is *incredibly* unlikely, so we don't worry
+    // about it.
+    for (const auto& metadata : input_metadata_) {
+      auto device = metadata.device();
+      if (device.type() != at::kCPU) {
+        return device;
+      }
+    }
+    // Only report to the CPU thread if there really were no tensors
+    // from other devices.
+    return at::kCPU;
+  }
+
   void clear_input_metadata() {
     input_metadata_.clear();
   }

@@ -222,24 +222,6 @@ void InputBuffer::add(
   }
 }
 
-auto InputBuffer::device() const -> at::Device {
-  // Since we pick the first non-CPU tensor, this won't work with
-  // mixed device-type operations (e.g., an op that is both CUDA
-  // and XLA). This is *incredibly* unlikely, so we don't worry
-  // about it.
-  for (auto& var : buffer) {
-    if (var.defined()) {
-      auto device = var.device();
-      if (device.type() != at::kCPU) {
-        return device;
-      }
-    }
-  }
-  // Only report to the CPU thread if there really were no tensors
-  // from other devices.
-  return at::kCPU;
-}
-
 auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
   std::vector<Variable> result = std::move(g.buffer);
   return result;

@@ -30,8 +30,6 @@ struct InputBuffer {
       const std::optional<c10::Stream>& opt_producer_stream,
       const std::optional<c10::Stream>& opt_consumer_stream);
 
-  at::Device device() const;
-
   Variable operator[](size_t pos) {
     return buffer[pos];
   }