Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[dynamo] control one_graph behavior additionally through config (#154283)
`torch.compile` now always goes through `torch._dynamo._optimize`; `fullgraph` is now implemented in `torch.compile` by checking `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283
Approved by: https://github.com/jansel, https://github.com/anijain2305
Committed by: PyTorch MergeBot
Parent: 2c68c3e8d5
Commit: b46eb1ccaf
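For context, a minimal sketch of the user-visible contract this change preserves (the function `f` and the deliberate `torch._dynamo.graph_break()` call are illustrative, not from the PR): with fullgraph=False a graph break splits the frame into two graphs and compilation proceeds, while fullgraph=True now raises via `config.error_on_graph_break` rather than via `tx.one_graph`.

    import torch

    def f(x):
        x = x + 1
        torch._dynamo.graph_break()  # deliberate graph break for illustration
        return x * 2

    # fullgraph=False: Dynamo splits the frame at the break and compiles two
    # graphs stitched together by a resume function; this call succeeds.
    torch.compile(f)(torch.randn(4))

    # fullgraph=True: per this PR, implemented via config.error_on_graph_break,
    # so the same break surfaces as an error instead of a fallback.
    try:
        torch.compile(f, fullgraph=True)(torch.randn(4))
    except Exception as e:
        print(type(e).__name__)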
@@ -1814,9 +1814,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
             # Ensure no more re-compilation after the second automatic dynamic shape version.
             if i == 0:
-                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
-            else:
                 self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+            else:
+                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 4)

     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes_fast)
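The assertions above read Dynamo's in-process frame counters. A minimal sketch of inspecting them outside the test harness (the compiled function is illustrative):

    import torch
    from torch._dynamo.utils import counters

    torch._dynamo.reset()  # clear caches and counters

    @torch.compile
    def f(x):
        return x * 2

    f(torch.randn(4))
    # "total" counts frames Dynamo saw; "ok" counts frames converted successfully.
    print(counters["frames"]["ok"], counters["frames"]["total"])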
@@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
     skipIfCrossRef,
     skipIfRocm,
     skipIfTorchDynamo,
+    skipIfWindows,
     TemporaryFileName,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
@@ -2226,6 +2227,9 @@ class FakeTensorDispatchCache(TestCase):
             lambda: torch.ops.aten.index(x, [None, idx_tensor1]),
         )

+    @skipIfWindows(
+        msg="weird bug - cache may not be cleared after https://github.com/pytorch/pytorch/pull/154283"
+    )
     @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
     def test_invoke_subgraph(self):
         """
@@ -325,6 +325,11 @@ skip_torchrec = True
 # Don't apply most trace_rules.py rules
 dont_skip_tracing = False

+# If True, enforce fullgraph=True - raise errors on graph break
+# NOTE: do not set manually - this is modified internally by Dynamo.
+# Use the fullgraph option of torch.compile instead.
+error_on_graph_break = False
+
 # No longer used
 optimize_ddp_lazy_compile = False

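A short sketch of the usage contract the comment above spells out: the new flag may be read, but it should be toggled only through torch.compile's fullgraph option (the lambda is illustrative):

    import torch
    import torch._dynamo.config as config

    print(config.error_on_graph_break)  # False by default

    # Supported way to get error-on-graph-break behavior; Dynamo flips the
    # config flag internally for the duration of compilation.
    compiled = torch.compile(lambda x: x + 1, fullgraph=True)
    compiled(torch.randn(2))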
@@ -654,7 +654,7 @@ def convert_frame_assert(
     export_constraints: Optional[typing.Never] = None,
     package: Optional[CompilePackage] = None,
 ) -> ConvertFrameAssert:
-    """Fully convert a frame into an FX graph"""
+    """Fully convert a frame into an FX graph, raising an exception if we fail."""
     return ConvertFrameAssert(
         compiler_fn, one_graph, export, export_constraints, package
     )
@@ -862,8 +862,10 @@ def _compile(
             code.co_filename,
             code.co_firstlineno,
         )
-        if one_graph:
-            log.debug("No graph captured with one_graph=True")
+        if one_graph or config.error_on_graph_break:
+            log.debug(
+                "No graph captured with one_graph=True or torch._dynamo.config.error_on_graph_break=True"
+            )
         return ConvertFrameReturn()

     assert distributed_state is None or distributed_state.all_states is not None, (  # type: ignore[has-type]
@@ -1029,9 +1031,10 @@ def _compile(
                 raise FailOnRecompileLimitHit(
                     f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
                 )
-            elif one_graph:
+            elif one_graph or config.error_on_graph_break:
                 raise FailOnRecompileLimitHit(
-                    f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade "
+                    f"{limit_type} reached with one_graph=True or torch._dynamo.config.error_on_graph_break=True. "
+                    "Excessive recompilations can degrade "
                     "performance due to the compilation overhead of each recompilation. To monitor "
                     "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
                     "increasing torch._dynamo.config.cache_size_limit to an appropriate value."
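The error message above points at two knobs. A hedged sketch of using them (the values are illustrative; cache_size_limit defaults to 8 at the time of this PR):

    import torch

    # Log every recompilation and its reason:
    #   TORCH_LOGS=recompiles python train.py
    torch._dynamo.config.cache_size_limit = 16  # allow more recompiles per code object
    torch._dynamo.config.fail_on_recompile_limit_hit = True  # hard-error at the limit

    @torch.compile
    def f(x):
        return x.sum()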
@@ -1245,6 +1248,7 @@ class ConvertFrame:
         self,
         compiler_fn: CompilerFn,
         hooks: Hooks,
+        error_on_graph_break: bool,
         package: Optional[CompilePackage] = None,
     ) -> None:
         self._torchdynamo_orig_callable = compiler_fn
@@ -1252,10 +1256,13 @@ class ConvertFrame:
             compiler_fn, one_graph=False, package=package
         )
         self._hooks = hooks
+        self._error_on_graph_break = error_on_graph_break

     @property
     def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
-        return lambda backend: convert_frame(backend, self._hooks)
+        return lambda backend: convert_frame(
+            backend, self._hooks, self._error_on_graph_break
+        )

     def __call__(
         self,
@@ -1267,13 +1274,17 @@ class ConvertFrame:
     ) -> ConvertFrameReturn:
         input_codes.add(frame.f_code)
         counters["frames"]["total"] += 1
+        prev_error_on_graph_break = config.error_on_graph_break
         try:
+            config.error_on_graph_break = self._error_on_graph_break
             result = self._inner_convert(
                 frame, cache_entry, hooks, frame_state, skip=skip + 1
             )
             counters["frames"]["ok"] += 1
             return result
         except Exception as e:
+            if config.error_on_graph_break:
+                raise
             # These two exception types are "soft" failure, in the sense that
             # we know this is due to something we didn't implement all the
             # way, scare the user less about it. That being said, if you
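The `prev_...`/`try`/`finally` dance introduced above is the standard idiom for scoping a module-level flag to a single call. In isolation (this wrapper is illustrative, not an API from the PR):

    import torch._dynamo.config as config

    def scoped_error_on_graph_break(fn, value, /, *args, **kwargs):
        # Flip the flag for the duration of fn and restore it even if fn raises,
        # mirroring what ConvertFrame.__call__ does around _inner_convert.
        prev = config.error_on_graph_break
        try:
            config.error_on_graph_break = value
            return fn(*args, **kwargs)
        finally:
            config.error_on_graph_break = prev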
@@ -1349,15 +1360,24 @@ class ConvertFrame:
                         FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
                     )
                 )
+        finally:
+            config.error_on_graph_break = prev_error_on_graph_break

         return ConvertFrameReturn()


 def convert_frame(
-    compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None
+    compiler_fn: CompilerFn,
+    hooks: Hooks,
+    error_on_graph_break: bool,
+    package: Optional[CompilePackage] = None,
 ) -> ConvertFrame:
-    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
-    return ConvertFrame(compiler_fn, hooks, package=package)
+    """Try to convert a frame into an FX graph, if error leave frame unmodified
+
+    If error_on_graph_break=True, graph breaks become errors (resulting in an unmodified frame).
+    If error_on_graph_break=False, we will attempt to generate optimized and resume functions.
+    """
+    return ConvertFrame(compiler_fn, hooks, error_on_graph_break, package=package)


 # TODO mlazos: add support for same args, or record them
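Whether a frame would trip error_on_graph_break=True can be checked ahead of time with torch._dynamo.explain, which reports graph breaks without installing a backend (the function below is illustrative):

    import torch

    def g(x):
        x = x * 2
        torch._dynamo.graph_break()  # forces a break for illustration
        return x + 1

    explanation = torch._dynamo.explain(g)(torch.randn(4))
    print(explanation.graph_break_count)  # 1 -> fullgraph=True would raise here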
@@ -1370,7 +1390,9 @@ def replay(filename: str) -> None:
         record = ExecutionRecord.load(in_file)
     record.globals = dict(itertools.chain(record.globals.items(), globals().items()))

+    prev_error_on_graph_break = config.error_on_graph_break
     try:
+        config.error_on_graph_break = False
         _compile(
             record.code,
             record.globals,
@@ -1390,6 +1412,7 @@ def replay(filename: str) -> None:
         )
     finally:
         config.replay_record_enabled = original_replay_val
+        config.error_on_graph_break = prev_error_on_graph_break


 def first_real_inst_idx(code: CodeType) -> int:
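replay() is a debugging entry point that re-compiles a frame from a saved ExecutionRecord; with this change it also pins error_on_graph_break=False for the duration. A hedged sketch of the workflow (the record path is hypothetical; records are produced when replay_record_enabled is set and a compile error is hit):

    import torch._dynamo.config as config
    from torch._dynamo.convert_frame import replay

    config.replay_record_enabled = True
    # ... reproduce the failing compile; Dynamo writes an ExecutionRecord ...
    replay("/tmp/dynamo_error_record.bin")  # hypothetical path to the record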
@@ -227,6 +227,7 @@ def _create_wrapped_callback(compiler_fn):
         convert_frame.convert_frame(  # type: ignore[arg-type]
             compiler_fn,
             hooks,
+            False,
         ),
         hooks,
     )
@@ -1080,15 +1081,6 @@ def _optimize(
     ):
         return _NullDecorator()

-    if nopython:
-        return optimize_assert(
-            backend,
-            dynamic=dynamic,
-            hooks=hooks,
-            rebuild_ctx=rebuild_ctx,
-            package=package,
-        )
-
     backend = get_compiler_fn(backend)

     # Find if backend has any extra context manager
@@ -1098,7 +1090,7 @@ def _optimize(
     # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
     # be used by eval_frame.c to insert a guard on the backend.
     return _optimize_catch_errors(
-        convert_frame.convert_frame(backend, hooks=hooks, package=package),
+        convert_frame.convert_frame(backend, hooks, nopython, package=package),
         hooks,
         backend_ctx_ctor,
         dynamic=dynamic,
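With the early `if nopython:` return gone, nopython now rides into convert_frame as error_on_graph_break, so the legacy entry point should behave like torch.compile(fullgraph=True). A sketch (the identity backend is illustrative):

    import torch

    def my_backend(gm, example_inputs):
        # Identity backend: run the captured FX graph as-is.
        return gm.forward

    @torch._dynamo.optimize(my_backend, nopython=True)
    def h(x):
        return torch.relu(x) + 1

    h(torch.randn(8))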
@@ -2002,7 +1994,11 @@ def _optimize_assert(
     package=None,
 ):
     """
-    The same as `torch._dynamo.optimize(backend, nopython=True)`
+    The same as `torch._dynamo.optimize(backend, nopython=True)`,
+    but ignores the config.error_on_graph_break setting.
+
+    Used for export, since we must always error on graph breaks and ignore
+    config.error_on_graph_break.
     """
     backend = get_compiler_fn(backend)

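Because export keeps using tx.one_graph, it errors on graph breaks no matter what config.error_on_graph_break says. A minimal sketch (the module is illustrative):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            torch._dynamo.graph_break()  # illustrative: export cannot tolerate this
            return x + 1

    try:
        torch.export.export(M(), (torch.randn(2),))
    except Exception as e:
        print("export raised:", type(e).__name__)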
@@ -3243,6 +3243,9 @@ class InstructionTranslatorBase(
         self.num_calls: dict[str, int] = {}
         # Flag to indicate whether tracing is used for export.
         self.export = export
+        # NOTE: one_graph is used for export/debugging to always force errors on graph breaks.
+        # To allow for fullgraph toggle during normal compile, config.error_on_graph_break
+        # is used instead.
         self.one_graph = False

         self.current_speculation = None
@@ -3507,6 +3510,7 @@ class InstructionTranslator(InstructionTranslatorBase):
         return (
             all(b.can_restore() for b in self.block_stack)
             and not self.one_graph
+            and not config.error_on_graph_break
             and not self.active_generic_context_managers
         )

@@ -3641,6 +3645,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             and not self.symbolic_locals_contain_module_class()
            and not self.export
             and not self.one_graph
+            and not config.error_on_graph_break
         ):
             raise exc.SkipFrame("because no content in function call")

@@ -1006,7 +1006,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
     ) -> "VariableTracker":
         if name == "queue_callback":
             if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
-                assert tx.one_graph, (
+                assert tx.one_graph or config.error_on_graph_break, (
                     "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
                 )
                 return variables.UserFunctionVariable(