[torch.export] Support is_compiling() flag for non-strict mode (#119602)

Summary: In non-strict mode of torch.export() we didn't set the `is_compiling()` flags to `True`, which some models rely on.

Test Plan: Unit tests and manual testing.

Differential Revision: D53624452

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119602
Approved by: https://github.com/suo
Oleg Khabinov
2024-02-29 05:52:51 +00:00
committed by PyTorch MergeBot
parent 0a46102b37
commit 4b18ab869f
15 changed files with 142 additions and 31 deletions
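
As a rough illustration of the change described in the summary, here is a minimal sketch mirroring the new test_compiling_state test further down in the diff (not part of the commit itself; the `Gate` module name is made up for illustration):

import torch
from torch.export import export

class Gate(torch.nn.Module):
    def forward(self, x):
        if torch.compiler.is_compiling():
            return x * 2  # branch taken while compiling/tracing
        return x * 3      # branch taken in plain eager execution

m = Gate()
x = torch.randn(5)
ep = export(m, (x,), strict=False)   # non-strict export now sets the flag too
assert torch.allclose(m(x), x * 3)   # eager: is_compiling() is False
assert torch.allclose(ep(x), x * 2)  # exported graph captured the True branch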

View File

@ -20,3 +20,5 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`.
list_backends
disable
cudagraph_mark_step_begin
is_compiling
is_dynamo_compiling

View File

@ -26,6 +26,8 @@ disable compilation are listed in the following table:
"``torch._dynamo.disallow_in_graph``", "Disallows the marked op in the TorchDynamo graph. TorchDynamo causes graph break, and runs the op in the eager (no compile) mode.\n\nThis is suitable for the ops, while ``torch.compiler.disable`` is suitable for decorating functions.", "This API is excellent for both debugging and unblocking if a custom op like ``torch.ops.fbgemm.*`` is causing issues with the ``torch.compile`` function."
"``torch.compile.allow_in_graph``", "The annotated callable goes as is in the TorchDynamo graph. For example, a black-box for TorchDynamo Dynamo.\n\nNote that AOT Autograd will trace through it, so the ``allow_in_graph`` is only a Dynamo-level concept.", "This API is useful for portions of the model which have known TorchDynamo hard-to-support features, like hooks or ``autograd.Function``. However, each usage of ``allow_in_graph`` **must be carefully screened** (no graph breaks, no closures)."
"``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``."
"``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()."
"``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when TorchDynamo is used."
``torch.compiler.disable``
~~~~~~~~~~~~~~~~~~~~~~~~~~

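To make the distinction drawn in the two new table rows concrete, a small sketch (not part of the commit; `branch_on_flags` is a made-up helper) assuming the public APIs added in this PR:

import torch

def branch_on_flags(x):
    if torch.compiler.is_dynamo_compiling():  # True only while TorchDynamo traces
        return x + 1
    if torch.compiler.is_compiling():         # also True for torch.export(), e.g. non-strict export
        return x + 2
    return x + 3                              # plain eager execution

x = torch.zeros(3)
print(branch_on_flags(x))                                  # eager: tensor of 3s
print(torch.compile(branch_on_flags, backend="eager")(x))  # Dynamo: tensor of 1s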
View File

@ -5908,16 +5908,35 @@ def fn():
self.assertEqual(cnt.frame_count, 0)
def test_is_compiling(self):
def f():
def f1():
if torch._dynamo.is_compiling():
return torch.ones(2, 2)
else:
return torch.zeros(2, 2)
opt_f = torch._dynamo.optimize("eager")(f)
def f2():
if torch._utils.is_compiling():
return torch.ones(2, 2)
else:
return torch.zeros(2, 2)
self.assertEqual(f(), torch.zeros(2, 2))
self.assertEqual(opt_f(), torch.ones(2, 2))
def f3():
if torch.compiler.is_compiling():
return torch.ones(2, 2)
else:
return torch.zeros(2, 2)
def f4():
if torch.compiler.is_dynamo_compiling():
return torch.ones(2, 2)
else:
return torch.zeros(2, 2)
for f in [f1, f2, f3, f4]:
opt_f = torch._dynamo.optimize("eager")(f)
self.assertEqual(f(), torch.zeros(2, 2))
self.assertEqual(opt_f(), torch.ones(2, 2))
def test_torch_generator_set_state(self):
def fn():

View File

@ -3404,6 +3404,37 @@ def forward(self, arg0_1, arg1_1, arg2_1):
# under a new FakeTensorMode.
ep = torch.export.export(m, (inp,))
def test_compiling_state(self):
class TestModule1(torch.nn.Module):
def forward(self, x):
if torch._dynamo.is_compiling():
return x * 2
else:
return x * 3
class TestModule2(torch.nn.Module):
def forward(self, x):
if torch._utils.is_compiling():
return x * 2
else:
return x * 3
class TestModule3(torch.nn.Module):
def forward(self, x):
if torch.compiler.is_compiling():
return x * 2
else:
return x * 3
for m in [TestModule1(), TestModule2(), TestModule3()]:
input = torch.randn(5)
ep_strict = export(m, (input,), strict=True)
ep_non_strict = export(m, (input,), strict=False)
self.assertTrue(torch.allclose(input * 3, m(input)))
self.assertTrue(torch.allclose(input * 2, ep_strict(input)))
self.assertTrue(torch.allclose(input * 2, ep_non_strict(input)))
def test_user_input_and_buffer_mutation(self):
class MyModule(torch.nn.Module):
def __init__(self):

View File

@ -12,7 +12,18 @@ except ModuleNotFoundError:
def is_compiling() -> bool:
return False
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
If you need to check specifically whether TorchDynamo is being used, use
torch.compiler.is_dynamo_compiling().
TODO(khabinov): we should deprecate this function and use one of these two:
* torch.compiler.is_compiling(),
* torch.compiler.is_dynamo_compiling().
Which one to use will depend on the context.
"""
return torch.compiler.is_compiling()
def wrap_inline(fn):

View File

@ -101,6 +101,8 @@ manual_torch_name_rule_map = {
"torch.overrides.get_default_nowrap_functions": TorchInGraphFunctionVariable,
"torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable,
"torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
"torch.autograd._profiler_enabled": SkipFunctionVariable,
# We graph break on RNG state setters or getters like
# `torch.get_rng_state` or `torch.set_rng_state`. These functions

View File

@ -110,6 +110,8 @@ tracing_state_functions = {
torch.onnx.is_in_onnx_export: False,
torch._dynamo.external_utils.is_compiling: True,
torch._utils.is_compiling: True,
torch.compiler.is_compiling: True,
torch.compiler.is_dynamo_compiling: True,
}
@ -304,6 +306,8 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if self.value in (
torch._utils.is_compiling,
torch._dynamo.external_utils.is_compiling,
torch.compiler.is_compiling,
torch.compiler.is_dynamo_compiling,
):
tx.mark_inconsistent_side_effects()
return ConstantVariable.create(tracing_state_functions[self.value])

View File

@ -103,7 +103,7 @@ def cond(pred, true_fn, false_fn, operands):
"""
if torch._dynamo.is_compiling():
if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)
def _validate_input(pred, true_fn, false_fn, operands):

View File

@ -1,5 +1,3 @@
from contextlib import contextmanager
import torch
import torch._subclasses.functional_tensor
@ -8,7 +6,7 @@ import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@ -20,21 +18,9 @@ from torch.fx.experimental.proxy_tensor import (
from torch.utils._python_dispatch import _get_current_dispatch_mode
@contextmanager
def _set_compilation_env():
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
try:
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
# once we are confident fx tracing works with dynamo.
torch.fx._symbolic_trace._is_fx_tracing_flag = False
yield
finally:
torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
@exposed_in("torch")
def strict_mode(callable, operands):
if torch._dynamo.is_compiling():
if torch.compiler.is_dynamo_compiling():
return strict_mode_op(callable, operands)
with _set_compilation_env():

View File

@ -96,7 +96,7 @@ def while_loop(cond_fn, body_fn, operands):
- 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
"""
if torch._dynamo.is_compiling():
if torch.compiler.is_dynamo_compiling():
return while_loop_op(cond_fn, body_fn, operands)
def _validate_input(cond_fn, body_fn, operands):

View File

@ -848,9 +848,13 @@ def classproperty(func):
return _ClassPropertyDescriptor(func)
# Whether we are compiling with torch.compile or not
def is_compiling():
return False
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
"""
return torch.compiler.is_compiling()
def _functionalize_sync(t):

View File

@ -10,6 +10,8 @@ __all__ = [
"disable",
"cudagraph_mark_step_begin",
"wrap_numpy",
"is_compiling",
"is_dynamo_compiling",
]
def compile(*args, **kwargs):
@ -149,3 +151,43 @@ def wrap_numpy(fn):
"""
from torch._dynamo.external_utils import wrap_numpy as wrap
return wrap(fn)
_is_compiling_flag: bool = False
def is_compiling() -> bool:
"""
Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
Note that there are 2 other related flags that should be deprecated eventually:
* torch._dynamo.external_utils.is_compiling()
* torch._utils.is_compiling()
Example::
>>> def forward(self, x):
>>> if not torch.compiler.is_compiling():
>>> ...logic that is not needed in a compiled/traced graph...
>>>
>>> ...rest of the function...
"""
if torch.jit.is_scripting():
return False
else:
return _is_compiling_flag
def is_dynamo_compiling() -> bool:
"""
Indicates whether a graph is traced via TorchDynamo.
It's stricter than the is_compiling() flag, as it is only set to True when
TorchDynamo is used.
Example::
>>> def forward(self, x):
>>> if not torch.compiler.is_dynamo_compiling():
>>> ...logic that is not needed in a TorchDynamo-traced graph...
>>>
>>> ...rest of the function...
"""
return False

View File

@ -29,9 +29,7 @@ if torch._running_with_deploy():
else:
try:
from torch._dynamo.external_utils import (
is_compiling as is_torchdynamo_compiling,
)
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:
warnings.warn(
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"

View File

@ -419,12 +419,22 @@ def _export_non_strict(
grad_safe_guard = (
AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext()
)
@contextmanager
def _compiling_state_context():
old_value = torch.compiler._is_compiling_flag
try:
torch.compiler._is_compiling_flag = True
yield
finally:
torch.compiler._is_compiling_flag = old_value
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module(
mod, fake_params_buffers
), grad_safe_guard, _ignore_backend_decomps(): # type: ignore[attr-defined]
), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined]
gm, graph_signature = transform(aot_export_module)(
mod,
fake_args,

View File

@ -74,7 +74,7 @@ class InterpreterModule(torch.nn.Module):
def forward(self, *args, **kwargs):
assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
if torch._dynamo.is_compiling():
if torch.compiler.is_dynamo_compiling():
# Dynamo cannot trace through torch.fx.Interpreter, so fall back to
# GraphModule codegen in this instance.
return self.graph_module(*args, **kwargs)