diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index b7b1c51e83b1..2c3a12020eb6 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -37,22 +37,10 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
   // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
   // to participate in the functorch transforms.
   key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
-  // For better error handling,
-  // we also don't want our wrapper tensor to be able to dispatch directly
-  // to a backend kernel.
-  // Dispatching directly to e.g. a CPU kernel would always segfault,
-  // because wrapper tensors don't have any real data.
-  // (This should never happen because we should always hit a functionalization kernel,
-  // but can help make bugs less nasty).
-  // Here, we defensively remove any backend keys from the wrapper's keyset.
-  // We don't want to remove actual backend bits though (say we're redispatching to autograd;
-  // we need to know if we're dispatching to AutogradCPU or AutogradXLA).
-  // Instead, it's sufficient to remove the `Dense` dispatch key,
-  // which prevents us from accidentally trying to directly run a CPU/CUDA kernel.
-  key_set_ = key_set_.remove(c10::DispatchKey::Dense);
   // We override a bunch of _custom(), so make sure they get called
   // TODO: metadata copying may not actually be necessary then
   set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
+  set_custom_device(true);
 }
 
 FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
@@ -335,6 +323,9 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
       std::move(version_counter), allow_tensor_metadata_change);
 }
 
+c10::Device FunctionalTensorWrapper::device_custom() const {
+  return value_.unsafeGetTensorImpl()->device();
+}
 at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
   return value_.unsafeGetTensorImpl()->sizes();
 }
diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h
index 27a88f13f872..0762fb1f7f9b 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.h
+++ b/aten/src/ATen/FunctionalTensorWrapper.h
@@ -148,6 +148,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   c10::SymInt sym_size_custom(int64_t d) const override;
   c10::SymIntArrayRef sym_strides_custom() const override;
   c10::SymInt sym_storage_offset_custom() const override;
+  c10::Device device_custom() const override;
 
  private:
   const char* tensorimpl_type_name() const override;
diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp
index 16c0aa42232f..5c8214b7d882 100644
--- a/aten/src/ATen/ThreadLocalState.cpp
+++ b/aten/src/ATen/ThreadLocalState.cpp
@@ -6,6 +6,7 @@
 #include <ATen/record_function.h>
 #include <ATen/SavedTensorHooks.h>
+#include <ATen/FunctionalTensorWrapper.h>
 
 namespace at {
 
@@ -15,7 +16,8 @@ ThreadLocalState::ThreadLocalState()
       functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
       autograd_tls_(c10::AutogradState::get_tls_state()),
       python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
-      python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
+      python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
+      functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) {
   rf_tls_ = at::get_record_function_tls_();
 
   saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state();
@@ -53,6 +55,8 @@ void ThreadLocalState::setThreadLocalState(
   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
 
   functorch::setFuncTorchTLS(state.functorch_tls_);
+
+  at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
 }
 
 } // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index 9e5f70a4224f..0184cc9b82c4 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -74,6 +74,8 @@ class TORCH_API ThreadLocalState {
   // TLS for saved tensors default hooks
   at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
 
+  bool functionalization_reapply_views_state_;
+
   friend class ThreadLocalStateGuard;
 };
 
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index 33916492a0ef..6a0be21f8fe0 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -876,7 +876,10 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
   // treatment;
   return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
           autocast_dispatch_keyset -
-          DispatchKeySet({DispatchKey::PythonTLSSnapshot, DispatchKey::Python}))
+          DispatchKeySet(
+              {DispatchKey::Functionalize,
+               DispatchKey::PythonTLSSnapshot,
+               DispatchKey::Python}))
       .highestPriorityTypeId();
 }
 
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index a360a65d42a3..27d65e2d8673 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1307,7 +1307,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * It can be expanded as needed in the future, e.g sparse Tensor.
    */
   inline bool support_as_strided() const {
-    return is_nested() ? false : device().supports_as_strided();
+    if (is_nested()) {
+      return false;
+    }
+    if (key_set_.has(DispatchKey::Functionalize)) {
+      return false;
+    }
+    return device().supports_as_strided();
   }
 
   // ~~~~~ Autograd API ~~~~~
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index e0c9d09d821d..d682d8b4b71b 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -18,7 +18,6 @@
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.nn.utils import stateless
 from functorch import make_fx
-from functorch.experimental import functionalize
 from torch._dispatch.python import enable_python_dispatcher
 from . import config
 from .named_members_polyfill import _named_buffers, _named_parameters
@@ -133,7 +132,99 @@ def setup_stacktrace_preservation_hooks(roots: List):
         node.register_hook(get_posthook(special_stack))
 
 
-def create_joint_forward_backward(fn):
+# This is a version of functionalization that is specifically designed
+# for the AOTAutograd use case.  It might be generally applicable though
+# (if so, move it out of this file), so I've tried to give it a name that
+# describes what it does.
+#
+# Given a function f, it produces a new function g that:
+#
+# - Detaches all inputs before running f; the inner function
+#   does not directly participate in any pre-existing autograd.
+#   preserve_requires_grad is provided as a convenience to set the
+#   requires_grad on the new detached leaves in sync with the originals.
+#   (NB: In principle, you could backward through the pure operations
+#   produced by functionalization; this is not used for AOTAutograd
+#   and we have not tested it.)
+#
+# - Functionalizes all operations on f, under the assumption that the passed
+#   in function f must be "observationally pure"; that is, it cannot perform
+#   any mutations (inplace data or view operations) on the passed in inputs,
+#   nor is it allowed to directly close over tensors that aren't passed via
+#   its arguments.  See
+#   https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
+#   for a discussion of how to implement the more complicated case.
+#
+# Unlike functorch's variant, this doesn't use the functorch level system;
+# instead, it directly uses PyTorch's conventional dispatcher to hit the
+# functionalization key.  In particular, this means that FunctionalTensorWrapper
+# can have autograd data stored directly on it.
+#
+# In typical AOTAutograd usage, the dispatch key order will look like:
+#
+#   Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
+#     outer tensor                        inner tensor
+#
+# TODO: Provide a faster version of this that assumes flat arguments
+# (so no pytree necessary)
+def detach_and_functionalize_pure(f, preserve_requires_grad=True):
+    @wraps(f)
+    def inner(*args, **kwargs):
+        def to_fun(t):
+            if isinstance(t, Tensor):
+                r = torch._to_functional_tensor(t)
+                # NB: r is a leaf; it has no grad_fn relating
+                # it to t.  If t has autograd metadata, that
+                # metadata was preserved *inside* the r wrapper
+                if preserve_requires_grad:
+                    r.requires_grad = t.requires_grad
+                return r
+            else:
+                return t
+
+        f_args, f_kwargs = pytree.tree_map(to_fun, (args, kwargs))
+
+        torch._enable_functionalization(reapply_views=True)
+        try:
+            outs = f(*f_args, **f_kwargs)
+        finally:
+            torch._disable_functionalization()
+
+        # Detect input mutation and error if found
+        flat_args, _ = pytree.tree_flatten((args, kwargs))
+        flat_f_args, _ = pytree.tree_flatten((f_args, f_kwargs))
+
+        # This is just for sanity checking, can be skipped
+        for arg, f_arg in zip(flat_args, flat_f_args):
+            if not isinstance(arg, Tensor):
+                continue
+            torch._sync(f_arg)
+            new_arg = torch._from_functional_tensor(f_arg)
+            # I want to do this assert, but it is annoying because
+            # we have operator tests that have mutating inputs.  So
+            # I do something unsound instead
+            # assert arg is new_arg, "input argument was mutated, this is not valid"
+            if arg is not new_arg:
+                assert arg.shape == new_arg.shape
+                arg.copy_(new_arg)
+
+        def from_fun(t):
+            if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
+                return t
+            torch._sync(t)
+            return torch._from_functional_tensor(t)
+
+        return pytree.tree_map(from_fun, outs)
+    return inner
+
+
+# This creates a joint forwards-backwards function given both
+# the primals (to run forwards) and tangents (to run backwards).
+#
+# It has a precondition, which is that the passed in function
+# must be observationally pure; it is not permitted to mutate
+# the primals or tangents.
+def create_joint_forward_backward_pure(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
@@ -366,7 +457,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
 
     deduped_flat_args = remove_dupe_args(flat_args)
 
-    joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))
+    joint_forward_backward = create_joint_forward_backward_pure(lambda *args: flat_fn(*add_dupe_args(args)))
 
     out = flat_fn(*flat_args)
 
     # Collect info on which output tensors require gradients,
@@ -392,27 +483,13 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
 
     if config.use_functionalize:
         with enable_python_dispatcher():
-            # Trace once without decompositions, into a graph of ATen ops.
-            # NB: tracing_mode is real, as it's assumed the calling context setup
-            # fake tensor mode / symbolic shapes if that is needed
-            fx_g = make_fx(joint_forward_backward)(*joint_inputs)
-
-            context = disable_autocast_manager if disable_amp else nullcontext
-
-            def fake_fn(primals, tangents):
-                with torch.fx.traceback.override_stack_trace():
-                    return torch.fx.Interpreter(fx_g).run(primals, tangents)
-
-            # Trace a second time, running functionalization, and THEN running decompositions.
-            # functionalization only acts on ATen today, and doesn't currently handle
-            # view and inplace ops that come from primtorch.
-            # Eventually, functionalization should support primtorch view/inplace ops,
-            # which will make it ok to run decompositions before functionalization.
-            with context():
-                fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
+            fx_g = make_fx(
+                detach_and_functionalize_pure(joint_forward_backward), aot_config.decompositions
+            )(*joint_inputs)
         fx_g.graph.eliminate_dead_code()
         fx_g.recompile()
     else:
+        warnings.warn("graph partitioning without functionalization is not sound; we may introduce errors")
         fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)
 
     if config.debug_joint:
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index d8d330b4f3fc..2b457cd05c6c 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -416,6 +416,21 @@ class TestAOTAutograd(AOTTestCase):
 
         self.assertEqual(ref_out, test_out)
 
+    def test_custom_autograd(self):
+        class CustomFn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output + 1
+
+        def f(x):
+            return CustomFn.apply(x)
+
+        self.verify_aot_autograd(f, [torch.randn(3)])
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_autocast_disable_guard(self):
         guard = torch._C._DisableAutocast()
@@ -1099,12 +1114,10 @@ symbolic_aot_autograd_failures = {
     xfail('masked.var', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=to...
     xfail('matmul', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
-    xfail('max', 'reduction_no_dim'),  # aten.logical_or_.default - couldn't find symbolic meta function/dec...
     xfail('max', 'reduction_with_dim'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('median', ''),  # could not find kernel
     xfail('meshgrid', 'list_of_tensors'),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('meshgrid', 'variadic_tensors'),  # Cannot call numel() on tensor with symbolic sizes/strides
-    xfail('min', 'reduction_no_dim'),  # aten.logical_or_.default - couldn't find symbolic meta function/dec...
     xfail('min', 'reduction_with_dim'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('mode', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('msort', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1116,7 +1129,6 @@ symbolic_aot_autograd_failures = {
 
     # Deleting this in a followup
     xfail('nn.functional.feature_alpha_dropout', 'with_train'),
-    xfail('nn.functional.pad', 'circular'),
     xfail('nn.functional.poisson_nll_loss', ''),
     xfail('nn.functional._scaled_dot_product_attention', ''),  # Cannot call sizes() on tensor with symbolic ...
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index bfb79675c7eb..c6c3d991771b 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -141,26 +141,23 @@ class TestFunctionalization(TestCase):
 def forward(self, a_1):
     view_copy = torch.ops.aten.view_copy.default(a_1, [1, 1024, 128, 128]);  a_1 = None
     clone = torch.ops.aten.clone.default(view_copy);  view_copy = None
-    view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = None
+    view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
     relu = torch.ops.aten.relu.default(view_copy_1);  view_copy_1 = None
+    view_copy_2 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = None
     sum_1 = torch.ops.aten.sum.default(relu)
     ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format);  sum_1 = None
     expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]);  ones_like = None
-    new_zeros = torch.ops.aten.new_zeros.default(expand_copy, [16777216])
-    as_strided_copy = torch.ops.aten.as_strided_copy.default(new_zeros, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
-    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(new_zeros, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0)
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(new_zeros, expand_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0);  new_zeros = expand_copy = None
-    as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0);  as_strided_scatter = None
-    new_empty_strided = torch.ops.aten.new_empty_strided.default(as_strided_copy_2, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
-    as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
-    as_strided_copy_4 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
-    clone_1 = torch.ops.aten.clone.default(as_strided_copy_4, memory_format = torch.contiguous_format);  as_strided_copy_4 = None
+    _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]);  expand_copy = None
+    new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
+    view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
+    view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
+    clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format);  view_copy_4 = None
     threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0);  clone_1 = relu = None
-    _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1])
-    detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy);  _reshape_alias_copy = None
-    as_strided_scatter_1 = torch.ops.aten.as_strided_scatter.default(as_strided_copy_2, threshold_backward, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0);  as_strided_copy_2 = threshold_backward = None
-    _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(as_strided_scatter_1, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  as_strided_scatter_1 = None
-    detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1);  _reshape_alias_copy_1 = None
+    _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  _reshape_alias_copy = None
+    detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1);  _reshape_alias_copy_1 = None
+    view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]);  threshold_backward = None
+    _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  view_copy_5 = None
+    detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2);  _reshape_alias_copy_2 = None
     return detach_copy_1
     """)  # noqa: B950
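
Reviewer note: the sketch below is illustrative only and is not part of the patch. It shows how the new detach_and_functionalize_pure helper composes with make_fx, mirroring the single-trace path that aot_dispatch_autograd now takes under config.use_functionalize. It assumes a build containing this patch; f and the input shapes are invented for the example.

    # Illustrative usage of detach_and_functionalize_pure outside of AOTAutograd.
    import torch
    from functorch import make_fx
    from functorch._src.aot_autograd import detach_and_functionalize_pure

    def f(a, b):
        t = a.clone()      # avoid mutating the real input; the helper copies back
                           # (and would ideally error) on input mutation
        t.add_(b)          # inplace op: functionalization traces this as aten.add
        return t.view(-1)  # with reapply_views=True this stays a view op

    gm = make_fx(detach_and_functionalize_pure(f))(torch.randn(2, 2), torch.randn(2, 2))
    print(gm.graph)  # no mutating (trailing-underscore) ops appear in the graph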
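For completeness, here is the raw wrapper round trip that to_fun/from_fun build on, again as an illustrative sketch rather than part of the patch. Every entry point used is a private torch._* API that appears verbatim in the diff above, so it may change without notice.

    # Illustrative: manual FunctionalTensorWrapper round trip.
    import torch

    t = torch.randn(3, requires_grad=True)
    r = torch._to_functional_tensor(t)  # fresh leaf wrapping t; no grad_fn back to t
    r.requires_grad = t.requires_grad   # what preserve_requires_grad=True does

    torch._enable_functionalization(reapply_views=True)
    try:
        out = r * 2                     # dispatches through the Functionalize key
    finally:
        torch._disable_functionalization()

    torch._sync(out)                                # flush pending view/mutation updates
    unwrapped = torch._from_functional_tensor(out)  # plain tensor again
    assert torch._is_functional_tensor(out)
    assert not torch._is_functional_tensor(unwrapped)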