Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"

This reverts commit 0e7e61f7cec82a43f2de52b83eff152d703be7a3.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
PyTorch MergeBot
2024-08-07 00:05:20 +00:00
parent e98eac76b3
commit cbee9c1fd2
29 changed files with 83 additions and 87 deletions
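For context: the helpers involved all answer the same question, whether torch.compile()/torch.export() tracing is currently in progress; the reverted PR only changed which spelling is preferred and added FutureWarning decorators to the private ones. A minimal compatibility sketch (the helper name is ours, not part of PyTorch):

```python
import torch

def is_compiling_compat() -> bool:
    # Prefer the public API where it exists (newer releases); otherwise fall
    # back to the private helper that this revert keeps warning-free.
    compiler = getattr(torch, "compiler", None)
    if compiler is not None and hasattr(compiler, "is_compiling"):
        return compiler.is_compiling()
    return torch._utils.is_compiling()

print(is_compiling_compat())  # False outside of any torch.compile trace
```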

View File

@@ -12,12 +12,12 @@ _variable_2 = 0
def user_function():
-return torch.compiler.is_compiling()
+return torch._utils.is_compiling()
def user_generator():
for _ in range(1):
-yield torch.compiler.is_compiling()
+yield torch._utils.is_compiling()
return
@@ -38,7 +38,7 @@ class MyModule(torch.nn.Module):
global _variable, _variable_2
if self.mode == 1:
-if torch.compiler.is_compiling():
+if torch._utils.is_compiling():
_variable += 1
else:
_variable_2 += 1
@@ -46,7 +46,7 @@ class MyModule(torch.nn.Module):
if user_function():
_variable += 1
elif self.mode == 3:
-lambda_f = lambda: torch.compiler.is_compiling() # noqa: E731
+lambda_f = lambda: torch._utils.is_compiling() # noqa: E731
if lambda_f():
_variable += 1
elif self.mode == 4:
@@ -163,7 +163,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
def test_do_not_skip_side_effects(self):
# https://github.com/pytorch/pytorch/issues/110765
-# By invoking torch.compiler.is_compiling(),
+# By invoking torch._utils.is_compiling(),
# there may be side-effects inconsistent with eager when
# compiling. Thus we force dynamo to commit the graph,
# even if it does not perform any tensor operation
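The comment above is the heart of the test: a frame whose only observable work is a side effect gated on `is_compiling()` must not be silently skipped by dynamo, or compiled and eager runs diverge. A standalone illustration of that divergence (hypothetical snippet, not the test code):

```python
import torch

_counter = 0

def bump_if_compiling(x):
    global _counter
    if torch._utils.is_compiling():  # True only while dynamo traces this frame
        _counter += 1
    return x  # no tensor ops, so dynamo could be tempted to skip the frame

torch.compile(bump_if_compiling, backend="eager")(torch.ones(1))
# If the frame were skipped and run eagerly, is_compiling() would be False and
# _counter would stay 0; committing the graph preserves the side effect.
```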

View File

@@ -1281,7 +1281,7 @@ class TestCompileTorchbind(TestCase):
f(_empty_tensor_queue(), x),
torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
)
-if not torch.compiler.is_compiling() and backend == "eager":
+if not torch._dynamo.is_compiling() and backend == "eager":
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\

View File

@@ -278,7 +278,7 @@ class NoChangeTestCase(TestCase):
# Test to repro issue with fx_graph_cse when
# hash((primals_2, 1.0)) == hash((primals_2, 1))
-if torch.compiler.is_compiling():
+if torch._dynamo.is_compiling():
self.skipTest("Unsupported if test run is compiled")
def f(inpt, osize):

View File

@@ -110,13 +110,13 @@ def init_fake_distributed(device="cpu"):
def init_module_bw_hooks(allow_eager):
def bw_pre_hook(mod, gO):
-assert allow_eager or torch.compiler.is_compiling()
+assert allow_eager or torch._dynamo.is_compiling()
assert mod.weight.size() == (10, 10)
mod.hook_count_pre.add_(1)
return (torch.sin(gO[0] + 1.2),)
def bw_post_hook(mod, gI, gO):
-assert allow_eager or torch.compiler.is_compiling()
+assert allow_eager or torch._dynamo.is_compiling()
assert mod.weight.size() == (10, 10)
mod.hook_count_post.add_(1)
return (torch.sin(gI[0] + 3.4),)

View File

@@ -288,7 +288,7 @@ class TestOptimRenewed(TestCase):
inpt = torch.randn(5, device=device, dtype=dtype)
# avoid endless recompiles by wrapping LR in a tensor if we're compiling
-lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01
+lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
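The `lr = torch.tensor(0.01) if ... else 0.01` trick exists because dynamo specializes on the value of a plain Python float, so a learning rate that changes every step can force a recompile every step, while a 0-dim tensor is treated as data and reused. A rough sketch of the idea (illustrative only, not part of the test):

```python
import torch

@torch.compile
def sgd_step(param, grad, lr):
    return param - lr * grad

p, g = torch.randn(5), torch.randn(5)
lr = torch.tensor(0.01)   # one compiled graph serves every decayed value
for _ in range(3):
    p = sgd_step(p, g, lr)
    lr.mul_(0.9)          # in-place decay: same tensor object, same guards
```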

View File

@@ -10,6 +10,7 @@ from . import trace_rules, variables
from .comptime import comptime
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
+from .external_utils import is_compiling
if TYPE_CHECKING:
@@ -291,7 +292,7 @@ def mark_static(t, index=None):
Unlike mark_dynamic, this can be done inside a graph, in which case it
induces specialization on the tensor.
"""
-if torch.compiler.is_compiling():
+if is_compiling():
if index is None:
for s in t.size():
comptime.force_static(s)

View File

@@ -3,7 +3,6 @@
import functools
from typing import List
-from typing_extensions import deprecated
import torch
import torch.utils._pytree as pytree
@@ -15,10 +14,6 @@ except ModuleNotFoundError:
np = None # type: ignore[assignment]
-@deprecated(
-"`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
-category=FutureWarning,
-)
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().

View File

@@ -191,7 +191,7 @@ def vmap(
vmap does not provide general autobatching or handle variable-length
sequences out of the box.
"""
-from torch.compiler import is_compiling
+from torch._dynamo import is_compiling
_check_randomness_arg(randomness)
if not (chunk_size is None or chunk_size > 0):
@@ -393,7 +393,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
"""
# To avoid cyclical dependency.
import torch._functorch.eager_transforms as eager_transforms
-from torch.compiler import is_compiling
+from torch._dynamo import is_compiling
def wrapper(*args, **kwargs):
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
@@ -435,8 +435,8 @@ def grad_and_value(
See :func:`grad` for examples
"""
+from torch._dynamo import is_compiling
from torch._functorch import eager_transforms
-from torch.compiler import is_compiling
def wrapper(*args, **kwargs):
return eager_transforms.grad_and_value_impl(

View File

@@ -764,7 +764,7 @@ def jacrev(
# Dynamo does not support HOP composition if their inner function is
# annotated with @functools.wraps(...). We circumvent this issue by applying
# wraps only if we're not tracing with dynamo.
-if not torch.compiler.is_compiling():
+if not torch._dynamo.is_compiling():
wrapper_fn = wraps(func)(wrapper_fn)
return wrapper_fn
@@ -1344,7 +1344,7 @@ def jacfwd(
# Dynamo does not support HOP composition if their inner function is
# annotated with @functools.wraps(...). We circumvent this issue by applying
# wraps only if we're not tracing with dynamo.
-if not torch.compiler.is_compiling():
+if not torch._dynamo.is_compiling():
wrapper_fn = wraps(func)(wrapper_fn)
return wrapper_fn
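The same guard appears in both `jacrev` and `jacfwd`: `functools.wraps` metadata is applied only when dynamo is not tracing, since a wrapped inner function can trip up dynamo's higher-order-op handling. The pattern in isolation (names here are illustrative):

```python
from functools import wraps

import torch

def maybe_wrap(func, wrapper_fn):
    # Skip the decoration while tracing; dynamo then sees the bare wrapper.
    if not torch._dynamo.is_compiling():
        wrapper_fn = wraps(func)(wrapper_fn)
    return wrapper_fn
```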

View File

@@ -74,7 +74,7 @@ def associative_scan(
assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}"
assert isinstance(dim, int), "dim must be an int, but got {type(dim)}"
-if not torch.compiler.is_compiling():
+if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(associative_scan, fullgraph=True)(
combine_fn, input, dim
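`associative_scan` is a higher-order op, and the hunk above is its eager entry point: when called outside a compile region it re-enters itself via `torch.compile(..., fullgraph=True)` so the combine function is captured as a single graph. A stripped-down version of that re-entry pattern (hypothetical op; the `_set_compilation_env` and cache-limit plumbing is omitted):

```python
import torch

def my_scan(combine_fn, x, dim):
    if not torch._dynamo.is_compiling():
        # Re-enter under torch.compile so the body below is traced, not run eagerly.
        return torch.compile(my_scan, fullgraph=True)(combine_fn, x, dim)
    return combine_fn(x, x)  # placeholder body for the sketch

out = my_scan(torch.add, torch.randn(4), 0)
```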

View File

@@ -8,7 +8,7 @@ import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional
-from typing_extensions import deprecated, ParamSpec
+from typing_extensions import ParamSpec
import torch
@@ -868,10 +868,6 @@ def classproperty(func):
return _ClassPropertyDescriptor(func)
-@deprecated(
-"`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
-category=FutureWarning,
-)
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().

View File

@@ -6,6 +6,7 @@ from enum import auto, Enum
from typing import Any, cast, List, Optional
import torch
+import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
@@ -114,7 +115,6 @@ def _from_local_no_grad(
This method is similar to ``DTensor.from_local()`` except that in eager mode
it avoids some CPU overhead by avoiding default args and not being differentiable.
"""
-import torch._dynamo.compiled_autograd as ca
if not ca.compiled_autograd_enabled:
return DTensor(
@@ -124,13 +124,14 @@ def _from_local_no_grad(
sharding_spec,
requires_grad=local_tensor.requires_grad,
)
-return DTensor.from_local(
-local_tensor,
-sharding_spec.mesh,
-sharding_spec.placements,
-shape=sharding_spec.shape,
-stride=sharding_spec.stride,
-)
+else:
+return DTensor.from_local(
+local_tensor,
+sharding_spec.mesh,
+sharding_spec.placements,
+shape=sharding_spec.shape,
+stride=sharding_spec.stride,
+)
def _to_dtype_if_needed(

View File

@@ -87,7 +87,7 @@ def fp16_compress_hook(
decompressed_tensor.copy_(value)
return decompressed_tensor
-if torch.compiler.is_compiling():
+if torch._utils.is_compiling():
grad = dist._functional_collectives.all_reduce(
compressed_tensor, "sum", group_to_use
)
@@ -136,7 +136,7 @@ def bf16_compress_hook(
decompressed_tensor.copy_(value)
return decompressed_tensor
-if torch.compiler.is_compiling():
+if torch._utils.is_compiling():
grad = dist._functional_collectives.all_reduce(
compressed_tensor, "sum", group_to_use
)

View File

@@ -2,20 +2,22 @@
import warnings
from typing import Tuple, Union
-import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.device_mesh import _mesh_resources
+try:
+from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
+except Exception:
+def is_torchdynamo_compiling(): # type: ignore[misc]
+return False
LayoutsType = Union[Placement, Tuple[Placement, ...]]
-def is_torchdynamo_compiling() -> bool:
-# Use local function to avoid circular imports
-return torch.compiler.is_compiling()
def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
"""
Inject common validation logics for `_prepare_input` funcs via this decorator.

View File

@@ -1484,7 +1484,7 @@ class DistributedDataParallel(Module, Joinable):
def _should_disable_cpp_reducer(self) -> bool:
return self._use_python_reducer and (
-torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer
+torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
)
def _pre_forward(self, *inputs, **kwargs):
@@ -1497,7 +1497,7 @@ class DistributedDataParallel(Module, Joinable):
h.remove()
self._accum_grad_hooks.clear()
-if not self._lazy_init_ran and not torch.compiler.is_compiling():
+if not self._lazy_init_ran and not torch._utils.is_compiling():
self._lazy_init()
if self._delay_all_reduce_all_params:

View File

@@ -423,7 +423,7 @@ def adafactor(
See :class:`~torch.optim.Adafactor` for details.
"""
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -259,7 +259,7 @@ def _single_tensor_adadelta(
has_complex: bool,
):
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -315,7 +315,7 @@ def _multi_tensor_adadelta(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -347,7 +347,7 @@ def _multi_tensor_adadelta(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
@@ -418,7 +418,7 @@ def adadelta(
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -451,7 +451,7 @@ def _multi_tensor_adagrad(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
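The comment recurs across these optimizer files: with step counters on CPU, `_foreach_add_` falls back to a per-tensor loop, and a bare Python `1` would be re-wrapped into a tensor on every call, so the scalar is wrapped once and `alpha=1.0` pins the Tensor overload. The call in isolation, with toy state:

```python
import torch

device_state_steps = [torch.zeros((), device="cpu") for _ in range(3)]
# One pre-built 0-dim tensor instead of re-wrapping the Python int each call;
# alpha=1.0 selects the Tensor `other` overload of _foreach_add_.
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
```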

View File

@@ -361,7 +361,7 @@ def _single_tensor_adam(
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
@@ -474,7 +474,7 @@ def _multi_tensor_adam(
)
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -521,7 +521,7 @@ def _multi_tensor_adam(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
@@ -754,7 +754,7 @@ def adam(
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -248,7 +248,7 @@ def _single_tensor_adamax(
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
@@ -320,7 +320,7 @@ def _multi_tensor_adamax(
return
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -352,7 +352,7 @@ def _multi_tensor_adamax(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
@@ -429,7 +429,7 @@ def adamax(
See :class:`~torch.optim.Adamax` for details.
"""
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -362,7 +362,7 @@ def _single_tensor_adamw(
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
@@ -475,7 +475,7 @@ def _multi_tensor_adamw(
)
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -521,7 +521,7 @@ def _multi_tensor_adamw(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
@@ -739,7 +739,7 @@ def adamw(
See :class:`~torch.optim.AdamW` for details.
"""
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -219,7 +219,7 @@ def _single_tensor_asgd(
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type
@@ -292,7 +292,7 @@ def _multi_tensor_asgd(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -326,7 +326,7 @@ def _multi_tensor_asgd(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)

View File

@@ -310,7 +310,7 @@ def _single_tensor_nadam(
exp_avg_sq = torch.view_as_real(exp_avg_sq)
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == mu_product.device.type == step_t.device.type
@@ -396,7 +396,7 @@ def _multi_tensor_nadam(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -430,7 +430,7 @@ def _multi_tensor_nadam(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)

View File

@@ -26,6 +26,7 @@ from typing_extensions import ParamSpec, Self, TypeAlias
import torch
import torch.utils.hooks as hooks
+from torch._utils import is_compiling
from torch.utils._foreach_utils import (
_get_foreach_kernels_supported_devices,
_get_fused_kernels_supported_devices,
@@ -99,14 +100,14 @@ def _use_grad_for_differentiable(func):
def _get_value(x):
# item is significantly faster than a cpu tensor in eager mode
-if not torch.jit.is_scripting() and torch.compiler.is_compiling():
+if not torch.jit.is_scripting() and is_compiling():
return x
else:
return x.item() if isinstance(x, torch.Tensor) else x
def _stack_if_compiling(x):
-if not torch.jit.is_scripting() and torch.compiler.is_compiling():
+if not torch.jit.is_scripting() and is_compiling():
return torch.stack(x)
else:
return x
@@ -138,7 +139,7 @@ def _disable_dynamo_if_unsupported(single_tensor_fn=None):
# the capturable flag. If capturable=True, this is not a problem.
@functools.wraps(func)
def maybe_fallback(*args, **kwargs):
-if torch.compiler.is_compiling() and (
+if is_compiling() and (
not kwargs.get("capturable", False)
and has_state_steps
and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
@@ -413,7 +414,7 @@ class Optimizer:
# Thus, when compiling, inductor will determine if cudagraphs
# can be enabled based on whether there is input mutation or CPU tensors.
if (
-not torch.compiler.is_compiling()
+not is_compiling()
and torch.backends.cuda.is_built()
and torch.cuda.is_available()
):
@@ -501,7 +502,7 @@ class Optimizer:
Skips this step if we are compiling since this will occur during inductor lowering.
"""
-if torch.compiler.is_compiling():
+if is_compiling():
return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
else:
return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type]
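A recurring theme in this file is branching on `is_compiling()` so that scalars stay 0-dim tensors under compilation (no `.item()` host syncs baked into the graph) but become cheap Python numbers in eager mode. The `_get_value` idea restated standalone (helper name is ours; the original also bails out under TorchScript):

```python
import torch

def get_scalar(x):
    # Under torch.compile keep the tensor so the graph stays symbolic and
    # avoids a device-to-host sync; in eager, .item() is the faster path.
    if torch.compiler.is_compiling():
        return x
    return x.item() if isinstance(x, torch.Tensor) else x

get_scalar(torch.tensor(0.5))  # -> 0.5 in eager mode
```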

View File

@@ -276,7 +276,7 @@ def _single_tensor_radam(
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
@@ -374,7 +374,7 @@ def _multi_tensor_radam(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
@@ -398,7 +398,7 @@ def _multi_tensor_radam(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)

View File

@@ -283,7 +283,7 @@ def _single_tensor_rmsprop(
step = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step.device.type
@@ -356,7 +356,7 @@ def _multi_tensor_rmsprop(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert all(
p.device.type == step.device.type
@@ -392,7 +392,7 @@ def _multi_tensor_rmsprop(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
@@ -475,7 +475,7 @@ def rmsprop(
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -243,7 +243,7 @@ def _single_tensor_rprop(
step = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step.device.type
@@ -309,7 +309,7 @@ def _multi_tensor_rprop(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-if not torch.compiler.is_compiling() and capturable:
+if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert all(
p.device.type == step.device.type
@@ -331,7 +331,7 @@ def _multi_tensor_rprop(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
-if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
+if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
@@ -421,7 +421,7 @@ def rprop(
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
-if not torch.compiler.is_compiling() and not all(
+if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(

View File

@@ -434,7 +434,7 @@ def _multi_tensor_sgd(
if not device_has_sparse_grad:
# handle internal item() call if lr is a tensor
-if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
+if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
grads_x_lr = torch._foreach_mul(device_grads, -lr)
torch._foreach_add_(device_params, grads_x_lr)
else:

View File

@@ -567,7 +567,7 @@ class OpCheckMode(TorchFunctionMode):
if (
torch.jit.is_tracing()
or torch.jit.is_scripting()
-or torch.compiler.is_compiling()
+or torch._dynamo.is_compiling()
):
return func(*args, **kwargs)
# Pre-existing code may not use the .default overload. If we see an