diff --git a/.gitignore b/.gitignore index ca87f1306e12..91d6c9f71550 100644 --- a/.gitignore +++ b/.gitignore @@ -82,6 +82,7 @@ torch/return_types.pyi torch/nn/functional.pyi torch/utils/data/datapipes/datapipe.pyi torch/csrc/autograd/generated/* +torch/csrc/functionalization/generated/* torch/csrc/lazy/generated/*.[!m]* torch_compile_debug/ # Listed manually because some files in this directory are not generated diff --git a/BUILD.bazel b/BUILD.bazel index 5d7625b40294..e814106e561f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -90,6 +90,8 @@ generated_cpu_cpp = [ "aten/src/ATen/NativeMetaFunctions.h", "aten/src/ATen/RegistrationDeclarations.h", "aten/src/ATen/VmapGeneratedPlumbing.h", + "aten/src/ATen/ViewMetaClasses.h", + "aten/src/ATen/ViewMetaClasses.cpp", "aten/src/ATen/core/aten_interned_strings.h", "aten/src/ATen/core/enum_tag.h", "aten/src/ATen/core/TensorBody.h", @@ -1074,6 +1076,7 @@ test_suite( "aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/RegisterDispatchKey.cpp", "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", + "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp", "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml", "aten/src/ATen/native/ts_native_functions.yaml", diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp index 2cf8d9727f65..9631872875c6 100644 --- a/aten/src/ATen/FunctionalStorageImpl.cpp +++ b/aten/src/ATen/FunctionalStorageImpl.cpp @@ -9,11 +9,6 @@ namespace at::functionalization { -ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { - if (out_idx == this->out_index) return *this; - return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx); -} - // Note [Functionalization: Alias Removal Part 2] // See Note [Functionalization: Alias Removal] for more details. // This function applies a single update from one of the views to the StorageImpl. @@ -42,12 +37,12 @@ ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) { at::Tensor t = update.new_val; TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); - if (update.view_metas.empty()) return t; + if (update.view_metas.empty()) { return t; } std::vector tmp_values({base}); tmp_values.reserve(update.view_metas.size()); for (size_t i = 0; i < update.view_metas.size() - 1; ++i) { - at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index); + at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back()); // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided // All of these ops require additional information to recover the sizes of the original tensor. // If need to, we could probably apply this optimization and only bother computing tmp_values @@ -55,9 +50,8 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co tmp_values.push_back(std::move(next_view)); } for(int64_t i = static_cast(update.view_metas.size()) - 1; i >= 0; --i) { - int64_t out_idx = update.view_metas[i].out_index; // Each view inverse is implemented in ViewInverses.cpp. 
- t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx); + t = update.view_metas[i]->reverse(tmp_values[i], t); } TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); return t; @@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base) TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); } -void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector& metas) { +void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector>& metas) { TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); if (metas.size() > 1) { for (size_t i = 1; i < metas.size(); ++i) { // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI - TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided, + TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided, "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i, " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today," "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you " diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h index 8cd1cb7434aa..0c9c1fd775f3 100644 --- a/aten/src/ATen/FunctionalStorageImpl.h +++ b/aten/src/ATen/FunctionalStorageImpl.h @@ -8,44 +8,89 @@ namespace at::functionalization { // See Note [Functionalization Pass In Core] +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) + /// scatter + /// inverse exists, in which case that will be used instead. + /// This avoids as_strided() calls that can be difficult for subclasses to + /// handle. + ViewOrScatterInverse, +}; + +#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \ + static const char* name() { \ + return #TYPE; \ + } + +#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \ + using SerializableTuple = std::tuple<__VA_ARGS__> + // ViewMeta is a class used by the functionalization pass to navigate between // a base tensor and a view tensor. 
// For example, if I call `b = a.view1(...)` -// the functionalization pass will generate and store a ViewMeta on b that looks -// like: +// the functionalization pass will generate and store a ViewMeta specialization +// for `view1` operation on b that looks like: // -// ViewMeta( -// [](const Tensor& base, int64_t mutated_view_idx) { -// return base.view1(...); -// }, -// [](const at::Tensor& base, const at::Tensor& mutated_view, -// int64_t mutated_view_idx) -> at::Tensor { -// return at::functionalization::impl::view1_inverse(base, mutated_view, -// ...); +// struct TORCH_API view1_ViewMeta : public ViewMeta { +// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta); +// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( +// bool /* reapply_views */, +// const std::vector&); +// +// view1_ViewMeta(const SerializableTuple& tpl) +// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} +// +// view1_ViewMeta(bool reapply_views, const std::vector& size) +// : ViewMeta(/*has_symbolic_inputs=*/false), +// reapply_views(reapply_views), +// size(size) {} +// +// Tensor forward(const Tensor& base) override { +// return base.view1(...); // } // -// The forward_fn lambda describes how to replay view1 on a tensor. +// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override { +// return at::functionalization::impl::view1_inverse(base, mutated_view, +// ...); +// } // -// The reverse_fn lambda describes how, given a tensor that is already a view, +// SerializableTuple to_serializable_tuple() { +// return std::make_tuple(reapply_views, size); +// } +// +// bool reapply_views; +// std::vector size; +// }; +// +// The forward function describes how to replay view1 on a tensor. +// +// The reverse function describes how, given a tensor that is already a view, // how to get the corresponding base tensor. See Note [Functionalization Pass: // View Inverses] for details. +// +// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type +// representing the `ViewMeta` instance state. Methods that take in/return such +// a type are used for supporting pickle serialization. struct ViewMeta { ViewMeta( - std::function forward, - std::function reverse, bool has_symbolic_inputs, bool is_multi_output = false, bool is_as_strided = false, int64_t out_idx = 0) - : forward_fn(std::move(forward)), - reverse_fn(std::move(reverse)), - out_index(out_idx), + : out_index(out_idx), is_multi_output(is_multi_output), is_as_strided(is_as_strided), has_symbolic_inputs(has_symbolic_inputs) {} - std::function forward_fn; - std::function reverse_fn; + virtual ~ViewMeta() = default; + + virtual Tensor forward(const Tensor& base) = 0; + virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0; + // See Note [out_idx in ViewMeta] int64_t out_index; @@ -57,10 +102,17 @@ struct ViewMeta { // Tells us if this view operation has any symbolic inputs bool has_symbolic_inputs; - // Returns a copy of the current ViewMeta, if out_idx matches the current - // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse + // Returns a new ViewMeta with the same forward/reverse // functions, but a new out index. - ViewMeta to_out_idx(int64_t out_idx); + // + // This method should be implemented by those `ViewMeta` that have more than + // one output. + virtual std::shared_ptr to_out_index(int64_t out_index) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "ViewMeta::to_out_index not implemented. 
", + "Likely because there's only one output."); + } }; // FunctionalStorageImpl is a subclass of StorageImpl used by the @@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const at::Tensor new_val; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::vector view_metas; + const std::vector> view_metas; }; explicit FunctionalStorageImpl(const Tensor& value); void add_update( const Tensor& updated_val, - const std::vector& view_metas); + const std::vector>& view_metas); bool apply_updates(); const Tensor& base() { return base_; diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 0a2fa153a6cf..d553cc1fb949 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -129,17 +129,19 @@ void FunctionalTensorWrapper::freeze_storage() const { // - view_value: The output tensor that we need to wrap. // - base: The "base" of the view that `view_value` was generated from. // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. -FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta) - : c10::TensorImpl( - c10::DispatchKeySet(DispatchKey::Functionalize), - view_value.dtype(), - base->storage().data_ptr().device() - ), - value_(view_value), - is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output), - was_storage_changed_(base->was_storage_changed_), - is_symbolic_(base->is_symbolic_) -{ +FunctionalTensorWrapper::FunctionalTensorWrapper( + const Tensor& view_value, + const FunctionalTensorWrapper* base, + const std::shared_ptr& meta) + : c10::TensorImpl( + c10::DispatchKeySet(DispatchKey::Functionalize), + view_value.dtype(), + base->storage().data_ptr().device()), + value_(view_value), + is_multi_output_view_( + base->is_multi_output_view_ || meta->is_multi_output), + was_storage_changed_(base->was_storage_changed_), + is_symbolic_(base->is_symbolic_) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_)); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); set_constructor_metadata(); @@ -148,11 +150,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const view_metas_ = base->view_metas_; // copy } view_metas_.push_back(meta); - maybe_mark_symbolic(meta); + maybe_mark_symbolic(meta.get()); storage_ = base->storage_; // alias this tensor's storage with the base tensor's } - functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { return static_cast(storage_.unsafeGetStorageImpl()); } @@ -176,18 +177,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const { } // See Note [Functionalization Pass - Inplace View Ops] -void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { +void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr& meta) { view_metas_.push_back(meta); // Manually track the fact that this tensor received a metadata mutation! has_metadata_mutation_ = true; // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. - maybe_mark_symbolic(meta); + maybe_mark_symbolic(meta.get()); // Note [Functionalization Pass - Inplace View Ops] // So, these ops are special - they're mutation AND view ops. 
They get special codegen. // An example is transpose_, e.g. `a.transpose_()` // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. at::AutoDispatchSkipFunctionalize guard; - value_ = meta.forward_fn(value_, meta.out_index); + value_ = meta->forward(value_); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); } @@ -368,15 +369,8 @@ void FunctionalTensorWrapper::sync_() { regenerate_from_base(); } -Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) { - auto t = base; - - // Reapply views to get the viewed tensor from the base in alias_ - for (auto& view_meta: view_metas_) { - t = view_meta.forward_fn(t, view_meta.out_index); - } - - return t; +const std::vector>& FunctionalTensorWrapper::view_metas() const { + return view_metas_; } void FunctionalTensorWrapper::regenerate_from_base() { @@ -385,7 +379,7 @@ void FunctionalTensorWrapper::regenerate_from_base() { auto t = storage_impl->base(); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); - t = apply_view_metas(t); + t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); replace_(t, /*from_lazy_regenerate=*/true); @@ -727,11 +721,11 @@ bool isFunctionalTensor(const std::optional& t) { } bool isFunctionalTensor(const c10::List<::std::optional>& t_list) { - if (t_list.empty()) return false; + if (t_list.empty()) { return false; } auto functional_count = 0; for (const auto i : c10::irange(t_list.size())) { auto const & e= t_list[i]; - if (!e.has_value() || !e->defined()) continue; + if (!e.has_value() || !e->defined()) { continue; } if (isFunctionalTensor(e)) { ++functional_count; } @@ -741,10 +735,10 @@ bool isFunctionalTensor(const c10::List<::std::optional>& t_list) { template static bool isFunctionalTensorIListRef(c10::IListRef list) { - if (list.size() == 0) return false; + if (list.size() == 0) { return false; } auto functional_count = 0; for (const auto& tensor : list) { - if (!tensor.defined()) continue; + if (!tensor.defined()) { continue; } if (isFunctionalTensor(tensor)) { ++functional_count; } @@ -762,20 +756,28 @@ void freeze_functional_tensor(const Tensor& tensor) { functional_base_impl->freeze_storage(); } -Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { +Tensor create_functional_tensor_with_view_meta( + const at::Tensor& view_to_wrap, + const at::Tensor& base, + const std::shared_ptr& meta, + int64_t out_idx) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); + auto meta_ = meta; if (out_idx != 0) { // Note [out_idx in ViewMeta] // When a view op outputs multiple tensors, each output needs its own separate ViewMeta. // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. 
- meta = meta.to_out_idx(out_idx); + meta_ = meta->to_out_index(out_idx); } - return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta); + return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta_); } -std::vector create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) { +std::vector create_functional_tensor_with_view_meta( + ITensorListRef view_to_wrap, + const at::Tensor& base, + const std::shared_ptr& meta) { std::vector outputs(view_to_wrap.size()); int64_t i = 0; for (const auto& tensor : view_to_wrap) { @@ -785,12 +787,22 @@ std::vector create_functional_tensor_with_view_meta(ITensorListRef view_ return outputs; } -void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) { +void mutate_view_meta(const at::Tensor& self, const std::shared_ptr& meta) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); self_impl->mutate_view_meta(meta); } +Tensor apply_view_meta_sequence( + const Tensor& base, + const std::vector>& sequence) { + Tensor r = base; + for (auto& vm : sequence) { + r = vm->forward(r); + } + return r; +} + // Note [Propagating strides in the functionalization pass] // In order to properly compute stride information, the functionalization pass // calls each {view} reference implementations with meta tensors. @@ -884,7 +896,7 @@ void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* s const auto& ivalue = returns[idx]; if (ivalue.isTensor()) { const auto& t = ivalue.toTensor(); - if (!t.defined()) continue; + if (!t.defined()) { continue; } at::functionalization::impl::sync(t); auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t)); (*stack)[returns_begin + idx] = t_new; diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index b260b7c9f958..6d9050728da7 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { explicit FunctionalTensorWrapper( const Tensor& view_value, const FunctionalTensorWrapper* base, - const functionalization::ViewMeta& meta); + const std::shared_ptr& meta); // Get the underlying, actual tensor, that doesn't know anything about // functionalization. @@ -99,17 +99,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { ->are_all_mutations_under_no_grad_or_inference_mode(); } - void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { - is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; + void maybe_mark_symbolic(functionalization::ViewMeta* meta) { + is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs; } bool is_symbolic() const { return is_symbolic_; } - // Runs the forward_fn of every ViewMeta collected in the current instance - // to some other base. - Tensor apply_view_metas(const Tensor& base); + // Retrieves the ViewMeta sequence of this tensor. + const std::vector>& view_metas() + const; // Sync's the underlying tensor with its alias, if it's out of date. This // involves two steps: 1) Apply any pending updates/mutations to the alias 2) @@ -146,7 +146,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // from the base tensor. This method is used by inplace-view ops like // transpose_. 
It appends a ViewMeta to the existing stack, and refreshes the // tensor by replaying the views off of the alias. - void mutate_view_meta(const at::functionalization::ViewMeta& meta); + void mutate_view_meta( + const std::shared_ptr& meta); // Custom implementation of self.set_(src) void set__impl(const FunctionalTensorWrapper* other); @@ -285,7 +286,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool is_symbolic_ = false; size_t generation_ = 0; - std::vector view_metas_; + std::vector> view_metas_; protected: static void copy_tensor_metadata( @@ -377,16 +378,20 @@ TORCH_API void propagate_xla_data_direct( Tensor create_functional_tensor_with_view_meta( const Tensor& view_to_wrap, const Tensor& base, - functionalization::ViewMeta meta, + const std::shared_ptr& meta, int64_t out_idx = 0); std::vector create_functional_tensor_with_view_meta( ITensorListRef view_to_wrap, const Tensor& base, - const functionalization::ViewMeta& meta); + const std::shared_ptr& meta); void mutate_view_meta( const Tensor& self, - const functionalization::ViewMeta& meta); + const std::shared_ptr& meta); + +TORCH_API Tensor apply_view_meta_sequence( + const Tensor& base, + const std::vector>& sequence); void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); void set_sizes_strides_offset( diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index 97094c9f125a..10f988b4d281 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -7,7 +9,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -28,6 +29,31 @@ #include #endif +namespace at::functionalization { + +Tensor resize__ViewMeta::forward(const Tensor& base) { + if (reapply_views) { + return base.as_strided(size, c10::contiguous_strides(size)); + } else { + return at::as_strided_copy(base, size, c10::contiguous_strides(size)); + } +} + +Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { + return base.as_strided_scatter( + mutated_view, size, c10::contiguous_strides(size)); +} + +Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) { + return at::_unsafe_view_symint(base, size); +} + +Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { + return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); +} + +} // namespace at::functionalization + namespace { void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) { const auto& schema = op.schema(); @@ -106,7 +132,9 @@ namespace { const auto& ivalue = returns[idx]; if (ivalue.isTensor() && should_wrap_outputs) { const auto& t = ivalue.toTensor(); - if (!t.defined()) continue; + if (!t.defined()) { + continue; + } auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t)); (*stack)[returns_begin + idx] = t_new; } else if (ivalue.isTensorList() && should_wrap_outputs) { @@ -169,19 +197,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch // The output of resizing is equivalent to taking a slice of a larger tensor. // We have to emulate this "slicing" with an as_strided call. 
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - if (reapply_views) { - return base.as_strided(size, c10::contiguous_strides(size)); - } else { - return at::as_strided_copy(base, size, c10::contiguous_strides(size)); - } - }, - [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size)); - }, - /*has_symbolic_inputs=*/false - ); + auto view_meta = std::make_shared( + reapply_views, size.vec()); at::functionalization::impl::mutate_view_meta(self, view_meta); return self; } @@ -300,17 +317,11 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt tmp_output = at::_unsafe_view_symint(self_, size); } - bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); - - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - return at::_unsafe_view_symint(base, size); - }, - [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); - }, - /*has_symbolic_inputs=*/has_symbolic_inputs - ); + bool has_symbolic_inputs = std::any_of( + size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); + auto view_meta = + std::make_shared( + has_symbolic_inputs, size.vec()); auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta)); // See Note [Propagating strides in the functionalization pass] diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.h b/aten/src/ATen/FunctionalizeFallbackKernel.h new file mode 100644 index 000000000000..aabcfc827af3 --- /dev/null +++ b/aten/src/ATen/FunctionalizeFallbackKernel.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +namespace at::functionalization { + +// `ViewMeta` implementation for `resize_` operation. +struct TORCH_API resize__ViewMeta : public ViewMeta { + FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta) + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( + bool /* reapply_views */, + const std::vector&); + + resize__ViewMeta(const SerializableTuple& tpl) + : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} + + resize__ViewMeta(bool reapply_views, const std::vector& size) + : ViewMeta(/*has_symbolic_inputs=*/false), + reapply_views(reapply_views), + size(size) {} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; + + SerializableTuple to_serializable_tuple() { + return std::make_tuple(reapply_views, size); + } + + bool reapply_views; + std::vector size; +}; + +// `ViewMeta` implementation for `_unsafe_view` operation. 
+struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta { + FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta) + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( + bool /* has_symbolic_inputs */, + const std::vector&); + + _unsafe_view_ViewMeta(const SerializableTuple& tpl) + : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} + + _unsafe_view_ViewMeta( + bool has_symbolic_inputs, + const std::vector& size) + : ViewMeta(has_symbolic_inputs), size(size) {} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; + + SerializableTuple to_serializable_tuple() { + return std::make_tuple(has_symbolic_inputs, size); + } + + std::vector size; +}; + +} // namespace at::functionalization diff --git a/aten/src/ATen/templates/FunctionalInverses.h b/aten/src/ATen/templates/FunctionalInverses.h index 3217e097d7ad..b15cd09a6c65 100644 --- a/aten/src/ATen/templates/FunctionalInverses.h +++ b/aten/src/ATen/templates/FunctionalInverses.h @@ -2,22 +2,12 @@ // ${generated_comment} +#include #include namespace at { namespace functionalization { -enum class InverseReturnMode { - /// Specifies that functional inverses should always return a view. - AlwaysView, - /// Specifies that functional inverses should always return a non-view / copy. - NeverView, - /// Specifies that functional inverses should return a view unless a (copying) scatter - /// inverse exists, in which case that will be used instead. - /// This avoids as_strided() calls that can be difficult for subclasses to handle. - ViewOrScatterInverse, -}; - struct FunctionalInverses { ${view_inverse_declarations} diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp index dc8619c25fc5..408aff0cdab4 100644 --- a/aten/src/ATen/templates/RegisterFunctionalization.cpp +++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/templates/ViewMetaClasses.cpp b/aten/src/ATen/templates/ViewMetaClasses.cpp new file mode 100644 index 000000000000..0fd53171935f --- /dev/null +++ b/aten/src/ATen/templates/ViewMetaClasses.cpp @@ -0,0 +1,19 @@ +// ${generated_comment} + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +${op_headers} +#endif + +namespace at { +namespace functionalization { + +${view_meta_implementations} + +} // namespace functionalization +} // namespace at diff --git a/aten/src/ATen/templates/ViewMetaClasses.h b/aten/src/ATen/templates/ViewMetaClasses.h new file mode 100644 index 000000000000..be2dee2a871b --- /dev/null +++ b/aten/src/ATen/templates/ViewMetaClasses.h @@ -0,0 +1,12 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include + +namespace at { +namespace functionalization { + +${view_meta_declarations} + +} // namespace functionalization +} // namespace at diff --git a/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp b/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp new file mode 100644 index 000000000000..c784e5abe5c8 --- /dev/null +++ b/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp @@ -0,0 +1,11 @@ +#include +#include + +namespace torch::functionalization { + +void initGenerated(PyObject* module) { + auto functionalization = py::handle(module).cast(); + $view_meta_bindings +} + +} // namespace torch::functionalization diff --git a/buckbuild.bzl b/buckbuild.bzl index 047ed71ad279..d1363a7e9d73 
100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -391,6 +391,8 @@ def get_aten_generated_files(enabled_backends): "CompositeExplicitAutogradFunctions_inl.h", "CompositeExplicitAutogradNonFunctionalFunctions.h", "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", + "ViewMetaClasses.h", + "ViewMetaClasses.cpp", "VmapGeneratedPlumbing.h", "core/ATenOpList.cpp", "core/TensorBody.h", @@ -1193,6 +1195,7 @@ def define_buck_targets( "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]", "Operators.h": ":gen_aten[Operators.h]", "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]", + "ViewMetaClasses.h": ":gen_aten[ViewMetaClasses.h]", "core/TensorBody.h": ":gen_aten[core/TensorBody.h]", "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]", "core/enum_tag.h": ":gen_aten[core/enum_tag.h]", diff --git a/build.bzl b/build.bzl index 7c2c3e24dc5a..91529e75c9f0 100644 --- a/build.bzl +++ b/build.bzl @@ -118,6 +118,9 @@ def define_targets(rules): ":LazyNonNativeIr.h", ":RegisterDispatchDefinitions.ini", ":RegisterDispatchKey.cpp", + ":ViewMetaClassesPythonBinding.cpp", + ":ViewMetaClasses.cpp", + ":ViewMetaClasses.h", ":native_functions.yaml", ":shape_inference.h", ":tags.yaml", @@ -170,6 +173,7 @@ GENERATED_H = [ "FunctionalInverses.h", "RedispatchFunctions.h", "RegistrationDeclarations.h", + "ViewMetaClasses.h", "VmapGeneratedPlumbing.h", ] @@ -246,6 +250,7 @@ GENERATED_CPP = [ "RegisterFunctionalization_1.cpp", "RegisterFunctionalization_2.cpp", "RegisterFunctionalization_3.cpp", + "ViewMetaClasses.cpp", ] GENERATED_CPP_CORE = [ @@ -307,6 +312,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [ "torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", + "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" ] GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP diff --git a/build_variables.bzl b/build_variables.bzl index 01b204458eee..ecd1e8b79f65 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1010,6 +1010,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/disable_torch_function.cpp", "torch/csrc/utils/verbose.cpp", "torch/csrc/cpu/Module.cpp", + "torch/csrc/functionalization/Module.cpp", "torch/csrc/instruction_counter/Module.cpp", "torch/nativert/python/Bindings.cpp", ] + lazy_tensor_core_python_sources @@ -1052,6 +1053,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): "torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", + "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp", ]] _libtorch_python_sources.extend(libtorch_python_core_sources) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 51e4023b0d18..287e39f8eb99 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -316,6 +316,7 @@ set(GENERATED_CXX_PYTHON "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp" + "${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" ) set(GENERATED_H_PYTHON @@ -379,6 +380,9 @@ add_custom_command( "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" 
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" + "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.h" + "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.cpp" + "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp" ${autograd_python} ${autograd_yaml} ${autograd_templates} diff --git a/pt_template_srcs.bzl b/pt_template_srcs.bzl index d3a8dcabaa7e..84f5f8bd3e62 100644 --- a/pt_template_srcs.bzl +++ b/pt_template_srcs.bzl @@ -156,6 +156,7 @@ def get_generate_code_bin_outs(): "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"], "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"], "autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"], + "functionalization/generated/ViewMetaClassesPythonBinding.cpp": ["functionalization/generated/ViewMetaClassesPythonBinding.cpp"], }) return outs diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 68ac9d427f8e..04af76c90c52 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -519,11 +519,7 @@ class AOTAutogradCacheTests(InductorTestCase): @functorch_config.patch( {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} ) - def test_view_replay_bypass(self): - """ - Should bypass when view replay is turned on - """ - + def test_view_replay(self): def fn(a): tmp = a.detach() a.mul_(2) @@ -531,10 +527,25 @@ class AOTAutogradCacheTests(InductorTestCase): with torch.autograd._force_original_view_tracking(True): compiled_fn = torch.compile(fn) - compiled_fn(torch.rand(2, 3)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) + def run_and_check(miss, hit, bypass): + self._clear_dynamo_and_codecache() + + inp = torch.rand(2, 3) + compiled_inp = inp.clone().detach() + + with torch.autograd._force_original_view_tracking(True): + out = fn(inp) + compiled_out = compiled_fn(compiled_inp) + + self.assertEqual(out, compiled_out) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit) + self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass) + + run_and_check(miss=1, hit=0, bypass=0) + run_and_check(miss=1, hit=1, bypass=0) + run_and_check(miss=1, hit=2, bypass=0) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 0f697acc886a..aa0cc23fcd76 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8500,7 +8500,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): { "enable_autograd_cache": True, "strict_autograd_cache": True, - "view_replay_for_aliased_outputs": False, } ) @torch._inductor.config.patch("fx_graph_cache", True) diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 64a12c0d228c..e53efd7288c1 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -189,6 +189,12 @@ def main() -> None: ) options = parser.parse_args() + # Path: aten/src/ATen + aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) + operator_selector = get_selector( + options.selected_op_list_path, 
options.operators_yaml_path + ) + generate_code( options.gen_dir, options.native_functions_path, @@ -198,18 +204,37 @@ def main() -> None: options.disable_autograd, options.force_schema_registration, # options.selected_op_list - operator_selector=get_selector( - options.selected_op_list_path, options.operators_yaml_path - ), + operator_selector=operator_selector, + ) + + # Generate the python bindings for functionalization's `ViewMeta` classes. + from torchgen.gen_functionalization_type import ( + gen_functionalization_view_meta_classes, + ) + + functionalization_templates_dir = os.path.join(aten_path, "templates") + install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc") + functionalization_install_dir = os.path.join( + install_dir, "functionalization", "generated" + ) + + os.makedirs(functionalization_install_dir, exist_ok=True) + assert os.path.isdir(functionalization_install_dir) + assert os.path.isdir(functionalization_templates_dir) + + gen_functionalization_view_meta_classes( + options.native_functions_path or NATIVE_FUNCTIONS_PATH, + options.tags_path or TAGS_PATH, + selector=operator_selector, + install_dir=functionalization_install_dir, + template_dir=functionalization_templates_dir, ) if options.gen_lazy_ts_backend: - aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml") ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp" ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h" - install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc") - lazy_install_dir = os.path.join(install_dir, "lazy/generated") + lazy_install_dir = os.path.join(install_dir, "lazy", "generated") os.makedirs(lazy_install_dir, exist_ok=True) assert os.path.isfile(ts_backend_yaml), ( diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 08b2616b3952..147dc9a86524 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -30,6 +30,7 @@ from torch._C import ( _cpu, _dynamo, _export, + _functionalization, _functorch, _lazy, _lazy_ts_backend, diff --git a/torch/_C/_functionalization.pyi b/torch/_C/_functionalization.pyi new file mode 100644 index 000000000000..4e00df97e271 --- /dev/null +++ b/torch/_C/_functionalization.pyi @@ -0,0 +1,16 @@ +from torch import Tensor +from torch.types import _bool + +# Defined in torch/csrc/functionalization/Module.cpp + +class ViewMeta: + has_symbolic_inputs: _bool + +# Returns the list of ViewMeta instances of the given functional tensor. +# +# Although we do have python bindings for their types, we won't +# expose them here, since they should not be used by users. +def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ... + +# Applies the ViewMeta sequence on top of the given base. +def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ... diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 4d370766aaf9..4d6a881b2a45 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -284,19 +284,6 @@ def check_cacheable(gm: torch.fx.GraphModule): check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type] -def check_metadata_cacheable(metadata: ViewAndMutationMeta): - """ - When view replay is turned on, we bypass autograd cache if - the output is aliased. 
- """ - if config.view_replay_for_aliased_outputs: - for info in metadata.output_info: - if info.functional_tensor is not None: - raise BypassAOTAutogradCache( - "Cannot cache a graph with functional tensor" - ) - - class AOTAutogradCacheDetails(FxGraphHashDetails): """ Object to capture all the details for a dynamo graph module relevant to computing @@ -803,7 +790,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]): """ Perform any preparations to make the cache entry ready for serialization. """ - check_metadata_cacheable(self.runtime_metadata) self.compiled_fw.pre_save() if self.compiled_bw is not None: self.compiled_bw.pre_save() diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 19d08a64f967..acfd40fe78c7 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -43,10 +43,10 @@ from .functional_utils import ( has_metadata_mutation, MetadataKey, to_fun, + ViewMetaSequence, was_inductor_storage_resized, ) from .schemas import ( - FunctionalTensorMetadataEq, InputAliasInfo, MemoryFormatMeta, MutationType, @@ -640,7 +640,7 @@ from a multi-output view call" # # The FunctionalTensor will be saved if one of the 2 conditions below # is true: - functional_tensor = None + view_meta_sequence = None if ( # 1. If the output_type is either of: # (i) alias_of_intermediate; @@ -672,7 +672,7 @@ from a multi-output view call" and not input_info[base_idx].mutates_metadata ): if isinstance(o, FunctionalTensor): - functional_tensor = FunctionalTensorMetadataEq(o.elem) + view_meta_sequence = ViewMetaSequence(o) out_info = OutputAliasInfo( output_type=output_type, @@ -680,7 +680,7 @@ from a multi-output view call" base_idx=base_idx, dynamic_dims=dynamic_dims, requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, - functional_tensor=functional_tensor, + view_meta_sequence=view_meta_sequence, ) output_info.append(out_info) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 4e74ed6341b9..958804e5c763 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -14,6 +14,7 @@ from typing import Optional import torch from torch import Tensor +from torch._C import _functionalization from torch._logging import getArtifactLogger from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor @@ -224,9 +225,9 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, - target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, + target_view_meta_sequence: Optional[ViewMetaSequence] = None, *, - replay_views, + replay_views: bool, ): # Patch the correct requires_grad field of the output tensor, depending on whether: # (i) the reconstructed output (out) was came from a tensor that requires grad or not; @@ -245,13 +246,11 @@ def gen_alias_from_base( # to replay them (view functions) on the aliased_base_tensor. 
if ( replay_views - and target_functional_tensor is not None - and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) + and target_view_meta_sequence is not None + and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence) ): - functional_tensor = target_functional_tensor.tensor - - out = torch._functionalize_apply_view_metas( - functional_tensor, aliased_base_tensor + out = _functionalization.apply_view_meta_sequence( + aliased_base_tensor, target_view_meta_sequence.sequence ) # If re-applying the ViewMeta sequence succeeded, there should be no more # problems going forward. We just check we got to the target shape and @@ -357,25 +356,45 @@ class MetadataKey: ) -# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata -# after applying all the ViewMeta operations. -class FunctionalTensorMetadataEq: - def __init__(self, tensor: torch.Tensor) -> None: - assert torch._is_functional_tensor(tensor) - self.tensor = tensor +# ViewMeta sequence wrapper for equality comparisons. +# +# Even though we can compare each ViewMeta instance, we compare the resulting +# tensor metadata, instead. That's because the creation of synthetic bases + the +# re-generation of input views might end-up creating a different sequence of +# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same +# metadata. +# +# Therefore, we store what the end result should look like as serializable +# metadata. +# +# When logging, this class should look like: +# +# ViewMetaSequence(view, select_int, slice_Tensor) +# +# i.e. a parenthesized list of view operations within that ViewMeta sequence. +class ViewMetaSequence: + def __init__(self, tensor: FunctionalTensor) -> None: + assert torch._is_functional_tensor(tensor.elem) + self.sequence = _functionalization.get_view_meta_sequence(tensor.elem) + self.metadata = MetadataKey.make(tensor) + + def __repr__(self) -> str: + suffix = len("_ViewMeta") + types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence) + return f"ViewMetaSequence({types})" def __eq__(self, other: object) -> bool: # If other is None, then it probably means that we weren't able to recreate - # the FunctionalTensorMetadataEq. One of this cases is when we update the - # view metadata by calling: create_synthetic_base_metadata. + # the ViewMeta sequence. One example is when we update the view metadata by + # calling: create_synthetic_base_metadata. if other is None: return True # Comparison against any other type is not implemented. 
- if not isinstance(other, FunctionalTensorMetadataEq): + if not isinstance(other, ViewMetaSequence): return NotImplemented - return has_same_metadata(self.tensor, other.tensor) + return self.metadata == other.metadata # new_arg and arg here are either: diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index dcee706f5cc2..06581e1524fd 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -89,7 +89,7 @@ def remove_dupe_metadata( dynamic_dims=o.dynamic_dims, base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], requires_grad=o.requires_grad, - functional_tensor=o.functional_tensor, + view_meta_sequence=o.view_meta_sequence, ) for o in m.output_info ], @@ -242,7 +242,7 @@ def create_synthetic_base_metadata( # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases base_idx=new_base_idx, # type: ignore[arg-type] requires_grad=o.requires_grad, - functional_tensor=o.functional_tensor, + view_meta_sequence=o.view_meta_sequence, ) ) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 5a5536913813..80564a90e61e 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -150,7 +150,7 @@ class AliasOfInputHandler: self.base_idx = info.base_idx self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad - self.functional_tensor = info.functional_tensor + self.view_meta_sequence = info.view_meta_sequence self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -159,7 +159,7 @@ class AliasOfInputHandler: aliased_base_tensor, self.unwrap_out(out), self.requires_grad, - self.functional_tensor, + self.view_meta_sequence, replay_views=self.replay_views, ) @@ -190,7 +190,7 @@ class AliasOfIntermediateHandler: self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad - self.functional_tensor = info.functional_tensor + self.view_meta_sequence = info.view_meta_sequence self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -199,7 +199,7 @@ class AliasOfIntermediateHandler: self._unwrap_aliased_base_tensor(aliased_base_tensor), self.unwrap_out(out), self.requires_grad, - self.functional_tensor, + self.view_meta_sequence, replay_views=self.replay_views, ) diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 9c8cfc0a318d..a65351c31934 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -7,7 +7,6 @@ input/output types, metadata, config, function signatures etc. from __future__ import annotations import collections -import dataclasses import functools import itertools from dataclasses import dataclass, field @@ -32,10 +31,7 @@ from torch.fx.experimental._backward_state import BackwardState from torch.utils._python_dispatch import is_traceable_wrapper_subclass from .. 
import config -from .functional_utils import ( - _check_if_mutation_can_be_in_graph, - FunctionalTensorMetadataEq, -) +from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence from .utils import strict_zip @@ -117,15 +113,14 @@ class OutputAliasInfo: dynamic_dims: Optional[set[int]] # requires_grad requires_grad: bool - # FunctionalTensorWrapper that represents this output. + # Sequence of ViewMeta objects. # - # Provides us the means to replay views from it. + # Provides us the means to re-run view functions on other tensors. # - # We need to wrap the actual FunctionalTensorWrapper with this class so that - # we only compare the tensor's metadata. That's because with the transformations - # of the model throughout AOTAutograd, the sequence of ViewMeta and the base - # tensor might change. - functional_tensor: Optional[FunctionalTensorMetadataEq] = None + # We need to wrap the actual list of ViewMeta with this class so that + # we compare the ViewMeta elements appropriately, i.e. their type and + # the elements returned by the `as_tuple()` call. + view_meta_sequence: Optional[ViewMetaSequence] = None class MutationType(Enum): @@ -665,17 +660,6 @@ class ViewAndMutationMeta: self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] # Clear traced tangents at runtime self.traced_tangents = [] - new_output_info = [] - for out in self.output_info: - if config.view_replay_for_aliased_outputs: - new_out = out - else: - # If we're not using view_replay, remove the functional tensor. - # Functional tensors are unfortunately not serializable, - # so doing this is required for AOTAutograd caching. - new_out = dataclasses.replace(out, functional_tensor=None) - new_output_info.append(new_out) - self.output_info = new_output_info for inp_meta in self.subclass_inp_meta: if isinstance(inp_meta, SubclassCreationMeta): inp_meta.make_runtime_safe() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index bd8491de6efb..34f7d1491802 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -72,6 +72,7 @@ #include #include #include +#include #include #include #include @@ -2080,6 +2081,7 @@ PyObject* initModule() { torch::instruction_counter::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); + torch::functionalization::initModule(module); #ifdef USE_CUDA // This will only initialise base classes and attach them to library namespace diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 1236fad45f36..79739b6e459d 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -644,15 +644,6 @@ void initTorchFunctions(PyObject* module) { at::functionalization::impl::isFunctionalTensor(t)); at::functionalization::impl::mark_mutation_hidden_from_autograd(t); }); - py_module.def( - "_functionalize_apply_view_metas", - [](const at::Tensor& tensor, const at::Tensor& base) { - TORCH_INTERNAL_ASSERT( - at::functionalization::impl::isFunctionalTensor(tensor)); - auto impl = - at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); - return impl->apply_view_metas(base); - }); py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); diff --git a/torch/csrc/functionalization/Module.cpp 
b/torch/csrc/functionalization/Module.cpp new file mode 100644 index 000000000000..d38cb1078054 --- /dev/null +++ b/torch/csrc/functionalization/Module.cpp @@ -0,0 +1,71 @@ +#include +#include + +#include +#include +#include +#include + +namespace torch::functionalization { + +void initModule(PyObject* module) { + auto m = py::handle(module).cast(); + + // Create a `torch._C._functionalization` Python module. + auto functionalization = m.def_submodule( + "_functionalization", "functionalization related pybind."); + + // Retrieve the ViewMeta sequence of a given functional tensor. + functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + at::functionalization::impl::isFunctionalTensor(tensor)); + auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); + return impl->view_metas(); + }); + + // Applies the given ViewMeta sequence to the given base. + functionalization.def( + "apply_view_meta_sequence", + [](const at::Tensor& base, + const std::vector>& + sequence) { + return at::functionalization::impl::apply_view_meta_sequence( + base, sequence); + }); + + // Binding for InverseReturnMode. + py::enum_( + functionalization, "InverseReturnMode") + .value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView) + .value("NeverView", at::functionalization::InverseReturnMode::NeverView) + .value( + "ViewOrScatterInverse", + at::functionalization::InverseReturnMode::ViewOrScatterInverse); + + // Create bindings for the ViewMeta base class. + // + // Needed so that we can take a list of ViewMeta objects as parameter. + // Specifically, in the Python-side, we will have a list of derived ViewMeta + // classes. We need to tell pybind11 that all of those are, in fact, instances + // of different ViewMeta sub-types. + py::class_< + at::functionalization::ViewMeta, + std::shared_ptr>( + functionalization, "ViewMeta") + .def_property_readonly( + "has_symbolic_inputs", + [](const std::shared_ptr& meta) { + return meta->has_symbolic_inputs; + }); + + // Bindings for `ViewMeta` specializations manually implemented. + create_binding_with_pickle( + functionalization); + create_binding_with_pickle( + functionalization); + + // Bindings for `ViewMeta` specializations automatically generated. + initGenerated(functionalization.ptr()); +} + +} // namespace torch::functionalization diff --git a/torch/csrc/functionalization/Module.h b/torch/csrc/functionalization/Module.h new file mode 100644 index 000000000000..2f77fd3098c3 --- /dev/null +++ b/torch/csrc/functionalization/Module.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include +#include + +namespace torch::functionalization { + +// Creates the default bindings for `ViewMeta` specializations. +// +// Defines a constructor using the types in `SerializableTuple`, as well +// as pickle methods. 
+template +void create_binding_with_pickle(py::module m) { + py::class_, at::functionalization::ViewMeta>( + m, T::name()) + .def(py::init()) + .def( + "as_tuple", + [](const std::shared_ptr& meta) { + return meta->to_serializable_tuple(); + }) + .def(py::pickle( + [](const std::shared_ptr& meta) { + return meta->to_serializable_tuple(); + }, + [](const typename T::SerializableTuple& tpl) { + return std::make_shared(tpl); + })); +} + +void initModule(PyObject* module); +void initGenerated(PyObject* module); + +} // namespace torch::functionalization diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 850f6b8b5d2b..7e4ed0fc2a95 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -213,7 +213,7 @@ _SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic) class SymIntEqByExpr: """ This is a wrapper around SymInt which has alternative semantics for - equality. Specifically, instead of erroring or guarding, we + equality and pickling. Specifically, instead of erroring or guarding, we instead will hash/compare equality based on the underlying sympy expression; e.g., s0 and s1 will always compare as False. @@ -222,31 +222,25 @@ class SymIntEqByExpr: canonicalize to the same expression via regular simplification. """ - val: Union[torch.SymInt, int] + @staticmethod + def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr: + if isinstance(val, torch.SymInt): + return val.node.expr + else: + return sympy.Integer(val) def __init__(self, val: Union[torch.SymInt, int]) -> None: - self.val = val + self.val: sympy.Expr = SymIntEqByExpr._extract(val) def __repr__(self) -> str: return repr(self.val) - def _extract(self) -> sympy.Expr: - if isinstance(self.val, torch.SymInt): - return self.val.node.expr - else: - return sympy.Integer(self.val) - def __eq__(self, other: object) -> bool: assert isinstance(other, SymIntEqByExpr) - - # int equality fastpath - if type(self.val) is int and type(other.val) is int: - return self.val == other.val - - return self._extract() == other._extract() + return self.val == other.val def __hash__(self) -> int: - return hash(self._extract()) + return hash(self.val) def _nested_int_aware_sort( diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 93667e39b17f..f4b46b5f1476 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -23,20 +23,13 @@ from torchgen.model import ( # This file describes the translation of JIT schema to API's used -# when creating view lambdas that are used by the functionalization pass. -# There are two types of lambdas: forward lambdas and reverse lambdas. -# These API's mostly follow the dispatcher API, with a few quirks: -# - The lambda capture has to convert reference types to value types -# - While the forward lambda just directly calls into the at::_ops API -# (following the dispatcher convention), the logic here for the reverse lambda +# when creating `ViewMeta` specializations that are used by the functionalization pass. +# These API's mostly follow the dispatcher API, with one difference: +# - While the forward function just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse function # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). 
-# The lambdas generated for each view op in the functionalization pass are of the form -# [capture_arguments](outer_arguments) -> returns_type { -# return name(inner_arguments); -# } - # Define some specific lambda input arguments. base_binding = Binding( name="base", @@ -46,6 +39,18 @@ base_binding = Binding( ), default=None, ) + +has_symbolic_inputs_binding = Binding( + name="has_symbolic_inputs", + nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)), + argument=Argument( + name="has_symbolic_inputs", + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) mutated_view_binding = Binding( name="mutated_view", nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), @@ -54,11 +59,11 @@ mutated_view_binding = Binding( ), default=None, ) -mutated_view_idx_binding = Binding( - name="mutated_view_idx", - nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), +out_index_binding = Binding( + name="out_index", + nctype=NamedCType(name="out_index", type=BaseCType(longT)), argument=Argument( - name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None ), default=None, ) @@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding( ) -# The lambda capture itself doesn't have a name. -# The name returned here corresponds to the name of the inner function called by the lambda. +# Name of the `ViewMeta` specialization class created. +def classname(func: FunctionSchema, with_namespace: bool = False) -> str: + namespace = "at::functionalization::" if with_namespace else "" + return f"{namespace}{func.name.unambiguous_name()}_ViewMeta" + + +# Name of the operation called inside the `forward`/`reverse` implementations. def name( g: NativeFunctionsViewGroup, *, @@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str: return f"{api_name}_inverse" -def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: - # capture arguments include all arguments except `self`. - # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture), - # So any reference types (IntArrayRef) need to be converted to value types (vector) - args = func.arguments.flat_all - assert args[0].type == BaseType(BaseTy.Tensor) - non_self_args = args[1:] - non_self_value_bindings = [ - dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args - ] - - all_bindings = [ - inverse_return_mode_binding if is_reverse else reapply_views_binding - ] - all_bindings.extend(non_self_value_bindings) - return all_bindings - - def returns_type(func: FunctionSchema) -> CType: # Assertion: all view ops return tensor-like outputs assert len(func.returns) >= 1 @@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType: return BaseCType(tensorT) -def outer_arguments(*, is_reverse: bool) -> list[Binding]: - if is_reverse: - return [base_binding, mutated_view_binding, mutated_view_idx_binding] - else: - return [base_binding, mutated_view_idx_binding] +# Checks whether `func` might return more than one value. +def is_multi_output(func: FunctionSchema) -> bool: + return len(func.returns) > 1 or ( + len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None + ) -def inner_call_index(func: FunctionSchema) -> Binding | None: - # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. 
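# Illustration only (a hypothetical helper, not called by the code generator):
# how `is_multi_output` above classifies view schemas; both schema strings are
# the real entries from native_functions.yaml.
def _is_multi_output_example() -> None:
    from torchgen.model import FunctionSchema

    narrow = FunctionSchema.parse(
        "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
    )
    split = FunctionSchema.parse(
        "split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"
    )
    # Single-output views do not need an output index...
    assert not is_multi_output(narrow)
    # ...while list-returning views like `split` do, so their `ViewMeta`
    # specializations also take `out_index` as a constructor argument.
    assert is_multi_output(split)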
- # When we replay a view op that returns multiple tensors, we need to index into the output appropriately - if len(func.returns) > 1 or ( - len(func.returns) == 1 and func.returns[0].type.is_list_like() - ): - return mutated_view_idx_binding - return None +# `ViewMeta` specialization constructor parameters. +def base_ctor_arguments(func: FunctionSchema) -> list[Binding]: + # All specializations are parematerized by `has_symbolic_inputs` flag. + arguments = [has_symbolic_inputs_binding] + + # If `func` might return more than 1 value, we also parameterize this specialization + # with the output index. + if is_multi_output(func): + arguments.append(out_index_binding) + + return arguments -def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: +# `ViewMeta` specialized class' constructor arguments. +# +# Values needed specifically by this specialization, that the base class does not need. +# Same as the class' attributes, but non-owning. +def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]: + return attributes(func, owning=False) + + +# `ViewMeta` specialized class' non-static member data. +# +# Essential data for calling the instance's `forward` and `reverse functions. You can +# think of them as values that should be captured from the functionalization kernel. +def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]: + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + return [ + reapply_views_binding, + inverse_return_mode_binding, + *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]], + ] + + +def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: args = func.arguments.flat_all assert args[0].type == BaseType(BaseTy.Tensor) non_self_args = args[1:] @@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: # the reverse lambda does the same, but with an additional "mutated_view" arg # additionally, we have a calling convention: for view ops that return multiple tensor outputs # their corresponding view_inverse function takes in an additional index argument. - index_binding = inner_call_index(func) - if index_binding is not None: + if is_multi_output(func): return [ base_binding, mutated_view_binding, inverse_return_mode_binding, - index_binding, + out_index_binding, ] + non_self_bindings else: return [ diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index b3856e65e700..d4a47536dd1f 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -300,83 +300,11 @@ class ViewInverseSignature: return_type = functionalization.returns_type(self.g.view.func) decls = [ a.decl() - for a in functionalization.inner_arguments( - self.g.view.func, is_reverse=True - ) + for a in functionalization.op_arguments(self.g.view.func, is_reverse=True) ] return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" -@dataclass(frozen=True) -class FunctionalizationLambda: - g: NativeFunctionsViewGroup - - # are we generating the forward lambda or the reverse lambda? - is_reverse: bool - - def captures(self) -> list[Expr]: - # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments - # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, - # and plumb it into the lambda. 
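# Illustration only (a hypothetical helper, not called by the code generator):
# what the removed lambda used to capture now lives as constructor arguments
# and fields of the generated `ViewMeta` specialization. Shown on the real
# `narrow` schema using the helpers from torchgen.api.functionalization.
def _view_meta_ctor_example() -> None:
    from torchgen.api import functionalization
    from torchgen.model import FunctionSchema

    schema = FunctionSchema.parse(
        "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
    )
    # Base-class constructor arguments: just `has_symbolic_inputs` for a
    # single-output view (multi-output views additionally take `out_index`).
    assert [b.name for b in functionalization.base_ctor_arguments(schema)] == [
        "has_symbolic_inputs"
    ]
    # Per-specialization fields: the reapply-views / inverse-return-mode state
    # plus every non-self argument of the view op, stored in owning form.
    assert [b.name for b in functionalization.attributes(schema)] == [
        "reapply_views",
        "inverse_return_mode",
        "dim",
        "start",
        "length",
    ]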
- outer_ctx = dispatcher.arguments(self.g.view.func) + [ - functionalization.reapply_views_binding, - functionalization.inverse_return_mode_binding, - ] - capture_bindings = functionalization.capture_arguments( - self.g.view.func, is_reverse=self.is_reverse - ) - # allow_expensive_conversions is set because we want to convert - # some reference types (IntArrayRef) to value types (vector). - capture_exprs = translate.translate( - outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True - ) - return capture_exprs - - def decl(self) -> str: - return_type = functionalization.returns_type(self.g.view.func) - capture_str = ", ".join( - f"{val.type.name} = {val.expr}" for val in self.captures() - ) - decls = [ - a.decl() - for a in functionalization.outer_arguments(is_reverse=self.is_reverse) - ] - return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" - - def inner_call(self, *, reapply_views: bool | None = None) -> str: - inner_call_name = functionalization.name( - self.g, - is_reverse=self.is_reverse, - include_namespace=True, - reapply_views=reapply_views, - ) - - arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse) - capture_ctx = functionalization.capture_arguments( - self.g.view.func, is_reverse=self.is_reverse - ) - full_ctx = arg_ctx + capture_ctx - - assert self.g.view_copy is not None - call_bindings = functionalization.inner_arguments( - self.g.view_copy.func, is_reverse=self.is_reverse - ) - maybe_index = functionalization.inner_call_index(self.g.view_copy.func) - call_exprs = [ - e.expr for e in translate.translate(full_ctx, call_bindings, method=False) - ] - if not self.is_reverse and maybe_index is not None: - return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];" - else: - return f"{inner_call_name}({', '.join(call_exprs)});" - - @staticmethod - def from_func( - g: NativeFunctionsViewGroup, *, is_reverse: bool - ) -> FunctionalizationLambda: - return FunctionalizationLambda(g, is_reverse) - - @dataclass(frozen=True) class StructuredImplSignature: g: NativeFunctionsGroup diff --git a/torchgen/gen.py b/torchgen/gen.py index b8290d6b8684..7bbdd4a7a741 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -43,6 +43,8 @@ from torchgen.gen_functionalization_type import ( gen_functionalization_definition, gen_functionalization_registration, gen_functionalization_view_inverse_declaration, + gen_functionalization_view_meta_classes_decl, + gen_functionalization_view_meta_classes_impl, GenCompositeViewCopyKernel, ) from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing @@ -2493,48 +2495,48 @@ def gen_source_files( }, ) + def gen_op_headers( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> list[str]: + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + else: + return [ + f"#include ", + f"#include ", + ] + def functionalization_env_callable( g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, ) -> dict[str, list[str]]: - def gen_op_headers( - g: NativeFunction | 
NativeFunctionsGroup | NativeFunctionsViewGroup, - ) -> list[str]: - if isinstance(g, NativeFunctionsViewGroup): - # view ops always get a functionalization kernel - headers = [ - f"#include ", - f"#include ", - ] - if g.view_copy is not None: - headers += [ - f"#include ", - f"#include ", - ] - return headers - elif isinstance(g, NativeFunctionsGroup): - headers = [ - f"#include ", - f"#include ", - f"#include ", - f"#include ", - ] - if g.inplace is not None: - headers += [ - f"#include ", - f"#include ", - ] - if g.mutable is not None: - headers += [ - f"#include ", - f"#include ", - ] - return headers - else: - return [ - f"#include ", - f"#include ", - ] - return { "ops_headers": gen_op_headers(g), "func_definitions": gen_functionalization_definition( @@ -2600,6 +2602,31 @@ def gen_source_files( }, ) + cpu_fm.write( + "ViewMetaClasses.h", + lambda: { + "view_meta_declarations": list( + concatMap( + lambda g: gen_functionalization_view_meta_classes_decl(selector, g), + view_groups, + ) + ) + }, + ) + + cpu_fm.write( + "ViewMetaClasses.cpp", + lambda: { + "view_meta_implementations": list( + concatMap( + lambda g: gen_functionalization_view_meta_classes_impl(selector, g), + view_groups, + ) + ), + "op_headers": list(concatMap(gen_op_headers, view_groups)), + }, + ) + # Note [view_copy NativeFunctions] # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd # needs to have a corresponding non-aliasing {view}_copy variant. diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 42407974087a..f47985837eac 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,16 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, TYPE_CHECKING +from typing import Callable, Optional, TYPE_CHECKING -from torchgen.api import cpp, dispatcher +from torchgen.api import cpp, dispatcher, functionalization from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, CType, DispatcherSignature, - FunctionalizationLambda, iTensorListRefT, NativeSignature, OptionalCType, @@ -48,7 +47,7 @@ from torchgen.native_function_generation import ( MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) -from torchgen.utils import dataclass_repr +from torchgen.utils import concatMap, dataclass_repr, FileManager if TYPE_CHECKING: @@ -365,6 +364,8 @@ def emit_view_functionalization_body( with native_function_manager(f): call_sig = DispatcherSignature.from_schema(g.view_copy.func) + spec = ViewMetaSpecialization(g, f=f) + # the "view_copy" op name that the functionalization kernels need to call api_name = g.view_copy.func.name.unambiguous_name() # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors) @@ -385,9 +386,6 @@ def emit_view_functionalization_body( for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False) ] - forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False) - reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True) - # The meta API call should use the same arguments, but convert all tensors to meta tensors first. 
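# Sketch only (not verbatim generator output): `spec.new()` below collapses the
# old hand-built lambdas into a single instantiation of the generated class.
# For a single-output view such as `narrow`, the emitted line looks roughly
# like:
#
#   auto view_meta = std::make_shared<at::functionalization::narrow_ViewMeta>(
#       has_symbolic_inputs, reapply_views, inverse_return_mode, dim, start, length);
#
# where the argument order is base_ctor_arguments() followed by
# extra_ctor_arguments().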
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) meta_call_args = [ @@ -415,19 +413,7 @@ def emit_view_functionalization_body( : at::functionalization::InverseReturnMode::NeverView ); {symbolic_inputs_check} - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - {forward_lambda.decl()} {{ - if (reapply_views) {{ - return {forward_lambda.inner_call(reapply_views=True)} - }} else {{ - return {forward_lambda.inner_call(reapply_views=False)} - }} - }}, - {reverse_lambda.decl()} {{ - return {reverse_lambda.inner_call()} - }}, - /*has_symbolic_inputs=*/{symbolic_inputs_varname} - ); + auto view_meta = {spec.new()}; auto compute_reference_meta = {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit); @@ -455,7 +441,6 @@ def emit_view_functionalization_body( """ else: - is_multi_output_view = isinstance(f.func.returns[0].type, ListType) return f""" {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {unwrap_tensor_args_str} @@ -489,21 +474,7 @@ def emit_view_functionalization_body( }} }} {symbolic_inputs_check} - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - {forward_lambda.decl()} {{ - if (reapply_views) {{ - return {forward_lambda.inner_call(reapply_views=True)} - }} else {{ - return {forward_lambda.inner_call(reapply_views=False)} - }} - }}, - {reverse_lambda.decl()} {{ - return {reverse_lambda.inner_call()} - }}, - /*has_symbolic_inputs=*/{symbolic_inputs_varname}, - /*is_multi_output=*/{str(is_multi_output_view).lower()}, - /*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()} - ); + auto view_meta = {spec.new()}; auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta); // See Note [Propagating strides in the functionalization pass] if (compute_reference_meta && !disable_meta_reference()) {{ @@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration( return emit_decl_helper(g) +# Helper class for generating `ViewMeta` specializations. +@dataclass +class ViewMetaSpecialization: + g: NativeFunctionsViewGroup + f: NativeFunction + + @property + def is_multi_output(self) -> bool: + return functionalization.is_multi_output(self.f.func) + + @property + def is_as_strided(self) -> bool: + return str(self.f.func.name) == "as_strided" + + @property + def out_index(self) -> str: + if self.is_multi_output: + return functionalization.out_index_binding.name + return "0" + + @property + def classname(self) -> str: + return functionalization.classname(self.f.func) + + def decl(self) -> list[str]: + base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func) + extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func) + attributes = functionalization.attributes(self.f.func) + + # List of types for declaring the `SerializableTuple` type. + serializable_tuple_args = ",\n".join( + f" {binding.type} /* {binding.name} */" + for binding in (base_ctor_arguments + attributes) + ) + + # Arguments used for forwarding the tuple elements to the constructor. + destructure_tuple_args = ", ".join( + f"std::get<{i}>(tpl)" + for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments)) + ) + + # List of constructor parameters + ctor_parameters = ", ".join( + binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments) + ) + + # Call the base class `ViewMeta` constructor. 
+ # + # Both of `is_multi_output` and `is_as_strided` are known values, given the + # operation schema. + is_multi_output_str = str(self.is_multi_output).lower() + is_as_strided_str = str(self.is_as_strided).lower() + + base_ctor_bindings = ", ".join( + [ + # `has_symbolic_inputs` is always taken as parameter. + functionalization.has_symbolic_inputs_binding.name, + f"/*is_multi_output=*/{is_multi_output_str}", + f"/*is_as_strided=*/{is_as_strided_str}", + # `out_index` is know if the operation returns only one value. Otherwise, + # we also take it as parameter. + f"/*out_index=*/{self.out_index}", + ] + ) + + # Assignments of `extra_ctor_arguments` to their corresponding fields. + # These are extra fields to-be-declared in this specialization. + # + # We need to set `allow_expensive_conversions`, since we are storing owned versions + # of the non-owning arguments. + ctor_assignments = ",\n".join( + f" {e.type.name}({e.expr})" + for e in translate( + extra_ctor_arguments, + attributes, + method=False, + allow_expensive_conversions=True, + ) + ) + + # List of arguments for constructing the `SerializableTuple` from an instance. + tuple_arguments = ", ".join( + binding.name for binding in (base_ctor_arguments + attributes) + ) + + # List of field declarations. + attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes) + + # Override `to_out_index` if this operation returns more than 1 value. + to_out_index_decl = "" + if self.is_multi_output: + to_out_index_decl = ( + " std::shared_ptr to_out_index(int64_t out_idx) override;" + ) + + return [ + f""" +struct TORCH_API {self.classname} : public ViewMeta {{ + FUNCTIONALIZATION_VIEWMETA_NAME({self.classname}) + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args}); + + {self.classname}(const SerializableTuple& tpl) + : {self.classname}({destructure_tuple_args}) {{}} + + {self.classname}({ctor_parameters}) + : at::functionalization::ViewMeta({base_ctor_bindings}), +{ctor_assignments} {{}} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; +{to_out_index_decl} + + SerializableTuple to_serializable_tuple() {{ + return std::make_tuple({tuple_arguments}); + }} + +{attr_declarations} +}}; +""" + ] + + # Generate a call to the actual operation. + def opcall(self, is_reverse: bool, reapply_views: bool) -> str: + opname = functionalization.name( + self.g, + is_reverse=is_reverse, + include_namespace=True, + reapply_views=reapply_views, + ) + + # Expected arguments for the operation. + assert self.g.view_copy is not None + op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse) + + # The context is composed by the constructor arguments (which are also + # the field variables stored in the instance), and the `base` tensor. + context = [functionalization.base_binding] + context += functionalization.base_ctor_arguments(self.f.func) + context += functionalization.attributes(self.f.func) + + # If we are generating the call for the reverse function, we also have + # access to `mutated_view` argument. + if is_reverse: + context.append(functionalization.mutated_view_binding) + + arguments = ", ".join( + [e.expr for e in translate(context, op_arguments, method=False)] + ) + + # Index the result if this operation returns multiple values. 
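# Sketch only (not verbatim generator output): for `narrow`, the forward call
# site produced here goes through the view op itself when `reapply_views` is
# set, and through its `*_copy` variant otherwise:
#
#   at::_ops::narrow::call(base, dim, start, length)
#   at::_ops::narrow_copy::call(base, dim, start, length)
#
# The reverse call site is assembled the same way from the manually implemented
# functional inverse, with `mutated_view` and `inverse_return_mode` placed
# ahead of the non-self arguments.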
+ maybe_index = "" + if not is_reverse and self.is_multi_output: + maybe_index = f"[{self.out_index}]" + + return f"{opname}({arguments}){maybe_index}" + + def impl(self) -> list[str]: + functions = [ + f""" +at::Tensor {self.classname}::forward(const at::Tensor& base) {{ + if (reapply_views) {{ + return {self.opcall(is_reverse=False, reapply_views=True)}; + }} else {{ + return {self.opcall(is_reverse=False, reapply_views=False)}; + }} +}}""", + f""" +at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{ + return {self.opcall(is_reverse=True, reapply_views=True)}; +}}""", + ] + + # If this operation returns multiple values, also generate a `to_out_index` + # implementation. + if self.is_multi_output: + functions.append(f""" +std::shared_ptr {self.classname}::to_out_index(int64_t out_index) {{ + return {self.new("out_index")}; +}} +""") + + return functions + + # Create the Python binding for this specialized class. + def binding(self) -> list[str]: + name = functionalization.classname(self.f.func, with_namespace=True) + return [f" create_binding_with_pickle<{name}>(functionalization);"] + + # Generate an instantiation of this specialized class. + def new(self, out_index: str = "0") -> str: + name = functionalization.classname(self.f.func, with_namespace=True) + ctor_arguments = functionalization.base_ctor_arguments( + self.f.func + ) + functionalization.extra_ctor_arguments(self.f.func) + # Replace the `out_index` parameter with the given `out_index`. + arguments = ", ".join( + binding.name if binding.name != "out_index" else out_index + for binding in ctor_arguments + ) + return f"std::make_shared<{name}>({arguments})" + + # Run the function `run` for both: `view` and `view_inplace` functions. + @staticmethod + def map( + g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]] + ) -> list[str]: + def maybe_run(f: Optional[NativeFunction]) -> list[str]: + if f is None: + return [] + with native_function_manager(f): + return run(ViewMetaSpecialization(g, f)) + + return list(concatMap(maybe_run, (g.view, g.view_inplace))) + + +def gen_functionalization_view_meta_classes_base( + selector: SelectiveBuilder, + g: NativeFunctionsViewGroup, + run: Callable[[ViewMetaSpecialization], list[str]], +) -> list[str]: + if not selector.include_all_operators: + return [] + + if g.composite: + return [] + + return ViewMetaSpecialization.map(g, run) + + +def gen_functionalization_view_meta_classes_decl( + selector: SelectiveBuilder, g: NativeFunctionsViewGroup +) -> list[str]: + return gen_functionalization_view_meta_classes_base( + selector, g, ViewMetaSpecialization.decl + ) + + +def gen_functionalization_view_meta_classes_impl( + selector: SelectiveBuilder, g: NativeFunctionsViewGroup +) -> list[str]: + return gen_functionalization_view_meta_classes_base( + selector, g, ViewMetaSpecialization.impl + ) + + +def gen_functionalization_view_meta_classes_binding( + selector: SelectiveBuilder, g: NativeFunctionsViewGroup +) -> list[str]: + return gen_functionalization_view_meta_classes_base( + selector, g, ViewMetaSpecialization.binding + ) + + +# Generates the Python bindings for the `ViewMeta` specialized classes. +def gen_functionalization_view_meta_classes( + native_functions_path: str, + tags_path: str, + selector: SelectiveBuilder, + install_dir: str, + template_dir: str, +) -> None: + from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml + + # Parse the native_functions.yaml. 
+ # Then, group them into `NativeFunctionsViewGroup`. + # + # These are the same steps we follow in gen.py (ATen codegen). + native_functions = parse_native_yaml( + native_functions_path, tags_path + ).native_functions + native_functions_with_view_groups = get_grouped_by_view_native_functions( + native_functions + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False) + fm.write( + "ViewMetaClassesPythonBinding.cpp", + lambda: { + "view_meta_bindings": list( + concatMap( + lambda g: gen_functionalization_view_meta_classes_binding( + selector, g + ), + view_groups, + ) + ), + }, + ) + + def gen_functionalization_registration( selector: SelectiveBuilder, g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,