Replace deprecated is_compiling method (#154476)

Replace depreacted `is_compiling` in `torch._dynamo` with `torch.compiler`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154476
Approved by: https://github.com/eellison
This commit is contained in:
zeshengzong
2025-06-24 05:16:36 +00:00
committed by PyTorch MergeBot
parent 1044934878
commit 495c317005

View File

@ -190,7 +190,7 @@ def vmap(
vmap does not provide general autobatching or handle variable-length
sequences out of the box.
"""
from torch._dynamo import is_compiling
from torch.compiler import is_compiling
_check_randomness_arg(randomness)
if not (chunk_size is None or chunk_size > 0):
@ -392,7 +392,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
"""
# To avoid cyclical dependency.
import torch._functorch.eager_transforms as eager_transforms
from torch._dynamo import is_compiling
from torch.compiler import is_compiling
def wrapper(*args, **kwargs):
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
@ -434,8 +434,8 @@ def grad_and_value(
See :func:`grad` for examples
"""
from torch._dynamo import is_compiling
from torch._functorch import eager_transforms
from torch.compiler import is_compiling
def wrapper(*args, **kwargs):
return eager_transforms.grad_and_value_impl(