mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Determine autograd engine ready queue based on InputMetadata instead of InputBuffer (#135633)
Thanks @awgu for raising this issue and the small repro From offline discussion with @albanD, in the case where a forward returns multiple outputs with different devices, we'd want to select the ready queue based on the device of the first one. Even though this is somewhat arbitrary, we prefer this over deciding which ready queue to push based on whichever input buffer's we happen to compute last, which can vary depending on more factors and thus be harder to reason about. This is in theory bc-breaking, but it seems unlikely that someone would depend on this behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135633 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
79562f3af8
commit
d6f340f66c
@ -2916,6 +2916,7 @@ known_failing_tests = {
|
||||
"test_grad_mode_restored_reentrant", # hangs with graph breaks
|
||||
"test_no_grad_copy", # setting static member in lifted backward
|
||||
"test_no_grad_copy_sparse", # setting static member in lifted backward
|
||||
"test_node_ordering_when_none_returned", # torch._dynamo.exc.Unsupported: TypeError <built-in method clone
|
||||
"test_reentrant_priority", # hangs with graph breaks
|
||||
"test_reentrant_with_callbacks_both_depths", # hangs with graph breaks
|
||||
"test_reentrant_with_callbacks_depth_0", # probably hangs with graph breaks
|
||||
|
@ -4379,6 +4379,49 @@ class TestAutograd(TestCase):
|
||||
run_test((10, 10), torch.zeros(10, 10))
|
||||
run_test((10,), 0)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
|
||||
def test_node_ordering_when_none_returned(self):
|
||||
class Matmul(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, w):
|
||||
# x: [M, N]
|
||||
# w: [N, K]
|
||||
ctx.save_for_backward(x, w)
|
||||
return x @ w
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, g_out):
|
||||
# g_out: [M, K]
|
||||
x, w = ctx.saved_tensors
|
||||
g_x = g_out @ w.T
|
||||
g_w = x.T @ g_out
|
||||
w.main_grad = g_w.float()
|
||||
return g_x, None
|
||||
|
||||
executed = []
|
||||
|
||||
class HookFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, g):
|
||||
executed.append("A")
|
||||
return g
|
||||
|
||||
def hook(*args, **kwargs):
|
||||
executed.append("B")
|
||||
|
||||
x = torch.randn((3, 3), dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
x = HookFunction.apply(x)
|
||||
w = torch.randn((3, 3), dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
w.register_hook(hook)
|
||||
o = Matmul.apply(x, w)
|
||||
o.sum().backward()
|
||||
|
||||
self.assertEqual(executed, ["B", "A"])
|
||||
|
||||
def test_current_graph_task_id(self):
|
||||
id = [-1]
|
||||
|
||||
|
@ -1110,7 +1110,7 @@ void Engine::evaluate_function(
|
||||
next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
|
||||
|
||||
if (is_ready) {
|
||||
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
|
||||
auto queue = ready_queue(cpu_ready_queue, next.function->device());
|
||||
queue->push(
|
||||
NodeTask(graph_task, next.function, std::move(input_buffer)));
|
||||
} else {
|
||||
@ -1125,7 +1125,7 @@ void Engine::evaluate_function(
|
||||
input_buffer.add(
|
||||
next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
|
||||
if (is_ready) {
|
||||
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
|
||||
auto queue = ready_queue(cpu_ready_queue, next.function->device());
|
||||
queue->push(
|
||||
NodeTask(graph_task, next.function, std::move(input_buffer)));
|
||||
not_ready.erase(not_ready_it);
|
||||
@ -1310,7 +1310,7 @@ c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
|
||||
// Lock mutex for GraphTask.
|
||||
std::unique_lock<std::mutex> lock(graph_task->mutex_);
|
||||
|
||||
auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
|
||||
auto queue = ready_queue(graph_task->cpu_ready_queue_, graph_root->device());
|
||||
|
||||
// worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the
|
||||
// autograd engine with corresponding GraphTask, and its NOT a re-entrant call
|
||||
|
@ -252,6 +252,23 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Used by the engine to determine what device thread to run on
|
||||
at::Device device() {
|
||||
// Since we pick the first non-CPU tensor, this won't work with
|
||||
// mixed device-type operations (e.g., an op that is both CUDA
|
||||
// and XLA). This is *incredibly* unlikely, so we don't worry
|
||||
// about it.
|
||||
for (const auto& metadata : input_metadata_) {
|
||||
auto device = metadata.device();
|
||||
if (device.type() != at::kCPU) {
|
||||
return device;
|
||||
}
|
||||
}
|
||||
// Only report to the CPU thread if there really were no tensors
|
||||
// from other devices.
|
||||
return at::kCPU;
|
||||
}
|
||||
|
||||
void clear_input_metadata() {
|
||||
input_metadata_.clear();
|
||||
}
|
||||
|
@ -222,24 +222,6 @@ void InputBuffer::add(
|
||||
}
|
||||
}
|
||||
|
||||
auto InputBuffer::device() const -> at::Device {
|
||||
// Since we pick the first non-CPU tensor, this won't work with
|
||||
// mixed device-type operations (e.g., an op that is both CUDA
|
||||
// and XLA). This is *incredibly* unlikely, so we don't worry
|
||||
// about it.
|
||||
for (auto& var : buffer) {
|
||||
if (var.defined()) {
|
||||
auto device = var.device();
|
||||
if (device.type() != at::kCPU) {
|
||||
return device;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Only report to the CPU thread if there really were no tensors
|
||||
// from other devices.
|
||||
return at::kCPU;
|
||||
}
|
||||
|
||||
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
|
||||
std::vector<Variable> result = std::move(g.buffer);
|
||||
return result;
|
||||
|
@ -30,8 +30,6 @@ struct InputBuffer {
|
||||
const std::optional<c10::Stream>& opt_producer_stream,
|
||||
const std::optional<c10::Stream>& opt_consumer_stream);
|
||||
|
||||
at::Device device() const;
|
||||
|
||||
Variable operator[](size_t pos) {
|
||||
return buffer[pos];
|
||||
}
|
||||
|
Reference in New Issue
Block a user