mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. (#157566)
inside torch.compile.disable function always triggers recompilation. because a user inside function decorated with torch._dynamo.disable would be used as an argument in the resume_in_xx function. In the current implementation, it will always be a new object, resulting in the ID_MATCH guard always failing and triggering recompilation. Fixes https://github.com/pytorch/pytorch/issues/157399 @xmfan Pull Request resolved: https://github.com/pytorch/pytorch/pull/157566 Approved by: https://github.com/mlazos, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
a76147c9e0
commit
8e07c9870d
@ -8556,15 +8556,64 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||
self.assertEqual(seen_frames[1].name, "uwu_inline_me")
|
||||
self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
|
||||
|
||||
def test_error_on_recompile(self):
|
||||
def test_recompile_on_disable_1(self):
|
||||
# fix https://github.com/pytorch/pytorch/issues/157399
|
||||
@torch.compile(backend="eager")
|
||||
def fn(a, b):
|
||||
return a + b
|
||||
def fn(x):
|
||||
@torch._dynamo.disable
|
||||
def inner(x):
|
||||
return x + 10
|
||||
|
||||
return inner(x) + 1
|
||||
|
||||
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
||||
try:
|
||||
for i in range(5):
|
||||
fn(torch.rand(2, 3))
|
||||
except torch._dynamo.exc.RecompileError as e:
|
||||
self.fail("RecompileError raised unexpectedly: " + str(e))
|
||||
|
||||
def test_recompile_on_disable_2(self):
|
||||
def outer(x, cond):
|
||||
@torch._dynamo.disable()
|
||||
def fn0(y):
|
||||
return y + 1
|
||||
|
||||
@torch._dynamo.disable()
|
||||
def fn1(y):
|
||||
return y + 2
|
||||
|
||||
if cond:
|
||||
f = fn0
|
||||
else:
|
||||
f = fn1
|
||||
|
||||
torch._dynamo.graph_break()
|
||||
# there will be a resume function here
|
||||
return f(x)
|
||||
|
||||
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
||||
with self.assertRaises(torch._dynamo.exc.RecompileError):
|
||||
fn(torch.rand(2, 3), torch.rand(2, 3))
|
||||
fn(torch.rand(2, 3), (1, 2, 3))
|
||||
x = torch.rand(2, 3)
|
||||
self.assertEqual(outer(x, True), torch.compile(outer)(x, True))
|
||||
self.assertEqual(outer(x, False), torch.compile(outer)(x, False))
|
||||
|
||||
def test_create_nested_fn_cache_clear(self):
|
||||
def outer(x):
|
||||
@torch._dynamo.disable()
|
||||
def f(y):
|
||||
return y + 2
|
||||
|
||||
return f(x) + 1
|
||||
|
||||
outer = torch.compile(outer)
|
||||
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
||||
with self.assertRaises(torch._dynamo.exc.RecompileError):
|
||||
outer(torch.randn(3, 3))
|
||||
from torch._dynamo.utils import create_nested_fn_cache
|
||||
|
||||
create_nested_fn_cache.clear()
|
||||
outer(torch.randn(3, 3))
|
||||
|
||||
def test_guards_strip_function_call(self):
|
||||
from torch._dynamo.guards import strip_function_call
|
||||
|
@ -610,6 +610,8 @@ class TestAutograd(TestCase):
|
||||
|
||||
with disable_gc():
|
||||
unpack_hook_ref = scope()
|
||||
if torch._dynamo.is_compiling():
|
||||
torch._dynamo.reset()
|
||||
self.assertIsNone(unpack_hook_ref())
|
||||
|
||||
def test_will_engine_execute_node(self):
|
||||
|
@ -51,6 +51,7 @@ from .mutation_guard import GenerationTracker
|
||||
from .pgo import reset_code_state
|
||||
from .symbolic_convert import TensorifyState
|
||||
from .utils import (
|
||||
create_nested_fn_cache,
|
||||
graph_break_reasons,
|
||||
guard_failures,
|
||||
orig_code_map,
|
||||
@ -144,6 +145,7 @@ def reset() -> None:
|
||||
torch._dynamo.utils.warn_once_cache.clear()
|
||||
torch._dynamo.utils.user_obj_id_to_weakref.clear()
|
||||
torch._C._autograd._saved_tensors_hooks_set_tracing(False)
|
||||
create_nested_fn_cache.clear()
|
||||
|
||||
|
||||
def reset_code_caches() -> None:
|
||||
|
@ -4780,3 +4780,22 @@ def get_traced_code() -> Optional[list[CodeType]]:
|
||||
from torch._guards import TracingContext
|
||||
|
||||
return TracingContext.get_traced_code()
|
||||
|
||||
|
||||
class CreateNestedFnCache:
|
||||
cache: dict[str, types.FunctionType] = {}
|
||||
|
||||
@classmethod
|
||||
def get(cls, key):
|
||||
return cls.cache.get(key, None)
|
||||
|
||||
@classmethod
|
||||
def set(cls, key, value):
|
||||
cls.cache[key] = value
|
||||
|
||||
@classmethod
|
||||
def clear(cls):
|
||||
cls.cache.clear()
|
||||
|
||||
|
||||
create_nested_fn_cache: CreateNestedFnCache = CreateNestedFnCache()
|
||||
|
@ -62,6 +62,7 @@ from ..utils import (
|
||||
check_unspec_or_constant_args,
|
||||
cmp_name_to_op_mapping,
|
||||
counters,
|
||||
create_nested_fn_cache,
|
||||
identity,
|
||||
is_function,
|
||||
is_wrapper_or_member_descriptor,
|
||||
@ -269,6 +270,11 @@ def _create_nested_fn(
|
||||
):
|
||||
from types import FunctionType
|
||||
|
||||
# Add caching for the actual IDs of user functions so that we can use them in the ID_MATCH guard.
|
||||
cache_key = str(id(code)) + str(id(closure)) + str(id(f_globals))
|
||||
if create_nested_fn_cache.get(cache_key):
|
||||
return create_nested_fn_cache.get(cache_key)
|
||||
|
||||
func = FunctionType(code, f_globals, name, defaults, closure)
|
||||
func.__kwdefaults__ = kwdefaults
|
||||
|
||||
@ -280,7 +286,7 @@ def _create_nested_fn(
|
||||
# TypeError: __annotations__ must be set to a dict object
|
||||
assert annotations is None or isinstance(annotations, dict)
|
||||
func.__annotations__ = annotations
|
||||
|
||||
create_nested_fn_cache.set(cache_key, func)
|
||||
return func
|
||||
|
||||
|
||||
@ -1427,7 +1433,13 @@ class SkipFunctionVariable(VariableTracker):
|
||||
|
||||
@classmethod
|
||||
def create_with_source(cls, value, source):
|
||||
if not is_wrapper_or_member_descriptor(value):
|
||||
if inspect.getattr_static(value, "_torchdynamo_orig_callable", False):
|
||||
install_guard(
|
||||
AttrSource(source, "_torchdynamo_orig_callable").make_guard(
|
||||
GuardBuilder.FUNCTION_MATCH
|
||||
)
|
||||
)
|
||||
elif not is_wrapper_or_member_descriptor(value):
|
||||
# These descriptors are not guaranteed to return the same object on
|
||||
# attribute lookup. They are unlikely to be changed, so we can skip
|
||||
# guarding them.
|
||||
|
Reference in New Issue
Block a user