Mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Fix graph break on calling functions decorated with special context manager (#160703)
As title. This is a follow-up to the previous patch, with the goal of
supporting a new pattern that showed up in ComfyUI:
644b23ac0b/comfy/ops.py (L44)
Effectively, the semantics of calling a function decorated with a
context manager are:
```python
@ctx_manager(args)
def f(x):
...
f(x)
# ----->
with ctx_manager(args):
f.__wrapped__(x)
```
Yes, a fresh context manager instance per invocation; see the CPython source:
https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L119-L122
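For concreteness, the per-call recreation can be observed with plain contextlib alone; a minimal, torch-free sketch (the names `ctx_manager` and `f` are made up for illustration):
```python
import contextlib

@contextlib.contextmanager
def ctx_manager(tag):
    # The object returned by ctx_manager(tag) is a _GeneratorContextManager,
    # which inherits from ContextDecorator, so it can also decorate a function.
    print(f"enter {tag}")
    try:
        yield
    finally:
        print(f"exit {tag}")

@ctx_manager("decorated")
def f(x):
    return x + 1

f(1)  # prints: enter decorated / exit decorated
f(2)  # prints them again -- ContextDecorator recreates the manager per call
```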
So Dynamo already
1. knows how to handle the `with ctx_manager(args)` syntax, and has
special handling for a few torch-native context managers, like
`sdpa_kernel` in this patch.
2. can trace through a good chunk of contextlib (at least the parts that
matter in this case).
This patch just lets Dynamo trace a bit further into contextlib, and
keeps the torch-native special cases by moving their handling a bit down
the stack, so that no additional logic is introduced -- it is only
refactored.
This also allows us to get rid of some `_sdpa_kernel_variadic` special
handling, since we now trace through its code, and it boils down to
`sdpa_kernel` anyway.
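The end-to-end pattern this enables looks roughly like the regression test added below; the following is only an illustrative sketch (backend choice, tensor shapes, and the names `attention`/`f` are arbitrary):
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# sdpa_kernel is a @contextlib.contextmanager, so using it as a decorator
# goes through contextlib's ContextDecorator machinery on every call.
@sdpa_kernel(backends=[SDPBackend.MATH])
def attention(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

def f(x):
    # Previously this call graph-broke under torch.compile; with this patch
    # Dynamo traces through the contextlib wrapper and special-cases sdpa_kernel.
    return attention(x, x, x)

opt_f = torch.compile(f, backend="eager", fullgraph=True)
x = torch.rand(2, 4, 8, 16)
torch.testing.assert_close(f(x), opt_f(x))
```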
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160703
Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos
ghstack dependencies: #160684
Committed by: PyTorch MergeBot
Parent: 72b559b2c8
Commit: a1a555ed7b
@@ -1764,6 +1764,33 @@ class GraphModule(torch.nn.Module):
         opt_f = torch.compile(f, backend="eager")
         opt_f(torch.randn(2, 2))
 
+    # Regression test to make sure dynamo won't graph break on calling functions
+    # decorated with special context manager.
+    def test_sdpa_kernel_ctx_manager_as_decorator(self):
+        SDPA_BACKEND_PRIORITY = [
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+        ]
+
+        @torch.nn.attention.sdpa_kernel(
+            backends=SDPA_BACKEND_PRIORITY, set_priority=True
+        )
+        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+            return torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, *args, **kwargs
+            )
+
+        def f(x):
+            return scaled_dot_product_attention(x, x, x)
+
+        opt_f = torch.compile(f, backend="eager", fullgraph=True)
+        x = torch.rand(16, 16, 64, 256, dtype=torch.float16)
+        ref = f(x)
+        res = opt_f(x)
+
+        self.assertEqual(ref, res)
+
     # Regression test to make sure the value of set_priority is used correctly.
     def test_sdpa_kernel_ctx_manager_set_priority(self):
         backends = [torch.nn.attention.SDPBackend.MATH]
@@ -125,8 +125,14 @@ supported_ctx_manager_classes = dict.fromkeys(
         torch.autograd.graph.disable_saved_tensors_hooks,
         torch.cpu.amp.autocast_mode.autocast,
         torch.cuda.amp.autocast_mode.autocast,
-        torch.nn.attention.sdpa_kernel,
-        torch.nn.attention._sdpa_kernel_variadic,
+        # We'll let Dynamo inline into the contextlib part of these context
+        # manager instances, all the way till it invokes the wrapped function
+        # itself (at which point we wrap it back to special context manager
+        # VTs).
+        #
+        # This allows us to support calling functions decorated with these
+        # context managers, without much extra effort or code dup.
+        torch.nn.attention.sdpa_kernel.__wrapped__,  # type: ignore[attr-defined]
     ]
 )
 
@@ -412,18 +418,13 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             return FSDPParamGroupUseTrainingStateVariable.create(
                 tx, args[0], args[1].as_python_constant()
             )
-        elif self.value is torch.nn.attention.sdpa_kernel:
-            source = AttrSource(self.source, "__wrapped__") if self.source else None
+        elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__:  # type: ignore[attr-defined]
             name_to_arg_map = bind_args_cached(
-                self.value.__wrapped__, tx, source, args, kwargs
+                self.value, tx, self.source, args, kwargs
             )
             backends = name_to_arg_map["backends"].as_python_constant()
             set_priority = name_to_arg_map["set_priority"].as_python_constant()
             return SDPAKernelVariable.create(tx, backends, set_priority)
-        elif self.value is torch.nn.attention._sdpa_kernel_variadic:
-            return SDPAKernelVariable.create(
-                tx, [arg.as_python_constant() for arg in args]
-            )
 
         return super().call_function(tx, args, kwargs)
 
@@ -72,7 +72,6 @@ from ..source import (
     UnspecializedParamBufferSource,
 )
 from ..utils import (
-    build_checkpoint_variable,
     check_constant_args,
     cmp_name_to_op_mapping,
     dict_methods,
@@ -82,7 +81,6 @@ from ..utils import (
     is_frozen_dataclass,
     is_lru_cache_wrapped_function,
     is_namedtuple_cls,
-    is_utils_checkpoint,
     is_wrapper_or_member_descriptor,
     istype,
     list_methods,
@@ -596,6 +594,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
             and self.source
             and not is_forbidden_context_manager(self.value)
         ):
+            from . import TorchCtxManagerClassVariable
            from .functions import (
                 BaseUserFunctionVariable,
                 FunctionDecoratedByContextlibContextManagerVariable,
@@ -627,7 +626,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
             )
 
             if self.value is contextlib._GeneratorContextManager and isinstance(
-                args[0], BaseUserFunctionVariable
+                args[0], (BaseUserFunctionVariable, TorchCtxManagerClassVariable)
             ):
                 if not torch._dynamo.config.enable_trace_contextlib:
                     unimplemented_v2(
@@ -638,6 +637,29 @@ class UserDefinedClassVariable(UserDefinedVariable):
                             "Set torch._dynamo.config.enable_trace_contextlib = True",
                         ],
                     )
+
+                # Special treatments for certain context managers created via
+                # contextlib, because
+                # 1. we (pytorch) own their impls
+                # 2. it's tedious to trace through them, so we effectively
+                #    "allow in graph" them without sacrificing soundness.
+                #
+                # We would typically reach here via either
+                # 1. the instance construction in `with ctx_manager(...):`:
+                #    https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L301
+                # 2. calling a function decorated with a context manager:
+                #    https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L122
+                #
+                # So we basically trace through the surface part of the
+                # contextlib code, and then special case the shared remaining
+                # logic (the actual context manager instance construction and
+                # usage later on).
+                if isinstance(args[0], TorchCtxManagerClassVariable):
+                    fn_var = args[0]
+                    args_list = args[1].items
+                    kwargs_dict = args[2].keys_as_python_constant()
+                    return fn_var.call_function(tx, args_list, kwargs_dict)
+
                 # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable
                 # if the function is annotated with @contextlib.contextmanager
                 # This shouldn't be necessary once generator functions are fully
@@ -1309,7 +1331,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )
 
     def var_getattr(self, tx: "InstructionTranslator", name):
-        from .. import trace_rules
         from . import ConstantVariable
 
         source = AttrSource(self.source, name) if self.source else None
@@ -1555,14 +1576,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
                 func, self, source_fn=source_fn, source=source
             )
         elif inspect.isfunction(dynamic_subobj):
-            if is_utils_checkpoint(func):
-                return build_checkpoint_variable(source=source)
-            elif source is not None:
-                return trace_rules.lookup(func).create_with_source(
-                    func, source=source
-                )
-            else:
-                return trace_rules.lookup(func)(func)
+            return VariableTracker.build(tx, func, source)
 
         if (
             # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to