[dynamo][be] hide warnings without invalidating warnings cache (#158520)

I feel uneasy about touching `__warningregistry__`, since it is an undocumented, private surface. The only public API hook that doesn't increment the warnings version seems to be https://docs.python.org/3/library/warnings.html#warnings.showwarning.

So we can whack-a-mole all the warning muters in compile to simply not display warnings, without invalidating the warnings cache. This PR does that for torch/_dynamo; I didn't find any warnings-version mutation coming from torch/_inductor.
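
Roughly, the approach looks like this (a minimal sketch; the predicate and names here are illustrative, the actual helper is `hide_warnings` in the diff below). Because only `warnings.showwarning` is swapped, `warnings.filters` is never mutated, the filters version is never bumped, and the per-module `__warningregistry__` caches stay valid:

```python
import warnings

prior = warnings.showwarning  # save the current display hook

def quiet_showwarning(message, category, filename, lineno, file=None, line=None):
    # Drop warnings we want hidden; forward everything else unchanged.
    if "internal detail" not in str(message):  # illustrative predicate
        prior(message, category, filename, lineno, file, line)

warnings.showwarning = quiet_showwarning
try:
    warnings.warn("internal detail: do not display")  # swallowed at display time
finally:
    warnings.showwarning = prior  # always restore the original hook
```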

There is a behavior change if someone calls a compiled function while `simplefilter("error")` is active:
```python
# e.g. test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing
with warnings.catch_warnings():
    warnings.simplefilter("error")  # turns all warnings into errors
    compiled_fn()  # will throw if any of the muted warnings fire
```
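
This happens because the `"error"` action is applied inside `warnings.warn` itself, before the `showwarning` display hook is ever consulted, so a `showwarning`-based mute cannot intercept it. A standalone illustration (no torch involved):

```python
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("error")  # the filter raises inside warn()
    warnings.showwarning = lambda *args, **kwargs: None  # mute at display time
    try:
        warnings.warn("muted?")  # still raises UserWarning
    except UserWarning:
        print("raised before showwarning was ever consulted")
# catch_warnings() restores both the filters and showwarning on exit
```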

FIXES https://github.com/pytorch/pytorch/issues/128427

A note for the future: the warnings module offers no thread-safe way to use it. Regular filters already have this problem, directly editing `__warningregistry__` would be even worse, and this PR's approach mutes warnings across all threads. Someone will eventually need to build a thread-safe warnings interface.
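
To make the cross-thread effect concrete (a sketch, not part of the PR): `warnings.showwarning` is process-global state, so swapping it in one thread silences warnings raised from every thread until it is restored.

```python
import threading
import warnings

def worker():
    warnings.warn("from another thread")  # also silenced while the hook is swapped

prior = warnings.showwarning
warnings.showwarning = lambda *args, **kwargs: None  # global, not thread-local
try:
    t = threading.Thread(target=worker)
    t.start()
    t.join()
finally:
    warnings.showwarning = prior
```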

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158520
Approved by: https://github.com/anijain2305, https://github.com/zou3519
Author: Simon Fan
Date: 2025-07-17 07:24:50 -07:00
Committed by: PyTorch MergeBot
Parent: 89850bbc07
Commit: 07c4c2a792

11 changed files with 104 additions and 15 deletions

```diff
@@ -7593,6 +7593,62 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         out2 = torch.compile(model, backend="eager")(input.clone())
         self.assertEqual(out1, out2)
 
+    def test_filter_warnings(self):
+        x = torch.ones(2, 2, requires_grad=True)
+
+        def call_foobar(x):
+            warnings.warn("foobar")
+
+        @torch.compile(backend="eager")
+        def f(x):
+            call_foobar(x)
+            call_foobar(x)
+            call_foobar(x)
+            call_foobar(x)
+            return call_foobar(x)
+
+        with warnings.catch_warnings(record=True) as w:
+            f(x)
+            self.assertEqual(len(w), 1)
+            self.assertEqual(str(w[0].message), "foobar")
+
+    def test_filter_safe_grad_warning(self):
+        x = torch.ones(2, 2, requires_grad=True)
+        y = x * 5  # non-leaf, .grad should warn
+        torch._subclasses.meta_utils.safe_grad(y)  # filters out warning
+
+        def unsafe_grad(y):
+            return y.grad
+
+        with warnings.catch_warnings(record=True) as w:
+            unsafe_grad(y)  # should still warn, different callsite
+            self.assertEqual(len(w), 1)
+            self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message))
+            unsafe_grad(y)  # should not warn
+            self.assertEqual(len(w), 1)
+
+    def test_filter_user_warnings(self):
+        x = torch.ones(2, 2, requires_grad=True)
+        y = x * 5  # non-leaf, .grad should warn
+
+        @torch._dynamo.eval_frame.TorchPatcher.suppress_torch_distributed_warnings
+        def mute_warn(y):
+            return y.grad
+
+        mute_warn(y)  # filters out warning
+
+        def unsafe_grad(y):
+            return y.grad
+
+        with warnings.catch_warnings(record=True) as w:
+            unsafe_grad(y)  # should still warn, different callsite
+            self.assertEqual(len(w), 1)
+            self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message))
+            unsafe_grad(y)  # should not warn
+            self.assertEqual(len(w), 1)
+
 instantiate_parametrized_tests(ReproTests)
```

```diff
@@ -2273,10 +2273,10 @@ class TorchPatcher:
         fn: Callable[..., Any],
     ) -> Callable[..., Any]:
         def inner_fn(*args: Any, **kwargs: Any) -> Any:
-            warnings.filterwarnings(
-                "ignore", category=UserWarning, module="torch.distributed"
-            )
-            return fn(*args, **kwargs)
+            with torch._logging.hide_warnings(
+                torch._logging._internal.user_warning_filter
+            ):
+                return fn(*args, **kwargs)
 
         return inner_fn
```

```diff
@@ -532,11 +532,6 @@ def unimplemented_v2(
     raise Unsupported(msg)
 
 
-def warning(msg: str) -> None:
-    counters["warnings"][msg] += 1
-
-    assert msg != os.environ.get("BREAK", False)
-
 # KeyError has special handling for its args
 # see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
 class KeyErrorMsg:
```

```diff
@@ -36,7 +36,6 @@ import re
 import sys
 import traceback
 import types
-import warnings
 import weakref
 from collections.abc import MutableMapping
 from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
@@ -304,8 +303,7 @@ DimList = list
 
 def safe_has_grad(t):
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
+    with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
         return hasattr(t, "grad")
```

```diff
@@ -12,6 +12,7 @@ from ._internal import (
     dtrace_structured,
     get_structured_logging_overhead,
     getArtifactLogger,
+    hide_warnings,
     LazyString,
     set_logs,
     trace_structured,
```

```diff
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import contextlib
 import functools
 import hashlib
 import importlib.util
@@ -12,6 +13,7 @@ import re
 import sys
 import tempfile
 import time
+import warnings
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Optional, Union
@@ -1156,6 +1158,45 @@ def warning_once(logger_obj, *args, **kwargs) -> None:
     logger_obj.warning(*args, **kwargs)
 
 
+def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool:
+    return "The .grad attribute of a Tensor" not in str(message)
+
+
+def user_warning_filter(
+    message, category, filename, lineno, file=None, line=None
+) -> bool:
+    return not category == UserWarning
+
+
+@contextlib.contextmanager
+def hide_warnings(filter_fn=lambda *args, **kwargs: True):
+    """
+    A context manager that temporarily suppresses warnings,
+    using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning.
+
+    Useful to hide warnings without mutating warnings module state, see:
+    https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162.
+
+    NOTE: Warnings issued under this context will still be cached in the __warningregistry__
+    and count towards the once/default rule. So you should NEVER use this on a user-land function.
+
+    Filter must implement the showwarning API:
+        def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool:
+            return True  # show this warning entry
+    """
+    prior = warnings.showwarning
+
+    def _showwarning(*args, **kwargs):
+        if filter_fn(*args, **kwargs):
+            prior(*args, **kwargs)
+
+    try:
+        warnings.showwarning = _showwarning
+        yield
+    finally:
+        warnings.showwarning = prior
+
+
 class LazyString(Generic[_P]):
     def __init__(
         self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs
```
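
For reference, a caller-side usage sketch (the tensor and variable names are illustrative; note that the default `filter_fn` always returns True, i.e. shows everything, so you pass a predicate that returns False for the entries to hide):

```python
import torch

y = torch.ones(2, 2, requires_grad=True) * 5  # non-leaf tensor

# safe_grad_filter returns False for the non-leaf ".grad attribute of a
# Tensor" warning, so it is dropped at display time; all other warnings
# still reach the prior showwarning hook.
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
    g = y.grad  # warning hidden here

# Per the docstring NOTE, the hidden warning was still recorded in this
# callsite's __warningregistry__ and counts toward the once/default rule.
```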

```diff
@@ -5,7 +5,6 @@ import dataclasses
 import functools
 import threading
 import typing
-import warnings
 import weakref
 from abc import abstractmethod
 from contextlib import AbstractContextManager, contextmanager
@@ -81,8 +80,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
 
 def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
+    with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
         return t.grad
```