mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Properly detect nested torch function args (#127496)
Dynamo was not detecting nested torch function classes in containers. This was due to pytree compatibility for variable trackers being removed. Fixes https://github.com/pytorch/pytorch/issues/127174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127496 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
16578e8584
commit
2129903aa3
@ -8,8 +8,9 @@ import pprint
|
||||
import pickle
|
||||
import collections
|
||||
import unittest
|
||||
import contextlib
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
|
||||
from torch.overrides import (
|
||||
handle_torch_function,
|
||||
has_torch_function,
|
||||
@ -377,6 +378,27 @@ class TensorLike:
|
||||
return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
|
||||
|
||||
class TestTorchFunctionOverride(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._stack = contextlib.ExitStack()
|
||||
if TEST_WITH_TORCHDYNAMO:
|
||||
# Add classes to the wrapped tensor subclasses
|
||||
@contextlib.contextmanager
|
||||
def setup_subclasses():
|
||||
old = set(torch._dynamo.config.traceable_tensor_subclasses)
|
||||
torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.traceable_tensor_subclasses.clear()
|
||||
torch._dynamo.config.traceable_tensor_subclasses.update(old)
|
||||
|
||||
cls._stack.enter_context(setup_subclasses())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._stack.close()
|
||||
|
||||
def test_mean_semantics(self):
|
||||
"""Test that a function with one argument can be overrided"""
|
||||
t1 = DiagonalTensor(5, 2)
|
||||
|
Reference in New Issue
Block a user