diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 7748d14c3af1..a9f62d75fa77 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -10,7 +10,6 @@ import torch._functorch.config
 import torch.utils._pytree as pytree
 import torch.utils.checkpoint
 from torch._dynamo.testing import normalize_gm
-from torch._functorch.aot_autograd import to_fun
 from torch._higher_order_ops.wrap import wrap
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
 
@@ -240,6 +239,12 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(actual, expected)
         self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
 
+        # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
+        def to_fun(x):
+            x_functional = torch._to_functional_tensor(x)
+            torch._mirror_autograd_meta_to(x, x_functional)
+            return x_functional
+
         def aot_f_wrapper(func):
             @functools.wraps(func)
             def wrapper(*args, **kwargs):
@@ -322,7 +327,8 @@ class GraphModule(torch.nn.Module):
         check_count_and_graph(2, 2, 2, expected_graph)
 
         try:
-            x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
+            x = torch._to_functional_tensor(t_clone2)
+            torch._mirror_autograd_meta_to(t_clone2, x)
             torch._enable_functionalization(reapply_views=False)
             aot_f_out = f(x)
         finally:
diff --git a/torch/__init__.py b/torch/__init__.py
index 10611c70a955..257bfb7064e4 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -24,6 +24,7 @@ def _running_with_deploy():
     return sys.modules.get("torch._meta_registrations", None) is object
 
 from ._utils import _import_dotted_name, classproperty
+from ._utils import _functionalize_sync as _sync
 from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
     USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS
 
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index e5358f1de3d4..e165754603ee 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -676,7 +676,9 @@ def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires
 
 def to_fun(t):
     if isinstance(t, Tensor):
-        return torch._to_functional_tensor(t, mirror_autograd_meta=True)
+        out = torch._to_functional_tensor(t)
+        torch._mirror_autograd_meta_to(t, out)
+        return out
     else:
         return t
 
@@ -727,7 +729,8 @@ def run_functionalized_fw_and_collect_metadata(
         if isinstance(t, Tensor):
             if t in memo:
                 return memo[t]
-            r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
+            r = torch._to_functional_tensor(t)
+            torch._mirror_autograd_meta_to(t, r)
             memo[t] = r
             return r
         else:
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 268120948f56..aea42c26bc1a 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -119,7 +119,9 @@ class FunctionalTensor(torch.Tensor):
         # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
         #   this is handled by FunctionalTensor.to_functional
         x_functional = torch._to_functional_tensor(x)
+        torch._mirror_autograd_meta_to(x, x_functional)
         out = FunctionalTensor(x_functional)
+        torch._mirror_autograd_meta_to(x_functional, out)
         return out
 
     def from_functional(self):
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 47dc36bbcc5a..915a1a0f3679 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -600,9 +600,9 @@ class MetaConverter:
                                     dynamic_dims=dynamic_dims,
                                     constraint_dims=constraint_dims,
                                 )
-                            return torch._to_functional_tensor(
-                                fake_t, mirror_autograd_meta=True
-                            )
+                            out = torch._to_functional_tensor(fake_t)
+                            torch._mirror_autograd_meta_to(fake_t, out)
+                            return out
                         else:
                             # torch.func.functionalize
                             reapply_views = torch._C._functionalization_reapply_views_tls()
diff --git a/torch/_utils.py b/torch/_utils.py
index 048b987cb8bb..04afde58b597 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -4,6 +4,7 @@ import sys
 import traceback
 import warnings
 from collections import defaultdict
+from contextlib import nullcontext
 from typing import Any, DefaultDict, List, Optional
 
 import torch
@@ -842,6 +843,44 @@ def is_compiling():
     return False
 
 
+def _functionalize_sync(t):
+    # This code lives in python instead of C++ since conditioning on a certain python subclass
+    # is much more of a pain in C++.
+    from torch._subclasses.functional_tensor import (
+        FunctionalTensor,
+        maybe_disable_functional_mode,
+    )
+
+    ctx = (
+        maybe_disable_functional_mode
+        if isinstance(t, FunctionalTensor)
+        else nullcontext
+    )
+    if isinstance(t, FunctionalTensor):
+        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
+        # when we sync our inner tensor.
+        # Why?
+        # (1) If there are input mutations in the graph, then they will be re-applied during
+        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
+        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
+        #     which dispatches to a bunch of view ops
+        # (3) The input to these view ops is our inner FunctionalTensorWrapper
+        #     (since the sync was called from C++), not the python FunctionalTensor
+        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
+        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
+        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
+        maybe_functional_mode = torch._C._unset_dispatch_mode(
+            torch._C._TorchDispatchModeKey.FUNCTIONAL
+        )
+        try:
+            torch._functionalize_sync(t.elem)
+        finally:
+            if maybe_functional_mode is not None:
+                torch._C._set_dispatch_mode(maybe_functional_mode)
+    else:
+        torch._functionalize_sync(t)
+
+
 @functools.lru_cache(2)
 def _get_device_module(device_type: str):
     device_module = getattr(torch, device_type, None)
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index 065830f131c5..469166bf3a7f 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -362,31 +362,52 @@ static PyObject* THPVariable__to_functional_tensor(
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
   static PythonArgParser parser(
-      {"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"},
+      {"_to_functional_tensor(Tensor t)"},
       /*traceable=*/true);
 
   ParsedArgs<2> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   auto self_ = r.tensor(0);
-  auto mirror_autograd_meta = r.toBool(1);
   auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
-  if (mirror_autograd_meta) {
-    // Here, we unsafely set the grad function on the wrapper to be the same as
-    // the inner. We expect this grad_fn to NEVER be used. It's needed so that
-    // .is_leaf metadata is accurate on the wrapper
-    auto inner_autograd_meta = impl::get_autograd_meta(self_);
-    if (inner_autograd_meta) {
-      wrapped.set_requires_grad(self_.requires_grad());
-      if (wrapped.requires_grad()) {
-        auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
-            new torch::autograd::Error(
-                "Cannot backprop through mirrored meta, file a bug in PyTorch"),
-            torch::autograd::deleteNode);
-        torch::autograd::set_history(wrapped, new_grad_fn);
-      }
+  return wrap(std::move(wrapped));
+  END_HANDLE_TH_ERRORS
+}
+
+// Given source and dest tensors,
+// Sets **some** (but not all) autograd metadata on dest, according to source:
+// - requires_grad
+// - grad_fn
+//   (If src has a grad_fn, we install an error grad_fn on dest to avoid
+//   difficult bugs.
+//   The main purpose is to ensure that dst.is_leaf == src.is_leaf)
static PyObject* THPVariable__mirror_autograd_meta_to(
+    PyObject* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_mirror_autograd_meta_to(Tensor source, Tensor dest)"},
+      /*traceable=*/true);
+
+  ParsedArgs<2> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto src_ = r.tensor(0);
+  auto dst_ = r.tensor(1);
+  // Here, we unsafely set the grad function on the wrapper to be the same as
+  // the inner. We expect this grad_fn to NEVER be used. It's needed so that
+  // .is_leaf metadata is accurate on the wrapper
+  auto inner_autograd_meta = impl::get_autograd_meta(src_);
+  if (inner_autograd_meta) {
+    dst_.set_requires_grad(src_.requires_grad());
+    if (dst_.requires_grad()) {
+      auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
+          new torch::autograd::Error(
+              "Cannot backprop through mirrored meta, file a bug in PyTorch"),
+          torch::autograd::deleteNode);
+      torch::autograd::set_history(dst_, new_grad_fn);
     }
   }
-  return wrap(std::move(wrapped));
+  Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
 
@@ -526,12 +547,13 @@ static PyObject* THPVariable__disable_functionalization(
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* THPVariable__sync(
+static PyObject* THPVariable__functionalize_sync(
     PyObject* self,
    PyObject* args,
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
-  static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);
+  static PythonArgParser parser(
+      {"_functionalize_sync(Tensor t)"}, /*traceable=*/true);
   ParsedArgs<1> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
 
@@ -568,6 +590,10 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
+    {"_mirror_autograd_meta_to",
+     castPyCFunctionWithKeywords(THPVariable__mirror_autograd_meta_to),
+     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
+     nullptr},
     {"_from_functional_tensor",
      castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
@@ -576,8 +602,8 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
-    {"_sync",
-     castPyCFunctionWithKeywords(THPVariable__sync),
+    {"_functionalize_sync",
+     castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
     {"_enable_functionalization",
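
For context, a minimal usage sketch of the call-site migration this diff performs (illustrative only; the tensor `t` and the assertion below are not part of the patch):

    import torch

    t = torch.randn(3, requires_grad=True)

    # Before: x = torch._to_functional_tensor(t, mirror_autograd_meta=True)
    # After: wrap first, then explicitly mirror requires_grad/grad_fn metadata
    # from the original tensor onto the functional wrapper.
    x = torch._to_functional_tensor(t)
    torch._mirror_autograd_meta_to(t, x)
    assert x.requires_grad == t.requires_grad

    # The C++ binding `_sync` is renamed to `_functionalize_sync`; `torch._sync`
    # now points at the Python wrapper added in torch/_utils.py, which unwraps
    # python FunctionalTensor subclasses before syncing. Here `x` is a plain
    # C++ FunctionalTensorWrapper, so it is synced directly (effectively a
    # no-op for a freshly created wrapper).
    torch._sync(x)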