Properly detect nested torch function args (#127496)

Dynamo was not detecting nested torch function classes in containers. This was due to pytree compatibility for variable trackers being removed.
Fixes https://github.com/pytorch/pytorch/issues/127174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127496
Approved by: https://github.com/anijain2305
This commit is contained in:
Michael Lazos
2024-06-02 03:43:22 +00:00
committed by PyTorch MergeBot
parent 16578e8584
commit 2129903aa3
7 changed files with 109 additions and 19 deletions

View File

@ -8,8 +8,9 @@ import pprint
import pickle
import collections
import unittest
import contextlib
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
from torch.overrides import (
handle_torch_function,
has_torch_function,
@ -377,6 +378,27 @@ class TensorLike:
return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
class TestTorchFunctionOverride(TestCase):
@classmethod
def setUpClass(cls):
cls._stack = contextlib.ExitStack()
if TEST_WITH_TORCHDYNAMO:
# Add classes to the wrapped tensor subclasses
@contextlib.contextmanager
def setup_subclasses():
old = set(torch._dynamo.config.traceable_tensor_subclasses)
torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
try:
yield
finally:
torch._dynamo.config.traceable_tensor_subclasses.clear()
torch._dynamo.config.traceable_tensor_subclasses.update(old)
cls._stack.enter_context(setup_subclasses())
@classmethod
def tearDownClass(cls):
cls._stack.close()
def test_mean_semantics(self):
"""Test that a function with one argument can be overrided"""
t1 = DiagonalTensor(5, 2)