From 6c713ccb5e0df227dd5b630057cbccd373cbe7d6 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 17 Jan 2025 00:52:50 +0000 Subject: [PATCH] Revert "Make functionalization `ViewMeta` serializable with pickle. (#143712)" This reverts commit b8abdaa286fd161af48af57a675827f4f849914d. Reverted https://github.com/pytorch/pytorch/pull/143712 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/143712#issuecomment-2597205261)) --- .gitignore | 1 - BUILD.bazel | 3 - aten/src/ATen/FunctionalStorageImpl.cpp | 14 +- aten/src/ATen/FunctionalStorageImpl.h | 100 ++---- aten/src/ATen/FunctionalTensorWrapper.cpp | 74 ++-- aten/src/ATen/FunctionalTensorWrapper.h | 27 +- aten/src/ATen/FunctionalizeFallbackKernel.cpp | 58 ++- aten/src/ATen/FunctionalizeFallbackKernel.h | 58 --- aten/src/ATen/templates/FunctionalInverses.h | 12 +- .../templates/RegisterFunctionalization.cpp | 2 +- aten/src/ATen/templates/ViewMetaClasses.cpp | 19 - aten/src/ATen/templates/ViewMetaClasses.h | 12 - .../ViewMetaClassesPythonBinding.cpp | 11 - build.bzl | 2 - build_variables.bzl | 1 - caffe2/CMakeLists.txt | 2 - test/dynamo/test_aot_autograd_cache.py | 27 +- test/functorch/test_aotdispatch.py | 1 + tools/setup_helpers/generate_code.py | 33 +- torch/_C/__init__.pyi.in | 1 - torch/_C/_functionalization.pyi | 16 - .../_aot_autograd/autograd_cache.py | 14 + .../collect_metadata_analysis.py | 8 +- .../_aot_autograd/functional_utils.py | 83 +++-- .../_aot_autograd/input_output_analysis.py | 4 +- .../_aot_autograd/runtime_wrappers.py | 8 +- torch/_functorch/_aot_autograd/schemas.py | 30 +- torch/csrc/Module.cpp | 2 - .../python_torch_functions_manual.cpp | 9 + torch/csrc/functionalization/Module.cpp | 71 ---- torch/csrc/functionalization/Module.h | 36 -- torchgen/api/functionalization.py | 120 +++---- torchgen/api/types/signatures.py | 74 +++- torchgen/gen.py | 105 ++---- torchgen/gen_functionalization_type.py | 338 ++---------------- 35 files changed, 425 insertions(+), 951 deletions(-) delete mode 100644 aten/src/ATen/FunctionalizeFallbackKernel.h delete mode 100644 aten/src/ATen/templates/ViewMetaClasses.cpp delete mode 100644 aten/src/ATen/templates/ViewMetaClasses.h delete mode 100644 aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp delete mode 100644 torch/_C/_functionalization.pyi delete mode 100644 torch/csrc/functionalization/Module.cpp delete mode 100644 torch/csrc/functionalization/Module.h diff --git a/.gitignore b/.gitignore index c81f0734665a..8d4ceaa811c0 100644 --- a/.gitignore +++ b/.gitignore @@ -79,7 +79,6 @@ torch/return_types.pyi torch/nn/functional.pyi torch/utils/data/datapipes/datapipe.pyi torch/csrc/autograd/generated/* -torch/csrc/functionalization/generated/* torch/csrc/lazy/generated/*.[!m]* torch_compile_debug/ # Listed manually because some files in this directory are not generated diff --git a/BUILD.bazel b/BUILD.bazel index 893dbfc6cec0..df46835f363e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -90,8 +90,6 @@ generated_cpu_cpp = [ "aten/src/ATen/NativeMetaFunctions.h", "aten/src/ATen/RegistrationDeclarations.h", "aten/src/ATen/VmapGeneratedPlumbing.h", - "aten/src/ATen/ViewMetaClasses.h", - "aten/src/ATen/ViewMetaClasses.cpp", "aten/src/ATen/core/aten_interned_strings.h", "aten/src/ATen/core/enum_tag.h", "aten/src/ATen/core/TensorBody.h", @@ -1089,7 +1087,6 @@ test_suite( "aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/RegisterDispatchKey.cpp", 
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini", - "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp", "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml", "aten/src/ATen/native/ts_native_functions.yaml", diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp index cae0ab0ba601..a5512818343f 100644 --- a/aten/src/ATen/FunctionalStorageImpl.cpp +++ b/aten/src/ATen/FunctionalStorageImpl.cpp @@ -9,6 +9,11 @@ namespace at::functionalization { +ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { + if (out_idx == this->out_index) return *this; + return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx); +} + // Note [Functionalization: Alias Removal Part 2] // See Note [Functionalization: Alias Removal] for more details. // This function applies a single update from one of the views to the StorageImpl. @@ -42,7 +47,7 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co std::vector tmp_values({base}); tmp_values.reserve(update.view_metas.size()); for (size_t i = 0; i < update.view_metas.size() - 1; ++i) { - at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back()); + at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index); // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided // All of these ops require additional information to recover the sizes of the original tensor. // If need to, we could probably apply this optimization and only bother computing tmp_values @@ -50,8 +55,9 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co tmp_values.push_back(std::move(next_view)); } for(int64_t i = static_cast(update.view_metas.size()) - 1; i >= 0; --i) { + int64_t out_idx = update.view_metas[i].out_index; // Each view inverse is implemented in ViewInverses.cpp. - t = update.view_metas[i]->reverse(tmp_values[i], t); + t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx); } TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); return t; @@ -105,13 +111,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base) TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); } -void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector>& metas) { +void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector& metas) { TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); if (metas.size() > 1) { for (size_t i = 1; i < metas.size(); ++i) { // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI - TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided, + TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided, "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i, " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today," "so this behavior is banned in compile. 
As a workaround, you can either remove the mutation from the model code, or you " diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h index 71c259937e9d..3f80171196fb 100644 --- a/aten/src/ATen/FunctionalStorageImpl.h +++ b/aten/src/ATen/FunctionalStorageImpl.h @@ -8,89 +8,44 @@ namespace at::functionalization { // See Note [Functionalization Pass In Core] -enum class InverseReturnMode { - /// Specifies that functional inverses should always return a view. - AlwaysView, - /// Specifies that functional inverses should always return a non-view / copy. - NeverView, - /// Specifies that functional inverses should return a view unless a (copying) - /// scatter - /// inverse exists, in which case that will be used instead. - /// This avoids as_strided() calls that can be difficult for subclasses to - /// handle. - ViewOrScatterInverse, -}; - -#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \ - static const char* name() { \ - return #TYPE; \ - } - -#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \ - using SerializableTuple = std::tuple<__VA_ARGS__>; - // ViewMeta is a class used by the functionalization pass to navigate between // a base tensor and a view tensor. // For example, if I call `b = a.view1(...)` -// the functionalization pass will generate and store a ViewMeta specialization -// for `view1` operation on b that looks like: +// the functionalization pass will generate and store a ViewMeta on b that looks +// like: // -// struct TORCH_API view1_ViewMeta : public ViewMeta { -// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta); -// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( -// bool /* reapply_views */, -// const std::vector&); -// -// view1_ViewMeta(const SerializableTuple& tpl) -// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} -// -// view1_ViewMeta(bool reapply_views, const std::vector& size) -// : ViewMeta(/*has_symbolic_inputs=*/false), -// reapply_views(reapply_views), -// size(size) {} -// -// Tensor forward(const Tensor& base) override { -// return base.view1(...); +// ViewMeta( +// [](const Tensor& base, int64_t mutated_view_idx) { +// return base.view1(...); +// }, +// [](const at::Tensor& base, const at::Tensor& mutated_view, +// int64_t mutated_view_idx) -> at::Tensor { +// return at::functionalization::impl::view1_inverse(base, mutated_view, +// ...); // } // -// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override { -// return at::functionalization::impl::view1_inverse(base, mutated_view, -// ...); -// } +// The forward_fn lambda describes how to replay view1 on a tensor. // -// SerializableTuple to_serializable_tuple() { -// return std::make_tuple(reapply_views, size); -// } -// -// bool reapply_views; -// std::vector size; -// }; -// -// The forward function describes how to replay view1 on a tensor. -// -// The reverse function describes how, given a tensor that is already a view, +// The reverse_fn lambda describes how, given a tensor that is already a view, // how to get the corresponding base tensor. See Note [Functionalization Pass: // View Inverses] for details. -// -// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type -// representing the `ViewMeta` instance state. Methods that take in/return such -// a type are used for supporting pickle serialization. 
struct ViewMeta { ViewMeta( + std::function forward, + std::function reverse, bool has_symbolic_inputs, bool is_multi_output = false, bool is_as_strided = false, int64_t out_idx = 0) - : out_index(out_idx), + : forward_fn(std::move(forward)), + reverse_fn(std::move(reverse)), + out_index(out_idx), is_multi_output(is_multi_output), is_as_strided(is_as_strided), has_symbolic_inputs(has_symbolic_inputs) {} - virtual ~ViewMeta() {} - - virtual Tensor forward(const Tensor& base) = 0; - virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0; - + std::function forward_fn; + std::function reverse_fn; // See Note [out_idx in ViewMeta] int64_t out_index; @@ -102,17 +57,10 @@ struct ViewMeta { // Tells us if this view operation has any symbolic inputs bool has_symbolic_inputs; - // Returns a new ViewMeta with the same forward/reverse + // Returns a copy of the current ViewMeta, if out_idx matches the current + // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse // functions, but a new out index. - // - // This method should be implemented by those `ViewMeta` that have more than - // one output. - virtual std::shared_ptr to_out_index(int64_t out_index) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "ViewMeta::to_out_index not implemented. ", - "Likely because there's only one output."); - } + ViewMeta to_out_idx(int64_t out_idx); }; // FunctionalStorageImpl is a subclass of StorageImpl used by the @@ -145,14 +93,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const at::Tensor new_val; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::vector> view_metas; + const std::vector view_metas; }; explicit FunctionalStorageImpl(const Tensor& value); void add_update( const Tensor& updated_val, - const std::vector>& view_metas); + const std::vector& view_metas); bool apply_updates(); const Tensor& base() { return base_; diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 4aed2aac4a0c..409f944a88e3 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -129,19 +129,17 @@ void FunctionalTensorWrapper::freeze_storage() const { // - view_value: The output tensor that we need to wrap. // - base: The "base" of the view that `view_value` was generated from. // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. 
-FunctionalTensorWrapper::FunctionalTensorWrapper( - const Tensor& view_value, - const FunctionalTensorWrapper* base, - const std::shared_ptr& meta) - : c10::TensorImpl( - c10::DispatchKeySet(DispatchKey::Functionalize), - view_value.dtype(), - view_value.device()), - value_(view_value), - is_multi_output_view_( - base->is_multi_output_view_ || meta->is_multi_output), - was_storage_changed_(base->was_storage_changed_), - is_symbolic_(base->is_symbolic_) { +FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta) + : c10::TensorImpl( + c10::DispatchKeySet(DispatchKey::Functionalize), + view_value.dtype(), + view_value.device() + ), + value_(view_value), + is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output), + was_storage_changed_(base->was_storage_changed_), + is_symbolic_(base->is_symbolic_) +{ TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_)); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); set_constructor_metadata(); @@ -150,10 +148,11 @@ FunctionalTensorWrapper::FunctionalTensorWrapper( view_metas_ = base->view_metas_; // copy } view_metas_.push_back(meta); - maybe_mark_symbolic(meta.get()); + maybe_mark_symbolic(meta); storage_ = base->storage_; // alias this tensor's storage with the base tensor's } + functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { return static_cast(storage_.unsafeGetStorageImpl()); } @@ -177,18 +176,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const { } // See Note [Functionalization Pass - Inplace View Ops] -void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr& meta) { +void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { view_metas_.push_back(meta); // Manually track the fact that this tensor recieved a metadata mutation! has_metadata_mutation_ = true; // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. - maybe_mark_symbolic(meta.get()); + maybe_mark_symbolic(meta); // Note [Functionalization Pass - Inplace View Ops] // So, these ops are special - they're mutation AND view ops. They get special codegen. // An example is transpose_, e.g. `a.transpose_()` // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. 
at::AutoDispatchSkipFunctionalize guard; - value_ = meta->forward(value_); + value_ = meta.forward_fn(value_, meta.out_index); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); } @@ -369,8 +368,15 @@ void FunctionalTensorWrapper::sync_() { regenerate_from_base(); } -const std::vector>& FunctionalTensorWrapper::view_metas() const { - return view_metas_; +Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) { + auto t = base; + + // Reapply views to get the viewed tensor from the base in alias_ + for (auto& view_meta: view_metas_) { + t = view_meta.forward_fn(t, view_meta.out_index); + } + + return t; } void FunctionalTensorWrapper::regenerate_from_base() { @@ -379,7 +385,7 @@ void FunctionalTensorWrapper::regenerate_from_base() { auto t = storage_impl->base(); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); - t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_); + t = apply_view_metas(t); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); replace_(t, /*from_lazy_regenerate=*/true); @@ -753,28 +759,20 @@ void freeze_functional_tensor(const Tensor& tensor) { functional_base_impl->freeze_storage(); } -Tensor create_functional_tensor_with_view_meta( - const at::Tensor& view_to_wrap, - const at::Tensor& base, - const std::shared_ptr& meta, - int64_t out_idx) { +Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); - auto meta_ = meta; if (out_idx != 0) { // Note [out_idx in ViewMeta] // When a view op outputs multiple tensors, each output needs its own separate ViewMeta. // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. - meta_ = meta->to_out_index(out_idx); + meta = meta.to_out_idx(out_idx); } - return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta_); + return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta); } -std::vector create_functional_tensor_with_view_meta( - ITensorListRef view_to_wrap, - const at::Tensor& base, - const std::shared_ptr& meta) { +std::vector create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) { std::vector outputs(view_to_wrap.size()); int64_t i = 0; for (const auto& tensor : view_to_wrap) { @@ -784,22 +782,12 @@ std::vector create_functional_tensor_with_view_meta( return outputs; } -void mutate_view_meta(const at::Tensor& self, const std::shared_ptr& meta) { +void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); self_impl->mutate_view_meta(meta); } -Tensor apply_view_meta_sequence( - const Tensor& base, - const std::vector>& sequence) { - Tensor r = base; - for (auto& vm : sequence) { - r = vm->forward(r); - } - return r; -} - // Note [Propagating strides in the functionalization pass] // In order to properly compute stride information, the functionalization pass // calls each {view} reference implementations with meta tensors. 
diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index f25a5637de3c..c418ef39427c 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { explicit FunctionalTensorWrapper( const Tensor& view_value, const FunctionalTensorWrapper* base, - const std::shared_ptr& meta); + const functionalization::ViewMeta& meta); // Get the underlying, actual tensor, that doesn't know anything about // functionalization. @@ -97,17 +97,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { ->are_all_mutations_under_no_grad_or_inference_mode(); } - void maybe_mark_symbolic(functionalization::ViewMeta* meta) { - is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs; + void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { + is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; } bool is_symbolic() const { return is_symbolic_; } - // Retrieves the ViewMeta sequence of this tensor. - const std::vector>& view_metas() - const; + // Runs the forward_fn of every ViewMeta collected in the current instance + // to some other base. + Tensor apply_view_metas(const Tensor& base); // Sync's the underlying tensor with its alias, if it's out of date. This // involves two steps: 1) Apply any pending updates/mutations to the alias 2) @@ -144,8 +144,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // from the base tensor. This method is used by inplace-view ops like // transpose_. It appends a ViewMeta to the existing stack, and refreshes the // tensor by replaying the views off of the alias. - void mutate_view_meta( - const std::shared_ptr& meta); + void mutate_view_meta(const at::functionalization::ViewMeta& meta); // Custom implementation of self.set_(src) void set__impl(const FunctionalTensorWrapper* other); @@ -274,7 +273,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool is_symbolic_ = false; size_t generation_ = 0; - std::vector> view_metas_; + std::vector view_metas_; protected: static void copy_tensor_metadata( @@ -366,20 +365,16 @@ TORCH_API void propagate_xla_data_direct( Tensor create_functional_tensor_with_view_meta( const Tensor& view_to_wrap, const Tensor& base, - const std::shared_ptr& meta, + functionalization::ViewMeta meta, int64_t out_idx = 0); std::vector create_functional_tensor_with_view_meta( ITensorListRef view_to_wrap, const Tensor& base, - const std::shared_ptr& meta); + const functionalization::ViewMeta& meta); void mutate_view_meta( const Tensor& self, - const std::shared_ptr& meta); - -TORCH_API Tensor apply_view_meta_sequence( - const Tensor& base, - const std::vector>& sequence); + const functionalization::ViewMeta& meta); void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); void set_sizes_strides_offset( diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index 1bf805d134f7..36b6f91c1d99 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -1,5 +1,3 @@ -#include - #include #include #include @@ -29,31 +27,6 @@ #include #endif -namespace at::functionalization { - -Tensor resize__ViewMeta::forward(const Tensor& base) { - if (reapply_views) { - return base.as_strided(size, c10::contiguous_strides(size)); - } else { - return at::as_strided_copy(base, size, c10::contiguous_strides(size)); - } -} - -Tensor 
resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { - return base.as_strided_scatter( - mutated_view, size, c10::contiguous_strides(size)); -} - -Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) { - return at::_unsafe_view_symint(base, size); -} - -Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { - return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); -} - -} // namespace at::functionalization - namespace { void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) { const auto& schema = op.schema(); @@ -195,8 +168,19 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch // The output of resizing is equivalent to taking a slice of a larger tensor. // We have to emulate this "slicing" with an as_strided call. auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); - auto view_meta = std::make_shared( - reapply_views, size.vec()); + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + if (reapply_views) { + return base.as_strided(size, c10::contiguous_strides(size)); + } else { + return at::as_strided_copy(base, size, c10::contiguous_strides(size)); + } + }, + [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size)); + }, + /*has_symbolic_inputs=*/false + ); at::functionalization::impl::mutate_view_meta(self, view_meta); return self; } @@ -315,11 +299,17 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt tmp_output = at::_unsafe_view_symint(self_, size); } - bool has_symbolic_inputs = std::any_of( - size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); - auto view_meta = - std::make_shared( - has_symbolic_inputs, size.vec()); + bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); + + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + return at::_unsafe_view_symint(base, size); + }, + [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); + }, + /*has_symbolic_inputs=*/has_symbolic_inputs + ); auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta)); // See Note [Propagating strides in the functionalization pass] diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.h b/aten/src/ATen/FunctionalizeFallbackKernel.h deleted file mode 100644 index cd4f64a70fab..000000000000 --- a/aten/src/ATen/FunctionalizeFallbackKernel.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include - -namespace at::functionalization { - -// `ViewMeta` implementation for `resize_` operation. 
-struct TORCH_API resize__ViewMeta : public ViewMeta { - FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta); - FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( - bool /* reapply_views */, - const std::vector&); - - resize__ViewMeta(const SerializableTuple& tpl) - : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} - - resize__ViewMeta(bool reapply_views, const std::vector& size) - : ViewMeta(/*has_symbolic_inputs=*/false), - reapply_views(reapply_views), - size(size) {} - - Tensor forward(const Tensor& base) override; - Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; - - SerializableTuple to_serializable_tuple() { - return std::make_tuple(reapply_views, size); - } - - bool reapply_views; - std::vector size; -}; - -// `ViewMeta` implementation for `_unsafe_view` operation. -struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta { - FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta); - FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( - bool /* has_symbolic_inputs */, - const std::vector&); - - _unsafe_view_ViewMeta(const SerializableTuple& tpl) - : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} - - _unsafe_view_ViewMeta( - bool has_symbolic_inputs, - const std::vector& size) - : ViewMeta(has_symbolic_inputs), size(size) {} - - Tensor forward(const Tensor& base) override; - Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; - - SerializableTuple to_serializable_tuple() { - return std::make_tuple(has_symbolic_inputs, size); - } - - std::vector size; -}; - -} // namespace at::functionalization diff --git a/aten/src/ATen/templates/FunctionalInverses.h b/aten/src/ATen/templates/FunctionalInverses.h index b15cd09a6c65..3217e097d7ad 100644 --- a/aten/src/ATen/templates/FunctionalInverses.h +++ b/aten/src/ATen/templates/FunctionalInverses.h @@ -2,12 +2,22 @@ // ${generated_comment} -#include #include namespace at { namespace functionalization { +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) scatter + /// inverse exists, in which case that will be used instead. + /// This avoids as_strided() calls that can be difficult for subclasses to handle. 
+ ViewOrScatterInverse, +}; + struct FunctionalInverses { ${view_inverse_declarations} diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp index 93848d673f8b..999c06e2cb89 100644 --- a/aten/src/ATen/templates/RegisterFunctionalization.cpp +++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/templates/ViewMetaClasses.cpp b/aten/src/ATen/templates/ViewMetaClasses.cpp deleted file mode 100644 index 0fd53171935f..000000000000 --- a/aten/src/ATen/templates/ViewMetaClasses.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// ${generated_comment} - -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -${op_headers} -#endif - -namespace at { -namespace functionalization { - -${view_meta_implementations} - -} // namespace functionalization -} // namespace at diff --git a/aten/src/ATen/templates/ViewMetaClasses.h b/aten/src/ATen/templates/ViewMetaClasses.h deleted file mode 100644 index be2dee2a871b..000000000000 --- a/aten/src/ATen/templates/ViewMetaClasses.h +++ /dev/null @@ -1,12 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -// ${generated_comment} - -#include - -namespace at { -namespace functionalization { - -${view_meta_declarations} - -} // namespace functionalization -} // namespace at diff --git a/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp b/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp deleted file mode 100644 index c784e5abe5c8..000000000000 --- a/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include - -namespace torch::functionalization { - -void initGenerated(PyObject* module) { - auto functionalization = py::handle(module).cast(); - $view_meta_bindings -} - -} // namespace torch::functionalization diff --git a/build.bzl b/build.bzl index 6acbe49d790b..ad8ea1c8cef2 100644 --- a/build.bzl +++ b/build.bzl @@ -117,7 +117,6 @@ def define_targets(rules): ":LazyNonNativeIr.h", ":RegisterDispatchDefinitions.ini", ":RegisterDispatchKey.cpp", - ":ViewMetaClassesPythonBinding.cpp", ":native_functions.yaml", ":shape_inference.h", ":tags.yaml", @@ -298,7 +297,6 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [ "torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", - "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" ] GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP diff --git a/build_variables.bzl b/build_variables.bzl index a206c6a4f9a4..8bd8ad3a8df0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -929,7 +929,6 @@ libtorch_python_core_sources = [ "torch/csrc/utils/disable_torch_function.cpp", "torch/csrc/utils/verbose.cpp", "torch/csrc/cpu/Module.cpp", - "torch/csrc/functionalization/Module.cpp", "torch/csrc/instruction_counter/Module.cpp", ] + lazy_tensor_core_python_sources diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 7e4174a212d0..11b590a48175 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -310,7 +310,6 @@ set(GENERATED_CXX_PYTHON "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp" - 
"${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" ) set(GENERATED_H_PYTHON @@ -374,7 +373,6 @@ add_custom_command( "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" - "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp" ${autograd_python} ${autograd_yaml} ${autograd_templates} diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 69f8310c4f6f..d543c7028b0b 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -250,7 +250,11 @@ class AOTAutogradCacheTests(InductorTestCase): @functorch_config.patch( {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} ) - def test_view_replay(self): + def test_view_replay_bypass(self): + """ + Shoud bypass when view replay is turned on + """ + def fn(a): tmp = a.detach() a.mul_(2) @@ -258,25 +262,10 @@ class AOTAutogradCacheTests(InductorTestCase): with torch.autograd._force_original_view_tracking(True): compiled_fn = torch.compile(fn) + compiled_fn(torch.rand(2, 3)) - def run_and_check(miss, hit, bypass): - self._clear_dynamo_and_codecache() - - inp = torch.rand(2, 3) - compiled_inp = inp.clone().detach() - - with torch.autograd._force_original_view_tracking(True): - out = fn(inp) - compiled_out = compiled_fn(compiled_inp) - - self.assertEqual(out, compiled_out) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit) - self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass) - - run_and_check(miss=1, hit=0, bypass=0) - run_and_check(miss=1, hit=1, bypass=0) - run_and_check(miss=1, hit=2, bypass=0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", False) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index fc1efde48835..50ef291417b1 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -6897,6 +6897,7 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): { "enable_autograd_cache": True, "strict_autograd_cache": True, + "view_replay_for_aliased_outputs": False, } ) @torch._inductor.config.patch("fx_graph_cache", True) diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index a57732e5eba6..6e0a64888f0a 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -189,12 +189,6 @@ def main() -> None: ) options = parser.parse_args() - # Path: aten/src/ATen - aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) - operator_selector = get_selector( - options.selected_op_list_path, options.operators_yaml_path - ) - generate_code( options.gen_dir, options.native_functions_path, @@ -204,32 +198,13 @@ def main() -> None: options.disable_autograd, options.force_schema_registration, # options.selected_op_list - operator_selector=operator_selector, - ) - - # Generate the python bindings for functionalization's `ViewMeta` classes. 
- from torchgen.gen_functionalization_type import ( - gen_functionalization_view_meta_classes, - ) - - functionalization_templates_dir = os.path.join(aten_path, "templates") - functionalization_install_dir = os.path.join( - options.gen_dir, "torch/csrc/functionalization/generated" - ) - - os.makedirs(functionalization_install_dir, exist_ok=True) - assert os.path.isdir(functionalization_install_dir) - assert os.path.isdir(functionalization_templates_dir) - - gen_functionalization_view_meta_classes( - options.native_functions_path or NATIVE_FUNCTIONS_PATH, - options.tags_path or TAGS_PATH, - selector=operator_selector, - install_dir=functionalization_install_dir, - template_dir=functionalization_templates_dir, + operator_selector=get_selector( + options.selected_op_list_path, options.operators_yaml_path + ), ) if options.gen_lazy_ts_backend: + aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml") ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp" ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h" diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f12e11c2c445..e1ae17217f21 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -67,7 +67,6 @@ from . import ( _export, _cpu, _dynamo, - _functionalization, _functorch, _lazy, _lazy_ts_backend, diff --git a/torch/_C/_functionalization.pyi b/torch/_C/_functionalization.pyi deleted file mode 100644 index 4e00df97e271..000000000000 --- a/torch/_C/_functionalization.pyi +++ /dev/null @@ -1,16 +0,0 @@ -from torch import Tensor -from torch.types import _bool - -# Defined in torch/csrc/functionalization/Module.cpp - -class ViewMeta: - has_symbolic_inputs: _bool - -# Returns the list of ViewMeta instances of the given functional tensor. -# -# Although we do have python bindings for their types, we won't -# expose them here, since they should not be used by users. -def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ... - -# Applies the ViewMeta sequence on top of the given base. -def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ... diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index c112c8929229..38092a992258 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -227,6 +227,19 @@ def check_cacheable(gm: torch.fx.GraphModule): check_node_safe(node) +def check_metadata_cacheable(metadata: ViewAndMutationMeta): + """ + When view replay is turned on, we bypass autograd cache if + the output is aliased. 
+ """ + if config.view_replay_for_aliased_outputs: + for info in metadata.output_info: + if info.functional_tensor is not None: + raise BypassAOTAutogradCache( + "Cannot cache a graph with functional tensor" + ) + + class AOTAutogradCacheDetails(FxGraphHashDetails): """ Object to capture all the details for a dynamo graph module relevant to computing @@ -862,6 +875,7 @@ class AOTAutogradCache: def save(key: str, entry: AOTAutogradCacheEntry, remote: bool): """Save a single entry into the cache.""" try: + check_metadata_cacheable(entry.runtime_metadata) content = pickle.dumps(entry) CacheArtifactManager.record_artifact( CacheArtifactType.AOT_AUTOGRAD, key, content diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 1476b1818126..0b7235358240 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -36,10 +36,10 @@ from .functional_utils import ( has_metadata_mutation, MetadataKey, to_fun, - ViewMetaSequence, was_inductor_storage_resized, ) from .schemas import ( + FunctionalTensorMetadataEq, InputAliasInfo, MutationType, OutputAliasInfo, @@ -604,7 +604,7 @@ from a multi-output view call" # # The FunctionalTensor will be saved if one of the 2 conditions below # is true: - view_meta_sequence = None + functional_tensor = None if ( # 1. If the output_type is either of: # (i) alias_of_intermediate; @@ -636,7 +636,7 @@ from a multi-output view call" and not input_info[base_idx].mutates_metadata ): if isinstance(o, FunctionalTensor): - view_meta_sequence = ViewMetaSequence(o) + functional_tensor = FunctionalTensorMetadataEq(o.elem) out_info = OutputAliasInfo( output_type=output_type, @@ -644,7 +644,7 @@ from a multi-output view call" base_idx=base_idx, dynamic_dims=dynamic_dims, requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, - view_meta_sequence=view_meta_sequence, + functional_tensor=functional_tensor, ) output_info.append(out_info) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 4add4ae845d6..fb509296705e 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -13,12 +13,15 @@ from typing import Optional, Tuple import torch from torch import Tensor -from torch._C import _functionalization from torch._logging import getArtifactLogger from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor from torch._subclasses.meta_utils import is_sparse_any -from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr +from torch.fx.experimental.symbolic_shapes import ( + definitely_true, + sym_eq, + SymIntEqByExpr, +) from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import ( is_traceable_wrapper_subclass, @@ -224,9 +227,9 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, - target_view_meta_sequence: Optional[ViewMetaSequence] = None, + target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, *, - replay_views: bool, + replay_views, ): # Patch the correct requires_grad field of the output tensor, depending on whether: # (i) the reconstructed output (out) was came from a tensor that requires grad or not; @@ -245,11 +248,13 @@ def gen_alias_from_base( # to replay them (view functions) on the aliased_base_tensor. 
if ( replay_views - and target_view_meta_sequence is not None - and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence) + and target_functional_tensor is not None + and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) ): - out = _functionalization.apply_view_meta_sequence( - aliased_base_tensor, target_view_meta_sequence.sequence + functional_tensor = target_functional_tensor.tensor + + out = torch._functionalize_apply_view_metas( + functional_tensor, aliased_base_tensor ) # If re-applying the ViewMeta sequence succeeded, there should be no more # problems going forward. We just check we got to the target shape and @@ -310,8 +315,28 @@ def gen_alias_from_base( return aliased_out +def has_same_metadata(t1, t2): + return ( + definitely_true(sym_eq(t1.size(), t2.size())) + and definitely_true(t1.layout == t2.layout) + and ( + is_sparse_any(t1) + or ( + definitely_true(sym_eq(t1.stride(), t2.stride())) + and definitely_true(t1.storage_offset() == t2.storage_offset()) + ) + ) + and t1.is_conj() == t2.is_conj() + and t1.is_neg() == t2.is_neg() + ) + + @dataclass(frozen=True) class MetadataKey: + """ + This should be equal whenever has_same_metadata would return True + """ + size: Tuple[SymIntEqByExpr, ...] layout: torch.layout is_sparse: bool @@ -335,45 +360,25 @@ class MetadataKey: ) -# ViewMeta sequence wrapper for equality comparisons. -# -# Even though we can compare each ViewMeta instance, we compare the resulting -# tensor metadata, instead. That's because the creation of synthetic bases + the -# re-generation of input views might end-up creating a different sequence of -# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same -# metadata. -# -# Therefore, we store what the end result should look like as serializable -# metadata. -# -# When logging, this class should look like: -# -# ViewMetaSequence(view, select_int, slice_Tensor) -# -# i.e. a parenthesized list of view operations within that ViewMeta sequence. -class ViewMetaSequence: - def __init__(self, tensor: FunctionalTensor) -> None: - assert torch._is_functional_tensor(tensor.elem) - self.sequence = _functionalization.get_view_meta_sequence(tensor.elem) - self.metadata = MetadataKey.make(tensor) - - def __repr__(self) -> str: - suffix = len("_ViewMeta") - types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence) - return f"ViewMetaSequence({types})" +# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata +# after applying all the ViewMeta operations. +class FunctionalTensorMetadataEq: + def __init__(self, tensor: torch.Tensor) -> None: + assert torch._is_functional_tensor(tensor) + self.tensor = tensor def __eq__(self, other: object) -> bool: # If other is None, then it probably means that we weren't able to recreate - # the ViewMeta sequence. One example is when we update the view metadata by - # calling: create_synthetic_base_metadata. + # the FunctionalTensorMetadataEq. One of these cases is when we update the + # view metadata by calling: create_synthetic_base_metadata. if other is None: return True - # Comparison against any other type is not implemented. - if not isinstance(other, ViewMetaSequence): + # Comparison against any other type is not implemented. 
+ if not isinstance(other, FunctionalTensorMetadataEq): return NotImplemented - return self.metadata == other.metadata + return has_same_metadata(self.tensor, other.tensor) # new_arg and arg here are either: diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index faa10e33547d..727b3af1e321 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -75,7 +75,7 @@ def remove_dupe_metadata( dynamic_dims=o.dynamic_dims, base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], requires_grad=o.requires_grad, - view_meta_sequence=o.view_meta_sequence, + functional_tensor=o.functional_tensor, ) for o in m.output_info ], @@ -226,7 +226,7 @@ def create_synthetic_base_metadata( # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases base_idx=new_base_idx, # type: ignore[arg-type] requires_grad=o.requires_grad, - view_meta_sequence=o.view_meta_sequence, + functional_tensor=o.functional_tensor, ) ) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index b81d3e929016..604d65408493 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -172,7 +172,7 @@ class AliasOfInputHandler: self.base_idx = info.base_idx self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad - self.view_meta_sequence = info.view_meta_sequence + self.functional_tensor = info.functional_tensor self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -181,7 +181,7 @@ class AliasOfInputHandler: aliased_base_tensor, self.unwrap_out(out), self.requires_grad, - self.view_meta_sequence, + self.functional_tensor, replay_views=self.replay_views, ) @@ -209,7 +209,7 @@ class AliasOfIntermediateHandler: self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad - self.view_meta_sequence = info.view_meta_sequence + self.functional_tensor = info.functional_tensor self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -218,7 +218,7 @@ class AliasOfIntermediateHandler: aliased_base_tensor, self.unwrap_out(out), self.requires_grad, - self.view_meta_sequence, + self.functional_tensor, replay_views=self.replay_views, ) diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 22b9941ee404..bab5b7c35add 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -5,6 +5,7 @@ input/output types, metadata, config, function signatures etc. """ import collections +import dataclasses import functools from dataclasses import dataclass, field from enum import Enum @@ -19,7 +20,10 @@ from torch._subclasses.fake_tensor import is_fake from torch.utils._python_dispatch import is_traceable_wrapper_subclass from .. import config -from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + FunctionalTensorMetadataEq, +) from .utils import strict_zip @@ -88,14 +92,15 @@ class OutputAliasInfo: dynamic_dims: Optional[Set[int]] # requires_grad requires_grad: bool - # Sequence of ViewMeta objects. + # FunctionalTensorWrapper that represents this output. 
# - # Provides us the means to re-run view functions on other tensors. + # Provides us the means to replay views from it. # - # We need to wrap the actual list of ViewMeta with this class so that - # we compare the ViewMeta elements appropriately, i.e. their type and - # the elements returned by the `as_tuple()` call. - view_meta_sequence: Optional[ViewMetaSequence] = None + # We need to wrap the actual FunctionalTensorWrapper with this class so that + # we only compare the tensor's metadata. That's because with the transformations + # of the model throughout AOTAutograd, the sequence of ViewMeta and the base + # tensor might change. + functional_tensor: Optional[FunctionalTensorMetadataEq] = None class MutationType(Enum): @@ -577,6 +582,17 @@ class ViewAndMutationMeta: self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] # Clear traced tangents at runtime self.traced_tangents = [] + new_output_info = [] + for out in self.output_info: + if config.view_replay_for_aliased_outputs: + new_out = out + else: + # If we're not using view_replay, remove the functional tensor. + # Functional tensors are unfortunately not serializable, + # so doing this is required for AOTAutograd caching. + new_out = dataclasses.replace(out, functional_tensor=None) + new_output_info.append(new_out) + self.output_info = new_output_info for inp_meta in self.subclass_inp_meta: if isinstance(inp_meta, SubclassCreationMeta): inp_meta.make_runtime_safe() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index ab59688adfe7..2230b15aeb3a 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -71,7 +71,6 @@ #include #include #include -#include #include #include #include @@ -1839,7 +1838,6 @@ PyObject* initModule() { torch::instruction_counter::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); - torch::functionalization::initModule(module); #ifdef USE_CUDA // This will only initialise base classes and attach them to library namespace diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index b125413f2e71..a4d9eed924b2 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -633,6 +633,15 @@ void initTorchFunctions(PyObject* module) { at::functionalization::impl::isFunctionalTensor(t)); at::functionalization::impl::mark_mutation_hidden_from_autograd(t); }); + py_module.def( + "_functionalize_apply_view_metas", + [](const at::Tensor& tensor, const at::Tensor& base) { + TORCH_INTERNAL_ASSERT( + at::functionalization::impl::isFunctionalTensor(tensor)); + auto impl = + at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); + return impl->apply_view_metas(base); + }); py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); diff --git a/torch/csrc/functionalization/Module.cpp b/torch/csrc/functionalization/Module.cpp deleted file mode 100644 index d38cb1078054..000000000000 --- a/torch/csrc/functionalization/Module.cpp +++ /dev/null @@ -1,71 +0,0 @@ -#include -#include - -#include -#include -#include -#include - -namespace torch::functionalization { - -void initModule(PyObject* module) { - auto m = py::handle(module).cast(); - - // Create a `torch._C._functionalization` Python module. 
- auto functionalization = m.def_submodule( - "_functionalization", "functionalization related pybind."); - - // Retrieve the ViewMeta sequence of a given functional tensor. - functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) { - TORCH_INTERNAL_ASSERT( - at::functionalization::impl::isFunctionalTensor(tensor)); - auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); - return impl->view_metas(); - }); - - // Applies the given ViewMeta sequence to the given base. - functionalization.def( - "apply_view_meta_sequence", - [](const at::Tensor& base, - const std::vector>& - sequence) { - return at::functionalization::impl::apply_view_meta_sequence( - base, sequence); - }); - - // Binding for InverseReturnMode. - py::enum_( - functionalization, "InverseReturnMode") - .value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView) - .value("NeverView", at::functionalization::InverseReturnMode::NeverView) - .value( - "ViewOrScatterInverse", - at::functionalization::InverseReturnMode::ViewOrScatterInverse); - - // Create bindings for the ViewMeta base class. - // - // Needed so that we can take a list of ViewMeta objects as parameter. - // Specifically, in the Python-side, we will have a list of derived ViewMeta - // classes. We need to tell pybind11 that all of those are, in fact, instances - // of different ViewMeta sub-types. - py::class_< - at::functionalization::ViewMeta, - std::shared_ptr>( - functionalization, "ViewMeta") - .def_property_readonly( - "has_symbolic_inputs", - [](const std::shared_ptr& meta) { - return meta->has_symbolic_inputs; - }); - - // Bindings for `ViewMeta` specializations manually implemented. - create_binding_with_pickle( - functionalization); - create_binding_with_pickle( - functionalization); - - // Bindings for `ViewMeta` specializations automatically generated. - initGenerated(functionalization.ptr()); -} - -} // namespace torch::functionalization diff --git a/torch/csrc/functionalization/Module.h b/torch/csrc/functionalization/Module.h deleted file mode 100644 index 2f77fd3098c3..000000000000 --- a/torch/csrc/functionalization/Module.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include - -#include -#include - -namespace torch::functionalization { - -// Creates the default bindings for `ViewMeta` specializations. -// -// Defines a constructor using the types in `SerializableTuple`, as well -// as pickle methods. -template -void create_binding_with_pickle(py::module m) { - py::class_, at::functionalization::ViewMeta>( - m, T::name()) - .def(py::init()) - .def( - "as_tuple", - [](const std::shared_ptr& meta) { - return meta->to_serializable_tuple(); - }) - .def(py::pickle( - [](const std::shared_ptr& meta) { - return meta->to_serializable_tuple(); - }, - [](const typename T::SerializableTuple& tpl) { - return std::make_shared(tpl); - })); -} - -void initModule(PyObject* module); -void initGenerated(PyObject* module); - -} // namespace torch::functionalization diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 367156d9ba7c..93667e39b17f 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -23,13 +23,20 @@ from torchgen.model import ( # This file describes the translation of JIT schema to API's used -# when creating `ViewMeta` specializations that are used by the functionalization pass. 
-# These API's mostly follow the dispatcher API, with one difference: -# - While the forward function just directly calls into the at::_ops API -# (following the dispatcher convention), the logic here for the reverse function +# when creating view lambdas that are used by the functionalization pass. +# There are two types of lambdas: forward lambdas and reverse lambdas. +# These API's mostly follow the dispatcher API, with a few quirks: +# - The lambda capture has to convert reference types to value types +# - While the forward lambda just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse lambda # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). +# The lambdas generated for each view op in the functionalization pass are of the form +# [capture_arguments](outer_arguments) -> returns_type { +# return name(inner_arguments); +# } + # Define some specific lambda input arguments. base_binding = Binding( name="base", @@ -39,18 +46,6 @@ base_binding = Binding( ), default=None, ) - -has_symbolic_inputs_binding = Binding( - name="has_symbolic_inputs", - nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)), - argument=Argument( - name="has_symbolic_inputs", - type=BaseType(BaseTy.bool), - default=None, - annotation=None, - ), - default=None, -) mutated_view_binding = Binding( name="mutated_view", nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), @@ -59,11 +54,11 @@ mutated_view_binding = Binding( ), default=None, ) -out_index_binding = Binding( - name="out_index", - nctype=NamedCType(name="out_index", type=BaseCType(longT)), +mutated_view_idx_binding = Binding( + name="mutated_view_idx", + nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), argument=Argument( - name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None ), default=None, ) @@ -91,13 +86,8 @@ inverse_return_mode_binding = Binding( ) -# Name of the `ViewMeta` specialization class created. -def classname(func: FunctionSchema, with_namespace: bool = False) -> str: - namespace = "at::functionalization::" if with_namespace else "" - return f"{namespace}{func.name.unambiguous_name()}_ViewMeta" - - -# Name of the operation called inside the `forward`/`reverse` implementations. +# The lambda capture itself doesn't have a name. +# The name returned here corresponds to the name of the inner function called by the lambda. def name( g: NativeFunctionsViewGroup, *, @@ -134,6 +124,24 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str: return f"{api_name}_inverse" +def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: + # capture arguments include all arguments except `self`. 
+    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
+    # So any reference types (IntArrayRef) need to be converted to value types (vector)
+    args = func.arguments.flat_all
+    assert args[0].type == BaseType(BaseTy.Tensor)
+    non_self_args = args[1:]
+    non_self_value_bindings = [
+        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
+    ]
+
+    all_bindings = [
+        inverse_return_mode_binding if is_reverse else reapply_views_binding
+    ]
+    all_bindings.extend(non_self_value_bindings)
+    return all_bindings
+
+
 def returns_type(func: FunctionSchema) -> CType:
     # Assertion: all view ops return tensor-like outputs
     assert len(func.returns) >= 1
@@ -144,49 +152,24 @@ def returns_type(func: FunctionSchema) -> CType:
     return BaseCType(tensorT)


-# Checks whether `func` might return more than one value.
-def is_multi_output(func: FunctionSchema) -> bool:
-    return len(func.returns) > 1 or (
-        len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
-    )
+def outer_arguments(*, is_reverse: bool) -> list[Binding]:
+    if is_reverse:
+        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
+    else:
+        return [base_binding, mutated_view_idx_binding]


-# `ViewMeta` specialization constructor parameters.
-def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
-    # All specializations are parameterized by the `has_symbolic_inputs` flag.
-    arguments = [has_symbolic_inputs_binding]
-
-    # If `func` might return more than 1 value, we also parameterize this specialization
-    # with the output index.
-    if is_multi_output(func):
-        arguments.append(out_index_binding)
-
-    return arguments
+def inner_call_index(func: FunctionSchema) -> Binding | None:
+    # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
+    # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
+    if len(func.returns) > 1 or (
+        len(func.returns) == 1 and func.returns[0].type.is_list_like()
+    ):
+        return mutated_view_idx_binding
+    return None


-# `ViewMeta` specialized class' constructor arguments.
-#
-# Values needed specifically by this specialization, that the base class does not need.
-# Same as the class' attributes, but non-owning.
-def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
-    return attributes(func, owning=False)
-
-
-# `ViewMeta` specialized class' non-static member data.
-#
-# Essential data for calling the instance's `forward` and `reverse` functions. You can
-# think of them as values that should be captured from the functionalization kernel.
-def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]: - args = func.arguments.flat_all - assert args[0].type == BaseType(BaseTy.Tensor) - return [ - reapply_views_binding, - inverse_return_mode_binding, - *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]], - ] - - -def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: +def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: args = func.arguments.flat_all assert args[0].type == BaseType(BaseTy.Tensor) non_self_args = args[1:] @@ -200,12 +183,13 @@ def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: # the reverse lambda does the same, but with an additional "mutated_view" arg # additionally, we have a calling convention: for view ops that return multiple tensor outputs # their corresponding view_inverse function takes in an additional index argument. - if is_multi_output(func): + index_binding = inner_call_index(func) + if index_binding is not None: return [ base_binding, mutated_view_binding, inverse_return_mode_binding, - out_index_binding, + index_binding, ] + non_self_bindings else: return [ diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index f34028a5aa70..d7c60e52d93a 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -300,11 +300,83 @@ class ViewInverseSignature: return_type = functionalization.returns_type(self.g.view.func) decls = [ a.decl() - for a in functionalization.op_arguments(self.g.view.func, is_reverse=True) + for a in functionalization.inner_arguments( + self.g.view.func, is_reverse=True + ) ] return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" +@dataclass(frozen=True) +class FunctionalizationLambda: + g: NativeFunctionsViewGroup + + # are we generating the forward lambda or the reverse lambda? + is_reverse: bool + + def captures(self) -> list[Expr]: + # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments + # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, + # and plumb it into the lambda. + outer_ctx = dispatcher.arguments(self.g.view.func) + [ + functionalization.reapply_views_binding, + functionalization.inverse_return_mode_binding, + ] + capture_bindings = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + # allow_expensive_conversions is set because we want to convert + # some reference types (IntArrayRef) to value types (vector). 
+ capture_exprs = translate.translate( + outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True + ) + return capture_exprs + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + capture_str = ", ".join( + f"{val.type.name} = {val.expr}" for val in self.captures() + ) + decls = [ + a.decl() + for a in functionalization.outer_arguments(is_reverse=self.is_reverse) + ] + return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" + + def inner_call(self, *, reapply_views: bool | None = None) -> str: + inner_call_name = functionalization.name( + self.g, + is_reverse=self.is_reverse, + include_namespace=True, + reapply_views=reapply_views, + ) + + arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse) + capture_ctx = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + full_ctx = arg_ctx + capture_ctx + + assert self.g.view_copy is not None + call_bindings = functionalization.inner_arguments( + self.g.view_copy.func, is_reverse=self.is_reverse + ) + maybe_index = functionalization.inner_call_index(self.g.view_copy.func) + call_exprs = [ + e.expr for e in translate.translate(full_ctx, call_bindings, method=False) + ] + if not self.is_reverse and maybe_index is not None: + return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];' + else: + return f'{inner_call_name}({", ".join(call_exprs)});' + + @staticmethod + def from_func( + g: NativeFunctionsViewGroup, *, is_reverse: bool + ) -> FunctionalizationLambda: + return FunctionalizationLambda(g, is_reverse) + + @dataclass(frozen=True) class StructuredImplSignature: g: NativeFunctionsGroup diff --git a/torchgen/gen.py b/torchgen/gen.py index 5009495885b4..e9a10b9c52e8 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -45,8 +45,6 @@ from torchgen.gen_functionalization_type import ( gen_functionalization_definition, gen_functionalization_registration, gen_functionalization_view_inverse_declaration, - gen_functionalization_view_meta_classes_decl, - gen_functionalization_view_meta_classes_impl, GenCompositeViewCopyKernel, ) from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing @@ -2579,48 +2577,48 @@ codegen to generate the correct cpp call for this op. 
Contact AOTInductor team f }, ) - def gen_op_headers( - g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, - ) -> list[str]: - if isinstance(g, NativeFunctionsViewGroup): - # view ops always get a functionalization kernel - headers = [ - f"#include ", - f"#include ", - ] - if g.view_copy is not None: - headers += [ - f"#include ", - f"#include ", - ] - return headers - elif isinstance(g, NativeFunctionsGroup): - headers = [ - f"#include ", - f"#include ", - f"#include ", - f"#include ", - ] - if g.inplace is not None: - headers += [ - f"#include ", - f"#include ", - ] - if g.mutable is not None: - headers += [ - f"#include ", - f"#include ", - ] - return headers - else: - return [ - f"#include ", - f"#include ", - ] - def functionalization_env_callable( g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, ) -> dict[str, list[str]]: + def gen_op_headers( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> list[str]: + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + else: + return [ + f"#include ", + f"#include ", + ] + return { "ops_headers": gen_op_headers(g), "func_definitions": gen_functionalization_definition( @@ -2686,31 +2684,6 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f }, ) - cpu_fm.write( - "ViewMetaClasses.h", - lambda: { - "view_meta_declarations": list( - concatMap( - lambda g: gen_functionalization_view_meta_classes_decl(selector, g), - view_groups, - ) - ) - }, - ) - - cpu_fm.write( - "ViewMetaClasses.cpp", - lambda: { - "view_meta_implementations": list( - concatMap( - lambda g: gen_functionalization_view_meta_classes_impl(selector, g), - view_groups, - ) - ), - "op_headers": list(concatMap(gen_op_headers, view_groups)), - }, - ) - # Note [view_copy NativeFunctions] # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd # needs to have a corresponding non-aliasing {view}_copy variant. 
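
For orientation between the two torchgen diffs, the net effect of this revert on the generated kernels is to go back to building each view op's replay/undo logic as a ViewMeta constructed directly from a forward lambda and a reverse lambda (plus the has_symbolic_inputs, is_multi_output, and is_as_strided flags), instead of the generated, pickle-friendly ViewMeta specializations with their SerializableTuple round-trip that the hunks above delete. The self-contained sketch below is illustrative only and is not part of the patch: it uses a toy Tensor rather than at::Tensor, a hypothetical helper make_slice_view_meta, and simplified member names (forward_fn, reverse_fn, out_index).

// Illustrative sketch only (not part of this patch). Toy types: `Tensor` here
// stands in for at::Tensor, and `ViewMeta` is a simplified mirror of the
// lambda-based struct that the generated kernels in this diff construct.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

namespace sketch {

// Toy stand-in for at::Tensor; only carries sizes for the example.
struct Tensor {
  std::vector<int64_t> sizes;
};

// Simplified ViewMeta: the replay ("forward") and undo ("reverse") logic are
// type-erased lambdas, so an instance can capture arbitrary arguments by value
// but cannot be pickled; that is what the deleted SerializableTuple bindings
// above were compensating for.
struct ViewMeta {
  std::function<Tensor(const Tensor&, int64_t)> forward_fn;
  std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
  bool has_symbolic_inputs = false;
  bool is_multi_output = false;
  bool is_as_strided = false;
  int64_t out_index = 0;
};

// Hypothetical helper: roughly what a generated functionalization kernel does
// for a view op such as slice, i.e. capture the non-tensor arguments inside
// the two lambdas.
inline ViewMeta make_slice_view_meta(int64_t dim, int64_t start, int64_t end) {
  ViewMeta meta;
  meta.forward_fn = [=](const Tensor& base, int64_t /*out_idx*/) {
    Tensor out = base;  // pretend slice: shrink the sliced dimension
    out.sizes.at(static_cast<size_t>(dim)) = end - start;
    return out;
  };
  meta.reverse_fn = [=](const Tensor& base, const Tensor& /*mutated_view*/,
                        int64_t /*out_idx*/) {
    return base;  // pretend inverse (e.g. a slice_scatter-style restore)
  };
  return meta;
}

}  // namespace sketch

int main() {
  sketch::Tensor base{{4, 2}};
  sketch::ViewMeta meta =
      sketch::make_slice_view_meta(/*dim=*/0, /*start=*/1, /*end=*/3);
  sketch::Tensor view = meta.forward_fn(base, meta.out_index);
  std::cout << "view dim0 = " << view.sizes[0] << "\n";          // prints 2
  sketch::Tensor restored = meta.reverse_fn(base, view, meta.out_index);
  std::cout << "restored dim0 = " << restored.sizes[0] << "\n";  // prints 4
  return 0;
}

The tradeoff this illustrates: type-erased lambda members keep the codegen small and let each kernel capture whatever arguments it needs, but they are opaque to pickle, which is exactly what the removed SerializableTuple-based bindings were working around.
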
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 2d6ad3768728..4f9865d6d3eb 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,15 +1,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, Optional, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING -from torchgen.api import cpp, dispatcher, functionalization +from torchgen.api import cpp, dispatcher from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, CType, DispatcherSignature, + FunctionalizationLambda, iTensorListRefT, NativeSignature, OptionalCType, @@ -47,7 +48,7 @@ from torchgen.native_function_generation import ( MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) -from torchgen.utils import concatMap, dataclass_repr, FileManager +from torchgen.utils import dataclass_repr if TYPE_CHECKING: @@ -364,8 +365,6 @@ def emit_view_functionalization_body( with native_function_manager(f): call_sig = DispatcherSignature.from_schema(g.view_copy.func) - spec = ViewMetaSpecialization(g, f=f) - # the "view_copy" op name that the functionalization kernels need to call api_name = g.view_copy.func.name.unambiguous_name() # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors) @@ -386,6 +385,9 @@ def emit_view_functionalization_body( for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False) ] + forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False) + reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True) + # The meta API call should use the same arguments, but convert all tensors to meta tensors first. 
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) meta_call_args = [ @@ -413,7 +415,19 @@ def emit_view_functionalization_body( : at::functionalization::InverseReturnMode::NeverView ); {symbolic_inputs_check} - auto view_meta = {spec.new()}; + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + {forward_lambda.decl()} {{ + if (reapply_views) {{ + return {forward_lambda.inner_call(reapply_views=True)} + }} else {{ + return {forward_lambda.inner_call(reapply_views=False)} + }} + }}, + {reverse_lambda.decl()} {{ + return {reverse_lambda.inner_call()} + }}, + /*has_symbolic_inputs=*/{symbolic_inputs_varname} + ); auto compute_reference_meta = {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit); @@ -441,6 +455,7 @@ def emit_view_functionalization_body( """ else: + is_multi_output_view = isinstance(f.func.returns[0].type, ListType) return f""" {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {unwrap_tensor_args_str} @@ -474,7 +489,21 @@ def emit_view_functionalization_body( }} }} {symbolic_inputs_check} - auto view_meta = {spec.new()}; + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + {forward_lambda.decl()} {{ + if (reapply_views) {{ + return {forward_lambda.inner_call(reapply_views=True)} + }} else {{ + return {forward_lambda.inner_call(reapply_views=False)} + }} + }}, + {reverse_lambda.decl()} {{ + return {reverse_lambda.inner_call()} + }}, + /*has_symbolic_inputs=*/{symbolic_inputs_varname}, + /*is_multi_output=*/{str(is_multi_output_view).lower()}, + /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()} + ); auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta); // See Note [Propagating strides in the functionalization pass] if (compute_reference_meta) {{ @@ -742,301 +771,6 @@ def gen_functionalization_view_inverse_declaration( return emit_decl_helper(g) -# Helper class for generating `ViewMeta` specializations. -@dataclass -class ViewMetaSpecialization: - g: NativeFunctionsViewGroup - f: NativeFunction - - @property - def is_multi_output(self) -> bool: - return functionalization.is_multi_output(self.f.func) - - @property - def is_as_strided(self) -> bool: - return str(self.f.func.name) == "as_strided" - - @property - def out_index(self) -> str: - if self.is_multi_output: - return functionalization.out_index_binding.name - return "0" - - @property - def classname(self) -> str: - return functionalization.classname(self.f.func) - - def decl(self) -> list[str]: - base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func) - extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func) - attributes = functionalization.attributes(self.f.func) - - # List of types for declaring the `SerializableTuple` type. - serializable_tuple_args = ",\n".join( - f" {binding.type} /* {binding.name} */" - for binding in (base_ctor_arguments + attributes) - ) - - # Arguments used for forwarding the tuple elements to the constructor. - destructure_tuple_args = ", ".join( - f"std::get<{i}>(tpl)" - for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments)) - ) - - # List of constructor parameters - ctor_parameters = ", ".join( - binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments) - ) - - # Call the base class `ViewMeta` constructor. 
- # - # Both of `is_multi_output` and `is_as_strided` are known values, given the - # operation schema. - is_multi_output_str = str(self.is_multi_output).lower() - is_as_strided_str = str(self.is_as_strided).lower() - - base_ctor_bindings = ", ".join( - [ - # `has_symbolic_inputs` is always taken as parameter. - functionalization.has_symbolic_inputs_binding.name, - f"/*is_multi_output=*/{is_multi_output_str}", - f"/*is_as_strided=*/{is_as_strided_str}", - # `out_index` is know if the operation returns only one value. Otherwise, - # we also take it as parameter. - f"/*out_index=*/{self.out_index}", - ] - ) - - # Assignments of `extra_ctor_arguments` to their corresponding fields. - # These are extra fields to-be-declared in this specialization. - # - # We need to set `allow_expensive_conversions`, since we are storing owned versions - # of the non-owning arguments. - ctor_assignments = ",\n".join( - f" {e.type.name}({e.expr})" - for e in translate( - extra_ctor_arguments, - attributes, - method=False, - allow_expensive_conversions=True, - ) - ) - - # List of arguments for constructing the `SerializableTuple` from an instance. - tuple_arguments = ", ".join( - binding.name for binding in (base_ctor_arguments + attributes) - ) - - # List of field declarations. - attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes) - - # Override `to_out_index` if this operation returns more than 1 value. - to_out_index_decl = "" - if self.is_multi_output: - to_out_index_decl = ( - " std::shared_ptr to_out_index(int64_t out_idx) override;" - ) - - return [ - f""" -struct TORCH_API {self.classname} : public ViewMeta {{ - FUNCTIONALIZATION_VIEWMETA_NAME({self.classname}); - FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args}); - - {self.classname}(const SerializableTuple& tpl) - : {self.classname}({destructure_tuple_args}) {{}} - - {self.classname}({ctor_parameters}) - : at::functionalization::ViewMeta({base_ctor_bindings}), -{ctor_assignments} {{}} - - Tensor forward(const Tensor& base) override; - Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; -{to_out_index_decl} - - SerializableTuple to_serializable_tuple() {{ - return std::make_tuple({tuple_arguments}); - }} - -{attr_declarations} -}}; -""" - ] - - # Generate a call to the actual operation. - def opcall(self, is_reverse: bool, reapply_views: bool) -> str: - opname = functionalization.name( - self.g, - is_reverse=is_reverse, - include_namespace=True, - reapply_views=reapply_views, - ) - - # Expected arguments for the operation. - assert self.g.view_copy is not None - op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse) - - # The context is composed by the constructor arguments (which are also - # the field variables stored in the instance), and the `base` tensor. - context = [functionalization.base_binding] - context += functionalization.base_ctor_arguments(self.f.func) - context += functionalization.attributes(self.f.func) - - # If we are generating the call for the reverse function, we also have - # access to `mutated_view` argument. - if is_reverse: - context.append(functionalization.mutated_view_binding) - - arguments = ", ".join( - [e.expr for e in translate(context, op_arguments, method=False)] - ) - - # Index the result if this operation returns multiple values. 
- maybe_index = "" - if not is_reverse and self.is_multi_output: - maybe_index = f"[{self.out_index}]" - - return f"{opname}({arguments}){maybe_index}" - - def impl(self) -> list[str]: - functions = [ - f""" -at::Tensor {self.classname}::forward(const at::Tensor& base) {{ - if (reapply_views) {{ - return {self.opcall(is_reverse=False, reapply_views=True)}; - }} else {{ - return {self.opcall(is_reverse=False, reapply_views=False)}; - }} -}}""", - f""" -at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{ - return {self.opcall(is_reverse=True, reapply_views=True)}; -}}""", - ] - - # If this operation returns multiple values, also generate a `to_out_index` - # implementation. - if self.is_multi_output: - functions.append(f""" -std::shared_ptr {self.classname}::to_out_index(int64_t out_index) {{ - return {self.new("out_index")}; -}} -""") - - return functions - - # Create the Python binding for this specialized class. - def binding(self) -> list[str]: - name = functionalization.classname(self.f.func, with_namespace=True) - return [f" create_binding_with_pickle<{name}>(functionalization);"] - - # Generate an instanciation of this specialized class. - def new(self, out_index: str = "0") -> str: - name = functionalization.classname(self.f.func, with_namespace=True) - ctor_arguments = functionalization.base_ctor_arguments( - self.f.func - ) + functionalization.extra_ctor_arguments(self.f.func) - # Replace the `out_index` parameter with the given `out_index`. - arguments = ", ".join( - binding.name if binding.name != "out_index" else out_index - for binding in ctor_arguments - ) - return f"std::make_shared<{name}>({arguments})" - - # Run the function `run` for both: `view` and `view_inplace` functions. - @staticmethod - def map( - g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]] - ) -> list[str]: - def maybe_run(f: Optional[NativeFunction]) -> list[str]: - if f is None: - return [] - with native_function_manager(f): - return run(ViewMetaSpecialization(g, f)) - - return list(concatMap(maybe_run, (g.view, g.view_inplace))) - - -def gen_functionalization_view_meta_classes_base( - selector: SelectiveBuilder, - g: NativeFunctionsViewGroup, - run: Callable[[ViewMetaSpecialization], list[str]], -) -> list[str]: - if not selector.include_all_operators: - return [] - - if g.composite: - return [] - - return ViewMetaSpecialization.map(g, run) - - -def gen_functionalization_view_meta_classes_decl( - selector: SelectiveBuilder, g: NativeFunctionsViewGroup -) -> list[str]: - return gen_functionalization_view_meta_classes_base( - selector, g, ViewMetaSpecialization.decl - ) - - -def gen_functionalization_view_meta_classes_impl( - selector: SelectiveBuilder, g: NativeFunctionsViewGroup -) -> list[str]: - return gen_functionalization_view_meta_classes_base( - selector, g, ViewMetaSpecialization.impl - ) - - -def gen_functionalization_view_meta_classes_binding( - selector: SelectiveBuilder, g: NativeFunctionsViewGroup -) -> list[str]: - return gen_functionalization_view_meta_classes_base( - selector, g, ViewMetaSpecialization.binding - ) - - -# Generates the Python bindings for the `ViewMeta` specialized classes. -def gen_functionalization_view_meta_classes( - native_functions_path: str, - tags_path: str, - selector: SelectiveBuilder, - install_dir: str, - template_dir: str, -) -> None: - from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml - - # Parse the native_functions.yaml. 
-    # Then, group them into `NativeFunctionsViewGroup`.
-    #
-    # These are the same steps we do in gen.py (ATen codegen).
-    native_functions = parse_native_yaml(
-        native_functions_path, tags_path
-    ).native_functions
-    native_functions_with_view_groups = get_grouped_by_view_native_functions(
-        native_functions
-    )
-    view_groups = [
-        g
-        for g in native_functions_with_view_groups
-        if isinstance(g, NativeFunctionsViewGroup)
-    ]
-
-    fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
-    fm.write(
-        "ViewMetaClassesPythonBinding.cpp",
-        lambda: {
-            "view_meta_bindings": list(
-                concatMap(
-                    lambda g: gen_functionalization_view_meta_classes_binding(
-                        selector, g
-                    ),
-                    view_groups,
-                )
-            ),
-        },
-    )
-
-
 def gen_functionalization_registration(
     selector: SelectiveBuilder,
     g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,