Reapply "Make functionalization ViewMeta serializable with pickle. (#143712)" (#163769)

### Summary:
NOTE: This is a re-export of https://github.com/pytorch/pytorch/pull/161994 ; the changes between these two PRs is exclusively to the buck/build files

(Summary from #161994 )
Attempted rebase of https://github.com/pytorch/pytorch/pull/143712.

This reverts commit 6c713ccb5e0df227dd5b630057cbccd373cbe7d6.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

imported-using-ghimport

Test Plan: Imported from OSS

Differential Revision: D81524507

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163769
Approved by: https://github.com/dolpm

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
This commit is contained in:
Brian Hirsh
2025-09-25 10:27:37 +00:00
committed by PyTorch MergeBot
parent 29cbcbac42
commit 7d710403b0
38 changed files with 981 additions and 425 deletions

1
.gitignore vendored
View File

@ -82,6 +82,7 @@ torch/return_types.pyi
torch/nn/functional.pyi torch/nn/functional.pyi
torch/utils/data/datapipes/datapipe.pyi torch/utils/data/datapipes/datapipe.pyi
torch/csrc/autograd/generated/* torch/csrc/autograd/generated/*
torch/csrc/functionalization/generated/*
torch/csrc/lazy/generated/*.[!m]* torch/csrc/lazy/generated/*.[!m]*
torch_compile_debug/ torch_compile_debug/
# Listed manually because some files in this directory are not generated # Listed manually because some files in this directory are not generated

View File

@ -90,6 +90,8 @@ generated_cpu_cpp = [
"aten/src/ATen/NativeMetaFunctions.h", "aten/src/ATen/NativeMetaFunctions.h",
"aten/src/ATen/RegistrationDeclarations.h", "aten/src/ATen/RegistrationDeclarations.h",
"aten/src/ATen/VmapGeneratedPlumbing.h", "aten/src/ATen/VmapGeneratedPlumbing.h",
"aten/src/ATen/ViewMetaClasses.h",
"aten/src/ATen/ViewMetaClasses.cpp",
"aten/src/ATen/core/aten_interned_strings.h", "aten/src/ATen/core/aten_interned_strings.h",
"aten/src/ATen/core/enum_tag.h", "aten/src/ATen/core/enum_tag.h",
"aten/src/ATen/core/TensorBody.h", "aten/src/ATen/core/TensorBody.h",
@ -1074,6 +1076,7 @@ test_suite(
"aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp", "aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini", "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
"aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
"aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml", "aten/src/ATen/native/tags.yaml",
"aten/src/ATen/native/ts_native_functions.yaml", "aten/src/ATen/native/ts_native_functions.yaml",

View File

@ -9,11 +9,6 @@
namespace at::functionalization { namespace at::functionalization {
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}
// Note [Functionalization: Alias Removal Part 2] // Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details. // See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl. // This function applies a single update from one of the views to the StorageImpl.
@ -42,12 +37,12 @@ ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) { static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
at::Tensor t = update.new_val; at::Tensor t = update.new_val;
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
if (update.view_metas.empty()) return t; if (update.view_metas.empty()) { return t; }
std::vector<at::Tensor> tmp_values({base}); std::vector<at::Tensor> tmp_values({base});
tmp_values.reserve(update.view_metas.size()); tmp_values.reserve(update.view_metas.size());
for (size_t i = 0; i < update.view_metas.size() - 1; ++i) { for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index); at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
// NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
// All of these ops require additional information to recover the sizes of the original tensor. // All of these ops require additional information to recover the sizes of the original tensor.
// If need to, we could probably apply this optimization and only bother computing tmp_values // If need to, we could probably apply this optimization and only bother computing tmp_values
@ -55,9 +50,8 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
tmp_values.push_back(std::move(next_view)); tmp_values.push_back(std::move(next_view));
} }
for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) { for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
int64_t out_idx = update.view_metas[i].out_index;
// Each view inverse is implemented in ViewInverses.cpp. // Each view inverse is implemented in ViewInverses.cpp.
t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx); t = update.view_metas[i]->reverse(tmp_values[i], t);
} }
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
return t; return t;
@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
} }
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) { void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
if (metas.size() > 1) { if (metas.size() > 1) {
for (size_t i = 1; i < metas.size(); ++i) { for (size_t i = 1; i < metas.size(); ++i) {
// Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided, TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
"During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i, "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
" was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today," " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
"so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you " "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "

View File

@ -8,44 +8,89 @@ namespace at::functionalization {
// See Note [Functionalization Pass In Core] // See Note [Functionalization Pass In Core]
enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying)
/// scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to
/// handle.
ViewOrScatterInverse,
};
#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
static const char* name() { \
return #TYPE; \
}
#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
using SerializableTuple = std::tuple<__VA_ARGS__>
// ViewMeta is a class used by the functionalization pass to navigate between // ViewMeta is a class used by the functionalization pass to navigate between
// a base tensor and a view tensor. // a base tensor and a view tensor.
// For example, if I call `b = a.view1(...)` // For example, if I call `b = a.view1(...)`
// the functionalization pass will generate and store a ViewMeta on b that looks // the functionalization pass will generate and store a ViewMeta specialization
// like: // for `view1` operation on b that looks like:
// //
// ViewMeta( // struct TORCH_API view1_ViewMeta : public ViewMeta {
// [<captures>](const Tensor& base, int64_t mutated_view_idx) { // FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
// bool /* reapply_views */,
// const std::vector<int64_t>&);
//
// view1_ViewMeta(const SerializableTuple& tpl)
// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
//
// view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
// : ViewMeta(/*has_symbolic_inputs=*/false),
// reapply_views(reapply_views),
// size(size) {}
//
// Tensor forward(const Tensor& base) override {
// return base.view1(...); // return base.view1(...);
// }, // }
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view, //
// int64_t mutated_view_idx) -> at::Tensor { // Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
// return at::functionalization::impl::view1_inverse(base, mutated_view, // return at::functionalization::impl::view1_inverse(base, mutated_view,
// ...); // ...);
// } // }
// //
// The forward_fn lambda describes how to replay view1 on a tensor. // SerializableTuple to_serializable_tuple() {
// return std::make_tuple(reapply_views, size);
// }
// //
// The reverse_fn lambda describes how, given a tensor that is already a view, // bool reapply_views;
// std::vector<int64_t> size;
// };
//
// The forward function describes how to replay view1 on a tensor.
//
// The reverse function describes how, given a tensor that is already a view,
// how to get the corresponding base tensor. See Note [Functionalization Pass: // how to get the corresponding base tensor. See Note [Functionalization Pass:
// View Inverses] for details. // View Inverses] for details.
//
// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type
// representing the `ViewMeta` instance state. Methods that take in/return such
// a type are used for supporting pickle serialization.
struct ViewMeta { struct ViewMeta {
ViewMeta( ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool has_symbolic_inputs, bool has_symbolic_inputs,
bool is_multi_output = false, bool is_multi_output = false,
bool is_as_strided = false, bool is_as_strided = false,
int64_t out_idx = 0) int64_t out_idx = 0)
: forward_fn(std::move(forward)), : out_index(out_idx),
reverse_fn(std::move(reverse)),
out_index(out_idx),
is_multi_output(is_multi_output), is_multi_output(is_multi_output),
is_as_strided(is_as_strided), is_as_strided(is_as_strided),
has_symbolic_inputs(has_symbolic_inputs) {} has_symbolic_inputs(has_symbolic_inputs) {}
std::function<Tensor(const Tensor&, int64_t)> forward_fn; virtual ~ViewMeta() = default;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
virtual Tensor forward(const Tensor& base) = 0;
virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
// See Note [out_idx in ViewMeta] // See Note [out_idx in ViewMeta]
int64_t out_index; int64_t out_index;
@ -57,10 +102,17 @@ struct ViewMeta {
// Tells us if this view operation has any symbolic inputs // Tells us if this view operation has any symbolic inputs
bool has_symbolic_inputs; bool has_symbolic_inputs;
// Returns a copy of the current ViewMeta, if out_idx matches the current // Returns a new ViewMeta with the same forward/reverse
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index. // functions, but a new out index.
ViewMeta to_out_idx(int64_t out_idx); //
// This method should be implemented by those `ViewMeta` that have more than
// one output.
virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"ViewMeta::to_out_index not implemented. ",
"Likely because there's only one output.");
}
}; };
// FunctionalStorageImpl is a subclass of StorageImpl used by the // FunctionalStorageImpl is a subclass of StorageImpl used by the
@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::Tensor new_val; const at::Tensor new_val;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::vector<ViewMeta> view_metas; const std::vector<std::shared_ptr<ViewMeta>> view_metas;
}; };
explicit FunctionalStorageImpl(const Tensor& value); explicit FunctionalStorageImpl(const Tensor& value);
void add_update( void add_update(
const Tensor& updated_val, const Tensor& updated_val,
const std::vector<ViewMeta>& view_metas); const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
bool apply_updates(); bool apply_updates();
const Tensor& base() { const Tensor& base() {
return base_; return base_;

View File

@ -129,17 +129,19 @@ void FunctionalTensorWrapper::freeze_storage() const {
// - view_value: The output tensor that we need to wrap. // - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from. // - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta) FunctionalTensorWrapper::FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
const std::shared_ptr<functionalization::ViewMeta>& meta)
: c10::TensorImpl( : c10::TensorImpl(
c10::DispatchKeySet(DispatchKey::Functionalize), c10::DispatchKeySet(DispatchKey::Functionalize),
view_value.dtype(), view_value.dtype(),
base->storage().data_ptr().device() base->storage().data_ptr().device()),
),
value_(view_value), value_(view_value),
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output), is_multi_output_view_(
base->is_multi_output_view_ || meta->is_multi_output),
was_storage_changed_(base->was_storage_changed_), was_storage_changed_(base->was_storage_changed_),
is_symbolic_(base->is_symbolic_) is_symbolic_(base->is_symbolic_) {
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata(); set_constructor_metadata();
@ -148,11 +150,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
view_metas_ = base->view_metas_; // copy view_metas_ = base->view_metas_; // copy
} }
view_metas_.push_back(meta); view_metas_.push_back(meta);
maybe_mark_symbolic(meta); maybe_mark_symbolic(meta.get());
storage_ = base->storage_; // alias this tensor's storage with the base tensor's storage_ = base->storage_; // alias this tensor's storage with the base tensor's
} }
functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl()); return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
} }
@ -176,18 +177,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const {
} }
// See Note [Functionalization Pass - Inplace View Ops] // See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
view_metas_.push_back(meta); view_metas_.push_back(meta);
// Manually track the fact that this tensor received a metadata mutation! // Manually track the fact that this tensor received a metadata mutation!
has_metadata_mutation_ = true; has_metadata_mutation_ = true;
// Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
maybe_mark_symbolic(meta); maybe_mark_symbolic(meta.get());
// Note [Functionalization Pass - Inplace View Ops] // Note [Functionalization Pass - Inplace View Ops]
// So, these ops are special - they're mutation AND view ops. They get special codegen. // So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()` // An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
at::AutoDispatchSkipFunctionalize guard; at::AutoDispatchSkipFunctionalize guard;
value_ = meta.forward_fn(value_, meta.out_index); value_ = meta->forward(value_);
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
} }
@ -368,15 +369,8 @@ void FunctionalTensorWrapper::sync_() {
regenerate_from_base(); regenerate_from_base();
} }
Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) { const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const {
auto t = base; return view_metas_;
// Reapply views to get the viewed tensor from the base in alias_
for (auto& view_meta: view_metas_) {
t = view_meta.forward_fn(t, view_meta.out_index);
}
return t;
} }
void FunctionalTensorWrapper::regenerate_from_base() { void FunctionalTensorWrapper::regenerate_from_base() {
@ -385,7 +379,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
auto t = storage_impl->base(); auto t = storage_impl->base();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
t = apply_view_metas(t); t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_);
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
replace_(t, /*from_lazy_regenerate=*/true); replace_(t, /*from_lazy_regenerate=*/true);
@ -727,11 +721,11 @@ bool isFunctionalTensor(const std::optional<Tensor>& t) {
} }
bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) { bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
if (t_list.empty()) return false; if (t_list.empty()) { return false; }
auto functional_count = 0; auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) { for (const auto i : c10::irange(t_list.size())) {
auto const & e= t_list[i]; auto const & e= t_list[i];
if (!e.has_value() || !e->defined()) continue; if (!e.has_value() || !e->defined()) { continue; }
if (isFunctionalTensor(e)) { if (isFunctionalTensor(e)) {
++functional_count; ++functional_count;
} }
@ -741,10 +735,10 @@ bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
template <typename T> template <typename T>
static bool isFunctionalTensorIListRef(c10::IListRef<T> list) { static bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
if (list.size() == 0) return false; if (list.size() == 0) { return false; }
auto functional_count = 0; auto functional_count = 0;
for (const auto& tensor : list) { for (const auto& tensor : list) {
if (!tensor.defined()) continue; if (!tensor.defined()) { continue; }
if (isFunctionalTensor(tensor)) { if (isFunctionalTensor(tensor)) {
++functional_count; ++functional_count;
} }
@ -762,20 +756,28 @@ void freeze_functional_tensor(const Tensor& tensor) {
functional_base_impl->freeze_storage(); functional_base_impl->freeze_storage();
} }
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { Tensor create_functional_tensor_with_view_meta(
const at::Tensor& view_to_wrap,
const at::Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta,
int64_t out_idx) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
auto meta_ = meta;
if (out_idx != 0) { if (out_idx != 0) {
// Note [out_idx in ViewMeta] // Note [out_idx in ViewMeta]
// When a view op outputs multiple tensors, each output needs its own separate ViewMeta. // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
// Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
meta = meta.to_out_idx(out_idx); meta_ = meta->to_out_index(out_idx);
} }
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta); return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_);
} }
std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) { std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const at::Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta) {
std::vector<Tensor> outputs(view_to_wrap.size()); std::vector<Tensor> outputs(view_to_wrap.size());
int64_t i = 0; int64_t i = 0;
for (const auto& tensor : view_to_wrap) { for (const auto& tensor : view_to_wrap) {
@ -785,12 +787,22 @@ std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_
return outputs; return outputs;
} }
void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) { void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
self_impl->mutate_view_meta(meta); self_impl->mutate_view_meta(meta);
} }
Tensor apply_view_meta_sequence(
const Tensor& base,
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) {
Tensor r = base;
for (auto& vm : sequence) {
r = vm->forward(r);
}
return r;
}
// Note [Propagating strides in the functionalization pass] // Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass // In order to properly compute stride information, the functionalization pass
// calls each {view} reference implementations with meta tensors. // calls each {view} reference implementations with meta tensors.
@ -884,7 +896,7 @@ void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* s
const auto& ivalue = returns[idx]; const auto& ivalue = returns[idx];
if (ivalue.isTensor()) { if (ivalue.isTensor()) {
const auto& t = ivalue.toTensor(); const auto& t = ivalue.toTensor();
if (!t.defined()) continue; if (!t.defined()) { continue; }
at::functionalization::impl::sync(t); at::functionalization::impl::sync(t);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t)); auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new; (*stack)[returns_begin + idx] = t_new;

View File

@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
explicit FunctionalTensorWrapper( explicit FunctionalTensorWrapper(
const Tensor& view_value, const Tensor& view_value,
const FunctionalTensorWrapper* base, const FunctionalTensorWrapper* base,
const functionalization::ViewMeta& meta); const std::shared_ptr<functionalization::ViewMeta>& meta);
// Get the underlying, actual tensor, that doesn't know anything about // Get the underlying, actual tensor, that doesn't know anything about
// functionalization. // functionalization.
@ -99,17 +99,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
->are_all_mutations_under_no_grad_or_inference_mode(); ->are_all_mutations_under_no_grad_or_inference_mode();
} }
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
} }
bool is_symbolic() const { bool is_symbolic() const {
return is_symbolic_; return is_symbolic_;
} }
// Runs the forward_fn of every ViewMeta collected in the current instance // Retrieves the ViewMeta sequence of this tensor.
// to some other base. const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
Tensor apply_view_metas(const Tensor& base); const;
// Sync's the underlying tensor with its alias, if it's out of date. This // Sync's the underlying tensor with its alias, if it's out of date. This
// involves two steps: 1) Apply any pending updates/mutations to the alias 2) // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
@ -146,7 +146,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// from the base tensor. This method is used by inplace-view ops like // from the base tensor. This method is used by inplace-view ops like
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
// tensor by replaying the views off of the alias. // tensor by replaying the views off of the alias.
void mutate_view_meta(const at::functionalization::ViewMeta& meta); void mutate_view_meta(
const std::shared_ptr<at::functionalization::ViewMeta>& meta);
// Custom implementation of self.set_(src) // Custom implementation of self.set_(src)
void set__impl(const FunctionalTensorWrapper* other); void set__impl(const FunctionalTensorWrapper* other);
@ -285,7 +286,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
bool is_symbolic_ = false; bool is_symbolic_ = false;
size_t generation_ = 0; size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_; std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
protected: protected:
static void copy_tensor_metadata( static void copy_tensor_metadata(
@ -377,16 +378,20 @@ TORCH_API void propagate_xla_data_direct(
Tensor create_functional_tensor_with_view_meta( Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap, const Tensor& view_to_wrap,
const Tensor& base, const Tensor& base,
functionalization::ViewMeta meta, const std::shared_ptr<functionalization::ViewMeta>& meta,
int64_t out_idx = 0); int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta( std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap, ITensorListRef view_to_wrap,
const Tensor& base, const Tensor& base,
const functionalization::ViewMeta& meta); const std::shared_ptr<functionalization::ViewMeta>& meta);
void mutate_view_meta( void mutate_view_meta(
const Tensor& self, const Tensor& self,
const functionalization::ViewMeta& meta); const std::shared_ptr<functionalization::ViewMeta>& meta);
TORCH_API Tensor apply_view_meta_sequence(
const Tensor& base,
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset( void set_sizes_strides_offset(

View File

@ -1,3 +1,5 @@
#include <ATen/FunctionalizeFallbackKernel.h>
#include <ATen/core/dispatch/Dispatcher.h> #include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h> #include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h> #include <ATen/EmptyTensor.h>
@ -7,7 +9,6 @@
#include <torch/library.h> #include <torch/library.h>
#include <c10/util/irange.h> #include <c10/util/irange.h>
#include <c10/util/strides.h> #include <c10/util/strides.h>
#include <ATen/EmptyTensor.h>
#ifndef AT_PER_OPERATOR_HEADERS #ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h> #include <ATen/ATen.h>
@ -28,6 +29,31 @@
#include <utility> #include <utility>
#endif #endif
namespace at::functionalization {
Tensor resize__ViewMeta::forward(const Tensor& base) {
if (reapply_views) {
return base.as_strided(size, c10::contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
}
}
Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
return base.as_strided_scatter(
mutated_view, size, c10::contiguous_strides(size));
}
Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) {
return at::_unsafe_view_symint(base, size);
}
Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
}
} // namespace at::functionalization
namespace { namespace {
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) { void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
const auto& schema = op.schema(); const auto& schema = op.schema();
@ -106,7 +132,9 @@ namespace {
const auto& ivalue = returns[idx]; const auto& ivalue = returns[idx];
if (ivalue.isTensor() && should_wrap_outputs) { if (ivalue.isTensor() && should_wrap_outputs) {
const auto& t = ivalue.toTensor(); const auto& t = ivalue.toTensor();
if (!t.defined()) continue; if (!t.defined()) {
continue;
}
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t)); auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new; (*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isTensorList() && should_wrap_outputs) { } else if (ivalue.isTensorList() && should_wrap_outputs) {
@ -169,19 +197,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
// The output of resizing is equivalent to taking a slice of a larger tensor. // The output of resizing is equivalent to taking a slice of a larger tensor.
// We have to emulate this "slicing" with an as_strided call. // We have to emulate this "slicing" with an as_strided call.
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>(
[reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { reapply_views, size.vec());
if (reapply_views) {
return base.as_strided(size, c10::contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
}
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
},
/*has_symbolic_inputs=*/false
);
at::functionalization::impl::mutate_view_meta(self, view_meta); at::functionalization::impl::mutate_view_meta(self, view_meta);
return self; return self;
} }
@ -300,17 +317,11 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
tmp_output = at::_unsafe_view_symint(self_, size); tmp_output = at::_unsafe_view_symint(self_, size);
} }
bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); bool has_symbolic_inputs = std::any_of(
size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( auto view_meta =
[size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { std::make_shared<at::functionalization::_unsafe_view_ViewMeta>(
return at::_unsafe_view_symint(base, size); has_symbolic_inputs, size.vec());
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
},
/*has_symbolic_inputs=*/has_symbolic_inputs
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta)); auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
// See Note [Propagating strides in the functionalization pass] // See Note [Propagating strides in the functionalization pass]

View File

@ -0,0 +1,58 @@
#pragma once
#include <ATen/FunctionalStorageImpl.h>
namespace at::functionalization {
// `ViewMeta` implementation for `resize_` operation.
struct TORCH_API resize__ViewMeta : public ViewMeta {
FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta)
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
bool /* reapply_views */,
const std::vector<int64_t>&);
resize__ViewMeta(const SerializableTuple& tpl)
: resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
: ViewMeta(/*has_symbolic_inputs=*/false),
reapply_views(reapply_views),
size(size) {}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
SerializableTuple to_serializable_tuple() {
return std::make_tuple(reapply_views, size);
}
bool reapply_views;
std::vector<int64_t> size;
};
// `ViewMeta` implementation for `_unsafe_view` operation.
struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta)
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
bool /* has_symbolic_inputs */,
const std::vector<c10::SymInt>&);
_unsafe_view_ViewMeta(const SerializableTuple& tpl)
: _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
_unsafe_view_ViewMeta(
bool has_symbolic_inputs,
const std::vector<c10::SymInt>& size)
: ViewMeta(has_symbolic_inputs), size(size) {}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
SerializableTuple to_serializable_tuple() {
return std::make_tuple(has_symbolic_inputs, size);
}
std::vector<c10::SymInt> size;
};
} // namespace at::functionalization

View File

@ -2,22 +2,12 @@
// ${generated_comment} // ${generated_comment}
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/Tensor.h> #include <ATen/Tensor.h>
namespace at { namespace at {
namespace functionalization { namespace functionalization {
enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying) scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to handle.
ViewOrScatterInverse,
};
struct FunctionalInverses { struct FunctionalInverses {
${view_inverse_declarations} ${view_inverse_declarations}

View File

@ -4,7 +4,7 @@
#include <ATen/core/LegacyTypeDispatch.h> #include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h> #include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h> #include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalInverses.h> #include <ATen/ViewMetaClasses.h>
#include <ATen/MemoryOverlap.h> #include <ATen/MemoryOverlap.h>
#include <torch/library.h> #include <torch/library.h>

View File

@ -0,0 +1,19 @@
// ${generated_comment}
#include <ATen/FunctionalInverses.h>
#include <ATen/ViewMetaClasses.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#include <ATen/NativeFunctions.h>
#else
${op_headers}
#endif
namespace at {
namespace functionalization {
${view_meta_implementations}
} // namespace functionalization
} // namespace at

View File

@ -0,0 +1,12 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include <ATen/FunctionalStorageImpl.h>
namespace at {
namespace functionalization {
${view_meta_declarations}
} // namespace functionalization
} // namespace at

View File

@ -0,0 +1,11 @@
#include <ATen/ViewMetaClasses.h>
#include <torch/csrc/functionalization/Module.h>
namespace torch::functionalization {
void initGenerated(PyObject* module) {
auto functionalization = py::handle(module).cast<py::module>();
$view_meta_bindings
}
} // namespace torch::functionalization

View File

@ -391,6 +391,8 @@ def get_aten_generated_files(enabled_backends):
"CompositeExplicitAutogradFunctions_inl.h", "CompositeExplicitAutogradFunctions_inl.h",
"CompositeExplicitAutogradNonFunctionalFunctions.h", "CompositeExplicitAutogradNonFunctionalFunctions.h",
"CompositeExplicitAutogradNonFunctionalFunctions_inl.h", "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
"ViewMetaClasses.h",
"ViewMetaClasses.cpp",
"VmapGeneratedPlumbing.h", "VmapGeneratedPlumbing.h",
"core/ATenOpList.cpp", "core/ATenOpList.cpp",
"core/TensorBody.h", "core/TensorBody.h",
@ -1193,6 +1195,7 @@ def define_buck_targets(
"NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]", "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]",
"Operators.h": ":gen_aten[Operators.h]", "Operators.h": ":gen_aten[Operators.h]",
"RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]", "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
"ViewMetaClasses.h": ":gen_aten[ViewMetaClasses.h]",
"core/TensorBody.h": ":gen_aten[core/TensorBody.h]", "core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
"core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]", "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
"core/enum_tag.h": ":gen_aten[core/enum_tag.h]", "core/enum_tag.h": ":gen_aten[core/enum_tag.h]",

View File

@ -118,6 +118,9 @@ def define_targets(rules):
":LazyNonNativeIr.h", ":LazyNonNativeIr.h",
":RegisterDispatchDefinitions.ini", ":RegisterDispatchDefinitions.ini",
":RegisterDispatchKey.cpp", ":RegisterDispatchKey.cpp",
":ViewMetaClassesPythonBinding.cpp",
":ViewMetaClasses.cpp",
":ViewMetaClasses.h",
":native_functions.yaml", ":native_functions.yaml",
":shape_inference.h", ":shape_inference.h",
":tags.yaml", ":tags.yaml",
@ -170,6 +173,7 @@ GENERATED_H = [
"FunctionalInverses.h", "FunctionalInverses.h",
"RedispatchFunctions.h", "RedispatchFunctions.h",
"RegistrationDeclarations.h", "RegistrationDeclarations.h",
"ViewMetaClasses.h",
"VmapGeneratedPlumbing.h", "VmapGeneratedPlumbing.h",
] ]
@ -246,6 +250,7 @@ GENERATED_CPP = [
"RegisterFunctionalization_1.cpp", "RegisterFunctionalization_1.cpp",
"RegisterFunctionalization_2.cpp", "RegisterFunctionalization_2.cpp",
"RegisterFunctionalization_3.cpp", "RegisterFunctionalization_3.cpp",
"ViewMetaClasses.cpp",
] ]
GENERATED_CPP_CORE = [ GENERATED_CPP_CORE = [
@ -307,6 +312,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
"torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
] ]
GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP

View File

@ -1010,6 +1010,7 @@ libtorch_python_core_sources = [
"torch/csrc/utils/disable_torch_function.cpp", "torch/csrc/utils/disable_torch_function.cpp",
"torch/csrc/utils/verbose.cpp", "torch/csrc/utils/verbose.cpp",
"torch/csrc/cpu/Module.cpp", "torch/csrc/cpu/Module.cpp",
"torch/csrc/functionalization/Module.cpp",
"torch/csrc/instruction_counter/Module.cpp", "torch/csrc/instruction_counter/Module.cpp",
"torch/nativert/python/Bindings.cpp", "torch/nativert/python/Bindings.cpp",
] + lazy_tensor_core_python_sources ] + lazy_tensor_core_python_sources
@ -1052,6 +1053,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
"torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp",
]] ]]
_libtorch_python_sources.extend(libtorch_python_core_sources) _libtorch_python_sources.extend(libtorch_python_core_sources)

View File

@ -316,6 +316,7 @@ set(GENERATED_CXX_PYTHON
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
"${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
) )
set(GENERATED_H_PYTHON set(GENERATED_H_PYTHON
@ -379,6 +380,9 @@ add_custom_command(
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.h"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp"
${autograd_python} ${autograd_python}
${autograd_yaml} ${autograd_yaml}
${autograd_templates} ${autograd_templates}

View File

@ -156,6 +156,7 @@ def get_generate_code_bin_outs():
"autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"], "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"],
"autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"], "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"],
"autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"], "autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"],
"functionalization/generated/ViewMetaClassesPythonBinding.cpp": ["functionalization/generated/ViewMetaClassesPythonBinding.cpp"],
}) })
return outs return outs

View File

@ -519,11 +519,7 @@ class AOTAutogradCacheTests(InductorTestCase):
@functorch_config.patch( @functorch_config.patch(
{"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True}
) )
def test_view_replay_bypass(self): def test_view_replay(self):
"""
Should bypass when view replay is turned on
"""
def fn(a): def fn(a):
tmp = a.detach() tmp = a.detach()
a.mul_(2) a.mul_(2)
@ -531,10 +527,25 @@ class AOTAutogradCacheTests(InductorTestCase):
with torch.autograd._force_original_view_tracking(True): with torch.autograd._force_original_view_tracking(True):
compiled_fn = torch.compile(fn) compiled_fn = torch.compile(fn)
compiled_fn(torch.rand(2, 3))
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) def run_and_check(miss, hit, bypass):
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) self._clear_dynamo_and_codecache()
inp = torch.rand(2, 3)
compiled_inp = inp.clone().detach()
with torch.autograd._force_original_view_tracking(True):
out = fn(inp)
compiled_out = compiled_fn(compiled_inp)
self.assertEqual(out, compiled_out)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass)
run_and_check(miss=1, hit=0, bypass=0)
run_and_check(miss=1, hit=1, bypass=0)
run_and_check(miss=1, hit=2, bypass=0)
@inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True) @inductor_config.patch("fx_graph_cache", True)

View File

@ -8500,7 +8500,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
{ {
"enable_autograd_cache": True, "enable_autograd_cache": True,
"strict_autograd_cache": True, "strict_autograd_cache": True,
"view_replay_for_aliased_outputs": False,
} }
) )
@torch._inductor.config.patch("fx_graph_cache", True) @torch._inductor.config.patch("fx_graph_cache", True)

View File

@ -189,6 +189,12 @@ def main() -> None:
) )
options = parser.parse_args() options = parser.parse_args()
# Path: aten/src/ATen
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
operator_selector = get_selector(
options.selected_op_list_path, options.operators_yaml_path
)
generate_code( generate_code(
options.gen_dir, options.gen_dir,
options.native_functions_path, options.native_functions_path,
@ -198,18 +204,37 @@ def main() -> None:
options.disable_autograd, options.disable_autograd,
options.force_schema_registration, options.force_schema_registration,
# options.selected_op_list # options.selected_op_list
operator_selector=get_selector( operator_selector=operator_selector,
options.selected_op_list_path, options.operators_yaml_path )
),
# Generate the python bindings for functionalization's `ViewMeta` classes.
from torchgen.gen_functionalization_type import (
gen_functionalization_view_meta_classes,
)
functionalization_templates_dir = os.path.join(aten_path, "templates")
install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
functionalization_install_dir = os.path.join(
install_dir, "functionalization", "generated"
)
os.makedirs(functionalization_install_dir, exist_ok=True)
assert os.path.isdir(functionalization_install_dir)
assert os.path.isdir(functionalization_templates_dir)
gen_functionalization_view_meta_classes(
options.native_functions_path or NATIVE_FUNCTIONS_PATH,
options.tags_path or TAGS_PATH,
selector=operator_selector,
install_dir=functionalization_install_dir,
template_dir=functionalization_templates_dir,
) )
if options.gen_lazy_ts_backend: if options.gen_lazy_ts_backend:
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml") ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp" ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h" ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc") lazy_install_dir = os.path.join(install_dir, "lazy", "generated")
lazy_install_dir = os.path.join(install_dir, "lazy/generated")
os.makedirs(lazy_install_dir, exist_ok=True) os.makedirs(lazy_install_dir, exist_ok=True)
assert os.path.isfile(ts_backend_yaml), ( assert os.path.isfile(ts_backend_yaml), (

View File

@ -30,6 +30,7 @@ from torch._C import (
_cpu, _cpu,
_dynamo, _dynamo,
_export, _export,
_functionalization,
_functorch, _functorch,
_lazy, _lazy,
_lazy_ts_backend, _lazy_ts_backend,

View File

@ -0,0 +1,16 @@
from torch import Tensor
from torch.types import _bool
# Defined in torch/csrc/functionalization/Module.cpp
class ViewMeta:
has_symbolic_inputs: _bool
# Returns the list of ViewMeta instances of the given functional tensor.
#
# Although we do have python bindings for their types, we won't
# expose them here, since they should not be used by users.
def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ...
# Applies the ViewMeta sequence on top of the given base.
def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ...

View File

@ -284,19 +284,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type] check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type]
def check_metadata_cacheable(metadata: ViewAndMutationMeta):
"""
When view replay is turned on, we bypass autograd cache if
the output is aliased.
"""
if config.view_replay_for_aliased_outputs:
for info in metadata.output_info:
if info.functional_tensor is not None:
raise BypassAOTAutogradCache(
"Cannot cache a graph with functional tensor"
)
class AOTAutogradCacheDetails(FxGraphHashDetails): class AOTAutogradCacheDetails(FxGraphHashDetails):
""" """
Object to capture all the details for a dynamo graph module relevant to computing Object to capture all the details for a dynamo graph module relevant to computing
@ -803,7 +790,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):
""" """
Perform any preparations to make the cache entry ready for serialization. Perform any preparations to make the cache entry ready for serialization.
""" """
check_metadata_cacheable(self.runtime_metadata)
self.compiled_fw.pre_save() self.compiled_fw.pre_save()
if self.compiled_bw is not None: if self.compiled_bw is not None:
self.compiled_bw.pre_save() self.compiled_bw.pre_save()

View File

@ -43,10 +43,10 @@ from .functional_utils import (
has_metadata_mutation, has_metadata_mutation,
MetadataKey, MetadataKey,
to_fun, to_fun,
ViewMetaSequence,
was_inductor_storage_resized, was_inductor_storage_resized,
) )
from .schemas import ( from .schemas import (
FunctionalTensorMetadataEq,
InputAliasInfo, InputAliasInfo,
MemoryFormatMeta, MemoryFormatMeta,
MutationType, MutationType,
@ -640,7 +640,7 @@ from a multi-output view call"
# #
# The FunctionalTensor will be saved if one of the 2 conditions below # The FunctionalTensor will be saved if one of the 2 conditions below
# is true: # is true:
functional_tensor = None view_meta_sequence = None
if ( if (
# 1. If the output_type is either of: # 1. If the output_type is either of:
# (i) alias_of_intermediate; # (i) alias_of_intermediate;
@ -672,7 +672,7 @@ from a multi-output view call"
and not input_info[base_idx].mutates_metadata and not input_info[base_idx].mutates_metadata
): ):
if isinstance(o, FunctionalTensor): if isinstance(o, FunctionalTensor):
functional_tensor = FunctionalTensorMetadataEq(o.elem) view_meta_sequence = ViewMetaSequence(o)
out_info = OutputAliasInfo( out_info = OutputAliasInfo(
output_type=output_type, output_type=output_type,
@ -680,7 +680,7 @@ from a multi-output view call"
base_idx=base_idx, base_idx=base_idx,
dynamic_dims=dynamic_dims, dynamic_dims=dynamic_dims,
requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
functional_tensor=functional_tensor, view_meta_sequence=view_meta_sequence,
) )
output_info.append(out_info) output_info.append(out_info)

View File

@ -14,6 +14,7 @@ from typing import Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torch._C import _functionalization
from torch._logging import getArtifactLogger from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor from torch._subclasses.functional_tensor import FunctionalTensor
@ -224,9 +225,9 @@ def gen_alias_from_base(
aliased_base_tensor, aliased_base_tensor,
target_meta_tensor, target_meta_tensor,
target_requires_grad, target_requires_grad,
target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, target_view_meta_sequence: Optional[ViewMetaSequence] = None,
*, *,
replay_views, replay_views: bool,
): ):
# Patch the correct requires_grad field of the output tensor, depending on whether: # Patch the correct requires_grad field of the output tensor, depending on whether:
# (i) the reconstructed output (out) was came from a tensor that requires grad or not; # (i) the reconstructed output (out) was came from a tensor that requires grad or not;
@ -245,13 +246,11 @@ def gen_alias_from_base(
# to replay them (view functions) on the aliased_base_tensor. # to replay them (view functions) on the aliased_base_tensor.
if ( if (
replay_views replay_views
and target_functional_tensor is not None and target_view_meta_sequence is not None
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence)
): ):
functional_tensor = target_functional_tensor.tensor out = _functionalization.apply_view_meta_sequence(
aliased_base_tensor, target_view_meta_sequence.sequence
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
) )
# If re-applying the ViewMeta sequence succeeded, there should be no more # If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and # problems going forward. We just check we got to the target shape and
@ -357,25 +356,45 @@ class MetadataKey:
) )
# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata # ViewMeta sequence wrapper for equality comparisons.
# after applying all the ViewMeta operations. #
class FunctionalTensorMetadataEq: # Even though we can compare each ViewMeta instance, we compare the resulting
def __init__(self, tensor: torch.Tensor) -> None: # tensor metadata, instead. That's because the creation of synthetic bases + the
assert torch._is_functional_tensor(tensor) # re-generation of input views might end-up creating a different sequence of
self.tensor = tensor # ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same
# metadata.
#
# Therefore, we store what the end result should look like as serializable
# metadata.
#
# When logging, this class should look like:
#
# ViewMetaSequence(view, select_int, slice_Tensor)
#
# i.e. a parenthesized list of view operations within that ViewMeta sequence.
class ViewMetaSequence:
def __init__(self, tensor: FunctionalTensor) -> None:
assert torch._is_functional_tensor(tensor.elem)
self.sequence = _functionalization.get_view_meta_sequence(tensor.elem)
self.metadata = MetadataKey.make(tensor)
def __repr__(self) -> str:
suffix = len("_ViewMeta")
types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence)
return f"ViewMetaSequence({types})"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
# If other is None, then it probably means that we weren't able to recreate # If other is None, then it probably means that we weren't able to recreate
# the FunctionalTensorMetadataEq. One of this cases is when we update the # the ViewMeta sequence. One example is when we update the view metadata by
# view metadata by calling: create_synthetic_base_metadata. # calling: create_synthetic_base_metadata.
if other is None: if other is None:
return True return True
# Comparison against any other type is not implemented. # Comparison against any other type is not implemented.
if not isinstance(other, FunctionalTensorMetadataEq): if not isinstance(other, ViewMetaSequence):
return NotImplemented return NotImplemented
return has_same_metadata(self.tensor, other.tensor) return self.metadata == other.metadata
# new_arg and arg here are either: # new_arg and arg here are either:

View File

@ -89,7 +89,7 @@ def remove_dupe_metadata(
dynamic_dims=o.dynamic_dims, dynamic_dims=o.dynamic_dims,
base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
requires_grad=o.requires_grad, requires_grad=o.requires_grad,
functional_tensor=o.functional_tensor, view_meta_sequence=o.view_meta_sequence,
) )
for o in m.output_info for o in m.output_info
], ],
@ -242,7 +242,7 @@ def create_synthetic_base_metadata(
# Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
base_idx=new_base_idx, # type: ignore[arg-type] base_idx=new_base_idx, # type: ignore[arg-type]
requires_grad=o.requires_grad, requires_grad=o.requires_grad,
functional_tensor=o.functional_tensor, view_meta_sequence=o.view_meta_sequence,
) )
) )

View File

@ -150,7 +150,7 @@ class AliasOfInputHandler:
self.base_idx = info.base_idx self.base_idx = info.base_idx
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
self.requires_grad = info.requires_grad self.requires_grad = info.requires_grad
self.functional_tensor = info.functional_tensor self.view_meta_sequence = info.view_meta_sequence
self.replay_views = config.view_replay_for_aliased_outputs self.replay_views = config.view_replay_for_aliased_outputs
def __call__(self, orig_inputs, fw_outs, out): def __call__(self, orig_inputs, fw_outs, out):
@ -159,7 +159,7 @@ class AliasOfInputHandler:
aliased_base_tensor, aliased_base_tensor,
self.unwrap_out(out), self.unwrap_out(out),
self.requires_grad, self.requires_grad,
self.functional_tensor, self.view_meta_sequence,
replay_views=self.replay_views, replay_views=self.replay_views,
) )
@ -190,7 +190,7 @@ class AliasOfIntermediateHandler:
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
self.requires_grad = info.requires_grad self.requires_grad = info.requires_grad
self.functional_tensor = info.functional_tensor self.view_meta_sequence = info.view_meta_sequence
self.replay_views = config.view_replay_for_aliased_outputs self.replay_views = config.view_replay_for_aliased_outputs
def __call__(self, orig_inputs, fw_outs, out): def __call__(self, orig_inputs, fw_outs, out):
@ -199,7 +199,7 @@ class AliasOfIntermediateHandler:
self._unwrap_aliased_base_tensor(aliased_base_tensor), self._unwrap_aliased_base_tensor(aliased_base_tensor),
self.unwrap_out(out), self.unwrap_out(out),
self.requires_grad, self.requires_grad,
self.functional_tensor, self.view_meta_sequence,
replay_views=self.replay_views, replay_views=self.replay_views,
) )

View File

@ -7,7 +7,6 @@ input/output types, metadata, config, function signatures etc.
from __future__ import annotations from __future__ import annotations
import collections import collections
import dataclasses
import functools import functools
import itertools import itertools
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -32,10 +31,7 @@ from torch.fx.experimental._backward_state import BackwardState
from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config from .. import config
from .functional_utils import ( from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence
_check_if_mutation_can_be_in_graph,
FunctionalTensorMetadataEq,
)
from .utils import strict_zip from .utils import strict_zip
@ -117,15 +113,14 @@ class OutputAliasInfo:
dynamic_dims: Optional[set[int]] dynamic_dims: Optional[set[int]]
# requires_grad # requires_grad
requires_grad: bool requires_grad: bool
# FunctionalTensorWrapper that represents this output. # Sequence of ViewMeta objects.
# #
# Provides us the means to replay views from it. # Provides us the means to re-run view functions on other tensors.
# #
# We need to wrap the actual FunctionalTensorWrapper with this class so that # We need to wrap the actual list of ViewMeta with this class so that
# we only compare the tensor's metadata. That's because with the transformations # we compare the ViewMeta elements appropriately, i.e. their type and
# of the model throughout AOTAutograd, the sequence of ViewMeta and the base # the elements returned by the `as_tuple()` call.
# tensor might change. view_meta_sequence: Optional[ViewMetaSequence] = None
functional_tensor: Optional[FunctionalTensorMetadataEq] = None
class MutationType(Enum): class MutationType(Enum):
@ -665,17 +660,6 @@ class ViewAndMutationMeta:
self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
# Clear traced tangents at runtime # Clear traced tangents at runtime
self.traced_tangents = [] self.traced_tangents = []
new_output_info = []
for out in self.output_info:
if config.view_replay_for_aliased_outputs:
new_out = out
else:
# If we're not using view_replay, remove the functional tensor.
# Functional tensors are unfortunately not serializable,
# so doing this is required for AOTAutograd caching.
new_out = dataclasses.replace(out, functional_tensor=None)
new_output_info.append(new_out)
self.output_info = new_output_info
for inp_meta in self.subclass_inp_meta: for inp_meta in self.subclass_inp_meta:
if isinstance(inp_meta, SubclassCreationMeta): if isinstance(inp_meta, SubclassCreationMeta):
inp_meta.make_runtime_safe() inp_meta.make_runtime_safe()

View File

@ -72,6 +72,7 @@
#include <torch/csrc/cpu/Module.h> #include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h> #include <torch/csrc/dynamo/init.h>
#include <torch/csrc/export/pybind.h> #include <torch/csrc/export/pybind.h>
#include <torch/csrc/functionalization/Module.h>
#include <torch/csrc/functorch/init.h> #include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h> #include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_package/pybind.h> #include <torch/csrc/inductor/aoti_package/pybind.h>
@ -2080,6 +2081,7 @@ PyObject* initModule() {
torch::instruction_counter::initModule(module); torch::instruction_counter::initModule(module);
torch::initVerboseBindings(module); torch::initVerboseBindings(module);
ASSERT_TRUE(THPStorage_init(module)); ASSERT_TRUE(THPStorage_init(module));
torch::functionalization::initModule(module);
#ifdef USE_CUDA #ifdef USE_CUDA
// This will only initialise base classes and attach them to library namespace // This will only initialise base classes and attach them to library namespace

View File

@ -644,15 +644,6 @@ void initTorchFunctions(PyObject* module) {
at::functionalization::impl::isFunctionalTensor(t)); at::functionalization::impl::isFunctionalTensor(t));
at::functionalization::impl::mark_mutation_hidden_from_autograd(t); at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
}); });
py_module.def(
"_functionalize_apply_view_metas",
[](const at::Tensor& tensor, const at::Tensor& base) {
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl =
at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
return impl->apply_view_metas(base);
});
py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) { py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);

View File

@ -0,0 +1,71 @@
#include <torch/csrc/functionalization/Module.h>
#include <torch/csrc/utils/pybind.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalizeFallbackKernel.h>
#include <memory>
namespace torch::functionalization {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
// Create a `torch._C._functionalization` Python module.
auto functionalization = m.def_submodule(
"_functionalization", "functionalization related pybind.");
// Retrieve the ViewMeta sequence of a given functional tensor.
functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
return impl->view_metas();
});
// Applies the given ViewMeta sequence to the given base.
functionalization.def(
"apply_view_meta_sequence",
[](const at::Tensor& base,
const std::vector<std::shared_ptr<at::functionalization::ViewMeta>>&
sequence) {
return at::functionalization::impl::apply_view_meta_sequence(
base, sequence);
});
// Binding for InverseReturnMode.
py::enum_<at::functionalization::InverseReturnMode>(
functionalization, "InverseReturnMode")
.value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView)
.value("NeverView", at::functionalization::InverseReturnMode::NeverView)
.value(
"ViewOrScatterInverse",
at::functionalization::InverseReturnMode::ViewOrScatterInverse);
// Create bindings for the ViewMeta base class.
//
// Needed so that we can take a list of ViewMeta objects as parameter.
// Specifically, in the Python-side, we will have a list of derived ViewMeta
// classes. We need to tell pybind11 that all of those are, in fact, instances
// of different ViewMeta sub-types.
py::class_<
at::functionalization::ViewMeta,
std::shared_ptr<at::functionalization::ViewMeta>>(
functionalization, "ViewMeta")
.def_property_readonly(
"has_symbolic_inputs",
[](const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
return meta->has_symbolic_inputs;
});
// Bindings for `ViewMeta` specializations manually implemented.
create_binding_with_pickle<at::functionalization::resize__ViewMeta>(
functionalization);
create_binding_with_pickle<at::functionalization::_unsafe_view_ViewMeta>(
functionalization);
// Bindings for `ViewMeta` specializations automatically generated.
initGenerated(functionalization.ptr());
}
} // namespace torch::functionalization

View File

@ -0,0 +1,36 @@
#pragma once
#include <ATen/FunctionalStorageImpl.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::functionalization {
// Creates the default bindings for `ViewMeta` specializations.
//
// Defines a constructor using the types in `SerializableTuple`, as well
// as pickle methods.
template <class T>
void create_binding_with_pickle(py::module m) {
py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
m, T::name())
.def(py::init<typename T::SerializableTuple>())
.def(
"as_tuple",
[](const std::shared_ptr<T>& meta) {
return meta->to_serializable_tuple();
})
.def(py::pickle(
[](const std::shared_ptr<T>& meta) {
return meta->to_serializable_tuple();
},
[](const typename T::SerializableTuple& tpl) {
return std::make_shared<T>(tpl);
}));
}
void initModule(PyObject* module);
void initGenerated(PyObject* module);
} // namespace torch::functionalization

View File

@ -213,7 +213,7 @@ _SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic)
class SymIntEqByExpr: class SymIntEqByExpr:
""" """
This is a wrapper around SymInt which has alternative semantics for This is a wrapper around SymInt which has alternative semantics for
equality. Specifically, instead of erroring or guarding, we equality and pickling. Specifically, instead of erroring or guarding, we
instead will hash/compare equality based on the underlying sympy instead will hash/compare equality based on the underlying sympy
expression; e.g., s0 and s1 will always compare as False. expression; e.g., s0 and s1 will always compare as False.
@ -222,31 +222,25 @@ class SymIntEqByExpr:
canonicalize to the same expression via regular simplification. canonicalize to the same expression via regular simplification.
""" """
val: Union[torch.SymInt, int] @staticmethod
def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr:
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return sympy.Integer(val)
def __init__(self, val: Union[torch.SymInt, int]) -> None: def __init__(self, val: Union[torch.SymInt, int]) -> None:
self.val = val self.val: sympy.Expr = SymIntEqByExpr._extract(val)
def __repr__(self) -> str: def __repr__(self) -> str:
return repr(self.val) return repr(self.val)
def _extract(self) -> sympy.Expr:
if isinstance(self.val, torch.SymInt):
return self.val.node.expr
else:
return sympy.Integer(self.val)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
assert isinstance(other, SymIntEqByExpr) assert isinstance(other, SymIntEqByExpr)
# int equality fastpath
if type(self.val) is int and type(other.val) is int:
return self.val == other.val return self.val == other.val
return self._extract() == other._extract()
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(self._extract()) return hash(self.val)
def _nested_int_aware_sort( def _nested_int_aware_sort(

View File

@ -23,20 +23,13 @@ from torchgen.model import (
# This file describes the translation of JIT schema to API's used # This file describes the translation of JIT schema to API's used
# when creating view lambdas that are used by the functionalization pass. # when creating `ViewMeta` specializations that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas. # These API's mostly follow the dispatcher API, with one difference:
# These API's mostly follow the dispatcher API, with a few quirks: # - While the forward function just directly calls into the at::_ops API
# - The lambda capture has to convert reference types to value types # (following the dispatcher convention), the logic here for the reverse function
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# is responsible for generating both the call-site, and the declarations # is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace). # (which are implemented manually in the at::functionalization::impl namespace).
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
# return name(inner_arguments);
# }
# Define some specific lambda input arguments. # Define some specific lambda input arguments.
base_binding = Binding( base_binding = Binding(
name="base", name="base",
@ -46,6 +39,18 @@ base_binding = Binding(
), ),
default=None, default=None,
) )
has_symbolic_inputs_binding = Binding(
name="has_symbolic_inputs",
nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
argument=Argument(
name="has_symbolic_inputs",
type=BaseType(BaseTy.bool),
default=None,
annotation=None,
),
default=None,
)
mutated_view_binding = Binding( mutated_view_binding = Binding(
name="mutated_view", name="mutated_view",
nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@ -54,11 +59,11 @@ mutated_view_binding = Binding(
), ),
default=None, default=None,
) )
mutated_view_idx_binding = Binding( out_index_binding = Binding(
name="mutated_view_idx", name="out_index",
nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), nctype=NamedCType(name="out_index", type=BaseCType(longT)),
argument=Argument( argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
), ),
default=None, default=None,
) )
@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
) )
# The lambda capture itself doesn't have a name. # Name of the `ViewMeta` specialization class created.
# The name returned here corresponds to the name of the inner function called by the lambda. def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
namespace = "at::functionalization::" if with_namespace else ""
return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
# Name of the operation called inside the `forward`/`reverse` implementations.
def name( def name(
g: NativeFunctionsViewGroup, g: NativeFunctionsViewGroup,
*, *,
@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
return f"{api_name}_inverse" return f"{api_name}_inverse"
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
# capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
all_bindings = [
inverse_return_mode_binding if is_reverse else reapply_views_binding
]
all_bindings.extend(non_self_value_bindings)
return all_bindings
def returns_type(func: FunctionSchema) -> CType: def returns_type(func: FunctionSchema) -> CType:
# Assertion: all view ops return tensor-like outputs # Assertion: all view ops return tensor-like outputs
assert len(func.returns) >= 1 assert len(func.returns) >= 1
@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
return BaseCType(tensorT) return BaseCType(tensorT)
def outer_arguments(*, is_reverse: bool) -> list[Binding]: # Checks whether `func` might return more than one value.
if is_reverse: def is_multi_output(func: FunctionSchema) -> bool:
return [base_binding, mutated_view_binding, mutated_view_idx_binding] return len(func.returns) > 1 or (
else: len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
return [base_binding, mutated_view_idx_binding] )
def inner_call_index(func: FunctionSchema) -> Binding | None: # `ViewMeta` specialization constructor parameters.
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately # All specializations are parematerized by `has_symbolic_inputs` flag.
if len(func.returns) > 1 or ( arguments = [has_symbolic_inputs_binding]
len(func.returns) == 1 and func.returns[0].type.is_list_like()
): # If `func` might return more than 1 value, we also parameterize this specialization
return mutated_view_idx_binding # with the output index.
return None if is_multi_output(func):
arguments.append(out_index_binding)
return arguments
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: # `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization, that the base class does not need.
# Same as the class' attributes, but non-owning.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
return attributes(func, owning=False)
# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse functions. You can
# think of them as values that should be captured from the functionalization kernel.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
return [
reapply_views_binding,
inverse_return_mode_binding,
*[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
]
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
args = func.arguments.flat_all args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor) assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:] non_self_args = args[1:]
@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
# the reverse lambda does the same, but with an additional "mutated_view" arg # the reverse lambda does the same, but with an additional "mutated_view" arg
# additionally, we have a calling convention: for view ops that return multiple tensor outputs # additionally, we have a calling convention: for view ops that return multiple tensor outputs
# their corresponding view_inverse function takes in an additional index argument. # their corresponding view_inverse function takes in an additional index argument.
index_binding = inner_call_index(func) if is_multi_output(func):
if index_binding is not None:
return [ return [
base_binding, base_binding,
mutated_view_binding, mutated_view_binding,
inverse_return_mode_binding, inverse_return_mode_binding,
index_binding, out_index_binding,
] + non_self_bindings ] + non_self_bindings
else: else:
return [ return [

View File

@ -300,83 +300,11 @@ class ViewInverseSignature:
return_type = functionalization.returns_type(self.g.view.func) return_type = functionalization.returns_type(self.g.view.func)
decls = [ decls = [
a.decl() a.decl()
for a in functionalization.inner_arguments( for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
self.g.view.func, is_reverse=True
)
] ]
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
@dataclass(frozen=True)
class FunctionalizationLambda:
g: NativeFunctionsViewGroup
# are we generating the forward lambda or the reverse lambda?
is_reverse: bool
def captures(self) -> list[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
outer_ctx = dispatcher.arguments(self.g.view.func) + [
functionalization.reapply_views_binding,
functionalization.inverse_return_mode_binding,
]
capture_bindings = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
# allow_expensive_conversions is set because we want to convert
# some reference types (IntArrayRef) to value types (vector<int64_t>).
capture_exprs = translate.translate(
outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
)
return capture_exprs
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
capture_str = ", ".join(
f"{val.type.name} = {val.expr}" for val in self.captures()
)
decls = [
a.decl()
for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: bool | None = None) -> str:
inner_call_name = functionalization.name(
self.g,
is_reverse=self.is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
capture_ctx = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
full_ctx = arg_ctx + capture_ctx
assert self.g.view_copy is not None
call_bindings = functionalization.inner_arguments(
self.g.view_copy.func, is_reverse=self.is_reverse
)
maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
call_exprs = [
e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
]
if not self.is_reverse and maybe_index is not None:
return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
else:
return f"{inner_call_name}({', '.join(call_exprs)});"
@staticmethod
def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> FunctionalizationLambda:
return FunctionalizationLambda(g, is_reverse)
@dataclass(frozen=True) @dataclass(frozen=True)
class StructuredImplSignature: class StructuredImplSignature:
g: NativeFunctionsGroup g: NativeFunctionsGroup

View File

@ -43,6 +43,8 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_definition, gen_functionalization_definition,
gen_functionalization_registration, gen_functionalization_registration,
gen_functionalization_view_inverse_declaration, gen_functionalization_view_inverse_declaration,
gen_functionalization_view_meta_classes_decl,
gen_functionalization_view_meta_classes_impl,
GenCompositeViewCopyKernel, GenCompositeViewCopyKernel,
) )
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@ -2493,9 +2495,6 @@ def gen_source_files(
}, },
) )
def functionalization_env_callable(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
def gen_op_headers( def gen_op_headers(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]: ) -> list[str]:
@ -2535,6 +2534,9 @@ def gen_source_files(
f"#include <ATen/ops/{g.root_name}_ops.h>", f"#include <ATen/ops/{g.root_name}_ops.h>",
] ]
def functionalization_env_callable(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
return { return {
"ops_headers": gen_op_headers(g), "ops_headers": gen_op_headers(g),
"func_definitions": gen_functionalization_definition( "func_definitions": gen_functionalization_definition(
@ -2600,6 +2602,31 @@ def gen_source_files(
}, },
) )
cpu_fm.write(
"ViewMetaClasses.h",
lambda: {
"view_meta_declarations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
view_groups,
)
)
},
)
cpu_fm.write(
"ViewMetaClasses.cpp",
lambda: {
"view_meta_implementations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
view_groups,
)
),
"op_headers": list(concatMap(gen_op_headers, view_groups)),
},
)
# Note [view_copy NativeFunctions] # Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
# needs to have a corresponding non-aliasing {view}_copy variant. # needs to have a corresponding non-aliasing {view}_copy variant.

View File

@ -1,16 +1,15 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING from typing import Callable, Optional, TYPE_CHECKING
from torchgen.api import cpp, dispatcher from torchgen.api import cpp, dispatcher, functionalization
from torchgen.api.translate import translate from torchgen.api.translate import translate
from torchgen.api.types import ( from torchgen.api.types import (
BaseCType, BaseCType,
Binding, Binding,
CType, CType,
DispatcherSignature, DispatcherSignature,
FunctionalizationLambda,
iTensorListRefT, iTensorListRefT,
NativeSignature, NativeSignature,
OptionalCType, OptionalCType,
@ -48,7 +47,7 @@ from torchgen.native_function_generation import (
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
) )
from torchgen.utils import dataclass_repr from torchgen.utils import concatMap, dataclass_repr, FileManager
if TYPE_CHECKING: if TYPE_CHECKING:
@ -365,6 +364,8 @@ def emit_view_functionalization_body(
with native_function_manager(f): with native_function_manager(f):
call_sig = DispatcherSignature.from_schema(g.view_copy.func) call_sig = DispatcherSignature.from_schema(g.view_copy.func)
spec = ViewMetaSpecialization(g, f=f)
# the "view_copy" op name that the functionalization kernels need to call # the "view_copy" op name that the functionalization kernels need to call
api_name = g.view_copy.func.name.unambiguous_name() api_name = g.view_copy.func.name.unambiguous_name()
# Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors) # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
@ -385,9 +386,6 @@ def emit_view_functionalization_body(
for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False) for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
] ]
forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
# The meta API call should use the same arguments, but convert all tensors to meta tensors first. # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [ meta_call_args = [
@ -415,19 +413,7 @@ def emit_view_functionalization_body(
: at::functionalization::InverseReturnMode::NeverView : at::functionalization::InverseReturnMode::NeverView
); );
{symbolic_inputs_check} {symbolic_inputs_check}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( auto view_meta = {spec.new()};
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname}
);
auto compute_reference_meta = auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit); {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
@ -455,7 +441,6 @@ def emit_view_functionalization_body(
""" """
else: else:
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
return f""" return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
{unwrap_tensor_args_str} {unwrap_tensor_args_str}
@ -489,21 +474,7 @@ def emit_view_functionalization_body(
}} }}
}} }}
{symbolic_inputs_check} {symbolic_inputs_check}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( auto view_meta = {spec.new()};
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname},
/*is_multi_output=*/{str(is_multi_output_view).lower()},
/*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta); auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass] // See Note [Propagating strides in the functionalization pass]
if (compute_reference_meta && !disable_meta_reference()) {{ if (compute_reference_meta && !disable_meta_reference()) {{
@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration(
return emit_decl_helper(g) return emit_decl_helper(g)
# Helper class for generating `ViewMeta` specializations.
@dataclass
class ViewMetaSpecialization:
g: NativeFunctionsViewGroup
f: NativeFunction
@property
def is_multi_output(self) -> bool:
return functionalization.is_multi_output(self.f.func)
@property
def is_as_strided(self) -> bool:
return str(self.f.func.name) == "as_strided"
@property
def out_index(self) -> str:
if self.is_multi_output:
return functionalization.out_index_binding.name
return "0"
@property
def classname(self) -> str:
return functionalization.classname(self.f.func)
def decl(self) -> list[str]:
base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
attributes = functionalization.attributes(self.f.func)
# List of types for declaring the `SerializableTuple` type.
serializable_tuple_args = ",\n".join(
f" {binding.type} /* {binding.name} */"
for binding in (base_ctor_arguments + attributes)
)
# Arguments used for forwarding the tuple elements to the constructor.
destructure_tuple_args = ", ".join(
f"std::get<{i}>(tpl)"
for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
)
# List of constructor parameters
ctor_parameters = ", ".join(
binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
)
# Call the base class `ViewMeta` constructor.
#
# Both of `is_multi_output` and `is_as_strided` are known values, given the
# operation schema.
is_multi_output_str = str(self.is_multi_output).lower()
is_as_strided_str = str(self.is_as_strided).lower()
base_ctor_bindings = ", ".join(
[
# `has_symbolic_inputs` is always taken as parameter.
functionalization.has_symbolic_inputs_binding.name,
f"/*is_multi_output=*/{is_multi_output_str}",
f"/*is_as_strided=*/{is_as_strided_str}",
# `out_index` is know if the operation returns only one value. Otherwise,
# we also take it as parameter.
f"/*out_index=*/{self.out_index}",
]
)
# Assignments of `extra_ctor_arguments` to their corresponding fields.
# These are extra fields to-be-declared in this specialization.
#
# We need to set `allow_expensive_conversions`, since we are storing owned versions
# of the non-owning arguments.
ctor_assignments = ",\n".join(
f" {e.type.name}({e.expr})"
for e in translate(
extra_ctor_arguments,
attributes,
method=False,
allow_expensive_conversions=True,
)
)
# List of arguments for constructing the `SerializableTuple` from an instance.
tuple_arguments = ", ".join(
binding.name for binding in (base_ctor_arguments + attributes)
)
# List of field declarations.
attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes)
# Override `to_out_index` if this operation returns more than 1 value.
to_out_index_decl = ""
if self.is_multi_output:
to_out_index_decl = (
" std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
)
return [
f"""
struct TORCH_API {self.classname} : public ViewMeta {{
FUNCTIONALIZATION_VIEWMETA_NAME({self.classname})
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});
{self.classname}(const SerializableTuple& tpl)
: {self.classname}({destructure_tuple_args}) {{}}
{self.classname}({ctor_parameters})
: at::functionalization::ViewMeta({base_ctor_bindings}),
{ctor_assignments} {{}}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
{to_out_index_decl}
SerializableTuple to_serializable_tuple() {{
return std::make_tuple({tuple_arguments});
}}
{attr_declarations}
}};
"""
]
# Generate a call to the actual operation.
def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
opname = functionalization.name(
self.g,
is_reverse=is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
# Expected arguments for the operation.
assert self.g.view_copy is not None
op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)
# The context is composed by the constructor arguments (which are also
# the field variables stored in the instance), and the `base` tensor.
context = [functionalization.base_binding]
context += functionalization.base_ctor_arguments(self.f.func)
context += functionalization.attributes(self.f.func)
# If we are generating the call for the reverse function, we also have
# access to `mutated_view` argument.
if is_reverse:
context.append(functionalization.mutated_view_binding)
arguments = ", ".join(
[e.expr for e in translate(context, op_arguments, method=False)]
)
# Index the result if this operation returns multiple values.
maybe_index = ""
if not is_reverse and self.is_multi_output:
maybe_index = f"[{self.out_index}]"
return f"{opname}({arguments}){maybe_index}"
def impl(self) -> list[str]:
functions = [
f"""
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
if (reapply_views) {{
return {self.opcall(is_reverse=False, reapply_views=True)};
}} else {{
return {self.opcall(is_reverse=False, reapply_views=False)};
}}
}}""",
f"""
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
return {self.opcall(is_reverse=True, reapply_views=True)};
}}""",
]
# If this operation returns multiple values, also generate a `to_out_index`
# implementation.
if self.is_multi_output:
functions.append(f"""
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
return {self.new("out_index")};
}}
""")
return functions
# Create the Python binding for this specialized class.
def binding(self) -> list[str]:
name = functionalization.classname(self.f.func, with_namespace=True)
return [f" create_binding_with_pickle<{name}>(functionalization);"]
# Generate an instantiation of this specialized class.
def new(self, out_index: str = "0") -> str:
name = functionalization.classname(self.f.func, with_namespace=True)
ctor_arguments = functionalization.base_ctor_arguments(
self.f.func
) + functionalization.extra_ctor_arguments(self.f.func)
# Replace the `out_index` parameter with the given `out_index`.
arguments = ", ".join(
binding.name if binding.name != "out_index" else out_index
for binding in ctor_arguments
)
return f"std::make_shared<{name}>({arguments})"
# Run the function `run` for both: `view` and `view_inplace` functions.
@staticmethod
def map(
g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
) -> list[str]:
def maybe_run(f: Optional[NativeFunction]) -> list[str]:
if f is None:
return []
with native_function_manager(f):
return run(ViewMetaSpecialization(g, f))
return list(concatMap(maybe_run, (g.view, g.view_inplace)))
def gen_functionalization_view_meta_classes_base(
selector: SelectiveBuilder,
g: NativeFunctionsViewGroup,
run: Callable[[ViewMetaSpecialization], list[str]],
) -> list[str]:
if not selector.include_all_operators:
return []
if g.composite:
return []
return ViewMetaSpecialization.map(g, run)
def gen_functionalization_view_meta_classes_decl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.decl
)
def gen_functionalization_view_meta_classes_impl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.impl
)
def gen_functionalization_view_meta_classes_binding(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.binding
)
# Generates the Python bindings for the `ViewMeta` specialized classes.
def gen_functionalization_view_meta_classes(
native_functions_path: str,
tags_path: str,
selector: SelectiveBuilder,
install_dir: str,
template_dir: str,
) -> None:
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
# Parse the native_functions.yaml.
# Then, group them into `NativeFunctionsViewGroup`.
#
# This is the same steps we do in gen.py (ATen codegen).
native_functions = parse_native_yaml(
native_functions_path, tags_path
).native_functions
native_functions_with_view_groups = get_grouped_by_view_native_functions(
native_functions
)
view_groups = [
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
]
fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
fm.write(
"ViewMetaClassesPythonBinding.cpp",
lambda: {
"view_meta_bindings": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_binding(
selector, g
),
view_groups,
)
),
},
)
def gen_functionalization_registration( def gen_functionalization_registration(
selector: SelectiveBuilder, selector: SelectiveBuilder,
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,