Change deprecate warning on dispatch_on_subclass to warn once (#132374)

Summary:
# Problem

`TORCH_WARN` can cause massive log spam.

I output the logs for before and after adding this change.

*Before:*

* The log file size was ~61.15 MB(61148028 bytes).

*After:*

* The log filesize was ~56.44 MB(56444057) bytes.

# Context

Looks like we tried to land this change earlier but it was reverted:

* D59413413
* Reverted https://github.com/pytorch/pytorch/pull/130047 on behalf of https://github.com/clee2000 due to broke test_overrides.py::TestTorchFunctionWarning::test_warn_on_invalid_torch_function

# Testing Update

`test_warn_on_invalid_torch_function` would fail because the warning would not be called on the handling of the second torch function class since `TORCH_WARN_ONCE` stops repeats globally.

Updated so that it runs separate programs. (Was not able to actually run the test, could someone help me with that

Test Plan: Need help with this...

Differential Revision: D60561181

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132374
Approved by: https://github.com/ezyang
This commit is contained in:
Basil Wong
2024-08-05 20:02:33 +00:00
committed by PyTorch MergeBot
parent 2764bee942
commit b7bcfdaff2
2 changed files with 19 additions and 14 deletions

View File

@ -1136,24 +1136,29 @@ class TestResolveName(TestCase):
)
class TestTorchFunctionWarning(TestCase):
def test_warn_on_invalid_torch_function(self):
class Bad1:
def test_warn_on_invalid_torch_function_standalone_class(self):
class StandaloneTorchFunctionClass:
def __torch_function__(self, *args, **kwargs):
pass
a = StandaloneTorchFunctionClass()
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
# Function that handles torch_function on the python side
torch.nn.functional.dropout(a)
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
# Function that handles torch_function in C++
torch.abs(a)
class Bad2(torch.Tensor):
def test_warn_on_invalid_torch_function_tensor_subclass(self):
class TensorSubclassTorchFunctionClass(torch.Tensor):
def __torch_function__(self, *args, **kwargs):
pass
a = Bad1()
for a in (Bad1(), Bad2()):
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
# Function that handles torch_function on the python side
torch.nn.functional.dropout(a)
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
# Function that handles torch_function in C++
torch.abs(a)
b = TensorSubclassTorchFunctionClass()
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
# Function that handles torch_function on the python side
torch.nn.functional.dropout(b)
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
# Function that handles torch_function in C++
torch.abs(b)
class TestDisabledUserWarnings(TestCase):
def test_no_implicit_user_warning_for_deprecated_functions(self):

View File

@ -308,7 +308,7 @@ static py::object dispatch_on_subclass(
PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
.is(py::handle(arg)) &&
torch_function.ptr() != torch::disabled_torch_function_impl()) {
TORCH_WARN(
TORCH_WARN_ONCE(
"Defining your `",
torch_function_name_str,
"` as a plain method is deprecated ",