[dynamo] Fix graph break on calling functions decorated with special context manager (#160703)

As title. This is a follow-up to the previous patch, with the goal of
supporting a new pattern that showed up in ComfyUI:
644b23ac0b/comfy/ops.py (L44)

Effectively, the semantics of calling a function decorated with a
context manager are:

```python
@ctx_manager(args)
def f(x):
    ...

f(x)
# ----->
with ctx_manager(args):
    f.__wrapped__(x)
```

Yes, a fresh context manager instance is created per invocation; see the
CPython source:
https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L119-L122
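
For illustration, here is a simplified sketch (paraphrased, not the verbatim
CPython code) of how `contextlib.ContextDecorator` produces that behavior:
the wrapper it returns re-creates the context manager on every call before
delegating to the wrapped function.

```python
import functools


# Simplified sketch of contextlib.ContextDecorator (paraphrased; see the
# CPython link above for the real implementation).
class SketchContextDecorator:
    def _recreate_cm(self):
        # _GeneratorContextManager overrides this to rebuild a fresh context
        # manager from the saved (func, args, kwds) on every call.
        return self

    def __call__(self, func):
        @functools.wraps(func)
        def inner(*args, **kwds):
            with self._recreate_cm():  # fresh instance per invocation
                return func(*args, **kwds)

        return inner
```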

So Dynamo already
1. knows how to handle the `with ctx_manager(args)` syntax (as sketched
   below), and has special handling for a few torch-native context
   managers, like `sdpa_kernel` in this patch;
2. can trace through a good chunk of contextlib (at least the parts that
   matter in this case).
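
For reference, a minimal sketch of point 1, i.e. the `with` form that Dynamo
already handled before this patch (the eager backend and tensor shapes are
just for illustration):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel


# The `with ctx_manager(args)` form: Dynamo special-cases sdpa_kernel here
# and models the body of the `with` block under the selected backends.
@torch.compile(backend="eager", fullgraph=True)
def f(q, k, v):
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


q = k = v = torch.randn(2, 4, 8, 16)
print(f(q, k, v).shape)  # torch.Size([2, 4, 8, 16])
```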

This patch just lets Dynamo trace a bit further into contextlib, and then
keeps the torch-native special cases by moving their handling a bit further
down the stack, so that no additional logic is introduced -- it's only
refactored.

This also allows us to get rid of the `_sdpa_kernel_variadic` special
handling, since we now trace through its code, and it boils down to
`sdpa_kernel` anyway.
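
For comparison, the decorator form that this patch enables under
`fullgraph=True`, mirroring the regression test added below (assumes a
PyTorch build where `sdpa_kernel` accepts `set_priority`):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel


# Calling a function decorated with sdpa_kernel used to graph break; Dynamo
# now traces into contextlib's wrapper and reuses the existing sdpa_kernel
# handling, so everything stays in one graph.
@sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], set_priority=True)
def attention(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return attention(x, x, x)


x = torch.randn(2, 4, 8, 16)
torch.testing.assert_close(f(x), attention(x, x, x))
```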

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160703
Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos
ghstack dependencies: #160684
Ryan Guo
2025-08-15 13:33:45 -07:00
committed by PyTorch MergeBot
parent 72b559b2c8
commit a1a555ed7b
3 changed files with 63 additions and 21 deletions

View File

@@ -1764,6 +1764,33 @@ class GraphModule(torch.nn.Module):
         opt_f = torch.compile(f, backend="eager")
         opt_f(torch.randn(2, 2))

+    # Regression test to make sure dynamo won't graph break on calling functions
+    # decorated with special context manager.
+    def test_sdpa_kernel_ctx_manager_as_decorator(self):
+        SDPA_BACKEND_PRIORITY = [
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+        ]
+
+        @torch.nn.attention.sdpa_kernel(
+            backends=SDPA_BACKEND_PRIORITY, set_priority=True
+        )
+        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+            return torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, *args, **kwargs
+            )
+
+        def f(x):
+            return scaled_dot_product_attention(x, x, x)
+
+        opt_f = torch.compile(f, backend="eager", fullgraph=True)
+        x = torch.rand(16, 16, 64, 256, dtype=torch.float16)
+        ref = f(x)
+        res = opt_f(x)
+        self.assertEqual(ref, res)
+
     # Regression test to make sure the value of set_priority is used correctly.
     def test_sdpa_kernel_ctx_manager_set_priority(self):
         backends = [torch.nn.attention.SDPBackend.MATH]

View File

@@ -125,8 +125,14 @@ supported_ctx_manager_classes = dict.fromkeys(
         torch.autograd.graph.disable_saved_tensors_hooks,
         torch.cpu.amp.autocast_mode.autocast,
         torch.cuda.amp.autocast_mode.autocast,
-        torch.nn.attention.sdpa_kernel,
-        torch.nn.attention._sdpa_kernel_variadic,
+        # We'll let Dynamo inline into the contextlib part of these context
+        # manager instances, all the way till it invokes the wrapped function
+        # itself (at which point we wrap it back to special context manager
+        # VTs).
+        #
+        # This allows us to support calling functions decorated with these
+        # context managers, without much extra effort or code dup.
+        torch.nn.attention.sdpa_kernel.__wrapped__,  # type: ignore[attr-defined]
     ]
 )
@@ -412,18 +418,13 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             return FSDPParamGroupUseTrainingStateVariable.create(
                 tx, args[0], args[1].as_python_constant()
             )
-        elif self.value is torch.nn.attention.sdpa_kernel:
-            source = AttrSource(self.source, "__wrapped__") if self.source else None
+        elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__:  # type: ignore[attr-defined]
             name_to_arg_map = bind_args_cached(
-                self.value.__wrapped__, tx, source, args, kwargs
+                self.value, tx, self.source, args, kwargs
             )
             backends = name_to_arg_map["backends"].as_python_constant()
             set_priority = name_to_arg_map["set_priority"].as_python_constant()
             return SDPAKernelVariable.create(tx, backends, set_priority)
-        elif self.value is torch.nn.attention._sdpa_kernel_variadic:
-            return SDPAKernelVariable.create(
-                tx, [arg.as_python_constant() for arg in args]
-            )

         return super().call_function(tx, args, kwargs)

View File

@@ -72,7 +72,6 @@ from ..source import (
     UnspecializedParamBufferSource,
 )
 from ..utils import (
-    build_checkpoint_variable,
     check_constant_args,
     cmp_name_to_op_mapping,
     dict_methods,
@@ -82,7 +81,6 @@ from ..utils import (
     is_frozen_dataclass,
     is_lru_cache_wrapped_function,
     is_namedtuple_cls,
-    is_utils_checkpoint,
     is_wrapper_or_member_descriptor,
     istype,
     list_methods,
@@ -596,6 +594,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
             and self.source
             and not is_forbidden_context_manager(self.value)
         ):
+            from . import TorchCtxManagerClassVariable
             from .functions import (
                 BaseUserFunctionVariable,
                 FunctionDecoratedByContextlibContextManagerVariable,
@@ -627,7 +626,7 @@
             )

             if self.value is contextlib._GeneratorContextManager and isinstance(
-                args[0], BaseUserFunctionVariable
+                args[0], (BaseUserFunctionVariable, TorchCtxManagerClassVariable)
             ):
                 if not torch._dynamo.config.enable_trace_contextlib:
                     unimplemented_v2(
@@ -638,6 +637,29 @@
                             "Set torch._dynamo.config.enable_trace_contextlib = True",
                         ],
                     )
+
+                # Special treatments for certain context managers created via
+                # contextlib, because
+                # 1. we (pytorch) own their impls
+                # 2. it's tedious to trace through them, so we effectively
+                #    "allow in graph" them without sacrificing soundness.
+                #
+                # We would typically reach here via either
+                # 1. the instance construction in `with ctx_manager(...):`:
+                #    https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L301
+                # 2. calling a function decorated with a context manager:
+                #    https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L122
+                #
+                # So we basically trace through the surface part of the
+                # contextlib code, and then special case the shared remaining
+                # logic (the actual context manager instance construction and
+                # usage later on).
+                if isinstance(args[0], TorchCtxManagerClassVariable):
+                    fn_var = args[0]
+                    args_list = args[1].items
+                    kwargs_dict = args[2].keys_as_python_constant()
+                    return fn_var.call_function(tx, args_list, kwargs_dict)
+
                 # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable
                 # if the function is annotated with @contextlib.contextmanager
                 # This shouldn't be necessary once generator functions are fully
@@ -1309,7 +1331,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )

     def var_getattr(self, tx: "InstructionTranslator", name):
-        from .. import trace_rules
         from . import ConstantVariable

         source = AttrSource(self.source, name) if self.source else None
@@ -1555,14 +1576,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
                 func, self, source_fn=source_fn, source=source
             )
         elif inspect.isfunction(dynamic_subobj):
-            if is_utils_checkpoint(func):
-                return build_checkpoint_variable(source=source)
-            elif source is not None:
-                return trace_rules.lookup(func).create_with_source(
-                    func, source=source
-                )
-            else:
-                return trace_rules.lookup(func)(func)
+            return VariableTracker.build(tx, func, source)

         if (
             # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to