[dynamo] context manager/decorator for dynamo config patching during tracing (#150586)

Implement traceable config patching for Dynamo: enables restricted patching of Dynamo config where user can use a context manager/decorator to change tracing behavior for parts of the code.

The new `dont_skip_tracing` decorator/context manager for ignoring most trace rules is easily implemented with this more generic traceable config patching feature.

Implementation:
- Create a new specialized context manager class representing a wrapper around torch._dynamo.config.patch
- Dynamo doesn't trace into the context manager but updates config at compile time
- Correctness is based on our correctness for handling supported context managers
- Implementation is inspired by how `GradModeVariable` is implemented.

Previous attempts: https://github.com/pytorch/pytorch/pull/148736 (decorator-only global approach) and https://github.com/pytorch/pytorch/pull/149439 (decorator-only traceback approach)

See https://docs.google.com/document/d/1vWNwKL_jpg-PLopifcaSa338wks3GqSVF4GHRguybGg/edit?tab=t.0 for more details on implementation - including previous approaches.

NOTE: this PR fixes a bug where skipped code objects were not tracked by convert_frame.py, leading to cases where code objects would be automatically skipped even after `torch._dynamo.reset()`. This exposed some latent dynamo-wrapped test failures in CI that previously passed in CI but not locally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150586
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
This commit is contained in:
William Wen
2025-04-22 20:28:34 +00:00
committed by PyTorch MergeBot
parent 62b5649b76
commit 5b9df57b50
19 changed files with 474 additions and 42 deletions

View File

@ -1534,6 +1534,148 @@ If the above doesn't work, please subtmit an issue to GitHub.
with torch.compiler.set_stance("default", force_backend=fail_backend):
f(torch.randn(3, 3))
# also tests a lot of torch._dynamo.patch_dynamo_config functionality
def test_dont_skip_tracing(self):
from torch._dynamo.test_dont_skip_tracing_functions import f1, f3, f4, f5, f6
cnts = torch._dynamo.testing.CompileCounter()
# make sure test_dont_skip_tracing_functions is actually skipped by trace rules
torch.compile(f1, backend=cnts)(torch.randn(3))
self.assertEqual(cnts.frame_count, 0)
f1_unskip = torch._dynamo.dont_skip_tracing(f1)
# basic test
def g1(x):
return f1_unskip(x)
cnts.clear()
torch.compile(g1, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test that dont_skip_tracing is traceable
def g2(x):
return torch._dynamo.dont_skip_tracing(f1)(x)
cnts.clear()
torch.compile(g2, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test that dont_skip_tracing is recursive, applied to non-skipped function
@torch._dynamo.dont_skip_tracing
def g3(x):
return f1(x)
cnts.clear()
torch.compile(g3, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test that dont_skip_tracing is recursive, applied to skipped function
f3_unskip = torch._dynamo.dont_skip_tracing(f3)
cnts.clear()
torch.compile(f3_unskip, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test dont_skip_tracing with graph breaks
inp = torch.ones(3)
res = torch.compile(f4, backend=cnts)(inp)
self.assertEqual(res, inp + 6)
@torch.compile(backend=cnts)
def g4(x):
x = f5(x, 1)
x = torch._dynamo.dont_skip_tracing(f6)(x)
x = f5(x, 8)
return x
res = g4(inp)
self.assertEqual(res, inp + 6)
# test nested dont_skip_tracing
# this also happens to test if a previously skipped frame (f4)
# can actually be compiled if called as a top-level function (in the case of a graph break)
# TODO the reset is necessary for now since attempting to trace f4 previously
# resulted in an unconditional skip
torch._dynamo.reset()
f4_unskip = torch._dynamo.dont_skip_tracing(f4)
res = torch.compile(f4_unskip, backend=cnts)(inp)
self.assertEqual(res, inp + 15)
# test dont_skip_tracing that is activated outside torch.compile
f4_unskip2 = torch._dynamo.dont_skip_tracing(torch.compile(f4, backend=cnts))
res = f4_unskip2(inp)
self.assertEqual(res, inp + 15)
# test context manager from inside
@torch.compile(backend=cnts)
def g5(x):
x = f5(x, 1)
with torch._dynamo.dont_skip_tracing():
x = f5(x, 2)
torch._dynamo.graph_break()
x = f5(x, 4)
x = f5(x, 8)
return x
res = g5(inp)
self.assertEqual(res, inp + 6)
# test context manager from outside
with torch._dynamo.dont_skip_tracing():
res = torch.compile(f4, backend=cnts)(inp)
self.assertEqual(res, inp + 15)
# test skipped function from different dont_skip_tracing regions
@torch.compile(backend=cnts)
def g6(x):
fn1 = f5
with torch._dynamo.dont_skip_tracing():
fn2 = f5
x = fn1(x, 1)
x = fn2(x, 2)
return x
res = g6(inp)
self.assertEqual(res, inp + 1)
def test_patch_dynamo_config_errors(self):
@torch.compile(backend="eager")
def f1(x):
with torch._dynamo.patch_dynamo_config(nonexistent=False):
return x + 1
with self.assertRaisesRegex(Exception, "patch_dynamo_config does not support"):
f1(torch.randn(3))
@torch.compile(backend="eager")
def f2(x):
with torch._dynamo.patch_dynamo_config("verbose", {"a": 1}):
return x + 1
with self.assertRaisesRegex(
Exception, "patch_dynamo_config does not support .* with non-safe-constant"
):
f2(torch.randn(3))
@torch.compile(backend="eager")
def f3(x):
with torch._dynamo.patch_dynamo_config({"recompile_limit": 1}):
return x + 1
with self.assertRaisesRegex(Exception, "patch_dynamo_config does not support"):
f3(torch.randn(3))
@torch.compile(backend="eager")
def f4(x):
with torch._dynamo.patch_dynamo_config(verbose=object()):
return x + 1
with self.assertRaisesRegex(
Exception, "Cannot convert patch_dynamo_config args/kwargs to constants."
):
f4(torch.randn(3))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -309,7 +309,7 @@ from user code:
Attempted to call function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `skip` in file `case.py` should not be traced.
Hint: Avoid calling the function `skip`.
Hint: Remove the function `skip` or the file `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `skip` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: module: unittest.case, qualname: skip, skip reason: <missing reason>
@ -358,7 +358,7 @@ from user code:
Attempted to inline function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `skip` should not be traced.
Hint: Avoid calling the function `skip`.
Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `skip` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest

View File

@ -20,6 +20,7 @@ from .decorators import (
assume_constant_result,
disable,
disallow_in_graph,
dont_skip_tracing,
forbid_in_graph,
graph_break,
mark_dynamic,
@ -27,6 +28,7 @@ from .decorators import (
mark_static_address,
maybe_mark_dynamic,
nonstrict_trace,
patch_dynamo_config,
run,
set_stance,
substitute_in_graph,
@ -57,6 +59,7 @@ __all__ = [
"allow_in_graph",
"assume_constant_result",
"disallow_in_graph",
"dont_skip_tracing",
"forbid_in_graph",
"substitute_in_graph",
"graph_break",
@ -67,6 +70,7 @@ __all__ = [
"nonstrict_trace",
"optimize",
"optimize_assert",
"patch_dynamo_config",
"export",
"explain",
"run",

View File

@ -38,7 +38,7 @@ verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False
# need this many ops to create an FX graph
# need this many ops to create an FX graph (deprecated: not used)
minimum_call_count = 1
# turn on/off DCE pass (deprecated: always true)
@ -322,6 +322,8 @@ do_not_emit_runtime_asserts: bool = (
# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
skip_torchrec = True
# Don't apply most trace_rules.py rules
dont_skip_tracing = False
# No longer used
optimize_ddp_lazy_compile = False

View File

@ -408,11 +408,16 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile"
)
prof = cProfile.Profile()
prof.enable()
start_ts = time.time()
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()
try:
prof.enable()
start_ts = time.time()
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()
except ValueError:
log.exception("failed to enable cProfile")
profile_latency = 0
retval = func(*args, **kwargs)
log.warning(
"### Cprofile for %s trace id [%s] took %.3f seconds ###",
func.__name__,
@ -1226,6 +1231,7 @@ class ConvertFrame:
frame_state: dict[str, Union[int, FrameStateSizeEntry]],
skip: int = 0,
) -> ConvertFrameReturn:
input_codes.add(frame.f_code)
counters["frames"]["total"] += 1
try:
result = self._inner_convert(
@ -1385,6 +1391,8 @@ class CatchErrorsWrapper:
) -> ConvertFrameReturn:
assert frame_state is not None
input_codes.add(frame.f_code)
is_skipfile = trace_rules.check(frame.f_code)
if sys.version_info >= (3, 13):
has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)

View File

@ -10,7 +10,7 @@ import inspect
import sys
import weakref
from dataclasses import dataclass
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -30,7 +30,11 @@ from .eval_frame import (
skip_code,
)
from .exc import IncorrectUsage
from .external_utils import get_nonrecursive_disable_wrapper, is_compiling
from .external_utils import (
_dynamo_config_patch_proxy_dunder_call,
get_nonrecursive_disable_wrapper,
is_compiling,
)
from .utils import is_function
@ -732,3 +736,105 @@ def _allow_in_graph_einops():
trace_rules.add_module_init_func("einops", _allow_in_graph_einops)
# Proxy class for torch._dynamo.config patching - so dynamo can identify context managers/decorators
# created by patch_dynamo_config, compared to ones created by a raw torch._dynamo.config.patch.
class DynamoConfigPatchProxy:
def __init__(self, config_patch):
self.config_patch = config_patch
@property
def changes(self):
return self.config_patch.changes
# Decorator implementation that simply sets up `self` as a context manager.
# Placed in external_utils so that we can trace through it.
__call__ = _dynamo_config_patch_proxy_dunder_call
def __enter__(self):
return self.config_patch.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
return self.config_patch.__exit__(exc_type, exc_val, exc_tb)
# Criteria for patchable config:
# - Config values must be constants (i.e. int, float, str, bool, None).
# - in particular, NO list, set, dict.
# - Traceable config patches are only useful for configs that change dynamo behavior
# from symbolic_convert and below.
# - e.g. patching recompile_limit won't really do anything.
# - For patching configs that affect Dynamo behavior above symbolic_convert,
# ensure that Dynamo behaves soundly even if tracing is done with different config.
# - e.g. be careful if patching guard-related configs as configs may have changed
# between guard creation and evaluation.
_allowed_config_patches = (
"verbose",
"verify_correctness",
"rewrite_assert_with_torch_assert",
"capture_scalar_outputs",
"allow_unspec_int_on_nn_module",
"skip_torchrec",
"dont_skip_tracing",
)
for name in _allowed_config_patches:
assert hasattr(torch._dynamo.config, name), "nonexistent config"
def _patch_dynamo_config_check(changes: dict[str, Any]):
for k, v in changes.items():
if k not in _allowed_config_patches:
raise ValueError(
f"patch_dynamo_config does not support patching config {k}"
)
if not torch._dynamo.utils.is_safe_constant(v):
raise ValueError(
f"patch_dynamo_config does not support patching config {k} "
f"with non-safe-constant value {v}"
)
# TODO: also implement nonrecursive patch_dynamo_config/dont_skip_tracing.
# Unlike config.patch, we also need to accept tuple as input in order to
# deal with context manager reconstruction.
def patch_dynamo_config(
arg1: Optional[Union[str, dict[str, Any], tuple[tuple[str, Any], ...]]] = None,
arg2: Any = None,
**kwargs: Any,
) -> DynamoConfigPatchProxy:
"""
A wrapper around torch._dynamo.config.patch that can be traced by Dynamo to
temporarily change config values DURING tracing.
See _allowed_config_patches for the list of allowed config patches.
Arguments are the same as with torch._dynamo.confing.patch.
Can be used as a decorator or a context manager.
User code SHOULD NOT MODIFY the return value of this function.
WARNING: changing Dynamo config during tracing can lead to unpredictable tracing behavior!
Proceed only as advised!
"""
if isinstance(arg1, tuple):
arg1 = dict(arg1)
config_patch = torch._dynamo.config.patch(arg1, arg2, **kwargs)
_patch_dynamo_config_check(config_patch.changes)
# check for valid patching using config_patch.changes
return DynamoConfigPatchProxy(config_patch)
def dont_skip_tracing(fn=None):
"""
Context manager/decorator to trace into functions intentionally marked by developers to be skipped
when tracing.
This decorator will also apply to recursively invoked functions.
"""
ctx = patch_dynamo_config(dont_skip_tracing=True)
if fn:
return ctx(fn)
return ctx

View File

@ -198,3 +198,14 @@ def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
return fn(*args, **kwargs)
return nonrecursive_disable_wrapper
def _dynamo_config_patch_proxy_dunder_call(
self: Any, func: Callable[_P, _R]
) -> Callable[_P, _R]:
@functools.wraps(func)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with self:
return func(*args, **kwargs)
return inner

View File

@ -3730,6 +3730,22 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
hints=[],
)
if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
func.get_function(), "_torchdynamo_disable", False
):
msg = inspect.getattr_static(
func.get_function(), "_torchdynamo_disable_msg", None
)
unimplemented_v2(
gb_type="Skip inlining `torch.compiler.disable()`d function",
context=str(func.get_function()),
explanation=f"Skip inlining function {func.get_function()} since it was wrapped "
f"with `torch.compiler.disable` (reason: {msg})",
hints=[
"Remove the `torch.compiler.disable` call",
],
)
result = trace_rules.check_verbose(func, is_inlined_call=True)
if result.skipped:
from torch._dynamo.variables.misc import produce_trampoline_autograd_apply
@ -3749,11 +3765,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
]
if "_dynamo" not in func.get_filename():
hints += [
f"Remove the function `{fn_qualname}` or the file `{func.get_filename()}` "
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
"attempting to trace into the function.",
f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{fn_qualname}` "
"to force tracing into the function. "
"More graph breaks may occur as a result of attempting to trace into the function.",
"Please file an issue to PyTorch.",
# TODO suggest mark_force_inline when implemented
]
unimplemented_v2(
gb_type="Attempted to inline function marked as skipped",
@ -3764,23 +3779,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
hints=hints,
)
if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
func.get_function(), "_torchdynamo_disable", False
):
msg = inspect.getattr_static(
func.get_function(), "_torchdynamo_disable_msg", None
)
unimplemented_v2(
gb_type="Skip inlining `torch.compiler.disable()`d function",
context=str(func.get_function()),
explanation=f"Skip inlining function {func.get_function()} since it was wrapped "
f"with `torch.compiler.disable` (reason: {msg})",
hints=[
"Remove the `torch.compiler.disable` call",
],
)
else:
return result
return result
@staticmethod
def build_inline_tracer(

View File

@ -0,0 +1,40 @@
"""
Functions used to test torch._dynamo.dont_skip_tracing.
This file is located in torch/_dynamo so that it is skipped by trace rules.
There is a special rule in trace_rules that doesn't skip this file when
dont_skip_tracing is active.
"""
import torch
def f1(x: torch.Tensor) -> torch.Tensor:
return x + 1
def f2(x: torch.Tensor) -> torch.Tensor:
return x + 1
def f3(x: torch.Tensor) -> torch.Tensor:
return f2(x)
def f4(x: torch.Tensor) -> torch.Tensor:
x = f5(x, 1)
x = torch._dynamo.dont_skip_tracing(f6)(x)
x = f5(x, 8)
return x
def f5(x: torch.Tensor, n: int) -> torch.Tensor:
if torch.compiler.is_compiling():
return x + n
return x
def f6(x: torch.Tensor) -> torch.Tensor:
x = f5(x, 2)
torch._dynamo.graph_break()
x = f5(x, 4)
return x

View File

@ -28,6 +28,7 @@ import torch.utils._content_store
from torch._environment import is_fbcode
from torch.utils import _config_module
from . import config
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .variables import (
@ -306,8 +307,10 @@ manual_torch_name_rule_map: dict[str, Any] = {
"torch._tensor._convert": UserFunctionVariable,
"torch.jit._unwrap_optional": UserFunctionVariable,
"torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
"torch._dynamo.dont_skip_tracing": UserFunctionVariable,
"torch._dynamo.mark_static": UserFunctionVariable,
"torch._dynamo.nonstrict_trace": UserFunctionVariable,
"torch._dynamo.patch_dynamo_config": UserFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,
@ -3261,7 +3264,7 @@ if torch.distributed.is_available():
# the forward_hook won't be ignored.
"torch.distributed._composable.replicate",
}
if not torch._dynamo.config.skip_fsdp_hooks:
if not config.skip_fsdp_hooks:
LEGACY_MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")
# Force inline functions under these modules, even they are in *_SKIPLIST.
@ -3324,7 +3327,7 @@ MOD_INLINELIST = set(MOD_INLINELIST)
if torch.distributed.is_available():
MOD_INLINELIST.add("torch.distributed")
if not torch._dynamo.config.skip_fsdp_hooks:
if not config.skip_fsdp_hooks:
MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")
@ -3575,7 +3578,7 @@ def check_file(filename, is_inlined_call=False):
if (
is_fbcode()
and torch._dynamo.config.skip_torchrec
and config.skip_torchrec
and FBCODE_SKIP_TORCHREC_DIRS
and bool(FBCODE_SKIP_TORCHREC_DIRS_RE.match(filename))
and not bool(FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE.match(filename))
@ -3752,12 +3755,48 @@ def lookup(obj):
return lookup_inner(obj)
# also takes config.dont_skip_tracing into account
def lookup_inner(
obj,
name=None,
filename=None,
is_direct_call=True,
reasons: Union[None, set[str]] = None,
):
result = _lookup_inner(
obj,
name=name,
filename=filename,
is_direct_call=is_direct_call,
reasons=reasons,
)
# There are still some modules we should absolutely NOT trace into - e.g. most of torch._dynamo,
# as this can result in really weird tracing behaviors.
# Note that if a torch._dynamo function is already not skipped (e.g. functions in external_utils.py),
# then this branch does not apply.
if config.dont_skip_tracing and result is SkipFunctionVariable:
if filename is None:
filename = getfile(obj)
filename = _as_posix_path(filename)
dynamo_path = _as_posix_path(_module_dir(torch)) + "_dynamo"
if filename.startswith(dynamo_path) and not filename.endswith(
"test_dont_skip_tracing_functions.py"
):
return SkipFunctionVariable
if reasons is not None:
reasons.add(
"Attempted skip but we are ignoring skips due to torch._dynamo.config.dont_skip_tracing"
)
return UserFunctionVariable
return result
def _lookup_inner(
obj,
name=None,
filename=None,
is_direct_call=True,
reasons: Union[None, set[str]] = None,
):
# Step 1: lookup obj's tracing rule in `torch_name_rule_map`.
# The rules defined in `torch_name_rule_map` mainly includes two parts:

View File

@ -4531,6 +4531,7 @@ def set_feature_use(feature: str, usage: bool):
_ddp_optimization_mode: tuple[str, ...] = (
"ddp_optimizer",
"python_reducer", # experimental mode
"python_reducer_without_compiled_forward",
"no_optimization",
)

View File

@ -26,6 +26,7 @@ from .ctx_manager import (
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
DynamoConfigPatchVariable,
FSDPParamGroupUseTrainingStateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
@ -164,6 +165,7 @@ __all__ = [
"DeletedVariable",
"DeterministicAlgorithmsVariable",
"DictKeySetVariable",
"DynamoConfigPatchVariable",
"EnumVariable",
"FakeItemVariable",
"GetAttrVariable",

View File

@ -154,6 +154,7 @@ from .base import (
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
AutocastModeVariable,
DynamoConfigPatchVariable,
EventVariable,
NullContextVariable,
PreserveVersionContextVariable,
@ -594,6 +595,8 @@ class VariableBuilder:
# import here to avoid circular dependencies
from torch.utils._triton import has_triton, has_triton_tma
from ..decorators import DynamoConfigPatchProxy
if has_triton():
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
@ -911,6 +914,8 @@ class VariableBuilder:
{},
)
)
elif isinstance(value, DynamoConfigPatchProxy):
return DynamoConfigPatchVariable(value.changes)
elif callable(value) and trace_rules.lookup_callable(value) is not None:
if trace_rules.is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True

View File

@ -1366,6 +1366,47 @@ class EventVariable(VariableTracker):
codegen.append_output(codegen.create_load_global(name, add=True))
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""
# NOTE: no need to guard on dynamo config because dynamo config should not affect soundness
# (though it may affect tracing behavior)
def __init__(self, target_values, **kwargs) -> None:
target_values = tuple(target_values.items())
super().__init__(target_values=(target_values,), initial_values=None, **kwargs)
self.initial_values = {}
for key, _ in target_values:
self.initial_values[key] = torch._dynamo.config.__getattr__(key)
self.initial_values = (tuple(self.initial_values.items()),)
def enter(self, tx):
# resets all config patches at the end of tracing
self.set_cleanup_hook(tx)
self._call_func(tx, self.target_values)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self._call_func(tx, self.initial_values)
return variables.ConstantVariable.create(None)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
# manually patch dynamo config
for key, val in value:
torch._dynamo.config.__setattr__(key, val)
# No need to keep track of global side effects because
# dynamo will properly restore this context manager for
# unsupported instructions and continuation functions.
# Dynamo config also should not affect the semantics of the compiled graph.
def module_name(self):
return "torch._dynamo"
def fn_name(self):
return "patch_dynamo_config"
class WithExitFunctionVariable(VariableTracker):
_nonvar_fields = {
"target",

View File

@ -36,7 +36,7 @@ from unittest.mock import patch
import torch
from .. import graph_break_hints, polyfills, variables
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
get_dynamo_observed_exception,
@ -64,7 +64,12 @@ from ..utils import (
istype,
make_cell,
)
from .base import AttributeMutationNew, ValueMutationNew, VariableTracker
from .base import (
AsPythonConstantNotImplementedError,
AttributeMutationNew,
ValueMutationNew,
VariableTracker,
)
from .constant import ConstantVariable
@ -372,6 +377,23 @@ class UserFunctionVariable(BaseUserFunctionVariable):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# Handle patch_dynamo_config call
if self.fn is torch._dynamo.patch_dynamo_config:
try:
args_const = [arg.as_python_constant() for arg in args]
kwargs_const = {
key: val.as_python_constant() for key, val in kwargs.items()
}
changes = torch._dynamo.patch_dynamo_config(
*args_const, **kwargs_const
).changes
return variables.DynamoConfigPatchVariable(changes)
except AsPythonConstantNotImplementedError as e:
raise RuntimeError(
"Cannot convert patch_dynamo_config args/kwargs to constants. "
"Please fix your call to patch_dynamo_config by using simpler inputs. "
f"args: {args}, kwargs: {kwargs}"
) from e
# Handle a `nonstrict_trace(fn)` call
if self.fn is torch._dynamo.nonstrict_trace:
bound = inspect.signature(self.fn).bind(*args, **kwargs)
@ -1297,6 +1319,14 @@ class SkipFunctionVariable(VariableTracker):
torch._dynamo.utils.warn_once(msg)
unimplemented(msg)
else:
if config.dont_skip_tracing:
from .builder import SourcelessBuilder
# re-build the function, attempting to not skip
rebuilt_fn = SourcelessBuilder.create(tx, self.value)
# if we still get SkipFunctionVariable, then we *really* should skip this function
if not isinstance(rebuilt_fn, SkipFunctionVariable):
return rebuilt_fn.call_function(tx, args, kwargs)
qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
try:
path = inspect.getfile(self.value)
@ -1312,11 +1342,10 @@ class SkipFunctionVariable(VariableTracker):
# Do a very basic check for now.
if "_dynamo" not in path:
hints += [
f"Remove the function `{qualname}` or the file `{path}` "
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
"attempting to trace into the function.",
f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` "
"to force tracing into the function. "
"More graph breaks may occur as a result of attempting to trace into the function.",
"Please file an issue to PyTorch.",
# TODO suggest mark_force_inline when implemented
]
except TypeError:
known_python_builtin_modules = {"_abc", "_warnings"}

View File

@ -667,12 +667,15 @@ class ConfigModule(ModuleType):
config = self
class ConfigPatch(ContextDecorator):
def __init__(self) -> None:
self.changes = changes
def __enter__(self) -> None:
assert not prior
for key in changes.keys():
for key in self.changes.keys():
# KeyError on invalid entry
prior[key] = config.__getattr__(key)
for k, v in changes.items():
for k, v in self.changes.items():
config.__setattr__(k, v)
def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def]