mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Functionalize and compute joint simultaneously. (#88063)
This also comes with some bug fixes that were uncovered from doing this: - Forward device calls to inner tensor in FunctionalTensorWrapper - Make legacyExtractDispatchKey exclude Functionalize, so that it can get at the real device type key. This is noncontroversial. - Stop stripping dense from key set. The reason for this is FunctionalWrapperTensor may be used in contexts where people query if it is dense or not. If it doesn't report this correctly (from the dispatch key), it will cause errors. This caused some torchbench models to fail when I did one-pass tracing. - Save and restore reapply views TLS correctly Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/88063 Approved by: https://github.com/bdhirsh
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							957a9b63c5
						
					
				
				
					commit
					0e3031f7e7
				
			| @ -37,22 +37,10 @@ void FunctionalTensorWrapper::set_constructor_metadata() { | ||||
|   // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect | ||||
|   // to participate in the functorch transforms. | ||||
|   key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks; | ||||
|   // For better error handling, | ||||
|   // we also don't want our wrapper tensor to be able to dispatch directly | ||||
|   // to a backend kernel. | ||||
|   // Dispatching directly to e.g. a CPU kernel would always segfault, | ||||
|   // because wrapper tensors don't have any real data. | ||||
|   // (This should never happen because we should always hit a functionalization kernel, | ||||
|   // but can help make bugs less nasty). | ||||
|   // Here, we defensively remove any backend keys from the wrapper's keyset. | ||||
|   // We don't want to remove actual backend bits though (say we're redispatching to autograd; | ||||
|   // we need to know if we're dispatching to AutogradCPU or AutogradXLA). | ||||
|   // Instead, it's sufficient to remove the `Dense` dispatch key, | ||||
|   // which prevents us from accidentally trying to directly run a CPU/CUDA kernel. | ||||
|   key_set_ = key_set_.remove(c10::DispatchKey::Dense); | ||||
|   // We override a bunch of _custom(), so make sure they get called | ||||
|   // TODO: metadata copying may not actually be necessary then | ||||
|   set_custom_sizes_strides(SizesStridesPolicy::CustomSizes); | ||||
|   set_custom_device(true); | ||||
| } | ||||
|  | ||||
| FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value) | ||||
| @ -335,6 +323,9 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach( | ||||
|       std::move(version_counter), allow_tensor_metadata_change); | ||||
| } | ||||
|  | ||||
| c10::Device FunctionalTensorWrapper::device_custom() const { | ||||
|   return value_.unsafeGetTensorImpl()->device(); | ||||
| } | ||||
| at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const { | ||||
|   return value_.unsafeGetTensorImpl()->sizes(); | ||||
| } | ||||
|  | ||||
| @ -148,6 +148,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | ||||
|   c10::SymInt sym_size_custom(int64_t d) const override; | ||||
|   c10::SymIntArrayRef sym_strides_custom() const override; | ||||
|   c10::SymInt sym_storage_offset_custom() const override; | ||||
|   c10::Device device_custom() const override; | ||||
|  | ||||
|  private: | ||||
|   const char* tensorimpl_type_name() const override; | ||||
|  | ||||
| @ -6,6 +6,7 @@ | ||||
|  | ||||
| #include <ATen/record_function.h> | ||||
| #include <ATen/SavedTensorHooks.h> | ||||
| #include <ATen/FunctionalTensorWrapper.h> | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| @ -15,7 +16,8 @@ ThreadLocalState::ThreadLocalState() | ||||
|       functorch_tls_(functorch::getCopyOfFuncTorchTLS()), | ||||
|       autograd_tls_(c10::AutogradState::get_tls_state()), | ||||
|       python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), | ||||
|       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) { | ||||
|       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()), | ||||
|       functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) { | ||||
|   rf_tls_ = at::get_record_function_tls_(); | ||||
|  | ||||
|   saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state(); | ||||
| @ -53,6 +55,8 @@ void ThreadLocalState::setThreadLocalState( | ||||
|   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_); | ||||
|  | ||||
|   functorch::setFuncTorchTLS(state.functorch_tls_); | ||||
|  | ||||
|   at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_); | ||||
| } | ||||
|  | ||||
| } // namespace at | ||||
|  | ||||
| @ -74,6 +74,8 @@ class TORCH_API ThreadLocalState { | ||||
|   // TLS for saved tensors default hooks | ||||
|   at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; | ||||
|  | ||||
|   bool functionalization_reapply_views_state_; | ||||
|  | ||||
|   friend class ThreadLocalStateGuard; | ||||
| }; | ||||
|  | ||||
|  | ||||
| @ -876,7 +876,10 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { | ||||
|   // treatment; | ||||
|   return (s - autograd_dispatch_keyset_with_ADInplaceOrView - | ||||
|           autocast_dispatch_keyset - | ||||
|           DispatchKeySet({DispatchKey::PythonTLSSnapshot, DispatchKey::Python})) | ||||
|           DispatchKeySet( | ||||
|               {DispatchKey::Functionalize, | ||||
|                DispatchKey::PythonTLSSnapshot, | ||||
|                DispatchKey::Python})) | ||||
|       .highestPriorityTypeId(); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -1307,7 +1307,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | ||||
|    * It can be expanded as needed in the future, e.g sparse Tensor. | ||||
|    */ | ||||
|   inline bool support_as_strided() const { | ||||
|     return is_nested() ? false : device().supports_as_strided(); | ||||
|     if (is_nested()) { | ||||
|       return false; | ||||
|     } | ||||
|     if (key_set_.has(DispatchKey::Functionalize)) { | ||||
|       return false; | ||||
|     } | ||||
|     return device().supports_as_strided(); | ||||
|   } | ||||
|  | ||||
|   // ~~~~~ Autograd API ~~~~~ | ||||
|  | ||||
| @ -18,7 +18,6 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv | ||||
| from torch.nn.utils import stateless | ||||
|  | ||||
| from functorch import make_fx | ||||
| from functorch.experimental import functionalize | ||||
| from torch._dispatch.python import enable_python_dispatcher | ||||
| from . import config | ||||
| from .named_members_polyfill import _named_buffers, _named_parameters | ||||
| @ -133,7 +132,99 @@ def setup_stacktrace_preservation_hooks(roots: List): | ||||
|         node.register_hook(get_posthook(special_stack)) | ||||
|  | ||||
|  | ||||
| def create_joint_forward_backward(fn): | ||||
| # This is a version of functionalization that is specifically designed | ||||
| # for the AOTAutograd use case.  It might be generally applicable though | ||||
| # (if so, move it out of this file), so I've tried to give it a name that | ||||
| # describes what it does. | ||||
| # | ||||
| # Given a function f, it produces a new function g that: | ||||
| # | ||||
| #   - Detaches all inputs before running f; the inner function | ||||
| #     does not directly participate in any pre-existing autograd. | ||||
| #     preserve_requires_grad is provided as a convenience to set the | ||||
| #     requires_grad on the new detached leaves in sync with the originals. | ||||
| #     (NB: In principle, you could backward through the pure operations | ||||
| #     produced by functionalization; this is not used for AOTAutograd | ||||
| #     and we have not tested it.) | ||||
| # | ||||
| #   - Functionalizes all operations on f, under the assumption that the passed | ||||
| #     in function f must be "observationally pure"; that is, it cannot perform any | ||||
| #     mutations (inplace data or view operations) on the passed in inputs, nor is | ||||
| #     it allowed to directly close over tensors that aren't passed via its | ||||
| #     arguments.  See | ||||
| #     https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit | ||||
| #     for discussion how how to implement the more complicated case. | ||||
| # | ||||
| # Unlike functorch's variant, this doesn't use the functorch level system, | ||||
| # instead it directly uses PyTorch's conventional dispatcher to hit the | ||||
| # functionalization key.  In particular, this means that FunctionalTensorWrapper | ||||
| # can have autograd data stored directly on it. | ||||
| # | ||||
| # In typical AOTAutograd usage, the dispatch key order will look like: | ||||
| # | ||||
| #   Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor | ||||
| #       outer tensor                        inner tensor | ||||
| # | ||||
| # TODO: Provide a faster version of this that assumes flat arguments | ||||
| # (so no pytree necessary) | ||||
| def detach_and_functionalize_pure(f, preserve_requires_grad=True): | ||||
|     @wraps(f) | ||||
|     def inner(*args, **kwargs): | ||||
|         def to_fun(t): | ||||
|             if isinstance(t, Tensor): | ||||
|                 r = torch._to_functional_tensor(t) | ||||
|                 # NB: r is a leaf; it has no grad_fn relating | ||||
|                 # it to t.  If t has autograd metadata, that | ||||
|                 # metadata was preserved *inside* the r wrapper | ||||
|                 if preserve_requires_grad: | ||||
|                     r.requires_grad = t.requires_grad | ||||
|                 return r | ||||
|             else: | ||||
|                 return t | ||||
|  | ||||
|         f_args, f_kwargs = pytree.tree_map(to_fun, (args, kwargs)) | ||||
|  | ||||
|         torch._enable_functionalization(reapply_views=True) | ||||
|         try: | ||||
|             outs = f(*f_args, **f_kwargs) | ||||
|         finally: | ||||
|             torch._disable_functionalization() | ||||
|  | ||||
|         # Detect input mutation and error if found | ||||
|         flat_args, _ = pytree.tree_flatten((args, kwargs)) | ||||
|         flat_f_args, _ = pytree.tree_flatten((f_args, f_kwargs)) | ||||
|  | ||||
|         # This is just for sanity checking, can be skipped | ||||
|         for arg, f_arg in zip(flat_args, flat_f_args): | ||||
|             if not isinstance(arg, Tensor): | ||||
|                 continue | ||||
|             torch._sync(f_arg) | ||||
|             new_arg = torch._from_functional_tensor(f_arg) | ||||
|             # I want to do this assert, but it is annoying because | ||||
|             # we have operator tests that have mutating inputs.  So | ||||
|             # I do something unsound instead | ||||
|             # assert arg is new_arg, "input argument was mutated, this is not valid" | ||||
|             if arg is not new_arg: | ||||
|                 assert arg.shape == new_arg.shape | ||||
|                 arg.copy_(new_arg) | ||||
|  | ||||
|         def from_fun(t): | ||||
|             if not isinstance(t, Tensor) or not torch._is_functional_tensor(t): | ||||
|                 return t | ||||
|             torch._sync(t) | ||||
|             return torch._from_functional_tensor(t) | ||||
|  | ||||
|         return pytree.tree_map(from_fun, outs) | ||||
|     return inner | ||||
|  | ||||
|  | ||||
| # This creates a joint forwards-backwards function given both | ||||
| # the primals (to run forwards) and tangents (to run backwards). | ||||
| # | ||||
| # It has a precondition which is that the passed in function | ||||
| # must be observationally pure; it is not permitted to mutate | ||||
| # the primals or tangents. | ||||
| def create_joint_forward_backward_pure(fn): | ||||
|     def joint_forward_backward( | ||||
|         primals: List[Any], tangents: List[Any] | ||||
|     ) -> Tuple[List[Any], List[Any]]: | ||||
| @ -366,7 +457,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi | ||||
|  | ||||
|     deduped_flat_args = remove_dupe_args(flat_args) | ||||
|  | ||||
|     joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args))) | ||||
|     joint_forward_backward = create_joint_forward_backward_pure(lambda *args: flat_fn(*add_dupe_args(args))) | ||||
|  | ||||
|     out = flat_fn(*flat_args) | ||||
|     # Collect info on which output tensors require gradients, | ||||
| @ -392,27 +483,13 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi | ||||
|  | ||||
|     if config.use_functionalize: | ||||
|         with enable_python_dispatcher(): | ||||
|             # Trace once without decompositions, into a graph of ATen ops. | ||||
|             # NB: tracing_mode is real, as it's assumed the calling context setup | ||||
|             # fake tensor mode / symbolic shapes if that is needed | ||||
|             fx_g = make_fx(joint_forward_backward)(*joint_inputs) | ||||
|  | ||||
|             context = disable_autocast_manager if disable_amp else nullcontext | ||||
|  | ||||
|             def fake_fn(primals, tangents): | ||||
|                 with torch.fx.traceback.override_stack_trace(): | ||||
|                     return torch.fx.Interpreter(fx_g).run(primals, tangents) | ||||
|  | ||||
|             # Trace a second time, running functionalization, and THEN running decompositions. | ||||
|             # functionalization only acts on ATen today, and doesn't currently handle | ||||
|             # view and inplace ops that come from primtorch. | ||||
|             # Eventually, functionalization should support primtorch view/inplace ops, | ||||
|             # which will make it ok to run decompositions before functionalization. | ||||
|             with context(): | ||||
|                 fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs) | ||||
|             fx_g = make_fx( | ||||
|                 detach_and_functionalize_pure(joint_forward_backward), aot_config.decompositions | ||||
|             )(*joint_inputs) | ||||
|         fx_g.graph.eliminate_dead_code() | ||||
|         fx_g.recompile() | ||||
|     else: | ||||
|         warnings.warn("graph partitioning without functionalization is not sound, we may introduce errors") | ||||
|         fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs) | ||||
|  | ||||
|     if config.debug_joint: | ||||
|  | ||||
| @ -416,6 +416,21 @@ class TestAOTAutograd(AOTTestCase): | ||||
|  | ||||
|         self.assertEqual(ref_out, test_out) | ||||
|  | ||||
|     def test_custom_autograd(self): | ||||
|         class CustomFn(torch.autograd.Function): | ||||
|             @staticmethod | ||||
|             def forward(ctx, x): | ||||
|                 return x.clone() | ||||
|  | ||||
|             @staticmethod | ||||
|             def backward(ctx, grad_output): | ||||
|                 return grad_output + 1 | ||||
|  | ||||
|         def f(x): | ||||
|             return CustomFn.apply(x) | ||||
|  | ||||
|         self.verify_aot_autograd(f, [torch.randn(3)]) | ||||
|  | ||||
|     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") | ||||
|     def test_autocast_disable_guard(self): | ||||
|         guard = torch._C._DisableAutocast() | ||||
| @ -1099,12 +1114,10 @@ symbolic_aot_autograd_failures = { | ||||
|     xfail('masked.var', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=to... | ||||
|     xfail('matmul', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides | ||||
|     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... | ||||
|     xfail('max', 'reduction_no_dim'),  # aten.logical_or_.default - couldn't find symbolic meta function/dec... | ||||
|     xfail('max', 'reduction_with_dim'),  # Cannot call sizes() on tensor with symbolic sizes/strides | ||||
|     xfail('median', ''),  # could not find kernel | ||||
|     xfail('meshgrid', 'list_of_tensors'),  # Cannot call numel() on tensor with symbolic sizes/strides | ||||
|     xfail('meshgrid', 'variadic_tensors'),  # Cannot call numel() on tensor with symbolic sizes/strides | ||||
|     xfail('min', 'reduction_no_dim'),  # aten.logical_or_.default - couldn't find symbolic meta function/dec... | ||||
|     xfail('min', 'reduction_with_dim'),  # Cannot call sizes() on tensor with symbolic sizes/strides | ||||
|     xfail('mode', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides | ||||
|     xfail('msort', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides | ||||
| @ -1116,7 +1129,6 @@ symbolic_aot_autograd_failures = { | ||||
|  | ||||
|     # Deleting this in a followup | ||||
|     xfail('nn.functional.feature_alpha_dropout', 'with_train'), | ||||
|     xfail('nn.functional.pad', 'circular'), | ||||
|     xfail('nn.functional.poisson_nll_loss', ''), | ||||
|  | ||||
|     xfail('nn.functional._scaled_dot_product_attention', ''),  # Cannot call sizes() on tensor with symbolic ... | ||||
|  | ||||
| @ -141,26 +141,23 @@ class TestFunctionalization(TestCase): | ||||
| def forward(self, a_1): | ||||
|     view_copy = torch.ops.aten.view_copy.default(a_1, [1, 1024, 128, 128]);  a_1 = None | ||||
|     clone = torch.ops.aten.clone.default(view_copy);  view_copy = None | ||||
|     view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = None | ||||
|     view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]) | ||||
|     relu = torch.ops.aten.relu.default(view_copy_1);  view_copy_1 = None | ||||
|     view_copy_2 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = None | ||||
|     sum_1 = torch.ops.aten.sum.default(relu) | ||||
|     ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format);  sum_1 = None | ||||
|     expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]);  ones_like = None | ||||
|     new_zeros = torch.ops.aten.new_zeros.default(expand_copy, [16777216]) | ||||
|     as_strided_copy = torch.ops.aten.as_strided_copy.default(new_zeros, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0) | ||||
|     as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(new_zeros, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0) | ||||
|     as_strided_scatter = torch.ops.aten.as_strided_scatter.default(new_zeros, expand_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0);  new_zeros = expand_copy = None | ||||
|     as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0);  as_strided_scatter = None | ||||
|     new_empty_strided = torch.ops.aten.new_empty_strided.default(as_strided_copy_2, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) | ||||
|     as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0) | ||||
|     as_strided_copy_4 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0) | ||||
|     clone_1 = torch.ops.aten.clone.default(as_strided_copy_4, memory_format = torch.contiguous_format);  as_strided_copy_4 = None | ||||
|     _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]);  expand_copy = None | ||||
|     new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) | ||||
|     view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128]) | ||||
|     view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128]) | ||||
|     clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format);  view_copy_4 = None | ||||
|     threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0);  clone_1 = relu = None | ||||
|     _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1]) | ||||
|     detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy);  _reshape_alias_copy = None | ||||
|     as_strided_scatter_1 = torch.ops.aten.as_strided_scatter.default(as_strided_copy_2, threshold_backward, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0);  as_strided_copy_2 = threshold_backward = None | ||||
|     _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(as_strided_scatter_1, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  as_strided_scatter_1 = None | ||||
|     detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1);  _reshape_alias_copy_1 = None | ||||
|     _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  _reshape_alias_copy = None | ||||
|     detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1);  _reshape_alias_copy_1 = None | ||||
|     view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]);  threshold_backward = None | ||||
|     _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]);  view_copy_5 = None | ||||
|     detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2);  _reshape_alias_copy_2 = None | ||||
|     return detach_copy_1 | ||||
|     """)  # noqa: B950 | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user