Compare commits

...

1 Commits

3 changed files with 55 additions and 1 deletions

View File

@ -151,6 +151,28 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
def test_inline_lru_cache_fn_with_default_args(a, b): def test_inline_lru_cache_fn_with_default_args(a, b):
return inline_lru_cache_fn_with_default_args(a, 2, b) return inline_lru_cache_fn_with_default_args(a, 2, b)
def test_lru_cache_warning_issued_during_tracing(self):
import warnings
from functools import lru_cache
@lru_cache
def foo(x):
return x + 1
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
torch.compile(foo, backend="eager")(torch.randn(4))
for warning in w:
warning_message = str(warning.message)
if (
"Dynamo detected a call to a `functools.lru_cache` wrapped function"
in warning_message
):
break
else:
self.assertTrue(False, "Expected warning about lru_cache not found")
@make_test @make_test
def test_add(a, b): def test_add(a, b):
return a + b return a + b

View File

@ -4717,6 +4717,29 @@ class ReproTests(torch._dynamo.test_case.TestCase):
): ):
f_compiled(a) f_compiled(a)
# https://github.com/pytorch/pytorch/issues/146598
@unittest.expectedFailure
def test_lru_cache_tracing(self):
from functools import lru_cache
counter = 0
@lru_cache
def cached_fn(x):
nonlocal counter
counter += 1
return x + 1
compiled_fn = torch.compile(cached_fn, backend="eager")
t = torch.randn(2, 2)
result1 = compiled_fn(t)
self.assertEqual(counter, 1)
result2 = compiled_fn(t)
self.assertEqual(counter, 1)
self.assertEqual(result1, result2)
def test_dont_aggressively_write_assert(self): def test_dont_aggressively_write_assert(self):
record_graph = torch._dynamo.testing.EagerAndRecordGraphs() record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
@ -5431,6 +5454,7 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
mod = Mod() mod = Mod()
opt_mod = torch.compile(mod, backend="eager", fullgraph=True) opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
x = torch.randn(4) x = torch.randn(4)
self.assertEqual(mod(x), opt_mod(x)) self.assertEqual(mod(x), opt_mod(x))
def test_enum(self): def test_enum(self):

View File

@ -29,14 +29,15 @@ import inspect
import itertools import itertools
import sys import sys
import types import types
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from types import FunctionType from types import FunctionType
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch from unittest.mock import patch
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
import torch import torch
from typing_extensions import Never
from .. import config, graph_break_hints, polyfills, variables from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
@ -445,6 +446,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
kwargs: "dict[str, VariableTracker]", kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker": ) -> "VariableTracker":
# Handle patch_dynamo_config call # Handle patch_dynamo_config call
if self.fn is torch._dynamo.patch_dynamo_config: if self.fn is torch._dynamo.patch_dynamo_config:
try: try:
args_const = [arg.as_python_constant() for arg in args] args_const = [arg.as_python_constant() for arg in args]
@ -1534,6 +1536,12 @@ class WrapperUserFunctionVariable(VariableTracker):
args: "list[VariableTracker]", args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]", kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker": ) -> "VariableTracker":
if hasattr(self.wrapper_obj, "cache_info"):
warnings.warn(
"Dynamo detected a call to a `functools.lru_cache` wrapped function."
"Dynamo currently ignores `functools.lru_cache` and directly traces the wrapped function."
"`functools.lru_cache` wrapped functions that read outside state may not be traced soundly."
)
return variables.UserFunctionVariable( return variables.UserFunctionVariable(
polyfills.getattr_and_trace polyfills.getattr_and_trace
).call_function( ).call_function(