Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)

This PR is split from PR #126898.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
Committed by: PyTorch MergeBot
Commit: e84d1121ad
Parent: ffb7a08921
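For code that still calls the deprecated helpers, the change is a drop-in rename to the public API. A minimal sketch follows (the add_one helper and the warning-capture scaffolding are illustrative, not part of this commit):

import warnings

import torch


def add_one(x: torch.Tensor) -> torch.Tensor:
    # torch.compiler.is_compiling() is the public replacement; it returns True
    # while tracing/compiling with torch.compile() or torch.export().
    if torch.compiler.is_compiling():
        return x + 1  # out-of-place, graph-friendly branch
    return x.add_(1)  # eager-only in-place branch


# The deprecated private entry points still work, but after this change they
# emit a FutureWarning via typing_extensions.deprecated (the same applies to
# torch._dynamo.external_utils.is_compiling()).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch._utils.is_compiling()
assert any(issubclass(w.category, FutureWarning) for w in caught)

The reconstructed diff of the commit follows.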
@@ -12,12 +12,12 @@ _variable_2 = 0


 def user_function():
-    return torch._utils.is_compiling()
+    return torch.compiler.is_compiling()


 def user_generator():
     for _ in range(1):
-        yield torch._utils.is_compiling()
+        yield torch.compiler.is_compiling()
         return


@@ -38,7 +38,7 @@ class MyModule(torch.nn.Module):
         global _variable, _variable_2

         if self.mode == 1:
-            if torch._utils.is_compiling():
+            if torch.compiler.is_compiling():
                 _variable += 1
             else:
                 _variable_2 += 1
@@ -46,7 +46,7 @@ class MyModule(torch.nn.Module):
             if user_function():
                 _variable += 1
         elif self.mode == 3:
-            lambda_f = lambda: torch._utils.is_compiling()  # noqa: E731
+            lambda_f = lambda: torch.compiler.is_compiling()  # noqa: E731
             if lambda_f():
                 _variable += 1
         elif self.mode == 4:
@@ -163,7 +163,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
     def test_do_not_skip_side_effects(self):
         # https://github.com/pytorch/pytorch/issues/110765

-        # By invoking torch._utils.is_compiling(),
+        # By invoking torch.compiler.is_compiling(),
         # there may be side-effects inconsistent with eager when
         # compiling. Thus we force dynamo to commit the graph,
         # even if it does not perform any tensor operation
@@ -1315,7 +1315,7 @@ class TestCompileTorchbind(TestCase):
                 f(_empty_tensor_queue(), x),
                 torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
             )
-            if not torch._dynamo.is_compiling() and backend == "eager":
+            if not torch.compiler.is_compiling() and backend == "eager":
                 self.assertExpectedInline(
                     backend.graphs[0].code.strip(),
                     """\
@@ -278,7 +278,7 @@ class NoChangeTestCase(TestCase):
         # Test to repro issue with fx_graph_cse when
         # hash((primals_2, 1.0)) == hash((primals_2, 1))

-        if torch._dynamo.is_compiling():
+        if torch.compiler.is_compiling():
             self.skipTest("Unsupported if test run is compiled")

         def f(inpt, osize):
@@ -91,13 +91,13 @@ def init_fake_distributed(device="cpu"):

 def init_module_bw_hooks(allow_eager):
     def bw_pre_hook(mod, gO):
-        assert allow_eager or torch._dynamo.is_compiling()
+        assert allow_eager or torch.compiler.is_compiling()
         assert mod.weight.size() == (10, 10)
         mod.hook_count_pre.add_(1)
         return (torch.sin(gO[0] + 1.2),)

     def bw_post_hook(mod, gI, gO):
-        assert allow_eager or torch._dynamo.is_compiling()
+        assert allow_eager or torch.compiler.is_compiling()
         assert mod.weight.size() == (10, 10)
         mod.hook_count_post.add_(1)
         return (torch.sin(gI[0] + 3.4),)
@@ -4354,7 +4354,7 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
         out = func(nt, dim=rd, keepdim=keepdim)
         ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
-        if not torch.compiler.is_compiling:  # if not using torch dynamo
+        if not torch.compiler.is_compiling():  # if not using torch dynamo
             self.assertEqual(len(out.shape), len(ref_shape))
             for o, r in zip(out.shape, ref_shape):
                 if r is not None:
@@ -4597,7 +4597,7 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         # requires_grad = False does not currently work with dynamo tests and throws this error:
         # AssertionError: SymInts must use SymNodeVariable.
         # If the underlying value is static, we will create a ConstantVariable and specialize.
-        if torch._dynamo.is_compiling() and not requires_grad:
+        if torch.compiler.is_compiling() and not requires_grad:
             return

         tensor_lists = self._get_example_tensor_lists(
@@ -288,7 +288,7 @@ class TestOptimRenewed(TestCase):
         inpt = torch.randn(5, device=device, dtype=dtype)

         # avoid endless recompiles by wrapping LR in a tensor if we're compiling
-        lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
+        lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01
         optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
         schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

@@ -19,7 +19,6 @@ from .eval_frame import (
     RunOnlyContext,
 )
 from .exc import IncorrectUsage
-from .external_utils import is_compiling
 from .utils import is_function


@@ -546,7 +545,7 @@ def mark_static(t, index=None):
     instances of the nn.Module can have different values of the attributes. The
     key point here is that the attributes are static.
     """
-    if is_compiling():
+    if torch.compiler.is_compiling():
         if index is None:
             for s in t.size():
                 comptime.force_static(s)
@@ -3,6 +3,7 @@
 import functools
 import warnings
 from typing import Any, Callable, List, Optional, Union
+from typing_extensions import deprecated

 import torch
 import torch.utils._pytree as pytree
@@ -14,6 +15,10 @@ except ModuleNotFoundError:
     np = None  # type: ignore[assignment]


+@deprecated(
+    "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
+    category=FutureWarning,
+)
 def is_compiling() -> bool:
     """
     Indicates whether we are tracing/compiling with torch.compile() or torch.export().
@@ -191,7 +191,7 @@ def vmap(
     vmap does not provide general autobatching or handle variable-length
     sequences out of the box.
     """
-    from torch._dynamo import is_compiling
+    from torch.compiler import is_compiling

     _check_randomness_arg(randomness)
     if not (chunk_size is None or chunk_size > 0):
@@ -393,7 +393,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
     """
     # To avoid cyclical dependency.
     import torch._functorch.eager_transforms as eager_transforms
-    from torch._dynamo import is_compiling
+    from torch.compiler import is_compiling

     def wrapper(*args, **kwargs):
         return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
@@ -435,8 +435,8 @@ def grad_and_value(

     See :func:`grad` for examples
     """
-    from torch._dynamo import is_compiling
     from torch._functorch import eager_transforms
+    from torch.compiler import is_compiling

     def wrapper(*args, **kwargs):
         return eager_transforms.grad_and_value_impl(
@@ -764,7 +764,7 @@ def jacrev(
     # Dynamo does not support HOP composition if their inner function is
     # annotated with @functools.wraps(...). We circumvent this issue by applying
     # wraps only if we're not tracing with dynamo.
-    if not torch._dynamo.is_compiling():
+    if not torch.compiler.is_compiling():
         wrapper_fn = wraps(func)(wrapper_fn)

     return wrapper_fn
@@ -1344,7 +1344,7 @@ def jacfwd(
     # Dynamo does not support HOP composition if their inner function is
     # annotated with @functools.wraps(...). We circumvent this issue by applying
     # wraps only if we're not tracing with dynamo.
-    if not torch._dynamo.is_compiling():
+    if not torch.compiler.is_compiling():
         wrapper_fn = wraps(func)(wrapper_fn)

     return wrapper_fn
@@ -132,7 +132,7 @@ def associative_scan(
             "Combine_mode must either 'pointwise' or 'generic', but got {combine_mode}"
         )

-    if not torch._dynamo.is_compiling():
+    if not torch.compiler.is_compiling():
         with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
             return torch.compile(associative_scan, fullgraph=True)(
                 combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
@@ -191,7 +191,7 @@ doesn't match the length of the pytree of the init {len(leaves_init)}"
         combine_fn, leaves_init, leaves_xs, dim, reverse, additional_inputs=[]
     )

-    if not torch._dynamo.is_compiling():
+    if not torch.compiler.is_compiling():
         from torch._dynamo.backends.debugging import (
             make_eager_backend_with_torch_function_mode,
         )
@@ -7,7 +7,7 @@ import traceback
 import warnings
 from collections import defaultdict
 from typing import Any, Callable, DefaultDict, Generic, List, Optional
-from typing_extensions import ParamSpec
+from typing_extensions import deprecated, ParamSpec

 import torch

@@ -882,6 +882,10 @@ def classproperty(func):
     return _ClassPropertyDescriptor(func)


+@deprecated(
+    "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
+    category=FutureWarning,
+)
 def is_compiling() -> bool:
     """
     Indicates whether we are tracing/compiling with torch.compile() or torch.export().
@@ -75,7 +75,7 @@ def _compress_hook(
         decompressed_tensor.copy_(value)
         return decompressed_tensor

-    if torch._utils.is_compiling():
+    if torch.compiler.is_compiling():
         grad = dist._functional_collectives.all_reduce(
             compressed_tensor, "sum", group_to_use
         )
@@ -7,17 +7,16 @@ from torch.distributed.tensor import DeviceMesh
 from torch.distributed.tensor.placement_types import Placement


-try:
-    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
-except Exception:
-
-    def is_torchdynamo_compiling():  # type: ignore[misc]
-        return False
-
-
 LayoutsType = Union[Placement, Tuple[Placement, ...]]


+def is_torchdynamo_compiling() -> bool:
+    # Use local function to avoid circular imports
+    from torch.compiler import is_compiling
+
+    return is_compiling()
+
+
 def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
     """
     Inject common validation logics for `_prepare_input` funcs via this decorator.
@@ -1828,8 +1828,6 @@ class Module:

             return result

-        from torch.compiler import is_compiling
-
         # This is technically not behavior equivalent when compiling, but it's
         # incredibly unlikely we will ever support throwing an exception in NN
         # module, and then catching it here, and then reraising it, and then
@@ -1837,7 +1835,7 @@ class Module:
         # The reraise here just gunks up our exception handling for no good
         # reason. Don't try to run the always called hooks in event of
         # exception.
-        if is_compiling():
+        if torch.compiler.is_compiling():
             return inner()

         try:
@@ -1487,7 +1487,7 @@ class DistributedDataParallel(Module, Joinable):

     def _should_disable_cpp_reducer(self) -> bool:
         return self._use_python_reducer and (
-            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
+            torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer
         )

     def _pre_forward(self, *inputs, **kwargs):
@@ -1500,7 +1500,7 @@ class DistributedDataParallel(Module, Joinable):
                 h.remove()
             self._accum_grad_hooks.clear()

-        if not self._lazy_init_ran and not torch._utils.is_compiling():
+        if not self._lazy_init_ran and not torch.compiler.is_compiling():
             self._lazy_init()

         if self._delay_all_reduce_all_params:
@@ -505,7 +505,7 @@ def _multi_tensor_adafactor(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(
                 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -624,7 +624,7 @@ def adafactor(

     See :class:`~torch.optim.Adafactor` for details.
     """
-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
@@ -259,7 +259,7 @@ def _single_tensor_adadelta(
     has_complex: bool,
 ):
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -315,7 +315,7 @@ def _multi_tensor_adadelta(
     assert not differentiable, "_foreach ops don't support autograd"

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -352,7 +352,7 @@ def _multi_tensor_adadelta(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(
                 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -423,7 +423,7 @@ def adadelta(

     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
@@ -451,7 +451,7 @@ def _multi_tensor_adagrad(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(
                 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -353,7 +353,7 @@ def _single_tensor_adam(
         step_t = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -466,7 +466,7 @@ def _multi_tensor_adam(
     )

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -520,7 +520,7 @@ def _multi_tensor_adam(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(
                 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -762,7 +762,7 @@ def adam(

     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
@@ -248,7 +248,7 @@ def _single_tensor_adamax(
         step_t = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -320,7 +320,7 @@ def _multi_tensor_adamax(
         return

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -358,7 +358,7 @@ def _multi_tensor_adamax(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
             torch._foreach_add_(
                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -435,7 +435,7 @@ def adamax(
     See :class:`~torch.optim.Adamax` for details.
     """

-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
@@ -350,7 +350,7 @@ def _single_tensor_adamw(
         step_t = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -463,7 +463,7 @@ def _multi_tensor_adamw(
     )

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -516,7 +516,7 @@ def _multi_tensor_adamw(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(
                 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -744,7 +744,7 @@ def adamw(

     See :class:`~torch.optim.AdamW` for details.
     """
-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
@@ -219,7 +219,7 @@ def _single_tensor_asgd(
         step_t = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type
@@ -292,7 +292,7 @@ def _multi_tensor_asgd(
     assert not differentiable, "_foreach ops don't support autograd"

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -333,7 +333,7 @@ def _multi_tensor_asgd(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
             torch._foreach_add_(
                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -310,7 +310,7 @@ def _single_tensor_nadam(
             exp_avg_sq = torch.view_as_real(exp_avg_sq)

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == mu_product.device.type == step_t.device.type
@@ -396,7 +396,7 @@ def _multi_tensor_nadam(
     assert not differentiable, "_foreach ops don't support autograd"

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -437,7 +437,7 @@ def _multi_tensor_nadam(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
             torch._foreach_add_(
                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -26,7 +26,6 @@ from typing_extensions import ParamSpec, Self, TypeAlias

 import torch
 import torch.utils.hooks as hooks
-from torch._utils import is_compiling
 from torch.utils._foreach_utils import (
     _get_foreach_kernels_supported_devices,
     _get_fused_kernels_supported_devices,
@@ -100,14 +99,14 @@ def _use_grad_for_differentiable(func):

 def _get_value(x):
     # item is significantly faster than a cpu tensor in eager mode
-    if not torch.jit.is_scripting() and is_compiling():
+    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
         return x
     else:
         return x.item() if isinstance(x, torch.Tensor) else x


 def _stack_if_compiling(x):
-    if not torch.jit.is_scripting() and is_compiling():
+    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
         return torch.stack(x)
     else:
         return x
@@ -139,7 +138,7 @@ def _disable_dynamo_if_unsupported(single_tensor_fn=None):
         # the capturable flag. If capturable=True, this is not a problem.
         @functools.wraps(func)
         def maybe_fallback(*args, **kwargs):
-            if is_compiling() and (
+            if torch.compiler.is_compiling() and (
                 not kwargs.get("capturable", False)
                 and has_state_steps
                 and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
@@ -429,7 +428,7 @@ class Optimizer:
         # Thus, when compiling, inductor will determine if cudagraphs
         # can be enabled based on whether there is input mutation or CPU tensors.
         if (
-            not is_compiling()
+            not torch.compiler.is_compiling()
             and torch.backends.cuda.is_built()
             and torch.cuda.is_available()
         ):
@@ -516,7 +515,7 @@ class Optimizer:

         Skips this step if we are compiling since this will occur during inductor lowering.
         """
-        if is_compiling():
+        if torch.compiler.is_compiling():
             return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
         else:
             return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]
@@ -276,7 +276,7 @@ def _single_tensor_radam(
         step_t = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step_t.device.type
@@ -374,7 +374,7 @@ def _multi_tensor_radam(
     assert not differentiable, "_foreach ops don't support autograd"

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -404,7 +404,7 @@ def _multi_tensor_radam(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
             torch._foreach_add_(
                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -284,7 +284,7 @@ def _single_tensor_rmsprop(
         step = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step.device.type
@@ -357,7 +357,7 @@ def _multi_tensor_rmsprop(
     assert not differentiable, "_foreach ops don't support autograd"

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices()
         assert all(
             p.device.type == step.device.type
@@ -402,7 +402,7 @@ def _multi_tensor_rmsprop(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
             torch._foreach_add_(
                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -489,7 +489,7 @@ def rmsprop(
     """
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
@@ -243,7 +243,7 @@ def _single_tensor_rprop(
         step = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if not torch.compiler.is_compiling() and capturable:
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
                 param.device.type == step.device.type
@@ -309,7 +309,7 @@ def _multi_tensor_rprop(
     assert not differentiable, "_foreach ops don't support autograd"

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if not torch.compiler.is_compiling() and capturable:
         capturable_supported_devices = _get_capturable_supported_devices()
         assert all(
             p.device.type == step.device.type
@@ -337,7 +337,7 @@ def _multi_tensor_rprop(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
+        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
             torch._foreach_add_(
                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
@@ -427,7 +427,7 @@ def rprop(
     """
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch._utils.is_compiling() and not all(
+    if not torch.compiler.is_compiling() and not all(
         isinstance(t, torch.Tensor) for t in state_steps
     ):
         raise RuntimeError(
|
@ -435,7 +435,7 @@ def _multi_tensor_sgd(
|
||||
|
||||
if not device_has_sparse_grad:
|
||||
# handle internal item() call if lr is a tensor
|
||||
if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
|
||||
if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
|
||||
grads_x_lr = torch._foreach_mul(device_grads, -lr)
|
||||
torch._foreach_add_(device_params, grads_x_lr)
|
||||
else:
|
||||
|
@ -565,7 +565,7 @@ class OpCheckMode(TorchFunctionMode):
|
||||
if (
|
||||
torch.jit.is_tracing()
|
||||
or torch.jit.is_scripting()
|
||||
or torch._dynamo.is_compiling()
|
||||
or torch.compiler.is_compiling()
|
||||
):
|
||||
return func(*args, **kwargs)
|
||||
# Pre-existing code may not use the .default overload. If we see an
|
||||
|