[torch.export] Support is_compiling() flag for non-strict mode (#119602)
Summary: In non-strict mode of torch.export() we didn't set the `is_compiling()` flags to `True`, which some models rely on.

Test Plan: Unit tests and manual testing.

Differential Revision: D53624452
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119602
Approved by: https://github.com/suo
Committed by: PyTorch MergeBot
Parent: 0a46102b37
Commit: 4b18ab869f
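For context, here is a minimal sketch (not part of the diff; modeled on the `test_compiling_state` test added below, with the made-up module name `Gate`) of the behavior this change enables: a module that branches on `torch.compiler.is_compiling()` now takes the "compiling" branch under `torch.export.export(..., strict=False)` just as it does in strict mode, while plain eager execution still takes the other branch. Calling the exported program directly mirrors how the new test exercises it.

    import torch
    from torch.export import export

    class Gate(torch.nn.Module):
        def forward(self, x):
            if torch.compiler.is_compiling():
                return x * 2  # traced/exported path
            else:
                return x * 3  # plain eager path

    m = Gate()
    x = torch.randn(5)
    print(torch.allclose(m(x), x * 3))                              # eager: flag is False
    print(torch.allclose(export(m, (x,), strict=True)(x), x * 2))   # strict export (Dynamo)
    print(torch.allclose(export(m, (x,), strict=False)(x), x * 2))  # non-strict export, enabled by this PR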
@@ -20,3 +20,5 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
     list_backends
     disable
     cudagraph_mark_step_begin
+    is_compiling
+    is_dynamo_compiling
@@ -26,6 +26,8 @@ disable compilation are listed in the following table:
    "``torch._dynamo.disallow_in_graph``", "Disallows the marked op in the TorchDynamo graph. TorchDynamo causes a graph break and runs the op in eager (no-compile) mode.\n\nThis is suitable for individual ops, while ``torch.compiler.disable`` is suitable for decorating functions.", "This API is excellent for both debugging and unblocking if a custom op like ``torch.ops.fbgemm.*`` is causing issues with the ``torch.compile`` function."
    "``torch.compile.allow_in_graph``", "The annotated callable goes as is in the TorchDynamo graph. For example, a black box for TorchDynamo.\n\nNote that AOT Autograd will trace through it, so ``allow_in_graph`` is only a Dynamo-level concept.", "This API is useful for portions of the model which have known TorchDynamo hard-to-support features, like hooks or ``autograd.Function``. However, each usage of ``allow_in_graph`` **must be carefully screened** (no graph breaks, no closures)."
    "``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``."
+   "``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()."
+   "``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It is stricter than the torch.compiler.is_compiling() flag, as it is only set to True when TorchDynamo is used."

 ``torch.compiler.disable``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
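To make the two new table rows concrete, here is a small sketch (not from the PR; the module names are made up) of how the flags differ under non-strict export: `is_compiling()` reports True because export now sets the flag, while `is_dynamo_compiling()` stays False because TorchDynamo is not involved.

    import torch
    from torch.export import export

    class UsesIsCompiling(torch.nn.Module):
        def forward(self, x):
            return x * 2 if torch.compiler.is_compiling() else x * 3

    class UsesIsDynamoCompiling(torch.nn.Module):
        def forward(self, x):
            return x * 2 if torch.compiler.is_dynamo_compiling() else x * 3

    x = torch.randn(5)
    ep_a = export(UsesIsCompiling(), (x,), strict=False)
    ep_b = export(UsesIsDynamoCompiling(), (x,), strict=False)
    print(torch.allclose(ep_a(x), x * 2))  # expected True: non-strict export sets the compiling flag
    print(torch.allclose(ep_b(x), x * 3))  # expected True: only TorchDynamo reports is_dynamo_compiling()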
@@ -5908,16 +5908,35 @@ def fn():
         self.assertEqual(cnt.frame_count, 0)

     def test_is_compiling(self):
-        def f():
+        def f1():
             if torch._dynamo.is_compiling():
                 return torch.ones(2, 2)
             else:
                 return torch.zeros(2, 2)

-        opt_f = torch._dynamo.optimize("eager")(f)
+        def f2():
+            if torch._utils.is_compiling():
+                return torch.ones(2, 2)
+            else:
+                return torch.zeros(2, 2)

-        self.assertEqual(f(), torch.zeros(2, 2))
-        self.assertEqual(opt_f(), torch.ones(2, 2))
+        def f3():
+            if torch.compiler.is_compiling():
+                return torch.ones(2, 2)
+            else:
+                return torch.zeros(2, 2)
+
+        def f4():
+            if torch.compiler.is_dynamo_compiling():
+                return torch.ones(2, 2)
+            else:
+                return torch.zeros(2, 2)
+
+        for f in [f1, f2, f3, f4]:
+            opt_f = torch._dynamo.optimize("eager")(f)
+
+            self.assertEqual(f(), torch.zeros(2, 2))
+            self.assertEqual(opt_f(), torch.ones(2, 2))

     def test_torch_generator_set_state(self):
         def fn():
@@ -3404,6 +3404,37 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         # under a new FakeTensorMode.
         ep = torch.export.export(m, (inp,))

+    def test_compiling_state(self):
+        class TestModule1(torch.nn.Module):
+            def forward(self, x):
+                if torch._dynamo.is_compiling():
+                    return x * 2
+                else:
+                    return x * 3
+
+        class TestModule2(torch.nn.Module):
+            def forward(self, x):
+                if torch._utils.is_compiling():
+                    return x * 2
+                else:
+                    return x * 3
+
+        class TestModule3(torch.nn.Module):
+            def forward(self, x):
+                if torch.compiler.is_compiling():
+                    return x * 2
+                else:
+                    return x * 3
+
+        for m in [TestModule1(), TestModule2(), TestModule3()]:
+            input = torch.randn(5)
+            ep_strict = export(m, (input,), strict=True)
+            ep_non_strict = export(m, (input,), strict=False)
+
+            self.assertTrue(torch.allclose(input * 3, m(input)))
+            self.assertTrue(torch.allclose(input * 2, ep_strict(input)))
+            self.assertTrue(torch.allclose(input * 2, ep_non_strict(input)))
+
     def test_user_input_and_buffer_mutation(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -12,7 +12,18 @@ except ModuleNotFoundError:


 def is_compiling() -> bool:
-    return False
+    """
+    Indicates whether we are tracing/compiling with torch.compile() or torch.export().
+
+    If you need to check specifically whether TorchDynamo is being used, use
+    torch.compiler.is_dynamo_compiling().
+
+    TODO(khabinov): we should deprecate this function and use one of these two:
+    * torch.compiler.is_compiling(),
+    * torch.compiler.is_dynamo_compiling().
+    Which one to use will depend on the context.
+    """
+    return torch.compiler.is_compiling()


 def wrap_inline(fn):
@@ -101,6 +101,8 @@ manual_torch_name_rule_map = {
     "torch.overrides.get_default_nowrap_functions": TorchInGraphFunctionVariable,
     "torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable,
     "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable,
+    "torch.compiler.is_compiling": TorchInGraphFunctionVariable,
+    "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
     "torch.autograd._profiler_enabled": SkipFunctionVariable,
     # We graph break on RNG state setters or getters like
     # `torch.get_rng_state` or `torch.set_rng_state`. These functions
@@ -110,6 +110,8 @@ tracing_state_functions = {
     torch.onnx.is_in_onnx_export: False,
     torch._dynamo.external_utils.is_compiling: True,
     torch._utils.is_compiling: True,
+    torch.compiler.is_compiling: True,
+    torch.compiler.is_dynamo_compiling: True,
 }

@@ -304,6 +306,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             if self.value in (
                 torch._utils.is_compiling,
                 torch._dynamo.external_utils.is_compiling,
+                torch.compiler.is_compiling,
+                torch.compiler.is_dynamo_compiling,
             ):
                 tx.mark_inconsistent_side_effects()
             return ConstantVariable.create(tracing_state_functions[self.value])
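The `tracing_state_functions` entries above are what make these flags useful under `torch.compile`: while TorchDynamo traces a frame, a call to any of these functions is folded to the listed constant, so the branch is resolved at trace time instead of causing a graph break. A small sketch of the observable effect (not part of the diff; it mirrors the `f4` case in the test above):

    import torch

    def f(x):
        if torch.compiler.is_dynamo_compiling():
            return x + 1  # branch Dynamo bakes into the compiled graph
        return x - 1      # branch taken in plain eager execution

    x = torch.zeros(3)
    print(f(x))                                  # eager: tensor([-1., -1., -1.])
    print(torch.compile(f, backend="eager")(x))  # compiled: tensor([1., 1., 1.])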
@@ -103,7 +103,7 @@ def cond(pred, true_fn, false_fn, operands):

     """

-    if torch._dynamo.is_compiling():
+    if torch.compiler.is_dynamo_compiling():
         return cond_op(pred, true_fn, false_fn, operands)


 def _validate_input(pred, true_fn, false_fn, operands):
@@ -1,5 +1,3 @@
-from contextlib import contextmanager
-
 import torch
 import torch._subclasses.functional_tensor

@@ -8,7 +6,7 @@ import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._functorch.utils import exposed_in

-from torch._higher_order_ops.utils import autograd_not_implemented
+from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
@@ -20,21 +18,9 @@ from torch.fx.experimental.proxy_tensor import (
 from torch.utils._python_dispatch import _get_current_dispatch_mode


-@contextmanager
-def _set_compilation_env():
-    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
-    try:
-        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
-        # once we are confident fx tracing works with dynamo.
-        torch.fx._symbolic_trace._is_fx_tracing_flag = False
-        yield
-    finally:
-        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
-
-
 @exposed_in("torch")
 def strict_mode(callable, operands):
-    if torch._dynamo.is_compiling():
+    if torch.compiler.is_dynamo_compiling():
         return strict_mode_op(callable, operands)

     with _set_compilation_env():
@@ -96,7 +96,7 @@ def while_loop(cond_fn, body_fn, operands):
      - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.

     """
-    if torch._dynamo.is_compiling():
+    if torch.compiler.is_dynamo_compiling():
         return while_loop_op(cond_fn, body_fn, operands)

 def _validate_input(cond_fn, body_fn, operands):
@@ -848,9 +848,13 @@ def classproperty(func):
     return _ClassPropertyDescriptor(func)


-# Whether we are compiling with torch.compile or not
-def is_compiling():
-    return False
+def is_compiling() -> bool:
+    """
+    Indicates whether we are tracing/compiling with torch.compile() or torch.export().
+
+    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
+    """
+    return torch.compiler.is_compiling()


 def _functionalize_sync(t):
@@ -10,6 +10,8 @@ __all__ = [
     "disable",
     "cudagraph_mark_step_begin",
     "wrap_numpy",
+    "is_compiling",
+    "is_dynamo_compiling",
 ]

 def compile(*args, **kwargs):
@@ -149,3 +151,43 @@ def wrap_numpy(fn):
     """
     from torch._dynamo.external_utils import wrap_numpy as wrap
     return wrap(fn)
+
+
+_is_compiling_flag: bool = False
+
+
+def is_compiling() -> bool:
+    """
+    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
+
+    Note that there are 2 other related flags that should be deprecated eventually:
+    * torch._dynamo.external_utils.is_compiling()
+    * torch._utils.is_compiling()
+
+    Example::
+
+        >>> def forward(self, x):
+        >>>     if not torch.compiler.is_compiling():
+        >>>        ...logic that is not needed in a compiled/traced graph...
+        >>>
+        >>>     ...rest of the function...
+    """
+    if torch.jit.is_scripting():
+        return False
+    else:
+        return _is_compiling_flag
+
+
+def is_dynamo_compiling() -> bool:
+    """
+    Indicates whether a graph is traced via TorchDynamo.
+
+    It is stricter than the is_compiling() flag, as it is only set to True when
+    TorchDynamo is used.
+
+    Example::
+
+        >>> def forward(self, x):
+        >>>     if not torch.compiler.is_dynamo_compiling():
+        >>>        ...logic that is not needed in a TorchDynamo-traced graph...
+        >>>
+        >>>     ...rest of the function...
+    """
+    return False
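For illustration only: the new `_is_compiling_flag` is meant to be flipped by the export machinery (see the `_compiling_state_context()` helper added later in this diff), not by user code. The sketch below (the helper name `_pretend_compiling` is made up) just shows the try/finally toggle pattern and that `is_dynamo_compiling()` is unaffected by the flag.

    import torch
    from contextlib import contextmanager

    @contextmanager
    def _pretend_compiling():
        # Mirrors the shape of the _compiling_state_context() helper in the export hunk below.
        old = torch.compiler._is_compiling_flag
        try:
            torch.compiler._is_compiling_flag = True
            yield
        finally:
            torch.compiler._is_compiling_flag = old

    print(torch.compiler.is_compiling())         # False in plain eager mode
    with _pretend_compiling():
        print(torch.compiler.is_compiling())     # True while the flag is set
    print(torch.compiler.is_dynamo_compiling())  # still False: only Dynamo reports True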
@@ -29,9 +29,7 @@ if torch._running_with_deploy():

 else:
     try:
-        from torch._dynamo.external_utils import (
-            is_compiling as is_torchdynamo_compiling,
-        )
+        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
     except Exception:
         warnings.warn(
             "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
@@ -419,12 +419,22 @@ def _export_non_strict(
     grad_safe_guard = (
         AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext()
     )
+
+    @contextmanager
+    def _compiling_state_context():
+        old_value = torch.compiler._is_compiling_flag
+        try:
+            torch.compiler._is_compiling_flag = True
+            yield
+        finally:
+            torch.compiler._is_compiling_flag = old_value
+
     # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
     # otherwise aot_export_module will error out because it sees a mix of fake_modes.
     # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
     with torch.nn.utils.stateless._reparametrize_module(
         mod, fake_params_buffers
-    ), grad_safe_guard, _ignore_backend_decomps():  # type: ignore[attr-defined]
+    ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context():  # type: ignore[attr-defined]
         gm, graph_signature = transform(aot_export_module)(
             mod,
             fake_args,
@@ -74,7 +74,7 @@ class InterpreterModule(torch.nn.Module):

     def forward(self, *args, **kwargs):
         assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
-        if torch._dynamo.is_compiling():
+        if torch.compiler.is_dynamo_compiling():
             # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
             # GraphModule codegen in this instance.
             return self.graph_module(*args, **kwargs)