Implement traceable config patching for Dynamo

Enables restricted patching of the Dynamo config: a user can apply a context manager/decorator to change tracing behavior for parts of their code. The new `dont_skip_tracing` decorator/context manager, which ignores most trace rules, is easily implemented on top of this more generic traceable config patching feature.

Implementation:
- Create a new specialized context manager class representing a wrapper around torch._dynamo.config.patch
- Dynamo doesn't trace into the context manager but updates the config at compile time
- Correctness is based on our correctness in handling supported context managers
- The implementation is inspired by how `GradModeVariable` is implemented

Previous attempts: https://github.com/pytorch/pytorch/pull/148736 (decorator-only global approach) and https://github.com/pytorch/pytorch/pull/149439 (decorator-only traceback approach).

See https://docs.google.com/document/d/1vWNwKL_jpg-PLopifcaSa338wks3GqSVF4GHRguybGg/edit?tab=t.0 for more details on the implementation, including previous approaches.

NOTE: this PR fixes a bug where skipped code objects were not tracked by convert_frame.py, leading to cases where code objects would be automatically skipped even after `torch._dynamo.reset()`. This exposed some latent dynamo-wrapped test failures that previously passed in CI but not locally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150586
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
160 lines
5.1 KiB
Python
"""
|
|
TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster.
|
|
TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python
|
|
bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of
|
|
PyTorch operations into an FX Graph which is then just-in-time compiled with a customizable backend.
|
|
It creates this FX Graph through bytecode analysis and is designed to mix Python execution with
|
|
compiled backends to get the best of both worlds: usability and performance. This allows it to
|
|
seamlessly optimize PyTorch programs, including those using modern Python features.
|
|
"""

import torch

from . import config, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, lookup_backend, register_backend
from .callback import callback_handler, on_compile_end, on_compile_start
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
    allow_in_graph,
    assume_constant_result,
    disable,
    disallow_in_graph,
    dont_skip_tracing,
    forbid_in_graph,
    graph_break,
    mark_dynamic,
    mark_static,
    mark_static_address,
    maybe_mark_dynamic,
    nonstrict_trace,
    patch_dynamo_config,
    run,
    set_stance,
    substitute_in_graph,
)
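
# A minimal usage sketch for the traceable config patching surfaced above
# (kept in comments so nothing runs at import; assumes, per the commit message,
# that `patch_dynamo_config` accepts the same keyword arguments as
# `torch._dynamo.config.patch` -- the `recompile_limit` key below is only an
# illustration):
#
#     @dont_skip_tracing
#     def helper(x):
#         # Trace rules that would normally skip this frame are ignored here.
#         return x + 1
#
#     @torch.compile
#     def fn(x):
#         with patch_dynamo_config(recompile_limit=16):
#             # Dynamo does not trace into this context manager; it applies
#             # the config change at compile time instead.
#             return helper(x) * 2
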
from .eval_frame import (
    _reset_guarded_backend_cache,
    explain,
    export,
    is_dynamo_supported,
    is_inductor_supported,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
)
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .pgo import reset_code_state
from .symbolic_convert import TensorifyState
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count


# Register polyfill functions
from .polyfills import loader as _  # usort: skip # noqa: F401


__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "dont_skip_tracing",
    "forbid_in_graph",
    "substitute_in_graph",
    "graph_break",
    "mark_dynamic",
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "nonstrict_trace",
    "optimize",
    "optimize_assert",
    "patch_dynamo_config",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "set_stance",
    "reset",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
    "lookup_backend",
    "config",
]


# allowlist this for weights_only load of NJTs
torch.serialization.add_safe_globals([torch._dynamo.decorators._DimRange])
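# Illustrative effect (hypothetical checkpoint path): a weights_only load of a
# checkpoint containing nested jagged tensors can now unpickle the _DimRange
# metadata allowlisted above instead of raising:
#
#     # torch.load("njt_checkpoint.pt", weights_only=True)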

if torch.manual_seed is torch.random.manual_seed:
    import torch.jit._builtins

    # Wrap manual_seed with the disable decorator.
    # Can't do it at its implementation due to dependency issues.
    torch.manual_seed = torch._disable_dynamo(torch.manual_seed)
    # Add the new manual_seed to the builtin registry.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
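
# Illustrative effect of the wrapping above: inside a compiled region, calls to
# torch.manual_seed now fall back to eager execution instead of being traced:
#
#     @torch.compile
#     def fn(x):
#         torch.manual_seed(0)  # graph break: runs eagerly, not traced
#         return torch.rand_like(x)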


def reset() -> None:
    """
    Clear all compile caches and restore initial state. This function is intended
    to reset Dynamo's state *as if* you had started a fresh process invocation, which
    makes it good for testing scenarios where you want to behave as if you started
    a new process. It does NOT affect any file system caches.

    NB: this does NOT reset logging state. Don't use this to test logging
    initialization/reinitialization.
    """
    # TODO: https://github.com/pytorch/pytorch/issues/139200
    import logging

    log = logging.getLogger(__name__)
    log.info("torch._dynamo.reset")
    with convert_frame.compile_lock:
        reset_code_caches()
        convert_frame.input_codes.clear()
        reset_code_state()
        convert_frame.output_codes.clear()
        orig_code_map.clear()
        guard_failures.clear()
        graph_break_reasons.clear()
        resume_execution.ContinueExecutionCache.cache.clear()
        _reset_guarded_backend_cache()
        reset_frame_count()
        torch._dynamo.compiled_autograd.reset()
        convert_frame.FRAME_COUNTER = 0
        convert_frame.FRAME_COMPILE_COUNTER.clear()
        callback_handler.clear()
        GenerationTracker.clear()
        TensorifyState.clear()
        torch._dynamo.utils.warn_once_cache.clear()
        torch._dynamo.utils.user_obj_id_to_weakref.clear()
        torch._C._autograd._saved_tensors_hooks_set_tracing(False)
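

# A typical testing pattern for reset() (illustrative):
#
#     compiled_fn = torch.compile(fn)
#     compiled_fn(inp)          # populates the compile caches
#     torch._dynamo.reset()     # behave as if this were a fresh process
#     compiled_fn(inp)          # recompiles from scratch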


def reset_code_caches() -> None:
    """
    Clears in-memory code cache, which is what stores compiled products. This
    resets less state than :func:`reset` and is mostly only used for testing
    purposes.
    """
    # TODO: https://github.com/pytorch/pytorch/issues/139200
    import logging

    log = logging.getLogger(__name__)
    log.info("torch._dynamo.reset_code_caches")
    with convert_frame.compile_lock:
        reset_code_state()
        for weak_code in (
            convert_frame.input_codes.seen + convert_frame.output_codes.seen
        ):
            code = weak_code()
            if code:
                reset_code(code)
        code_context.clear()