mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change deprecate warning on dispatch_on_subclass to warn once (#132374)
Summary: # Problem `TORCH_WARN` can cause massive log spam. I output the logs for before and after adding this change. *Before:* * The log file size was ~61.15 MB(61148028 bytes). *After:* * The log filesize was ~56.44 MB(56444057) bytes. # Context Looks like we tried to land this change earlier but it was reverted: * D59413413 * Reverted https://github.com/pytorch/pytorch/pull/130047 on behalf of https://github.com/clee2000 due to broke test_overrides.py::TestTorchFunctionWarning::test_warn_on_invalid_torch_function # Testing Update `test_warn_on_invalid_torch_function` would fail because the warning would not be called on the handling of the second torch function class since `TORCH_WARN_ONCE` stops repeats globally. Updated so that it runs separate programs. (Was not able to actually run the test, could someone help me with that Test Plan: Need help with this... Differential Revision: D60561181 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132374 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
2764bee942
commit
b7bcfdaff2
@ -1136,24 +1136,29 @@ class TestResolveName(TestCase):
|
||||
)
|
||||
|
||||
class TestTorchFunctionWarning(TestCase):
|
||||
def test_warn_on_invalid_torch_function(self):
|
||||
class Bad1:
|
||||
def test_warn_on_invalid_torch_function_standalone_class(self):
|
||||
class StandaloneTorchFunctionClass:
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
pass
|
||||
a = StandaloneTorchFunctionClass()
|
||||
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
|
||||
# Function that handles torch_function on the python side
|
||||
torch.nn.functional.dropout(a)
|
||||
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
|
||||
# Function that handles torch_function in C++
|
||||
torch.abs(a)
|
||||
|
||||
class Bad2(torch.Tensor):
|
||||
def test_warn_on_invalid_torch_function_tensor_subclass(self):
|
||||
class TensorSubclassTorchFunctionClass(torch.Tensor):
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
a = Bad1()
|
||||
for a in (Bad1(), Bad2()):
|
||||
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
|
||||
# Function that handles torch_function on the python side
|
||||
torch.nn.functional.dropout(a)
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
|
||||
# Function that handles torch_function in C++
|
||||
torch.abs(a)
|
||||
b = TensorSubclassTorchFunctionClass()
|
||||
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
|
||||
# Function that handles torch_function on the python side
|
||||
torch.nn.functional.dropout(b)
|
||||
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
|
||||
# Function that handles torch_function in C++
|
||||
torch.abs(b)
|
||||
|
||||
class TestDisabledUserWarnings(TestCase):
|
||||
def test_no_implicit_user_warning_for_deprecated_functions(self):
|
||||
|
@ -308,7 +308,7 @@ static py::object dispatch_on_subclass(
|
||||
PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
|
||||
.is(py::handle(arg)) &&
|
||||
torch_function.ptr() != torch::disabled_torch_function_impl()) {
|
||||
TORCH_WARN(
|
||||
TORCH_WARN_ONCE(
|
||||
"Defining your `",
|
||||
torch_function_name_str,
|
||||
"` as a plain method is deprecated ",
|
||||
|
Reference in New Issue
Block a user