Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Revert "Warn if AccumulateGrad stream does not match producer node stream (#165065)"
This reverts commit a70ef954b919e990ebaba715b4072e76352867bf. Reverted https://github.com/pytorch/pytorch/pull/165065 on behalf of https://github.com/izaitsevfb because it breaks lint ([comment](https://github.com/pytorch/pytorch/pull/165065#issuecomment-3391387386)).
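For context, the reverted change exposed a Python-level toggle for the new warning. Below is a minimal sketch of how that toggle was exercised, adapted from the test removed in the diff that follows; these functions exist only with the reverted change applied and are not part of the API after this revert:

    import torch

    # With the reverted change applied, the warning was enabled by default.
    assert torch._C._warn_on_accumulate_grad_stream_mismatch()

    try:
        # Silence the AccumulateGrad stream-mismatch warning.
        torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
        # ... run backward passes that intentionally mix streams ...
    finally:
        # Restore the default so other callers still see the warning.
        torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(True)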
@@ -825,14 +825,6 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
   display_vmap_fallback_warnings_ = enabled;
 }
 
-bool Context::warnOnAccumulateGradStreamMismatch() const {
-  return warn_on_accumulate_grad_stream_mismatch_;
-}
-
-void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
-  warn_on_accumulate_grad_stream_mismatch_ = enabled;
-}
-
 bool Context::isDefaultMobileCPUAllocatorSet() {
   return prev_allocator_ptr_ != nullptr;
 }
@@ -401,9 +401,6 @@ class TORCH_API Context {
   void setDisplayVmapFallbackWarnings(bool enabled);
   bool areVmapFallbackWarningsEnabled() const;
 
-  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
-  bool warnOnAccumulateGradStreamMismatch() const;
-
   bool isDefaultMobileCPUAllocatorSet();
   void setDefaultMobileCPUAllocator();
   void unsetDefaultMobileCPUAllocator();
@@ -494,7 +491,6 @@ class TORCH_API Context {
   bool release_original_weights = false;
 #endif
   bool display_vmap_fallback_warnings_ = false;
-  bool warn_on_accumulate_grad_stream_mismatch_ = true;
   std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
   bool enable_sparse_tensor_invariant_checks = false;
   bool allow_fp16_reduction_cpu = false;
@@ -423,10 +423,8 @@ Also see {ref}`saved-tensors-hooks-doc`.
 
 ```{eval-rst}
 .. autofunction:: torch.autograd.graph.get_gradient_edge
 ```
 
-```{eval-rst}
-.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
-
-```
-
 % This module needs to be documented. Adding here in the meantime
@@ -13712,53 +13712,6 @@ class TestAutogradStreamSynchronization(TestCase):
         populate_events()
         check_ordering()
 
-    # Fails on MPS
-    @skipIfMPS
-    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
-    def test_warn_on_accumulate_grad_stream_mismatch_flag(self):
-        def do_test(suppress_warn, keep_grad_acc):
-            def _test():
-                with warnings.catch_warnings(record=True) as warns:
-                    warnings.simplefilter("always")
-
-                    with torch.Stream(0) as s0:
-                        a = torch.ones(8, 8, device="cuda", requires_grad=True)
-                        if keep_grad_acc:
-                            # create grad_acc under s1 and keep alive with b
-                            b = a.clone()
-
-                    with torch.Stream(0) as s1:
-                        s1.wait_stream(s0)
-                        c = a.sum()
-
-                    c.backward()
-
-                filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
-                return sum([filter_str in str(w.message) for w in warns]) > 0
-
-            if suppress_warn:
-                try:
-                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
-                        False
-                    )
-                    actual_warn = _test()
-                finally:
-                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
-                        True
-                    )
-            else:
-                actual_warn = _test()
-
-            expect_warn = not suppress_warn and keep_grad_acc
-            self.assertEqual(actual_warn, expect_warn)
-
-        # Warn by default
-        self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())
-
-        for suppress_warn in (True, False):
-            for keep_grad_acc in (True, False):
-                do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)
-
 
 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(
@@ -1307,8 +1307,6 @@ def _group_tensors_by_device_and_dtype(
 ]: ...
 def _initCrashHandler() -> None: ...
 
-def _set_warn_on_accumulate_grad_stream_mismatch(enabled: _bool) -> None: ...
-
 # NB: There is no Capsule type in typing, see
 # https://github.com/python/cpython/issues/109562
 def _to_dlpack(
@@ -44,7 +44,6 @@ __all__ = [
     "GradientEdge",
     "get_gradient_edge",
     "increment_version",
-    "set_warn_on_accumulate_grad_stream_mismatch",
 ]
 
 
@@ -438,13 +437,6 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non
         torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
 
 
-def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None:
-    """Whether to warn when the AccumulateGrad node's stream does not match the stream
-    of the node that produced the incoming gradient.
-    """
-    return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled)
-
-
 class _MultiHandle(RemovableHandle):
     handles: tuple[RemovableHandle, ...]
 
@@ -1604,32 +1604,6 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch(
-    PyObject* _unused,
-    PyObject* arg) {
-  HANDLE_TH_ERRORS
-  TORCH_CHECK(
-      PyBool_Check(arg),
-      "enabled must be a bool, "
-      "but got ",
-      THPUtils_typename(arg));
-  at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPModule_warn_on_accumulate_grad_stream_mismatch(
-    PyObject* _unused,
-    PyObject* noargs) {
-  HANDLE_TH_ERRORS
-  if (at::globalContext().warnOnAccumulateGradStreamMismatch()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
 static PyObject* THCPModule_ensureCUDADeviceGuardSet(
     PyObject* self,
     PyObject* noargs) {
@@ -1847,14 +1821,6 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
      THPModule_are_vmap_fallback_warnings_enabled,
      METH_NOARGS,
      nullptr},
-    {"_set_warn_on_accumulate_grad_stream_mismatch",
-     THPModule_set_warn_on_accumulate_grad_stream_mismatch,
-     METH_O,
-     nullptr},
-    {"_warn_on_accumulate_grad_stream_mismatch",
-     THPModule_warn_on_accumulate_grad_stream_mismatch,
-     METH_NOARGS,
-     nullptr},
     {"_to_dlpack",
      castPyCFunctionWithKeywords(THPModule_toDLPack),
      METH_VARARGS | METH_KEYWORDS,
@@ -1199,11 +1199,7 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr,
-          std::move(output),
-          opt_parent_stream,
-          opt_next_stream,
-          next.function);
+          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
 
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
@@ -1219,11 +1215,7 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr,
-          std::move(output),
-          opt_parent_stream,
-          opt_next_stream,
-          next.function);
+          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
         queue->push(
@@ -1376,8 +1368,7 @@ auto Engine::execute(
       root_edges.at(0).input_nr,
       std::move(input),
       input_stream,
-      opt_next_stream,
-      root_edges.at(0).function);
+      opt_next_stream);
 
   execute_with_graph_task(
       graph_task, std::move(graph_root), std::move(input_buffer));
@@ -1,4 +1,3 @@
-#include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/input_buffer.h>
 
 #include <ATen/CachedTensorUtils.h>
@@ -12,7 +11,6 @@
 #include <c10/core/DeviceGuard.h>
 #include <c10/core/Event.h>
 #include <c10/core/StreamGuard.h>
-#include <c10/util/Logging.h>
 #include <optional>
 
 #include <cstddef>
@@ -193,8 +191,7 @@ void InputBuffer::add(
     size_t pos,
     Variable&& var,
     const std::optional<c10::Stream>& opt_producer_stream_,
-    const std::optional<c10::Stream>& opt_consumer_stream_,
-    const std::shared_ptr<Node>& fn) {
+    const std::optional<c10::Stream>& opt_consumer_stream_) {
   TORCH_INTERNAL_ASSERT(pos < buffer.size());
 
   if (!var.defined()) {
@@ -234,21 +231,6 @@ void InputBuffer::add(
 
   TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
 
-  if (*opt_consumer_stream != *opt_producer_stream &&
-      dynamic_cast<AccumulateGrad*>(fn.get()) &&
-      at::globalContext().warnOnAccumulateGradStreamMismatch()) {
-    TORCH_WARN_ONCE(
-        "The AccumulateGrad node's stream does not match the stream of the node that produced "
-        "the incoming gradient. This may incur unnecessary synchronization and break CUDA graph "
-        "capture if the AccumulateGrad node's stream is the default stream. This mismatch is "
-        "caused by an AccumulateGrad node created prior to the current iteration being kept alive. "
-        "This can happen if the autograd graph is still being kept alive by tensors such as the "
-        "loss, or if you are using DDP, which will stash a reference to the node. To resolve the "
-        "mismatch, delete all references to the autograd graph or ensure that DDP initialization is "
-        "performed under the same stream as subsequent forwards. If the mismatch is intentional, "
-        "you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this "
-        "warning.");
-  }
   // See Note: [Autograd Producer-Consumer Stream Syncs]
   if (!opt_accum_streams[pos].has_value()) {
     // [ First producer ]
@@ -32,8 +32,7 @@ struct InputBuffer {
       size_t pos,
      Variable&& var,
       const std::optional<c10::Stream>& opt_producer_stream,
-      const std::optional<c10::Stream>& opt_consumer_stream,
-      const std::shared_ptr<Node>& fn);
+      const std::optional<c10::Stream>& opt_consumer_stream);
 
   Variable operator[](size_t pos) {
     return buffer[pos];
@@ -98,8 +98,7 @@ void DistEngine::globalCpuThread(
        InputBuffer::variables(std::move(task.inputs_))]() mutable {
         InputBuffer inputs(variables.size());
         for (const auto i : c10::irange(variables.size())) {
-          inputs.add(
-              i, std::move(variables[i]), std::nullopt, std::nullopt, graphRoot);
+          inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
         }
         execute_graph_task_until_ready_queue_empty(
             /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
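For reference, the scenario the reverted warning targeted looks roughly like the following sketch, adapted from the test deleted above; it assumes a CUDA build with the reverted change applied:

    import torch

    with torch.Stream(0) as s0:
        a = torch.ones(8, 8, device="cuda", requires_grad=True)
        # Cloning `a` creates its AccumulateGrad node under s0 and keeps it alive via `b`.
        b = a.clone()

    with torch.Stream(0) as s1:
        s1.wait_stream(s0)
        # The gradient for `a` is produced by work recorded under s1.
        c = a.sum()

    # With the reverted change applied, this emitted a one-time warning because the
    # AccumulateGrad node's stream (s0) differs from the producer node's stream (s1).
    c.backward()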