mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Deprecate torch._utils.is_compiling()
and torch._dynamo.external_utils.is_compiling()
(#127690)"
This reverts commit 0e7e61f7cec82a43f2de52b83eff152d703be7a3. Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
This commit is contained in:
@ -12,12 +12,12 @@ _variable_2 = 0
|
||||
|
||||
|
||||
def user_function():
|
||||
return torch.compiler.is_compiling()
|
||||
return torch._utils.is_compiling()
|
||||
|
||||
|
||||
def user_generator():
|
||||
for _ in range(1):
|
||||
yield torch.compiler.is_compiling()
|
||||
yield torch._utils.is_compiling()
|
||||
return
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class MyModule(torch.nn.Module):
|
||||
global _variable, _variable_2
|
||||
|
||||
if self.mode == 1:
|
||||
if torch.compiler.is_compiling():
|
||||
if torch._utils.is_compiling():
|
||||
_variable += 1
|
||||
else:
|
||||
_variable_2 += 1
|
||||
@ -46,7 +46,7 @@ class MyModule(torch.nn.Module):
|
||||
if user_function():
|
||||
_variable += 1
|
||||
elif self.mode == 3:
|
||||
lambda_f = lambda: torch.compiler.is_compiling() # noqa: E731
|
||||
lambda_f = lambda: torch._utils.is_compiling() # noqa: E731
|
||||
if lambda_f():
|
||||
_variable += 1
|
||||
elif self.mode == 4:
|
||||
@ -163,7 +163,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
|
||||
def test_do_not_skip_side_effects(self):
|
||||
# https://github.com/pytorch/pytorch/issues/110765
|
||||
|
||||
# By invoking torch.compiler.is_compiling(),
|
||||
# By invoking torch._utils.is_compiling(),
|
||||
# there may be side-effects inconsistent with eager when
|
||||
# compiling. Thus we force dynamo to commit the graph,
|
||||
# even if it does not perform any tensor operation
|
||||
|
@ -1281,7 +1281,7 @@ class TestCompileTorchbind(TestCase):
|
||||
f(_empty_tensor_queue(), x),
|
||||
torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
|
||||
)
|
||||
if not torch.compiler.is_compiling() and backend == "eager":
|
||||
if not torch._dynamo.is_compiling() and backend == "eager":
|
||||
self.assertExpectedInline(
|
||||
backend.graphs[0].code.strip(),
|
||||
"""\
|
||||
|
@ -278,7 +278,7 @@ class NoChangeTestCase(TestCase):
|
||||
# Test to repro issue with fx_graph_cse when
|
||||
# hash((primals_2, 1.0)) == hash((primals_2, 1))
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
if torch._dynamo.is_compiling():
|
||||
self.skipTest("Unsupported if test run is compiled")
|
||||
|
||||
def f(inpt, osize):
|
||||
|
@ -110,13 +110,13 @@ def init_fake_distributed(device="cpu"):
|
||||
|
||||
def init_module_bw_hooks(allow_eager):
|
||||
def bw_pre_hook(mod, gO):
|
||||
assert allow_eager or torch.compiler.is_compiling()
|
||||
assert allow_eager or torch._dynamo.is_compiling()
|
||||
assert mod.weight.size() == (10, 10)
|
||||
mod.hook_count_pre.add_(1)
|
||||
return (torch.sin(gO[0] + 1.2),)
|
||||
|
||||
def bw_post_hook(mod, gI, gO):
|
||||
assert allow_eager or torch.compiler.is_compiling()
|
||||
assert allow_eager or torch._dynamo.is_compiling()
|
||||
assert mod.weight.size() == (10, 10)
|
||||
mod.hook_count_post.add_(1)
|
||||
return (torch.sin(gI[0] + 3.4),)
|
||||
|
@ -288,7 +288,7 @@ class TestOptimRenewed(TestCase):
|
||||
inpt = torch.randn(5, device=device, dtype=dtype)
|
||||
|
||||
# avoid endless recompiles by wrapping LR in a tensor if we're compiling
|
||||
lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01
|
||||
lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
|
||||
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
|
||||
schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
|
||||
|
||||
|
@ -10,6 +10,7 @@ from . import trace_rules, variables
|
||||
from .comptime import comptime
|
||||
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
|
||||
from .exc import IncorrectUsage
|
||||
from .external_utils import is_compiling
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -291,7 +292,7 @@ def mark_static(t, index=None):
|
||||
Unlike mark_dynamic, this can be done inside a graph, in which case it
|
||||
induces specialization on the tensor.
|
||||
"""
|
||||
if torch.compiler.is_compiling():
|
||||
if is_compiling():
|
||||
if index is None:
|
||||
for s in t.size():
|
||||
comptime.force_static(s)
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
import functools
|
||||
from typing import List
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -15,10 +14,6 @@ except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@deprecated(
|
||||
"`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def is_compiling() -> bool:
|
||||
"""
|
||||
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
|
||||
|
@ -191,7 +191,7 @@ def vmap(
|
||||
vmap does not provide general autobatching or handle variable-length
|
||||
sequences out of the box.
|
||||
"""
|
||||
from torch.compiler import is_compiling
|
||||
from torch._dynamo import is_compiling
|
||||
|
||||
_check_randomness_arg(randomness)
|
||||
if not (chunk_size is None or chunk_size > 0):
|
||||
@ -393,7 +393,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
|
||||
"""
|
||||
# To avoid cyclical dependency.
|
||||
import torch._functorch.eager_transforms as eager_transforms
|
||||
from torch.compiler import is_compiling
|
||||
from torch._dynamo import is_compiling
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
|
||||
@ -435,8 +435,8 @@ def grad_and_value(
|
||||
|
||||
See :func:`grad` for examples
|
||||
"""
|
||||
from torch._dynamo import is_compiling
|
||||
from torch._functorch import eager_transforms
|
||||
from torch.compiler import is_compiling
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return eager_transforms.grad_and_value_impl(
|
||||
|
@ -764,7 +764,7 @@ def jacrev(
|
||||
# Dynamo does not support HOP composition if their inner function is
|
||||
# annotated with @functools.wraps(...). We circumvent this issue by applying
|
||||
# wraps only if we're not tracing with dynamo.
|
||||
if not torch.compiler.is_compiling():
|
||||
if not torch._dynamo.is_compiling():
|
||||
wrapper_fn = wraps(func)(wrapper_fn)
|
||||
|
||||
return wrapper_fn
|
||||
@ -1344,7 +1344,7 @@ def jacfwd(
|
||||
# Dynamo does not support HOP composition if their inner function is
|
||||
# annotated with @functools.wraps(...). We circumvent this issue by applying
|
||||
# wraps only if we're not tracing with dynamo.
|
||||
if not torch.compiler.is_compiling():
|
||||
if not torch._dynamo.is_compiling():
|
||||
wrapper_fn = wraps(func)(wrapper_fn)
|
||||
|
||||
return wrapper_fn
|
||||
|
@ -74,7 +74,7 @@ def associative_scan(
|
||||
assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}"
|
||||
assert isinstance(dim, int), "dim must be an int, but got {type(dim)}"
|
||||
|
||||
if not torch.compiler.is_compiling():
|
||||
if not torch._dynamo.is_compiling():
|
||||
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
|
||||
return torch.compile(associative_scan, fullgraph=True)(
|
||||
combine_fn, input, dim
|
||||
|
@ -8,7 +8,7 @@ import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, DefaultDict, Generic, List, Optional
|
||||
from typing_extensions import deprecated, ParamSpec
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
||||
@ -868,10 +868,6 @@ def classproperty(func):
|
||||
return _ClassPropertyDescriptor(func)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def is_compiling() -> bool:
|
||||
"""
|
||||
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
|
||||
|
@ -6,6 +6,7 @@ from enum import auto, Enum
|
||||
from typing import Any, cast, List, Optional
|
||||
|
||||
import torch
|
||||
import torch._dynamo.compiled_autograd as ca
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed._composable.contract import _get_registry
|
||||
@ -114,7 +115,6 @@ def _from_local_no_grad(
|
||||
This method is similar to ``DTensor.from_local()`` except that in eager mode
|
||||
it avoids some CPU overhead by avoiding default args and not being differentiable.
|
||||
"""
|
||||
import torch._dynamo.compiled_autograd as ca
|
||||
|
||||
if not ca.compiled_autograd_enabled:
|
||||
return DTensor(
|
||||
@ -124,13 +124,14 @@ def _from_local_no_grad(
|
||||
sharding_spec,
|
||||
requires_grad=local_tensor.requires_grad,
|
||||
)
|
||||
return DTensor.from_local(
|
||||
local_tensor,
|
||||
sharding_spec.mesh,
|
||||
sharding_spec.placements,
|
||||
shape=sharding_spec.shape,
|
||||
stride=sharding_spec.stride,
|
||||
)
|
||||
else:
|
||||
return DTensor.from_local(
|
||||
local_tensor,
|
||||
sharding_spec.mesh,
|
||||
sharding_spec.placements,
|
||||
shape=sharding_spec.shape,
|
||||
stride=sharding_spec.stride,
|
||||
)
|
||||
|
||||
|
||||
def _to_dtype_if_needed(
|
||||
|
@ -87,7 +87,7 @@ def fp16_compress_hook(
|
||||
decompressed_tensor.copy_(value)
|
||||
return decompressed_tensor
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
if torch._utils.is_compiling():
|
||||
grad = dist._functional_collectives.all_reduce(
|
||||
compressed_tensor, "sum", group_to_use
|
||||
)
|
||||
@ -136,7 +136,7 @@ def bf16_compress_hook(
|
||||
decompressed_tensor.copy_(value)
|
||||
return decompressed_tensor
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
if torch._utils.is_compiling():
|
||||
grad = dist._functional_collectives.all_reduce(
|
||||
compressed_tensor, "sum", group_to_use
|
||||
)
|
||||
|
@ -2,20 +2,22 @@
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed._tensor import DeviceMesh
|
||||
from torch.distributed._tensor.placement_types import Placement
|
||||
from torch.distributed.device_mesh import _mesh_resources
|
||||
|
||||
|
||||
try:
|
||||
from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
|
||||
except Exception:
|
||||
|
||||
def is_torchdynamo_compiling(): # type: ignore[misc]
|
||||
return False
|
||||
|
||||
|
||||
LayoutsType = Union[Placement, Tuple[Placement, ...]]
|
||||
|
||||
|
||||
def is_torchdynamo_compiling() -> bool:
|
||||
# Use local function to avoid circular imports
|
||||
return torch.compiler.is_compiling()
|
||||
|
||||
|
||||
def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
|
||||
"""
|
||||
Inject common validation logics for `_prepare_input` funcs via this decorator.
|
||||
|
@ -1484,7 +1484,7 @@ class DistributedDataParallel(Module, Joinable):
|
||||
|
||||
def _should_disable_cpp_reducer(self) -> bool:
|
||||
return self._use_python_reducer and (
|
||||
torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer
|
||||
torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
|
||||
)
|
||||
|
||||
def _pre_forward(self, *inputs, **kwargs):
|
||||
@ -1497,7 +1497,7 @@ class DistributedDataParallel(Module, Joinable):
|
||||
h.remove()
|
||||
self._accum_grad_hooks.clear()
|
||||
|
||||
if not self._lazy_init_ran and not torch.compiler.is_compiling():
|
||||
if not self._lazy_init_ran and not torch._utils.is_compiling():
|
||||
self._lazy_init()
|
||||
|
||||
if self._delay_all_reduce_all_params:
|
||||
|
@ -423,7 +423,7 @@ def adafactor(
|
||||
|
||||
See :class:`~torch.optim.Adafactor` for details.
|
||||
"""
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -259,7 +259,7 @@ def _single_tensor_adadelta(
|
||||
has_complex: bool,
|
||||
):
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -315,7 +315,7 @@ def _multi_tensor_adadelta(
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -347,7 +347,7 @@ def _multi_tensor_adadelta(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
@ -418,7 +418,7 @@ def adadelta(
|
||||
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -451,7 +451,7 @@ def _multi_tensor_adagrad(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
|
@ -361,7 +361,7 @@ def _single_tensor_adam(
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
@ -474,7 +474,7 @@ def _multi_tensor_adam(
|
||||
)
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -521,7 +521,7 @@ def _multi_tensor_adam(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
@ -754,7 +754,7 @@ def adam(
|
||||
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -248,7 +248,7 @@ def _single_tensor_adamax(
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
@ -320,7 +320,7 @@ def _multi_tensor_adamax(
|
||||
return
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -352,7 +352,7 @@ def _multi_tensor_adamax(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
@ -429,7 +429,7 @@ def adamax(
|
||||
See :class:`~torch.optim.Adamax` for details.
|
||||
"""
|
||||
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -362,7 +362,7 @@ def _single_tensor_adamw(
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
@ -475,7 +475,7 @@ def _multi_tensor_adamw(
|
||||
)
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -521,7 +521,7 @@ def _multi_tensor_adamw(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
@ -739,7 +739,7 @@ def adamw(
|
||||
|
||||
See :class:`~torch.optim.AdamW` for details.
|
||||
"""
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -219,7 +219,7 @@ def _single_tensor_asgd(
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type
|
||||
@ -292,7 +292,7 @@ def _multi_tensor_asgd(
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -326,7 +326,7 @@ def _multi_tensor_asgd(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
|
@ -310,7 +310,7 @@ def _single_tensor_nadam(
|
||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == mu_product.device.type == step_t.device.type
|
||||
@ -396,7 +396,7 @@ def _multi_tensor_nadam(
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -430,7 +430,7 @@ def _multi_tensor_nadam(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
|
@ -26,6 +26,7 @@ from typing_extensions import ParamSpec, Self, TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.utils.hooks as hooks
|
||||
from torch._utils import is_compiling
|
||||
from torch.utils._foreach_utils import (
|
||||
_get_foreach_kernels_supported_devices,
|
||||
_get_fused_kernels_supported_devices,
|
||||
@ -99,14 +100,14 @@ def _use_grad_for_differentiable(func):
|
||||
|
||||
def _get_value(x):
|
||||
# item is significantly faster than a cpu tensor in eager mode
|
||||
if not torch.jit.is_scripting() and torch.compiler.is_compiling():
|
||||
if not torch.jit.is_scripting() and is_compiling():
|
||||
return x
|
||||
else:
|
||||
return x.item() if isinstance(x, torch.Tensor) else x
|
||||
|
||||
|
||||
def _stack_if_compiling(x):
|
||||
if not torch.jit.is_scripting() and torch.compiler.is_compiling():
|
||||
if not torch.jit.is_scripting() and is_compiling():
|
||||
return torch.stack(x)
|
||||
else:
|
||||
return x
|
||||
@ -138,7 +139,7 @@ def _disable_dynamo_if_unsupported(single_tensor_fn=None):
|
||||
# the capturable flag. If capturable=True, this is not a problem.
|
||||
@functools.wraps(func)
|
||||
def maybe_fallback(*args, **kwargs):
|
||||
if torch.compiler.is_compiling() and (
|
||||
if is_compiling() and (
|
||||
not kwargs.get("capturable", False)
|
||||
and has_state_steps
|
||||
and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
|
||||
@ -413,7 +414,7 @@ class Optimizer:
|
||||
# Thus, when compiling, inductor will determine if cudagraphs
|
||||
# can be enabled based on whether there is input mutation or CPU tensors.
|
||||
if (
|
||||
not torch.compiler.is_compiling()
|
||||
not is_compiling()
|
||||
and torch.backends.cuda.is_built()
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
@ -501,7 +502,7 @@ class Optimizer:
|
||||
|
||||
Skips this step if we are compiling since this will occur during inductor lowering.
|
||||
"""
|
||||
if torch.compiler.is_compiling():
|
||||
if is_compiling():
|
||||
return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
|
||||
else:
|
||||
return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type]
|
||||
|
@ -276,7 +276,7 @@ def _single_tensor_radam(
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
@ -374,7 +374,7 @@ def _multi_tensor_radam(
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
@ -398,7 +398,7 @@ def _multi_tensor_radam(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
|
@ -283,7 +283,7 @@ def _single_tensor_rmsprop(
|
||||
step = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step.device.type
|
||||
@ -356,7 +356,7 @@ def _multi_tensor_rmsprop(
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert all(
|
||||
p.device.type == step.device.type
|
||||
@ -392,7 +392,7 @@ def _multi_tensor_rmsprop(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
@ -475,7 +475,7 @@ def rmsprop(
|
||||
"""
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -243,7 +243,7 @@ def _single_tensor_rprop(
|
||||
step = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step.device.type
|
||||
@ -309,7 +309,7 @@ def _multi_tensor_rprop(
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert all(
|
||||
p.device.type == step.device.type
|
||||
@ -331,7 +331,7 @@ def _multi_tensor_rprop(
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
@ -421,7 +421,7 @@ def rprop(
|
||||
"""
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
|
@ -434,7 +434,7 @@ def _multi_tensor_sgd(
|
||||
|
||||
if not device_has_sparse_grad:
|
||||
# handle internal item() call if lr is a tensor
|
||||
if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
|
||||
if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
|
||||
grads_x_lr = torch._foreach_mul(device_grads, -lr)
|
||||
torch._foreach_add_(device_params, grads_x_lr)
|
||||
else:
|
||||
|
@ -567,7 +567,7 @@ class OpCheckMode(TorchFunctionMode):
|
||||
if (
|
||||
torch.jit.is_tracing()
|
||||
or torch.jit.is_scripting()
|
||||
or torch.compiler.is_compiling()
|
||||
or torch._dynamo.is_compiling()
|
||||
):
|
||||
return func(*args, **kwargs)
|
||||
# Pre-existing code may not use the .default overload. If we see an
|
||||
|
Reference in New Issue
Block a user