mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo][be] hide warnings without invalidating warnings cache (#158520)
I feel uneasy about touching `__warningregistry__` since it is undocumented and private surface. The only public API hook that doesn't increment warnings version seems to be https://docs.python.org/3/library/warnings.html#warnings.showwarning. So we could wack a mole all the warnings muters in compile to just not display warnings, and we wouldn't invalidate warnings cache. This PR adds it for torch/_dynamo, and I didn't find any warnings versioning mutation from torch/_inductor. There is a behavior change if someone calls a compiled graph with simplefilter("error"): ```python # e.g. test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing with warnings.catch_warnings(): warnings.simplefilter("error") # turns all warnings into errors compiled_fn() # will throw if any of the muted warnings fire ``` FIXES https://github.com/pytorch/pytorch/issues/128427 A note for the future: The warnings module doesn't offer a thread safe way of using it. Even regular filters have this problem, directly editing `__warningregistry__` would be very bad, and this PR would mute all threads. Someone will need to build a thread safe warnings interface. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158520 Approved by: https://github.com/anijain2305, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
89850bbc07
commit
07c4c2a792
@ -7593,6 +7593,62 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
out2 = torch.compile(model, backend="eager")(input.clone())
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_filter_warnings(self):
|
||||
x = torch.ones(2, 2, requires_grad=True)
|
||||
|
||||
def call_foobar(x):
|
||||
warnings.warn("foobar")
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def f(x):
|
||||
call_foobar(x)
|
||||
call_foobar(x)
|
||||
call_foobar(x)
|
||||
call_foobar(x)
|
||||
return call_foobar(x)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
f(x)
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(str(w[0].message), "foobar")
|
||||
|
||||
def test_filter_safe_grad_warning(self):
|
||||
x = torch.ones(2, 2, requires_grad=True)
|
||||
y = x * 5 # non-leaf, .grad should warn
|
||||
torch._subclasses.meta_utils.safe_grad(y) # filters out warning
|
||||
|
||||
def unsafe_grad(y):
|
||||
return y.grad
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
unsafe_grad(y) # should still warn, different callsite
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message))
|
||||
|
||||
unsafe_grad(y) # should not warn
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
def test_filter_user_warnings(self):
|
||||
x = torch.ones(2, 2, requires_grad=True)
|
||||
y = x * 5 # non-leaf, .grad should warn
|
||||
|
||||
@torch._dynamo.eval_frame.TorchPatcher.suppress_torch_distributed_warnings
|
||||
def mute_warn(y):
|
||||
return y.grad
|
||||
|
||||
mute_warn(y) # filters out warning
|
||||
|
||||
def unsafe_grad(y):
|
||||
return y.grad
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
unsafe_grad(y) # should still warn, different callsite
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message))
|
||||
|
||||
unsafe_grad(y) # should not warn
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
@ -2273,10 +2273,10 @@ class TorchPatcher:
|
||||
fn: Callable[..., Any],
|
||||
) -> Callable[..., Any]:
|
||||
def inner_fn(*args: Any, **kwargs: Any) -> Any:
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, module="torch.distributed"
|
||||
)
|
||||
return fn(*args, **kwargs)
|
||||
with torch._logging.hide_warnings(
|
||||
torch._logging._internal.user_warning_filter
|
||||
):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner_fn
|
||||
|
||||
|
@ -532,11 +532,6 @@ def unimplemented_v2(
|
||||
raise Unsupported(msg)
|
||||
|
||||
|
||||
def warning(msg: str) -> None:
|
||||
counters["warnings"][msg] += 1
|
||||
assert msg != os.environ.get("BREAK", False)
|
||||
|
||||
|
||||
# KeyError has special handling for its args
|
||||
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
|
||||
class KeyErrorMsg:
|
||||
|
@ -36,7 +36,6 @@ import re
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
import warnings
|
||||
import weakref
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
@ -304,8 +303,7 @@ DimList = list
|
||||
|
||||
|
||||
def safe_has_grad(t):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
|
||||
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
|
||||
return hasattr(t, "grad")
|
||||
|
||||
|
||||
|
@ -12,6 +12,7 @@ from ._internal import (
|
||||
dtrace_structured,
|
||||
get_structured_logging_overhead,
|
||||
getArtifactLogger,
|
||||
hide_warnings,
|
||||
LazyString,
|
||||
set_logs,
|
||||
trace_structured,
|
||||
|
@ -1,4 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib.util
|
||||
@ -12,6 +13,7 @@ import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generic, Optional, Union
|
||||
@ -1156,6 +1158,45 @@ def warning_once(logger_obj, *args, **kwargs) -> None:
|
||||
logger_obj.warning(*args, **kwargs)
|
||||
|
||||
|
||||
def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool:
|
||||
return "The .grad attribute of a Tensor" not in str(message)
|
||||
|
||||
|
||||
def user_warning_filter(
|
||||
message, category, filename, lineno, file=None, line=None
|
||||
) -> bool:
|
||||
return not category == UserWarning
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hide_warnings(filter_fn=lambda *args, **kwargs: True):
|
||||
"""
|
||||
A context manager that temporarily suppresses warnings,
|
||||
using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning.
|
||||
|
||||
Useful to hide warnings without mutating warnings module state, see:
|
||||
https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162.
|
||||
|
||||
NOTE: Warnings issued under this context will still be cached in the __warningregistry__
|
||||
and count towards the once/default rule. So you should NEVER use this on a user-land function.
|
||||
|
||||
Filter must implement the showwarning API:
|
||||
def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool:
|
||||
return True # show this warning entry
|
||||
"""
|
||||
prior = warnings.showwarning
|
||||
|
||||
def _showwarning(*args, **kwargs):
|
||||
if filter_fn(*args, **kwargs):
|
||||
prior(*args, **kwargs)
|
||||
|
||||
try:
|
||||
warnings.showwarning = _showwarning
|
||||
yield
|
||||
finally:
|
||||
warnings.showwarning = prior
|
||||
|
||||
|
||||
class LazyString(Generic[_P]):
|
||||
def __init__(
|
||||
self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs
|
||||
|
@ -5,7 +5,6 @@ import dataclasses
|
||||
import functools
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import abstractmethod
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
@ -81,8 +80,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
|
||||
|
||||
|
||||
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
|
||||
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
|
||||
return t.grad
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user