From f975bd58af0e8dc4a71001b71a75a361c20a6203 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Fri, 10 Oct 2025 17:29:29 +0000
Subject: [PATCH] Revert "Warn if AccumulateGrad stream does not match producer node stream (#165065)"

This reverts commit a70ef954b919e990ebaba715b4072e76352867bf.

Reverted https://github.com/pytorch/pytorch/pull/165065 on behalf of https://github.com/izaitsevfb due to breaks lint ([comment](https://github.com/pytorch/pytorch/pull/165065#issuecomment-3391387386))
---
 aten/src/ATen/Context.cpp            |  8 ----
 aten/src/ATen/Context.h              |  4 --
 docs/source/autograd.md              |  4 +-
 test/test_autograd.py                | 47 -------------------
 torch/_C/__init__.pyi.in             |  2 -
 torch/autograd/graph.py              |  8 ----
 torch/csrc/Module.cpp                | 34 --------------
 torch/csrc/autograd/engine.cpp       | 15 ++----
 torch/csrc/autograd/input_buffer.cpp | 20 +-------
 torch/csrc/autograd/input_buffer.h   |  3 +-
 .../autograd/engine/dist_engine.cpp  |  3 +-
 11 files changed, 7 insertions(+), 141 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index facb88c47bd1..3310abfb41d5 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -825,14 +825,6 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
   display_vmap_fallback_warnings_ = enabled;
 }
 
-bool Context::warnOnAccumulateGradStreamMismatch() const {
-  return warn_on_accumulate_grad_stream_mismatch_;
-}
-
-void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
-  warn_on_accumulate_grad_stream_mismatch_ = enabled;
-}
-
 bool Context::isDefaultMobileCPUAllocatorSet() {
   return prev_allocator_ptr_ != nullptr;
 }
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 5bc265891bff..d0f6ce18862a 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -401,9 +401,6 @@ class TORCH_API Context {
   void setDisplayVmapFallbackWarnings(bool enabled);
   bool areVmapFallbackWarningsEnabled() const;
 
-  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
-  bool warnOnAccumulateGradStreamMismatch() const;
-
   bool isDefaultMobileCPUAllocatorSet();
   void setDefaultMobileCPUAllocator();
   void unsetDefaultMobileCPUAllocator();
@@ -494,7 +491,6 @@ class TORCH_API Context {
   bool release_original_weights = false;
 #endif
   bool display_vmap_fallback_warnings_ = false;
-  bool warn_on_accumulate_grad_stream_mismatch_ = true;
   std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
   bool enable_sparse_tensor_invariant_checks = false;
   bool allow_fp16_reduction_cpu = false;
diff --git a/docs/source/autograd.md b/docs/source/autograd.md
index e78b77e4eb45..4218eac05d79 100644
--- a/docs/source/autograd.md
+++ b/docs/source/autograd.md
@@ -423,10 +423,8 @@ Also see {ref}`saved-tensors-hooks-doc`.
 
 ```{eval-rst}
 .. autofunction:: torch.autograd.graph.get_gradient_edge
-```
-```{eval-rst}
-.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
+
 ```
 
 % This module needs to be documented. Adding here in the meantime
 
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 58f7ed0526b1..a94a26afdbb8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -13712,53 +13712,6 @@ class TestAutogradStreamSynchronization(TestCase):
         populate_events()
         check_ordering()
 
-    # Fails on MPS
-    @skipIfMPS
-    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
-    def test_warn_on_accumulate_grad_stream_mismatch_flag(self):
-        def do_test(suppress_warn, keep_grad_acc):
-            def _test():
-                with warnings.catch_warnings(record=True) as warns:
-                    warnings.simplefilter("always")
-
-                    with torch.Stream(0) as s0:
-                        a = torch.ones(8, 8, device="cuda", requires_grad=True)
-                        if keep_grad_acc:
-                            # create grad_acc under s1 and keep alive with b
-                            b = a.clone()
-
-                    with torch.Stream(0) as s1:
-                        s1.wait_stream(s0)
-                        c = a.sum()
-
-                    c.backward()
-
-                filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
-                return sum([filter_str in str(w.message) for w in warns]) > 0
-
-            if suppress_warn:
-                try:
-                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
-                        False
-                    )
-                    actual_warn = _test()
-                finally:
-                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
-                        True
-                    )
-            else:
-                actual_warn = _test()
-
-            expect_warn = not suppress_warn and keep_grad_acc
-            self.assertEqual(actual_warn, expect_warn)
-
-        # Warn by default
-        self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())
-
-        for suppress_warn in (True, False):
-            for keep_grad_acc in (True, False):
-                do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)
-
 
 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index ad59b3efd5bd..2f6ad3f6de67 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1307,8 +1307,6 @@ def _group_tensors_by_device_and_dtype(
 ]: ...
 def _initCrashHandler() -> None: ...
 
-def _set_warn_on_accumulate_grad_stream_mismatch(enabled: _bool) -> None: ...
-
 # NB: There is no Capsule type in typing, see
 # https://github.com/python/cpython/issues/109562
 def _to_dlpack(
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index 888fd1e6e140..7fcc5e4b8769 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -44,7 +44,6 @@ __all__ = [
     "GradientEdge",
     "get_gradient_edge",
    "increment_version",
-    "set_warn_on_accumulate_grad_stream_mismatch",
 ]
 
 
@@ -438,13 +437,6 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non
         torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
 
 
-def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None:
-    """Whether to warn when the AccumulateGrad node's stream does not match the stream
-    of the node that produced the incoming gradient.
-    """
-    return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled)
-
-
 class _MultiHandle(RemovableHandle):
     handles: tuple[RemovableHandle, ...]
 
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index ef3a05a7649d..4f99fa40bc6c 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1604,32 +1604,6 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch(
-    PyObject* _unused,
-    PyObject* arg) {
-  HANDLE_TH_ERRORS
-  TORCH_CHECK(
-      PyBool_Check(arg),
-      "enabled must be a bool, "
-      "but got ",
-      THPUtils_typename(arg));
-  at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPModule_warn_on_accumulate_grad_stream_mismatch(
-    PyObject* _unused,
-    PyObject* noargs) {
-  HANDLE_TH_ERRORS
-  if (at::globalContext().warnOnAccumulateGradStreamMismatch()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
 static PyObject* THCPModule_ensureCUDADeviceGuardSet(
     PyObject* self,
     PyObject* noargs) {
@@ -1847,14 +1821,6 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
      THPModule_are_vmap_fallback_warnings_enabled,
      METH_NOARGS,
      nullptr},
-    {"_set_warn_on_accumulate_grad_stream_mismatch",
-     THPModule_set_warn_on_accumulate_grad_stream_mismatch,
-     METH_O,
-     nullptr},
-    {"_warn_on_accumulate_grad_stream_mismatch",
-     THPModule_warn_on_accumulate_grad_stream_mismatch,
-     METH_NOARGS,
-     nullptr},
     {"_to_dlpack",
      castPyCFunctionWithKeywords(THPModule_toDLPack),
      METH_VARARGS | METH_KEYWORDS,
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 556c2c6ad17e..f92af4994fd5 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1199,11 +1199,7 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr,
-          std::move(output),
-          opt_parent_stream,
-          opt_next_stream,
-          next.function);
+          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
 
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
@@ -1219,11 +1215,7 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr,
-          std::move(output),
-          opt_parent_stream,
-          opt_next_stream,
-          next.function);
+          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
         queue->push(
@@ -1376,8 +1368,7 @@ auto Engine::execute(
       root_edges.at(0).input_nr,
       std::move(input),
      input_stream,
-      opt_next_stream,
-      root_edges.at(0).function);
+      opt_next_stream);
 
   execute_with_graph_task(
       graph_task, std::move(graph_root), std::move(input_buffer));
diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp
index 3224860592ca..63ca5daedd23 100644
--- a/torch/csrc/autograd/input_buffer.cpp
+++ b/torch/csrc/autograd/input_buffer.cpp
@@ -1,4 +1,3 @@
-#include <ATen/Context.h>
 #include <torch/csrc/autograd/input_buffer.h>
 
 #include <ATen/CachedTensorUtils.h>
@@ -12,7 +11,6 @@
 #include <c10/core/DeviceGuard.h>
 #include <c10/core/Event.h>
 #include <c10/core/StreamGuard.h>
-#include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <c10/util/Optional.h>
 
 #include <cstddef>
@@ -193,8 +191,7 @@ void InputBuffer::add(
     size_t pos,
     Variable&& var,
     const std::optional<c10::Stream>& opt_producer_stream_,
-    const std::optional<c10::Stream>& opt_consumer_stream_,
-    const std::shared_ptr<Node>& fn) {
+    const std::optional<c10::Stream>& opt_consumer_stream_) {
   TORCH_INTERNAL_ASSERT(pos < buffer.size());
 
   if (!var.defined()) {
@@ -234,21 +231,6 @@ void InputBuffer::add(
 
   TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
 
-  if (*opt_consumer_stream != *opt_producer_stream &&
-      dynamic_cast<AccumulateGrad*>(fn.get()) &&
-      at::globalContext().warnOnAccumulateGradStreamMismatch()) {
-    TORCH_WARN_ONCE(
-        "The AccumulateGrad node's stream does not match the stream of the node that produced "
-        "the incoming gradient. This may incur unnecessary synchronization and break CUDA graph "
-        "capture if the AccumulateGrad node's stream is the default stream. This mismatch is "
-        "caused by an AccumulateGrad node created prior to the current iteration being kept alive. "
-        "This can happen if the autograd graph is still being kept alive by tensors such as the "
-        "loss, or if you are using DDP, which will stash a reference to the node. To resolve the "
-        "mismatch, delete all references to the autograd graph or ensure that DDP initialization is "
-        "performed under the same stream as subsequent forwards. If the mismatch is intentional, "
-        "you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this "
-        "warning.");
-  }
   // See Note: [Autograd Producer-Consumer Stream Syncs]
   if (!opt_accum_streams[pos].has_value()) {
     // [ First producer ]
diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h
index d4928935aa6b..89abd91f4912 100644
--- a/torch/csrc/autograd/input_buffer.h
+++ b/torch/csrc/autograd/input_buffer.h
@@ -32,8 +32,7 @@ struct InputBuffer {
       size_t pos,
       Variable&& var,
      const std::optional<c10::Stream>& opt_producer_stream,
-      const std::optional<c10::Stream>& opt_consumer_stream,
-      const std::shared_ptr<Node>& fn);
+      const std::optional<c10::Stream>& opt_consumer_stream);
 
   Variable operator[](size_t pos) {
     return buffer[pos];
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
index b1347c3f715b..3743476c7a52 100644
--- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp
+++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -98,8 +98,7 @@ void DistEngine::globalCpuThread(
            InputBuffer::variables(std::move(task.inputs_))]() mutable {
             InputBuffer inputs(variables.size());
             for (const auto i : c10::irange(variables.size())) {
-              inputs.add(
-                  i, std::move(variables[i]), std::nullopt, std::nullopt, graphRoot);
+              inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
             }
             execute_graph_task_until_ready_queue_empty(
                 /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
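
For context, the toggle removed above was exercised roughly as sketched below: an AccumulateGrad node is created under one stream and kept alive, backward then runs under a different stream, and the warning is suppressed around the intentional mismatch. This is a minimal sketch distilled from the deleted test, assuming a CUDA build and assuming the reverted torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch API is available again (it is not once this revert lands); the helper name backward_with_stream_mismatch is illustrative only.

    import warnings
    import torch

    def backward_with_stream_mismatch():
        # Hypothetical helper: build the grad accumulator for `a` under s0 and
        # keep it alive via the clone, then run backward under a different
        # stream s1 -- the situation the (reverted) warning was meant to flag.
        with torch.Stream(0) as s0:
            a = torch.ones(8, 8, device="cuda", requires_grad=True)
            keep_alive = a.clone()  # holding this keeps a's AccumulateGrad node alive
        with torch.Stream(0) as s1:
            s1.wait_stream(s0)
            a.sum().backward()
        return keep_alive

    # Intentional mismatch: silence the (reverted) warning, then restore the default.
    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
    try:
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")
            backward_with_stream_mismatch()
        assert not any(
            "set_warn_on_accumulate_grad_stream_mismatch" in str(w.message)
            for w in warns
        )
    finally:
        torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(True)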