Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[dynamo] context manager/decorator for dynamo config patching during tracing (#150586)
Implement traceable config patching for Dynamo: restricted patching of Dynamo config, where users can apply a context manager/decorator to change tracing behavior for parts of their code. The new `dont_skip_tracing` decorator/context manager for ignoring most trace rules is easily implemented on top of this more generic traceable config-patching feature.

Implementation:
- Create a new specialized context manager class that wraps torch._dynamo.config.patch.
- Dynamo doesn't trace into the context manager but updates the config at compile time.
- Correctness rests on our existing correctness in handling supported context managers.
- The implementation is inspired by how `GradModeVariable` is implemented.

Previous attempts: https://github.com/pytorch/pytorch/pull/148736 (decorator-only global approach) and https://github.com/pytorch/pytorch/pull/149439 (decorator-only traceback approach). See https://docs.google.com/document/d/1vWNwKL_jpg-PLopifcaSa338wks3GqSVF4GHRguybGg/edit?tab=t.0 for more details on the implementation, including the previous approaches.

NOTE: this PR fixes a bug where skipped code objects were not tracked by convert_frame.py, leading to cases where code objects would be automatically skipped even after `torch._dynamo.reset()`. This exposed some latent dynamo-wrapped test failures in CI that previously passed in CI but not locally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150586
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
Commit 5b9df57b50 (parent 62b5649b76), committed by PyTorch MergeBot.
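Before the diff, a minimal usage sketch distilled from the new tests (it assumes the helper module torch/_dynamo/test_dont_skip_tracing_functions.py added by this PR; a sketch, not official documentation):

import torch
import torch._dynamo.testing
from torch._dynamo.test_dont_skip_tracing_functions import f1  # skipped by trace rules

cnts = torch._dynamo.testing.CompileCounter()

# Decorator form: force tracing into a function that trace rules would skip.
g = torch._dynamo.dont_skip_tracing(f1)
torch.compile(g, backend=cnts, fullgraph=True)(torch.randn(3))
assert cnts.frame_count == 1  # f1 is traced instead of skipped

# Context-manager form, usable (and traceable) inside compiled code.
@torch.compile(backend="eager")
def h(x):
    with torch._dynamo.dont_skip_tracing():
        return f1(x)

# The generic mechanism: patch an allowed Dynamo config value during tracing.
@torch.compile(backend="eager")
def k(x):
    with torch._dynamo.patch_dynamo_config(verbose=True):
        return x + 1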
@@ -1534,6 +1534,148 @@ If the above doesn't work, please submit an issue to GitHub.
        with torch.compiler.set_stance("default", force_backend=fail_backend):
            f(torch.randn(3, 3))

    # also tests a lot of torch._dynamo.patch_dynamo_config functionality
    def test_dont_skip_tracing(self):
        from torch._dynamo.test_dont_skip_tracing_functions import f1, f3, f4, f5, f6

        cnts = torch._dynamo.testing.CompileCounter()

        # make sure test_dont_skip_tracing_functions is actually skipped by trace rules
        torch.compile(f1, backend=cnts)(torch.randn(3))
        self.assertEqual(cnts.frame_count, 0)

        f1_unskip = torch._dynamo.dont_skip_tracing(f1)

        # basic test
        def g1(x):
            return f1_unskip(x)

        cnts.clear()
        torch.compile(g1, backend=cnts, fullgraph=True)(torch.randn(3))
        self.assertEqual(cnts.frame_count, 1)

        # test that dont_skip_tracing is traceable
        def g2(x):
            return torch._dynamo.dont_skip_tracing(f1)(x)

        cnts.clear()
        torch.compile(g2, backend=cnts, fullgraph=True)(torch.randn(3))
        self.assertEqual(cnts.frame_count, 1)

        # test that dont_skip_tracing is recursive, applied to non-skipped function
        @torch._dynamo.dont_skip_tracing
        def g3(x):
            return f1(x)

        cnts.clear()
        torch.compile(g3, backend=cnts, fullgraph=True)(torch.randn(3))
        self.assertEqual(cnts.frame_count, 1)

        # test that dont_skip_tracing is recursive, applied to skipped function
        f3_unskip = torch._dynamo.dont_skip_tracing(f3)
        cnts.clear()
        torch.compile(f3_unskip, backend=cnts, fullgraph=True)(torch.randn(3))
        self.assertEqual(cnts.frame_count, 1)

        # test dont_skip_tracing with graph breaks
        inp = torch.ones(3)
        res = torch.compile(f4, backend=cnts)(inp)
        self.assertEqual(res, inp + 6)

        @torch.compile(backend=cnts)
        def g4(x):
            x = f5(x, 1)
            x = torch._dynamo.dont_skip_tracing(f6)(x)
            x = f5(x, 8)
            return x

        res = g4(inp)
        self.assertEqual(res, inp + 6)

        # test nested dont_skip_tracing
        # this also happens to test if a previously skipped frame (f4)
        # can actually be compiled if called as a top-level function (in the case of a graph break)
        # TODO the reset is necessary for now since attempting to trace f4 previously
        # resulted in an unconditional skip
        torch._dynamo.reset()
        f4_unskip = torch._dynamo.dont_skip_tracing(f4)
        res = torch.compile(f4_unskip, backend=cnts)(inp)
        self.assertEqual(res, inp + 15)

        # test dont_skip_tracing that is activated outside torch.compile
        f4_unskip2 = torch._dynamo.dont_skip_tracing(torch.compile(f4, backend=cnts))
        res = f4_unskip2(inp)
        self.assertEqual(res, inp + 15)

        # test context manager from inside
        @torch.compile(backend=cnts)
        def g5(x):
            x = f5(x, 1)
            with torch._dynamo.dont_skip_tracing():
                x = f5(x, 2)
                torch._dynamo.graph_break()
                x = f5(x, 4)
            x = f5(x, 8)
            return x

        res = g5(inp)
        self.assertEqual(res, inp + 6)

        # test context manager from outside
        with torch._dynamo.dont_skip_tracing():
            res = torch.compile(f4, backend=cnts)(inp)
        self.assertEqual(res, inp + 15)

        # test skipped function from different dont_skip_tracing regions
        @torch.compile(backend=cnts)
        def g6(x):
            fn1 = f5
            with torch._dynamo.dont_skip_tracing():
                fn2 = f5
            x = fn1(x, 1)
            x = fn2(x, 2)
            return x

        res = g6(inp)
        self.assertEqual(res, inp + 1)

    def test_patch_dynamo_config_errors(self):
        @torch.compile(backend="eager")
        def f1(x):
            with torch._dynamo.patch_dynamo_config(nonexistent=False):
                return x + 1

        with self.assertRaisesRegex(Exception, "patch_dynamo_config does not support"):
            f1(torch.randn(3))

        @torch.compile(backend="eager")
        def f2(x):
            with torch._dynamo.patch_dynamo_config("verbose", {"a": 1}):
                return x + 1

        with self.assertRaisesRegex(
            Exception, "patch_dynamo_config does not support .* with non-safe-constant"
        ):
            f2(torch.randn(3))

        @torch.compile(backend="eager")
        def f3(x):
            with torch._dynamo.patch_dynamo_config({"recompile_limit": 1}):
                return x + 1

        with self.assertRaisesRegex(Exception, "patch_dynamo_config does not support"):
            f3(torch.randn(3))

        @torch.compile(backend="eager")
        def f4(x):
            with torch._dynamo.patch_dynamo_config(verbose=object()):
                return x + 1

        with self.assertRaisesRegex(
            Exception, "Cannot convert patch_dynamo_config args/kwargs to constants."
        ):
            f4(torch.randn(3))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -309,7 +309,7 @@ from user code:

Attempted to call function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `skip` in file `case.py` should not be traced.
Hint: Avoid calling the function `skip`.
Hint: Remove the function `skip` or the file `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `skip` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.

Developer debug context: module: unittest.case, qualname: skip, skip reason: <missing reason>
@@ -358,7 +358,7 @@ from user code:

Attempted to inline function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `skip` should not be traced.
Hint: Avoid calling the function `skip`.
Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `skip` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.

Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest
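For context, a call shaped like the following sketch reproduces messages of this kind (here `unittest.skip`, which lives in `case.py` and is marked as skipped per the debug context above; the exact error path is an assumption):

import unittest

import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    unittest.skip("reason")  # calling a function marked as skipped by trace rules
    return x + 1

f(torch.randn(3))  # graph break: "Attempted to call function marked as skipped"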
@@ -20,6 +20,7 @@ from .decorators import (
    assume_constant_result,
    disable,
    disallow_in_graph,
+    dont_skip_tracing,
    forbid_in_graph,
    graph_break,
    mark_dynamic,
@@ -27,6 +28,7 @@ from .decorators import (
    mark_static_address,
    maybe_mark_dynamic,
    nonstrict_trace,
+    patch_dynamo_config,
    run,
    set_stance,
    substitute_in_graph,
@@ -57,6 +59,7 @@ __all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
+    "dont_skip_tracing",
    "forbid_in_graph",
    "substitute_in_graph",
    "graph_break",
@@ -67,6 +70,7 @@ __all__ = [
    "nonstrict_trace",
    "optimize",
    "optimize_assert",
+    "patch_dynamo_config",
    "export",
    "explain",
    "run",
@@ -38,7 +38,7 @@ verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False

-# need this many ops to create an FX graph
+# need this many ops to create an FX graph (deprecated: not used)
minimum_call_count = 1

# turn on/off DCE pass (deprecated: always true)

@@ -322,6 +322,8 @@ do_not_emit_runtime_asserts: bool = (
# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
skip_torchrec = True

+# Don't apply most trace_rules.py rules
+dont_skip_tracing = False

# No longer used
optimize_ddp_lazy_compile = False
@@ -408,11 +408,16 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
            f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile"
        )
        prof = cProfile.Profile()
-        prof.enable()
-        start_ts = time.time()
-        retval = prof.runcall(func, *args, **kwargs)
-        profile_latency = time.time() - start_ts
-        prof.disable()
+        try:
+            prof.enable()
+            start_ts = time.time()
+            retval = prof.runcall(func, *args, **kwargs)
+            profile_latency = time.time() - start_ts
+            prof.disable()
+        except ValueError:
+            log.exception("failed to enable cProfile")
+            profile_latency = 0
+            retval = func(*args, **kwargs)
        log.warning(
            "### Cprofile for %s trace id [%s] took %.3f seconds ###",
            func.__name__,
@@ -1226,6 +1231,7 @@ class ConvertFrame:
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
        skip: int = 0,
    ) -> ConvertFrameReturn:
+        input_codes.add(frame.f_code)
        counters["frames"]["total"] += 1
        try:
            result = self._inner_convert(
@@ -1385,6 +1391,8 @@ class CatchErrorsWrapper:
    ) -> ConvertFrameReturn:
        assert frame_state is not None

+        input_codes.add(frame.f_code)
+
        is_skipfile = trace_rules.check(frame.f_code)
        if sys.version_info >= (3, 13):
            has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
@@ -10,7 +10,7 @@ import inspect
import sys
import weakref
from dataclasses import dataclass
-from typing import Any, Callable, TYPE_CHECKING, TypeVar
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec

import torch

@@ -30,7 +30,11 @@ from .eval_frame import (
    skip_code,
)
from .exc import IncorrectUsage
-from .external_utils import get_nonrecursive_disable_wrapper, is_compiling
+from .external_utils import (
+    _dynamo_config_patch_proxy_dunder_call,
+    get_nonrecursive_disable_wrapper,
+    is_compiling,
+)
from .utils import is_function
@@ -732,3 +736,105 @@ def _allow_in_graph_einops():


trace_rules.add_module_init_func("einops", _allow_in_graph_einops)


# Proxy class for torch._dynamo.config patching - so dynamo can identify context managers/decorators
# created by patch_dynamo_config, compared to ones created by a raw torch._dynamo.config.patch.
class DynamoConfigPatchProxy:
    def __init__(self, config_patch):
        self.config_patch = config_patch

    @property
    def changes(self):
        return self.config_patch.changes

    # Decorator implementation that simply sets up `self` as a context manager.
    # Placed in external_utils so that we can trace through it.
    __call__ = _dynamo_config_patch_proxy_dunder_call

    def __enter__(self):
        return self.config_patch.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.config_patch.__exit__(exc_type, exc_val, exc_tb)


# Criteria for patchable config:
# - Config values must be constants (i.e. int, float, str, bool, None).
#   - in particular, NO list, set, dict.
# - Traceable config patches are only useful for configs that change dynamo behavior
#   from symbolic_convert and below.
#   - e.g. patching recompile_limit won't really do anything.
# - For patching configs that affect Dynamo behavior above symbolic_convert,
#   ensure that Dynamo behaves soundly even if tracing is done with different config.
#   - e.g. be careful if patching guard-related configs as configs may have changed
#     between guard creation and evaluation.
_allowed_config_patches = (
    "verbose",
    "verify_correctness",
    "rewrite_assert_with_torch_assert",
    "capture_scalar_outputs",
    "allow_unspec_int_on_nn_module",
    "skip_torchrec",
    "dont_skip_tracing",
)

for name in _allowed_config_patches:
    assert hasattr(torch._dynamo.config, name), "nonexistent config"


def _patch_dynamo_config_check(changes: dict[str, Any]):
    for k, v in changes.items():
        if k not in _allowed_config_patches:
            raise ValueError(
                f"patch_dynamo_config does not support patching config {k}"
            )
        if not torch._dynamo.utils.is_safe_constant(v):
            raise ValueError(
                f"patch_dynamo_config does not support patching config {k} "
                f"with non-safe-constant value {v}"
            )


# TODO: also implement nonrecursive patch_dynamo_config/dont_skip_tracing.
# Unlike config.patch, we also need to accept tuple as input in order to
# deal with context manager reconstruction.
def patch_dynamo_config(
    arg1: Optional[Union[str, dict[str, Any], tuple[tuple[str, Any], ...]]] = None,
    arg2: Any = None,
    **kwargs: Any,
) -> DynamoConfigPatchProxy:
    """
    A wrapper around torch._dynamo.config.patch that can be traced by Dynamo to
    temporarily change config values DURING tracing.

    See _allowed_config_patches for the list of allowed config patches.

    Arguments are the same as with torch._dynamo.config.patch.

    Can be used as a decorator or a context manager.

    User code SHOULD NOT MODIFY the return value of this function.

    WARNING: changing Dynamo config during tracing can lead to unpredictable tracing behavior!
    Proceed only as advised!
    """
    if isinstance(arg1, tuple):
        arg1 = dict(arg1)
    config_patch = torch._dynamo.config.patch(arg1, arg2, **kwargs)
    # check for valid patching using config_patch.changes
    _patch_dynamo_config_check(config_patch.changes)
    return DynamoConfigPatchProxy(config_patch)


def dont_skip_tracing(fn=None):
    """
    Context manager/decorator to trace into functions intentionally marked by developers to be skipped
    when tracing.

    This decorator will also apply to recursively invoked functions.
    """
    ctx = patch_dynamo_config(dont_skip_tracing=True)
    if fn:
        return ctx(fn)
    return ctx
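To illustrate the validation above, a small sketch of what `_patch_dynamo_config_check` accepts and rejects (eager behavior, outside of compilation):

import torch

# Allowed: key in _allowed_config_patches with a safe-constant value.
ok = torch._dynamo.patch_dynamo_config(verbose=True)

# Rejected: key not in the allowlist.
try:
    torch._dynamo.patch_dynamo_config(recompile_limit=1)
except ValueError as e:
    print(e)  # patch_dynamo_config does not support patching config recompile_limit

# Rejected: non-safe-constant value (lists/sets/dicts are not allowed).
try:
    torch._dynamo.patch_dynamo_config(verbose={"a": 1})
except ValueError as e:
    print(e)  # ... with non-safe-constant value ...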
@@ -198,3 +198,14 @@ def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
        return fn(*args, **kwargs)

    return nonrecursive_disable_wrapper


def _dynamo_config_patch_proxy_dunder_call(
    self: Any, func: Callable[_P, _R]
) -> Callable[_P, _R]:
    @functools.wraps(func)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with self:
            return func(*args, **kwargs)

    return inner
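In other words, decorating a function with the proxy wraps every call in the context manager; a sketch of the equivalence:

import torch

@torch._dynamo.patch_dynamo_config(verbose=True)
def f(x):
    # torch._dynamo.config.verbose is patched for the duration of the call
    return x + 1

# ...behaves like:
def g(x):
    with torch._dynamo.patch_dynamo_config(verbose=True):
        return x + 1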
@@ -3730,6 +3730,22 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                hints=[],
            )

+        if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
+            func.get_function(), "_torchdynamo_disable", False
+        ):
+            msg = inspect.getattr_static(
+                func.get_function(), "_torchdynamo_disable_msg", None
+            )
+            unimplemented_v2(
+                gb_type="Skip inlining `torch.compiler.disable()`d function",
+                context=str(func.get_function()),
+                explanation=f"Skip inlining function {func.get_function()} since it was wrapped "
+                f"with `torch.compiler.disable` (reason: {msg})",
+                hints=[
+                    "Remove the `torch.compiler.disable` call",
+                ],
+            )
+
        result = trace_rules.check_verbose(func, is_inlined_call=True)
        if result.skipped:
            from torch._dynamo.variables.misc import produce_trampoline_autograd_apply

@@ -3749,11 +3765,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
            ]
            if "_dynamo" not in func.get_filename():
                hints += [
                    f"Remove the function `{fn_qualname}` or the file `{func.get_filename()}` "
                    "from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
                    "attempting to trace into the function.",
                    f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{fn_qualname}` "
                    "to force tracing into the function. "
                    "More graph breaks may occur as a result of attempting to trace into the function.",
                    "Please file an issue to PyTorch.",
                    # TODO suggest mark_force_inline when implemented
                ]
            unimplemented_v2(
                gb_type="Attempted to inline function marked as skipped",

@@ -3764,23 +3779,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                hints=hints,
            )

-        if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
-            func.get_function(), "_torchdynamo_disable", False
-        ):
-            msg = inspect.getattr_static(
-                func.get_function(), "_torchdynamo_disable_msg", None
-            )
-            unimplemented_v2(
-                gb_type="Skip inlining `torch.compiler.disable()`d function",
-                context=str(func.get_function()),
-                explanation=f"Skip inlining function {func.get_function()} since it was wrapped "
-                f"with `torch.compiler.disable` (reason: {msg})",
-                hints=[
-                    "Remove the `torch.compiler.disable` call",
-                ],
-            )
-        else:
-            return result
        return result

    @staticmethod
    def build_inline_tracer(
torch/_dynamo/test_dont_skip_tracing_functions.py (new file, 40 lines)
@@ -0,0 +1,40 @@
"""
Functions used to test torch._dynamo.dont_skip_tracing.
This file is located in torch/_dynamo so that it is skipped by trace rules.
There is a special rule in trace_rules that doesn't skip this file when
dont_skip_tracing is active.
"""

import torch


def f1(x: torch.Tensor) -> torch.Tensor:
    return x + 1


def f2(x: torch.Tensor) -> torch.Tensor:
    return x + 1


def f3(x: torch.Tensor) -> torch.Tensor:
    return f2(x)


def f4(x: torch.Tensor) -> torch.Tensor:
    x = f5(x, 1)
    x = torch._dynamo.dont_skip_tracing(f6)(x)
    x = f5(x, 8)
    return x


def f5(x: torch.Tensor, n: int) -> torch.Tensor:
    if torch.compiler.is_compiling():
        return x + n
    return x


def f6(x: torch.Tensor) -> torch.Tensor:
    x = f5(x, 2)
    torch._dynamo.graph_break()
    x = f5(x, 4)
    return x
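The expected values in the tests earlier follow from f5's is_compiling() gate: f5(x, n) adds n only when it executes under compilation, so the result records exactly which calls were traced (my reading of the test arithmetic, not stated in the PR):

# Fully traced (dont_skip_tracing applied to f4): 1 + 2 + 4 + 8 = 15 -> inp + 15.
# Default rules: f4 and f5 stay skipped and run eagerly (is_compiling() is False),
# but f6 - wrapped with dont_skip_tracing inside f4 - is traced, so only
# 2 + 4 = 6 land in compiled code -> inp + 6.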
@@ -28,6 +28,7 @@ import torch.utils._content_store
from torch._environment import is_fbcode
from torch.utils import _config_module

+from . import config
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .variables import (
@@ -306,8 +307,10 @@ manual_torch_name_rule_map: dict[str, Any] = {
    "torch._tensor._convert": UserFunctionVariable,
    "torch.jit._unwrap_optional": UserFunctionVariable,
    "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
+    "torch._dynamo.dont_skip_tracing": UserFunctionVariable,
    "torch._dynamo.mark_static": UserFunctionVariable,
    "torch._dynamo.nonstrict_trace": UserFunctionVariable,
+    "torch._dynamo.patch_dynamo_config": UserFunctionVariable,
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
    "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable,
    "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,
@@ -3261,7 +3264,7 @@ if torch.distributed.is_available():
        # the forward_hook won't be ignored.
        "torch.distributed._composable.replicate",
    }
-    if not torch._dynamo.config.skip_fsdp_hooks:
+    if not config.skip_fsdp_hooks:
        LEGACY_MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")

# Force inline functions under these modules, even they are in *_SKIPLIST.
@@ -3324,7 +3327,7 @@ MOD_INLINELIST = set(MOD_INLINELIST)

if torch.distributed.is_available():
    MOD_INLINELIST.add("torch.distributed")
-    if not torch._dynamo.config.skip_fsdp_hooks:
+    if not config.skip_fsdp_hooks:
        MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")
@@ -3575,7 +3578,7 @@ def check_file(filename, is_inlined_call=False):

    if (
        is_fbcode()
-        and torch._dynamo.config.skip_torchrec
+        and config.skip_torchrec
        and FBCODE_SKIP_TORCHREC_DIRS
        and bool(FBCODE_SKIP_TORCHREC_DIRS_RE.match(filename))
        and not bool(FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE.match(filename))
@@ -3752,12 +3755,48 @@ def lookup(obj):
    return lookup_inner(obj)


# also takes config.dont_skip_tracing into account
def lookup_inner(
    obj,
    name=None,
    filename=None,
    is_direct_call=True,
    reasons: Union[None, set[str]] = None,
):
    result = _lookup_inner(
        obj,
        name=name,
        filename=filename,
        is_direct_call=is_direct_call,
        reasons=reasons,
    )
    # There are still some modules we should absolutely NOT trace into - e.g. most of torch._dynamo,
    # as this can result in really weird tracing behaviors.
    # Note that if a torch._dynamo function is already not skipped (e.g. functions in external_utils.py),
    # then this branch does not apply.
    if config.dont_skip_tracing and result is SkipFunctionVariable:
        if filename is None:
            filename = getfile(obj)
        filename = _as_posix_path(filename)
        dynamo_path = _as_posix_path(_module_dir(torch)) + "_dynamo"
        if filename.startswith(dynamo_path) and not filename.endswith(
            "test_dont_skip_tracing_functions.py"
        ):
            return SkipFunctionVariable
        if reasons is not None:
            reasons.add(
                "Attempted skip but we are ignoring skips due to torch._dynamo.config.dont_skip_tracing"
            )
        return UserFunctionVariable
    return result


def _lookup_inner(
    obj,
    name=None,
    filename=None,
    is_direct_call=True,
    reasons: Union[None, set[str]] = None,
):
    # Step 1: lookup obj's tracing rule in `torch_name_rule_map`.
    # The rules defined in `torch_name_rule_map` mainly includes two parts:
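A sketch of the resulting behavior (assuming the internals above; f1 lives in the otherwise-skipped helper file):

import torch
from torch._dynamo import trace_rules
from torch._dynamo.test_dont_skip_tracing_functions import f1

# Default: functions under torch/_dynamo are skipped.
print(trace_rules.lookup_inner(f1))  # SkipFunctionVariable

# With dont_skip_tracing, the skip is overridden for the helper file,
# while most of torch._dynamo itself remains skipped.
with torch._dynamo.config.patch(dont_skip_tracing=True):
    print(trace_rules.lookup_inner(f1))  # UserFunctionVariable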
@@ -4531,6 +4531,7 @@ def set_feature_use(feature: str, usage: bool):
_ddp_optimization_mode: tuple[str, ...] = (
    "ddp_optimizer",
    "python_reducer",  # experimental mode
    "python_reducer_without_compiled_forward",
    "no_optimization",
)
@@ -26,6 +26,7 @@ from .ctx_manager import (
    DeterministicAlgorithmsVariable,
    DisabledSavedTensorsHooksVariable,
    DualLevelContextManager,
+    DynamoConfigPatchVariable,
    FSDPParamGroupUseTrainingStateVariable,
    GradIncrementNestingCtxManagerVariable,
    GradInplaceRequiresGradCtxManagerVariable,
@@ -164,6 +165,7 @@ __all__ = [
    "DeletedVariable",
    "DeterministicAlgorithmsVariable",
    "DictKeySetVariable",
+    "DynamoConfigPatchVariable",
    "EnumVariable",
    "FakeItemVariable",
    "GetAttrVariable",
@@ -154,6 +154,7 @@ from .base import (
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
    AutocastModeVariable,
+    DynamoConfigPatchVariable,
    EventVariable,
    NullContextVariable,
    PreserveVersionContextVariable,

@@ -594,6 +595,8 @@ class VariableBuilder:
        # import here to avoid circular dependencies
        from torch.utils._triton import has_triton, has_triton_tma

+        from ..decorators import DynamoConfigPatchProxy
+
        if has_triton():
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction

@@ -911,6 +914,8 @@ class VariableBuilder:
                    {},
                )
            )
+        elif isinstance(value, DynamoConfigPatchProxy):
+            return DynamoConfigPatchVariable(value.changes)
        elif callable(value) and trace_rules.lookup_callable(value) is not None:
            if trace_rules.is_callable_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
@@ -1366,6 +1366,47 @@ class EventVariable(VariableTracker):
        codegen.append_output(codegen.create_load_global(name, add=True))


class DynamoConfigPatchVariable(ContextWrappingVariable):
    """represents torch._dynamo.patch_dynamo_config"""

    # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness
    # (though it may affect tracing behavior)
    def __init__(self, target_values, **kwargs) -> None:
        target_values = tuple(target_values.items())
        super().__init__(target_values=(target_values,), initial_values=None, **kwargs)
        self.initial_values = {}
        for key, _ in target_values:
            self.initial_values[key] = torch._dynamo.config.__getattr__(key)
        self.initial_values = (tuple(self.initial_values.items()),)

    def enter(self, tx):
        # resets all config patches at the end of tracing
        self.set_cleanup_hook(tx)
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        # manually patch dynamo config
        for key, val in value:
            torch._dynamo.config.__setattr__(key, val)
        # No need to keep track of global side effects because
        # dynamo will properly restore this context manager for
        # unsupported instructions and continuation functions.
        # Dynamo config also should not affect the semantics of the compiled graph.

    def module_name(self):
        return "torch._dynamo"

    def fn_name(self):
        return "patch_dynamo_config"


class WithExitFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "target",
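The stored values have a fixed nesting; a small illustration of the layout computed by __init__ above (a hypothetical patch of two allowed keys):

import torch

changes = {"verbose": True, "skip_torchrec": False}

# target_values as stored: a 1-tuple wrapping the items tuple.
target_values = (tuple(changes.items()),)
# -> ((("verbose", True), ("skip_torchrec", False)),)

# initial_values snapshots the current config so exit() can restore it.
initial_values = (tuple((k, getattr(torch._dynamo.config, k)) for k in changes),)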
@@ -36,7 +36,7 @@ from unittest.mock import patch

import torch

-from .. import graph_break_hints, polyfills, variables
+from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
    get_dynamo_observed_exception,

@@ -64,7 +64,12 @@ from ..utils import (
    istype,
    make_cell,
)
-from .base import AttributeMutationNew, ValueMutationNew, VariableTracker
+from .base import (
+    AsPythonConstantNotImplementedError,
+    AttributeMutationNew,
+    ValueMutationNew,
+    VariableTracker,
+)
from .constant import ConstantVariable


@@ -372,6 +377,23 @@ class UserFunctionVariable(BaseUserFunctionVariable):
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
+        # Handle patch_dynamo_config call
+        if self.fn is torch._dynamo.patch_dynamo_config:
+            try:
+                args_const = [arg.as_python_constant() for arg in args]
+                kwargs_const = {
+                    key: val.as_python_constant() for key, val in kwargs.items()
+                }
+                changes = torch._dynamo.patch_dynamo_config(
+                    *args_const, **kwargs_const
+                ).changes
+                return variables.DynamoConfigPatchVariable(changes)
+            except AsPythonConstantNotImplementedError as e:
+                raise RuntimeError(
+                    "Cannot convert patch_dynamo_config args/kwargs to constants. "
+                    "Please fix your call to patch_dynamo_config by using simpler inputs. "
+                    f"args: {args}, kwargs: {kwargs}"
+                ) from e
        # Handle a `nonstrict_trace(fn)` call
        if self.fn is torch._dynamo.nonstrict_trace:
            bound = inspect.signature(self.fn).bind(*args, **kwargs)

@@ -1297,6 +1319,14 @@ class SkipFunctionVariable(VariableTracker):
            torch._dynamo.utils.warn_once(msg)
            unimplemented(msg)
        else:
+            if config.dont_skip_tracing:
+                from .builder import SourcelessBuilder
+
+                # re-build the function, attempting to not skip
+                rebuilt_fn = SourcelessBuilder.create(tx, self.value)
+                # if we still get SkipFunctionVariable, then we *really* should skip this function
+                if not isinstance(rebuilt_fn, SkipFunctionVariable):
+                    return rebuilt_fn.call_function(tx, args, kwargs)
            qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
            try:
                path = inspect.getfile(self.value)

@@ -1312,11 +1342,10 @@ class SkipFunctionVariable(VariableTracker):
                # Do a very basic check for now.
                if "_dynamo" not in path:
                    hints += [
                        f"Remove the function `{qualname}` or the file `{path}` "
                        "from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
                        "attempting to trace into the function.",
                        f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` "
                        "to force tracing into the function. "
                        "More graph breaks may occur as a result of attempting to trace into the function.",
                        "Please file an issue to PyTorch.",
                        # TODO suggest mark_force_inline when implemented
                    ]
            except TypeError:
                known_python_builtin_modules = {"_abc", "_warnings"}
@@ -667,12 +667,15 @@ class ConfigModule(ModuleType):
        config = self

        class ConfigPatch(ContextDecorator):
+            def __init__(self) -> None:
+                self.changes = changes
+
            def __enter__(self) -> None:
                assert not prior
-                for key in changes.keys():
+                for key in self.changes.keys():
                    # KeyError on invalid entry
                    prior[key] = config.__getattr__(key)
-                for k, v in changes.items():
+                for k, v in self.changes.items():
                    config.__setattr__(k, v)

            def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore[no-untyped-def]
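Storing the changes on the instance is what lets DynamoConfigPatchProxy expose them; a quick sketch:

import torch

cm = torch._dynamo.config.patch(verbose=True)
print(cm.changes)  # {'verbose': True} - read via DynamoConfigPatchProxy.changes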