Revert "Make functionalization ViewMeta serializable with pickle. (#143712)"

This reverts commit b8abdaa286fd161af48af57a675827f4f849914d.

Reverted https://github.com/pytorch/pytorch/pull/143712 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/143712#issuecomment-2597205261))
This commit is contained in:
PyTorch MergeBot
2025-01-17 00:52:50 +00:00
parent 42c64bd35c
commit 6c713ccb5e
35 changed files with 425 additions and 951 deletions

.gitignore
View File

@ -79,7 +79,6 @@ torch/return_types.pyi
torch/nn/functional.pyi
torch/utils/data/datapipes/datapipe.pyi
torch/csrc/autograd/generated/*
torch/csrc/functionalization/generated/*
torch/csrc/lazy/generated/*.[!m]*
torch_compile_debug/
# Listed manually because some files in this directory are not generated

View File

@ -90,8 +90,6 @@ generated_cpu_cpp = [
"aten/src/ATen/NativeMetaFunctions.h",
"aten/src/ATen/RegistrationDeclarations.h",
"aten/src/ATen/VmapGeneratedPlumbing.h",
"aten/src/ATen/ViewMetaClasses.h",
"aten/src/ATen/ViewMetaClasses.cpp",
"aten/src/ATen/core/aten_interned_strings.h",
"aten/src/ATen/core/enum_tag.h",
"aten/src/ATen/core/TensorBody.h",
@ -1089,7 +1087,6 @@ test_suite(
"aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
"aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"aten/src/ATen/native/ts_native_functions.yaml",

View File

@ -9,6 +9,11 @@
namespace at::functionalization {
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}
// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
@ -42,7 +47,7 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
std::vector<at::Tensor> tmp_values({base});
tmp_values.reserve(update.view_metas.size());
for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
// NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
// All of these ops require additional information to recover the sizes of the original tensor.
// If we needed to, we could probably apply this optimization and only bother computing tmp_values
@ -50,8 +55,9 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
tmp_values.push_back(std::move(next_view));
}
for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
int64_t out_idx = update.view_metas[i].out_index;
// Each view inverse is implemented in ViewInverses.cpp.
t = update.view_metas[i]->reverse(tmp_values[i], t);
t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
}
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
return t;
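To make the two replay loops above concrete, here is a hedged walkthrough (illustrative, not code from this commit) of an update t applied through a two-view chain:
// Assuming update.view_metas replays base.view1().view2():
//   forward pass:  tmp_values = { base, view1(base) }
//   reverse pass:  t = view2_inverse(view1(base), t);
//                  t = view1_inverse(base, t);
// i.e. each reverse_fn scatters the update one level back toward the base.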
@ -105,13 +111,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
if (metas.size() > 1) {
for (size_t i = 1; i < metas.size(); ++i) {
// Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
"During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
" was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
"so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "

View File

@ -8,89 +8,44 @@ namespace at::functionalization {
// See Note [Functionalization Pass In Core]
enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying)
/// scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to
/// handle.
ViewOrScatterInverse,
};
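As an illustration of the three modes (the select example below is hypothetical, not taken from this commit):
// Inverting `view = base.select(0, idx)` under each mode:
//   NeverView            -> at::select_scatter(base, mutated_view, 0, idx)  (a copy)
//   AlwaysView           -> a view-based reconstruction (may require as_strided())
//   ViewOrScatterInverse -> prefers select_scatter here, since a scatter inverse exists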
#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
static const char* name() { \
return #TYPE; \
}
#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
using SerializableTuple = std::tuple<__VA_ARGS__>;
// ViewMeta is a class used by the functionalization pass to navigate between
// a base tensor and a view tensor.
// For example, if I call `b = a.view1(...)`
// the functionalization pass will generate and store a ViewMeta specialization
// for the `view1` operation on b that looks like:
// the functionalization pass will generate and store a ViewMeta on b that looks
// like:
//
// struct TORCH_API view1_ViewMeta : public ViewMeta {
// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
// bool /* reapply_views */,
// const std::vector<int64_t>&);
//
// view1_ViewMeta(const SerializableTuple& tpl)
// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
//
// view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
// : ViewMeta(/*has_symbolic_inputs=*/false),
// reapply_views(reapply_views),
// size(size) {}
//
// Tensor forward(const Tensor& base) override {
// ViewMeta(
// [<captures>](const Tensor& base, int64_t mutated_view_idx) {
// return base.view1(...);
// }
//
// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
// },
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
// int64_t mutated_view_idx) -> at::Tensor {
// return at::functionalization::impl::view1_inverse(base, mutated_view,
// ...);
// }
//
// SerializableTuple to_serializable_tuple() {
// return std::make_tuple(reapply_views, size);
// }
// The forward_fn lambda describes how to replay view1 on a tensor.
//
// bool reapply_views;
// std::vector<int64_t> size;
// };
//
// The forward function describes how to replay view1 on a tensor.
//
// The reverse function describes how, given a tensor that is already a view,
// The reverse_fn lambda describes how, given a tensor that is already a view,
// how to get the corresponding base tensor. See Note [Functionalization Pass:
// View Inverses] for details.
//
// `SerializableTuple` is a typedef that defines an `std::tuple<...>` type
// representing the `ViewMeta` instance state. Methods that take in/return such
// a type are used for supporting pickle serialization.
struct ViewMeta {
ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool has_symbolic_inputs,
bool is_multi_output = false,
bool is_as_strided = false,
int64_t out_idx = 0)
: out_index(out_idx),
: forward_fn(std::move(forward)),
reverse_fn(std::move(reverse)),
out_index(out_idx),
is_multi_output(is_multi_output),
is_as_strided(is_as_strided),
has_symbolic_inputs(has_symbolic_inputs) {}
virtual ~ViewMeta() {}
virtual Tensor forward(const Tensor& base) = 0;
virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
std::function<Tensor(const Tensor&, int64_t)> forward_fn;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
// See Note [out_idx in ViewMeta]
int64_t out_index;
@ -102,17 +57,10 @@ struct ViewMeta {
// Tells us if this view operation has any symbolic inputs
bool has_symbolic_inputs;
// Returns a new ViewMeta with the same forward/reverse
// Returns a copy of the current ViewMeta, if out_idx matches the current
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index.
//
// This method should be implemented by those `ViewMeta` specializations that
// have more than one output.
virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"ViewMeta::to_out_index not implemented. ",
"Likely because there's only one output.");
}
ViewMeta to_out_idx(int64_t out_idx);
};
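For context, a minimal sketch of how a kernel would construct a lambda-based ViewMeta after this revert; the slice op and the captured dim/start/end values are illustrative, and real kernels are generated by torchgen:
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    // forward_fn: replays the view op on a (possibly different) base.
    [dim, start, end](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
      return base.slice(dim, start, end);
    },
    // reverse_fn: recovers the base from a mutated view via the scatter inverse.
    [dim, start, end](const at::Tensor& base, const at::Tensor& mutated_view,
                      int64_t mutated_view_idx) -> at::Tensor {
      return at::slice_scatter(base, mutated_view, dim, start, end);
    },
    /*has_symbolic_inputs=*/false);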
// FunctionalStorageImpl is a subclass of StorageImpl used by the
@ -145,14 +93,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::Tensor new_val;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::vector<std::shared_ptr<ViewMeta>> view_metas;
const std::vector<ViewMeta> view_metas;
};
explicit FunctionalStorageImpl(const Tensor& value);
void add_update(
const Tensor& updated_val,
const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
const std::vector<ViewMeta>& view_metas);
bool apply_updates();
const Tensor& base() {
return base_;

View File

@ -129,19 +129,17 @@ void FunctionalTensorWrapper::freeze_storage() const {
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
const std::shared_ptr<functionalization::ViewMeta>& meta)
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
: c10::TensorImpl(
c10::DispatchKeySet(DispatchKey::Functionalize),
view_value.dtype(),
view_value.device()),
view_value.device()
),
value_(view_value),
is_multi_output_view_(
base->is_multi_output_view_ || meta->is_multi_output),
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
was_storage_changed_(base->was_storage_changed_),
is_symbolic_(base->is_symbolic_) {
is_symbolic_(base->is_symbolic_)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
@ -150,10 +148,11 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(
view_metas_ = base->view_metas_; // copy
}
view_metas_.push_back(meta);
maybe_mark_symbolic(meta.get());
maybe_mark_symbolic(meta);
storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}
functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}
@ -177,18 +176,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const {
}
// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
view_metas_.push_back(meta);
// Manually track the fact that this tensor received a metadata mutation!
has_metadata_mutation_ = true;
// Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
maybe_mark_symbolic(meta.get());
maybe_mark_symbolic(meta);
// Note [Functionalization Pass - Inplace View Ops]
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
at::AutoDispatchSkipFunctionalize guard;
value_ = meta->forward(value_);
value_ = meta.forward_fn(value_, meta.out_index);
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
@ -369,8 +368,15 @@ void FunctionalTensorWrapper::sync_() {
regenerate_from_base();
}
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const {
return view_metas_;
Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
auto t = base;
// Reapply views to get the viewed tensor from the base in alias_
for (auto& view_meta: view_metas_) {
t = view_meta.forward_fn(t, view_meta.out_index);
}
return t;
}
void FunctionalTensorWrapper::regenerate_from_base() {
@ -379,7 +385,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
auto t = storage_impl->base();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_);
t = apply_view_metas(t);
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
replace_(t, /*from_lazy_regenerate=*/true);
@ -753,28 +759,20 @@ void freeze_functional_tensor(const Tensor& tensor) {
functional_base_impl->freeze_storage();
}
Tensor create_functional_tensor_with_view_meta(
const at::Tensor& view_to_wrap,
const at::Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta,
int64_t out_idx) {
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
auto meta_ = meta;
if (out_idx != 0) {
// Note [out_idx in ViewMeta]
// When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
// Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
meta_ = meta->to_out_index(out_idx);
meta = meta.to_out_idx(out_idx);
}
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_);
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}
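A schematic illustration of the note above for a multi-output op such as split (not generated code): each of the N outputs is wrapped with its own copy of the ViewMeta, differing only in out_index, so that reverse_fn knows which output it is inverting:
// outputs[0] <- meta              (out_index == 0; to_out_idx(0) returns *this)
// outputs[1] <- meta.to_out_idx(1)
// outputs[2] <- meta.to_out_idx(2)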
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const at::Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta) {
std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
std::vector<Tensor> outputs(view_to_wrap.size());
int64_t i = 0;
for (const auto& tensor : view_to_wrap) {
@ -784,22 +782,12 @@ std::vector<Tensor> create_functional_tensor_with_view_meta(
return outputs;
}
void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) {
void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
self_impl->mutate_view_meta(meta);
}
Tensor apply_view_meta_sequence(
const Tensor& base,
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) {
Tensor r = base;
for (auto& vm : sequence) {
r = vm->forward(r);
}
return r;
}
// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} reference implementation with meta tensors.

View File

@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
explicit FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
const std::shared_ptr<functionalization::ViewMeta>& meta);
const functionalization::ViewMeta& meta);
// Get the underlying, actual tensor, that doesn't know anything about
// functionalization.
@ -97,17 +97,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
->are_all_mutations_under_no_grad_or_inference_mode();
}
void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
}
bool is_symbolic() const {
return is_symbolic_;
}
// Retrieves the ViewMeta sequence of this tensor.
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
const;
// Runs the forward_fn of every ViewMeta collected in the current instance
// to some other base.
Tensor apply_view_metas(const Tensor& base);
// Sync's the underlying tensor with its alias, if it's out of date. This
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
@ -144,8 +144,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// from the base tensor. This method is used by inplace-view ops like
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
// tensor by replaying the views off of the alias.
void mutate_view_meta(
const std::shared_ptr<at::functionalization::ViewMeta>& meta);
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
// Custom implementation of self.set_(src)
void set__impl(const FunctionalTensorWrapper* other);
@ -274,7 +273,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
bool is_symbolic_ = false;
size_t generation_ = 0;
std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
std::vector<at::functionalization::ViewMeta> view_metas_;
protected:
static void copy_tensor_metadata(
@ -366,20 +365,16 @@ TORCH_API void propagate_xla_data_direct(
Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta,
functionalization::ViewMeta meta,
int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta);
const functionalization::ViewMeta& meta);
void mutate_view_meta(
const Tensor& self,
const std::shared_ptr<functionalization::ViewMeta>& meta);
TORCH_API Tensor apply_view_meta_sequence(
const Tensor& base,
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
const functionalization::ViewMeta& meta);
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(

View File

@ -1,5 +1,3 @@
#include <ATen/FunctionalizeFallbackKernel.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
@ -29,31 +27,6 @@
#include <utility>
#endif
namespace at::functionalization {
Tensor resize__ViewMeta::forward(const Tensor& base) {
if (reapply_views) {
return base.as_strided(size, c10::contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
}
}
Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
return base.as_strided_scatter(
mutated_view, size, c10::contiguous_strides(size));
}
Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) {
return at::_unsafe_view_symint(base, size);
}
Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
}
} // namespace at::functionalization
namespace {
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
const auto& schema = op.schema();
@ -195,8 +168,19 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
// The output of resizing is equivalent to taking a slice of a larger tensor.
// We have to emulate this "slicing" with an as_strided call.
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>(
reapply_views, size.vec());
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
if (reapply_views) {
return base.as_strided(size, c10::contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
}
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
},
/*has_symbolic_inputs=*/false
);
at::functionalization::impl::mutate_view_meta(self, view_meta);
return self;
}
@ -315,11 +299,17 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
tmp_output = at::_unsafe_view_symint(self_, size);
}
bool has_symbolic_inputs = std::any_of(
size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
auto view_meta =
std::make_shared<at::functionalization::_unsafe_view_ViewMeta>(
has_symbolic_inputs, size.vec());
bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return at::_unsafe_view_symint(base, size);
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
},
/*has_symbolic_inputs=*/has_symbolic_inputs
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
// See Note [Propagating strides in the functionalization pass]

View File

@ -1,58 +0,0 @@
#pragma once
#include <ATen/FunctionalStorageImpl.h>
namespace at::functionalization {
// `ViewMeta` implementation for `resize_` operation.
struct TORCH_API resize__ViewMeta : public ViewMeta {
FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta);
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
bool /* reapply_views */,
const std::vector<int64_t>&);
resize__ViewMeta(const SerializableTuple& tpl)
: resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
: ViewMeta(/*has_symbolic_inputs=*/false),
reapply_views(reapply_views),
size(size) {}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
SerializableTuple to_serializable_tuple() {
return std::make_tuple(reapply_views, size);
}
bool reapply_views;
std::vector<int64_t> size;
};
// `ViewMeta` implementation for `_unsafe_view` operation.
struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta);
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
bool /* has_symbolic_inputs */,
const std::vector<c10::SymInt>&);
_unsafe_view_ViewMeta(const SerializableTuple& tpl)
: _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
_unsafe_view_ViewMeta(
bool has_symbolic_inputs,
const std::vector<c10::SymInt>& size)
: ViewMeta(has_symbolic_inputs), size(size) {}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
SerializableTuple to_serializable_tuple() {
return std::make_tuple(has_symbolic_inputs, size);
}
std::vector<c10::SymInt> size;
};
} // namespace at::functionalization

View File

@ -2,12 +2,22 @@
// ${generated_comment}
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/Tensor.h>
namespace at {
namespace functionalization {
enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying) scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to handle.
ViewOrScatterInverse,
};
struct FunctionalInverses {
${view_inverse_declarations}

View File

@ -4,7 +4,7 @@
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/ViewMetaClasses.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/MemoryOverlap.h>
#include <torch/library.h>

View File

@ -1,19 +0,0 @@
// ${generated_comment}
#include <ATen/FunctionalInverses.h>
#include <ATen/ViewMetaClasses.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#include <ATen/NativeFunctions.h>
#else
${op_headers}
#endif
namespace at {
namespace functionalization {
${view_meta_implementations}
} // namespace functionalization
} // namespace at

View File

@ -1,12 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include <ATen/FunctionalStorageImpl.h>
namespace at {
namespace functionalization {
${view_meta_declarations}
} // namespace functionalization
} // namespace at

View File

@ -1,11 +0,0 @@
#include <ATen/ViewMetaClasses.h>
#include <torch/csrc/functionalization/Module.h>
namespace torch::functionalization {
void initGenerated(PyObject* module) {
auto functionalization = py::handle(module).cast<py::module>();
$view_meta_bindings
}
} // namespace torch::functionalization

View File

@ -117,7 +117,6 @@ def define_targets(rules):
":LazyNonNativeIr.h",
":RegisterDispatchDefinitions.ini",
":RegisterDispatchKey.cpp",
":ViewMetaClassesPythonBinding.cpp",
":native_functions.yaml",
":shape_inference.h",
":tags.yaml",
@ -298,7 +297,6 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
]
GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP

View File

@ -929,7 +929,6 @@ libtorch_python_core_sources = [
"torch/csrc/utils/disable_torch_function.cpp",
"torch/csrc/utils/verbose.cpp",
"torch/csrc/cpu/Module.cpp",
"torch/csrc/functionalization/Module.cpp",
"torch/csrc/instruction_counter/Module.cpp",
] + lazy_tensor_core_python_sources

View File

@ -310,7 +310,6 @@ set(GENERATED_CXX_PYTHON
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
"${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
)
set(GENERATED_H_PYTHON
@ -374,7 +373,6 @@ add_custom_command(
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp"
${autograd_python}
${autograd_yaml}
${autograd_templates}

View File

@ -250,7 +250,11 @@ class AOTAutogradCacheTests(InductorTestCase):
@functorch_config.patch(
{"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True}
)
def test_view_replay(self):
def test_view_replay_bypass(self):
"""
Should bypass when view replay is turned on
"""
def fn(a):
tmp = a.detach()
a.mul_(2)
@ -258,25 +262,10 @@ class AOTAutogradCacheTests(InductorTestCase):
with torch.autograd._force_original_view_tracking(True):
compiled_fn = torch.compile(fn)
compiled_fn(torch.rand(2, 3))
def run_and_check(miss, hit, bypass):
self._clear_dynamo_and_codecache()
inp = torch.rand(2, 3)
compiled_inp = inp.clone().detach()
with torch.autograd._force_original_view_tracking(True):
out = fn(inp)
compiled_out = compiled_fn(compiled_inp)
self.assertEqual(out, compiled_out)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass)
run_and_check(miss=1, hit=0, bypass=0)
run_and_check(miss=1, hit=1, bypass=0)
run_and_check(miss=1, hit=2, bypass=0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", False)

View File

@ -6897,6 +6897,7 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
{
"enable_autograd_cache": True,
"strict_autograd_cache": True,
"view_replay_for_aliased_outputs": False,
}
)
@torch._inductor.config.patch("fx_graph_cache", True)

View File

@ -189,12 +189,6 @@ def main() -> None:
)
options = parser.parse_args()
# Path: aten/src/ATen
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
operator_selector = get_selector(
options.selected_op_list_path, options.operators_yaml_path
)
generate_code(
options.gen_dir,
options.native_functions_path,
@ -204,32 +198,13 @@ def main() -> None:
options.disable_autograd,
options.force_schema_registration,
# options.selected_op_list
operator_selector=operator_selector,
)
# Generate the python bindings for functionalization's `ViewMeta` classes.
from torchgen.gen_functionalization_type import (
gen_functionalization_view_meta_classes,
)
functionalization_templates_dir = os.path.join(aten_path, "templates")
functionalization_install_dir = os.path.join(
options.gen_dir, "torch/csrc/functionalization/generated"
)
os.makedirs(functionalization_install_dir, exist_ok=True)
assert os.path.isdir(functionalization_install_dir)
assert os.path.isdir(functionalization_templates_dir)
gen_functionalization_view_meta_classes(
options.native_functions_path or NATIVE_FUNCTIONS_PATH,
options.tags_path or TAGS_PATH,
selector=operator_selector,
install_dir=functionalization_install_dir,
template_dir=functionalization_templates_dir,
operator_selector=get_selector(
options.selected_op_list_path, options.operators_yaml_path
),
)
if options.gen_lazy_ts_backend:
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"

View File

@ -67,7 +67,6 @@ from . import (
_export,
_cpu,
_dynamo,
_functionalization,
_functorch,
_lazy,
_lazy_ts_backend,

View File

@ -1,16 +0,0 @@
from torch import Tensor
from torch.types import _bool
# Defined in torch/csrc/functionalization/Module.cpp
class ViewMeta:
has_symbolic_inputs: _bool
# Returns the list of ViewMeta instances of the given functional tensor.
#
# Although we do have python bindings for their types, we won't
# expose them here, since they should not be used by users.
def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ...
# Applies the ViewMeta sequence on top of the given base.
def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ...
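For reference, a hedged sketch of how this removed stub was exercised (mirroring its call sites in functional_utils.py later in this diff); t is assumed to be a FunctionalTensor whose elem attribute holds the underlying functional tensor:
from torch._C import _functionalization

seq = _functionalization.get_view_meta_sequence(t.elem)       # list[ViewMeta]
out = _functionalization.apply_view_meta_sequence(base, seq)  # replay the views onto base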

View File

@ -227,6 +227,19 @@ def check_cacheable(gm: torch.fx.GraphModule):
check_node_safe(node)
def check_metadata_cacheable(metadata: ViewAndMutationMeta):
"""
When view replay is turned on, we bypass autograd cache if
the output is aliased.
"""
if config.view_replay_for_aliased_outputs:
for info in metadata.output_info:
if info.functional_tensor is not None:
raise BypassAOTAutogradCache(
"Cannot cache a graph with functional tensor"
)
class AOTAutogradCacheDetails(FxGraphHashDetails):
"""
Object to capture all the details for a dynamo graph module relevant to computing
@ -862,6 +875,7 @@ class AOTAutogradCache:
def save(key: str, entry: AOTAutogradCacheEntry, remote: bool):
"""Save a single entry into the cache."""
try:
check_metadata_cacheable(entry.runtime_metadata)
content = pickle.dumps(entry)
CacheArtifactManager.record_artifact(
CacheArtifactType.AOT_AUTOGRAD, key, content

View File

@ -36,10 +36,10 @@ from .functional_utils import (
has_metadata_mutation,
MetadataKey,
to_fun,
ViewMetaSequence,
was_inductor_storage_resized,
)
from .schemas import (
FunctionalTensorMetadataEq,
InputAliasInfo,
MutationType,
OutputAliasInfo,
@ -604,7 +604,7 @@ from a multi-output view call"
#
# The FunctionalTensor will be saved if one of the 2 conditions below
# is true:
view_meta_sequence = None
functional_tensor = None
if (
# 1. If the output_type is either of:
# (i) alias_of_intermediate;
@ -636,7 +636,7 @@ from a multi-output view call"
and not input_info[base_idx].mutates_metadata
):
if isinstance(o, FunctionalTensor):
view_meta_sequence = ViewMetaSequence(o)
functional_tensor = FunctionalTensorMetadataEq(o.elem)
out_info = OutputAliasInfo(
output_type=output_type,
@ -644,7 +644,7 @@ from a multi-output view call"
base_idx=base_idx,
dynamic_dims=dynamic_dims,
requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
view_meta_sequence=view_meta_sequence,
functional_tensor=functional_tensor,
)
output_info.append(out_info)

View File

@ -13,12 +13,15 @@ from typing import Optional, Tuple
import torch
from torch import Tensor
from torch._C import _functionalization
from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
from torch.fx.experimental.symbolic_shapes import (
definitely_true,
sym_eq,
SymIntEqByExpr,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
@ -224,9 +227,9 @@ def gen_alias_from_base(
aliased_base_tensor,
target_meta_tensor,
target_requires_grad,
target_view_meta_sequence: Optional[ViewMetaSequence] = None,
target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
*,
replay_views: bool,
replay_views,
):
# Patch the correct requires_grad field of the output tensor, depending on whether:
# (i) the reconstructed output (out) came from a tensor that requires grad or not;
@ -245,11 +248,13 @@ def gen_alias_from_base(
# to replay them (view functions) on the aliased_base_tensor.
if (
replay_views
and target_view_meta_sequence is not None
and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence)
and target_functional_tensor is not None
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
):
out = _functionalization.apply_view_meta_sequence(
aliased_base_tensor, target_view_meta_sequence.sequence
functional_tensor = target_functional_tensor.tensor
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
)
# If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and
@ -310,8 +315,28 @@ def gen_alias_from_base(
return aliased_out
def has_same_metadata(t1, t2):
return (
definitely_true(sym_eq(t1.size(), t2.size()))
and definitely_true(t1.layout == t2.layout)
and (
is_sparse_any(t1)
or (
definitely_true(sym_eq(t1.stride(), t2.stride()))
and definitely_true(t1.storage_offset() == t2.storage_offset())
)
)
and t1.is_conj() == t2.is_conj()
and t1.is_neg() == t2.is_neg()
)
@dataclass(frozen=True)
class MetadataKey:
"""
This should be equal whenever has_same_metadata would return True
"""
size: Tuple[SymIntEqByExpr, ...]
layout: torch.layout
is_sparse: bool
@ -335,45 +360,25 @@ class MetadataKey:
)
# ViewMeta sequence wrapper for equality comparisons.
#
# Even though we can compare each ViewMeta instance, we compare the resulting
# tensor metadata instead. That's because the creation of synthetic bases + the
# re-generation of input views might end up creating a different sequence of
# ViewMeta that is semantically equivalent, i.e. one that gets to a tensor with
# the same metadata.
#
# Therefore, we store what the end result should look like as serializable
# metadata.
#
# When logging, this class should look like:
#
# ViewMetaSequence(view, select_int, slice_Tensor)
#
# i.e. a parenthesized list of view operations within that ViewMeta sequence.
class ViewMetaSequence:
def __init__(self, tensor: FunctionalTensor) -> None:
assert torch._is_functional_tensor(tensor.elem)
self.sequence = _functionalization.get_view_meta_sequence(tensor.elem)
self.metadata = MetadataKey.make(tensor)
def __repr__(self) -> str:
suffix = len("_ViewMeta")
types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence)
return f"ViewMetaSequence({types})"
# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
# after applying all the ViewMeta operations.
class FunctionalTensorMetadataEq:
def __init__(self, tensor: torch.Tensor) -> None:
assert torch._is_functional_tensor(tensor)
self.tensor = tensor
def __eq__(self, other: object) -> bool:
# If other is None, then it probably means that we weren't able to recreate
# the ViewMeta sequence. One example is when we update the view metadata by
# calling: create_synthetic_base_metadata.
# the FunctionalTensorMetadataEq. One of these cases is when we update the
# view metadata by calling create_synthetic_base_metadata.
if other is None:
return True
# Comparison against any other type is not implemented.
if not isinstance(other, ViewMetaSequence):
# Comparison against any other type is not implemented.
if not isinstance(other, FunctionalTensorMetadataEq):
return NotImplemented
return self.metadata == other.metadata
return has_same_metadata(self.tensor, other.tensor)
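A small usage sketch (a and b are hypothetical functional tensors): equality delegates entirely to has_same_metadata, so two wrappers compare equal whenever their tensors have matching size/stride/layout metadata, regardless of the ViewMeta sequences that produced them.
meta_a = FunctionalTensorMetadataEq(a)
meta_b = FunctionalTensorMetadataEq(b)
assert (meta_a == meta_b) == has_same_metadata(a, b)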
# new_arg and arg here are either:

View File

@ -75,7 +75,7 @@ def remove_dupe_metadata(
dynamic_dims=o.dynamic_dims,
base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
requires_grad=o.requires_grad,
view_meta_sequence=o.view_meta_sequence,
functional_tensor=o.functional_tensor,
)
for o in m.output_info
],
@ -226,7 +226,7 @@ def create_synthetic_base_metadata(
# Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
base_idx=new_base_idx, # type: ignore[arg-type]
requires_grad=o.requires_grad,
view_meta_sequence=o.view_meta_sequence,
functional_tensor=o.functional_tensor,
)
)

View File

@ -172,7 +172,7 @@ class AliasOfInputHandler:
self.base_idx = info.base_idx
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
self.requires_grad = info.requires_grad
self.view_meta_sequence = info.view_meta_sequence
self.functional_tensor = info.functional_tensor
self.replay_views = config.view_replay_for_aliased_outputs
def __call__(self, orig_inputs, fw_outs, out):
@ -181,7 +181,7 @@ class AliasOfInputHandler:
aliased_base_tensor,
self.unwrap_out(out),
self.requires_grad,
self.view_meta_sequence,
self.functional_tensor,
replay_views=self.replay_views,
)
@ -209,7 +209,7 @@ class AliasOfIntermediateHandler:
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
self.requires_grad = info.requires_grad
self.view_meta_sequence = info.view_meta_sequence
self.functional_tensor = info.functional_tensor
self.replay_views = config.view_replay_for_aliased_outputs
def __call__(self, orig_inputs, fw_outs, out):
@ -218,7 +218,7 @@ class AliasOfIntermediateHandler:
aliased_base_tensor,
self.unwrap_out(out),
self.requires_grad,
self.view_meta_sequence,
self.functional_tensor,
replay_views=self.replay_views,
)

View File

@ -5,6 +5,7 @@ input/output types, metadata, config, function signatures etc.
"""
import collections
import dataclasses
import functools
from dataclasses import dataclass, field
from enum import Enum
@ -19,7 +20,10 @@ from torch._subclasses.fake_tensor import is_fake
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config
from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence
from .functional_utils import (
_check_if_mutation_can_be_in_graph,
FunctionalTensorMetadataEq,
)
from .utils import strict_zip
@ -88,14 +92,15 @@ class OutputAliasInfo:
dynamic_dims: Optional[Set[int]]
# requires_grad
requires_grad: bool
# Sequence of ViewMeta objects.
# FunctionalTensorWrapper that represents this output.
#
# Provides us the means to re-run view functions on other tensors.
# Provides us the means to replay views from it.
#
# We need to wrap the actual list of ViewMeta with this class so that
# we compare the ViewMeta elements appropriately, i.e. their type and
# the elements returned by the `as_tuple()` call.
view_meta_sequence: Optional[ViewMetaSequence] = None
# We need to wrap the actual FunctionalTensorWrapper with this class so that
# we only compare the tensor's metadata. That's because with the transformations
# of the model throughout AOTAutograd, the sequence of ViewMeta and the base
# tensor might change.
functional_tensor: Optional[FunctionalTensorMetadataEq] = None
class MutationType(Enum):
@ -577,6 +582,17 @@ class ViewAndMutationMeta:
self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
# Clear traced tangents at runtime
self.traced_tangents = []
new_output_info = []
for out in self.output_info:
if config.view_replay_for_aliased_outputs:
new_out = out
else:
# If we're not using view_replay, remove the functional tensor.
# Functional tensors are unfortunately not serializable,
# so doing this is required for AOTAutograd caching.
new_out = dataclasses.replace(out, functional_tensor=None)
new_output_info.append(new_out)
self.output_info = new_output_info
for inp_meta in self.subclass_inp_meta:
if isinstance(inp_meta, SubclassCreationMeta):
inp_meta.make_runtime_safe()

View File

@ -71,7 +71,6 @@
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/export/pybind.h>
#include <torch/csrc/functionalization/Module.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_package/pybind.h>
@ -1839,7 +1838,6 @@ PyObject* initModule() {
torch::instruction_counter::initModule(module);
torch::initVerboseBindings(module);
ASSERT_TRUE(THPStorage_init(module));
torch::functionalization::initModule(module);
#ifdef USE_CUDA
// This will only initialise base classes and attach them to library namespace

View File

@ -633,6 +633,15 @@ void initTorchFunctions(PyObject* module) {
at::functionalization::impl::isFunctionalTensor(t));
at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
});
py_module.def(
"_functionalize_apply_view_metas",
[](const at::Tensor& tensor, const at::Tensor& base) {
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl =
at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
return impl->apply_view_metas(base);
});
py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);

View File

@ -1,71 +0,0 @@
#include <torch/csrc/functionalization/Module.h>
#include <torch/csrc/utils/pybind.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalizeFallbackKernel.h>
#include <memory>
namespace torch::functionalization {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
// Create a `torch._C._functionalization` Python module.
auto functionalization = m.def_submodule(
"_functionalization", "functionalization related pybind.");
// Retrieve the ViewMeta sequence of a given functional tensor.
functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
return impl->view_metas();
});
// Applies the given ViewMeta sequence to the given base.
functionalization.def(
"apply_view_meta_sequence",
[](const at::Tensor& base,
const std::vector<std::shared_ptr<at::functionalization::ViewMeta>>&
sequence) {
return at::functionalization::impl::apply_view_meta_sequence(
base, sequence);
});
// Binding for InverseReturnMode.
py::enum_<at::functionalization::InverseReturnMode>(
functionalization, "InverseReturnMode")
.value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView)
.value("NeverView", at::functionalization::InverseReturnMode::NeverView)
.value(
"ViewOrScatterInverse",
at::functionalization::InverseReturnMode::ViewOrScatterInverse);
// Create bindings for the ViewMeta base class.
//
// Needed so that we can take a list of ViewMeta objects as parameter.
// Specifically, on the Python side, we will have a list of derived ViewMeta
// classes. We need to tell pybind11 that all of those are, in fact, instances
// of different ViewMeta sub-types.
py::class_<
at::functionalization::ViewMeta,
std::shared_ptr<at::functionalization::ViewMeta>>(
functionalization, "ViewMeta")
.def_property_readonly(
"has_symbolic_inputs",
[](const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
return meta->has_symbolic_inputs;
});
// Bindings for `ViewMeta` specializations manually implemented.
create_binding_with_pickle<at::functionalization::resize__ViewMeta>(
functionalization);
create_binding_with_pickle<at::functionalization::_unsafe_view_ViewMeta>(
functionalization);
// Bindings for `ViewMeta` specializations automatically generated.
initGenerated(functionalization.ptr());
}
} // namespace torch::functionalization

View File

@ -1,36 +0,0 @@
#pragma once
#include <ATen/FunctionalStorageImpl.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::functionalization {
// Creates the default bindings for `ViewMeta` specializations.
//
// Defines a constructor using the types in `SerializableTuple`, as well
// as pickle methods.
template <class T>
void create_binding_with_pickle(py::module m) {
py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
m, T::name())
.def(py::init<typename T::SerializableTuple>())
.def(
"as_tuple",
[](const std::shared_ptr<T>& meta) {
return meta->to_serializable_tuple();
})
.def(py::pickle(
[](const std::shared_ptr<T>& meta) {
return meta->to_serializable_tuple();
},
[](const typename T::SerializableTuple& tpl) {
return std::make_shared<T>(tpl);
}));
}
void initModule(PyObject* module);
void initGenerated(PyObject* module);
} // namespace torch::functionalization
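Given the template above, round-tripping one of the manually bound specializations through pickle presumably looked like the following (illustrative; the tuple layout follows resize__ViewMeta's SerializableTuple of (reapply_views, size)):
import pickle
import torch

meta = torch._C._functionalization.resize__ViewMeta((True, [2, 3]))
restored = pickle.loads(pickle.dumps(meta))
assert restored.as_tuple() == meta.as_tuple()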

View File

@ -23,13 +23,20 @@ from torchgen.model import (
# This file describes the translation of JIT schema to API's used
# when creating `ViewMeta` specializations that are used by the functionalization pass.
# These APIs mostly follow the dispatcher API, with one difference:
# - While the forward function just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse function
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# is responsible for generating both the call-site and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
# return name(inner_arguments);
# }
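As a concrete, hand-written illustration of that shape (roughly what the generated forward lambda for a transpose-like view would look like; not verbatim codegen output):
# [reapply_views = reapply_views, dim0 = dim0, dim1 = dim1]
#     (const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
#   if (reapply_views) {
#     return at::_ops::transpose_int::call(base, dim0, dim1);
#   } else {
#     return at::_ops::transpose_copy_int::call(base, dim0, dim1);
#   }
# }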
# Define some specific lambda input arguments.
base_binding = Binding(
name="base",
@ -39,18 +46,6 @@ base_binding = Binding(
),
default=None,
)
has_symbolic_inputs_binding = Binding(
name="has_symbolic_inputs",
nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
argument=Argument(
name="has_symbolic_inputs",
type=BaseType(BaseTy.bool),
default=None,
annotation=None,
),
default=None,
)
mutated_view_binding = Binding(
name="mutated_view",
nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@ -59,11 +54,11 @@ mutated_view_binding = Binding(
),
default=None,
)
out_index_binding = Binding(
name="out_index",
nctype=NamedCType(name="out_index", type=BaseCType(longT)),
mutated_view_idx_binding = Binding(
name="mutated_view_idx",
nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
argument=Argument(
name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
@ -91,13 +86,8 @@ inverse_return_mode_binding = Binding(
)
# Name of the `ViewMeta` specialization class created.
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
namespace = "at::functionalization::" if with_namespace else ""
return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
# Name of the operation called inside the `forward`/`reverse` implementations.
# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(
g: NativeFunctionsViewGroup,
*,
@ -134,6 +124,24 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
return f"{api_name}_inverse"
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
# Capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# so any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>).
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
all_bindings = [
inverse_return_mode_binding if is_reverse else reapply_views_binding
]
all_bindings.extend(non_self_value_bindings)
return all_bindings
def returns_type(func: FunctionSchema) -> CType:
# Assertion: all view ops return tensor-like outputs
assert len(func.returns) >= 1
@ -144,49 +152,24 @@ def returns_type(func: FunctionSchema) -> CType:
return BaseCType(tensorT)
# Checks whether `func` might return more than one value.
def is_multi_output(func: FunctionSchema) -> bool:
return len(func.returns) > 1 or (
len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
)
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
if is_reverse:
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
else:
return [base_binding, mutated_view_idx_binding]
# `ViewMeta` specialization constructor parameters.
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
# All specializations are parameterized by the `has_symbolic_inputs` flag.
arguments = [has_symbolic_inputs_binding]
# If `func` might return more than 1 value, we also parameterize this specialization
# with the output index.
if is_multi_output(func):
arguments.append(out_index_binding)
return arguments
def inner_call_index(func: FunctionSchema) -> Binding | None:
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
if len(func.returns) > 1 or (
len(func.returns) == 1 and func.returns[0].type.is_list_like()
):
return mutated_view_idx_binding
return None
# `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization that the base class does not need.
# Same as the class' attributes, but non-owning.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
return attributes(func, owning=False)
# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse` functions. You can
# think of them as values that should be captured from the functionalization kernel.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
return [
reapply_views_binding,
inverse_return_mode_binding,
*[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
]
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
@ -200,12 +183,13 @@ def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
# the reverse lambda does the same, but with an additional "mutated_view" arg
# additionally, we have a calling convention: for view ops that return multiple tensor outputs
# their corresponding view_inverse function takes in an additional index argument.
if is_multi_output(func):
index_binding = inner_call_index(func)
if index_binding is not None:
return [
base_binding,
mutated_view_binding,
inverse_return_mode_binding,
out_index_binding,
index_binding,
] + non_self_bindings
else:
return [

View File

@ -300,11 +300,83 @@ class ViewInverseSignature:
return_type = functionalization.returns_type(self.g.view.func)
decls = [
a.decl()
for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
for a in functionalization.inner_arguments(
self.g.view.func, is_reverse=True
)
]
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
@dataclass(frozen=True)
class FunctionalizationLambda:
g: NativeFunctionsViewGroup
# are we generating the forward lambda or the reverse lambda?
is_reverse: bool
def captures(self) -> list[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
outer_ctx = dispatcher.arguments(self.g.view.func) + [
functionalization.reapply_views_binding,
functionalization.inverse_return_mode_binding,
]
capture_bindings = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
# allow_expensive_conversions is set because we want to convert
# some reference types (IntArrayRef) to value types (vector<int64_t>).
capture_exprs = translate.translate(
outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
)
return capture_exprs
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
capture_str = ", ".join(
f"{val.type.name} = {val.expr}" for val in self.captures()
)
decls = [
a.decl()
for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: bool | None = None) -> str:
inner_call_name = functionalization.name(
self.g,
is_reverse=self.is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
capture_ctx = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
full_ctx = arg_ctx + capture_ctx
assert self.g.view_copy is not None
call_bindings = functionalization.inner_arguments(
self.g.view_copy.func, is_reverse=self.is_reverse
)
maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
call_exprs = [
e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
]
if not self.is_reverse and maybe_index is not None:
return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
else:
return f'{inner_call_name}({", ".join(call_exprs)});'
@staticmethod
def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> FunctionalizationLambda:
return FunctionalizationLambda(g, is_reverse)
@dataclass(frozen=True)
class StructuredImplSignature:
g: NativeFunctionsGroup

View File

@ -45,8 +45,6 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_functionalization_view_meta_classes_decl,
gen_functionalization_view_meta_classes_impl,
GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@ -2579,6 +2577,9 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
},
)
def functionalization_env_callable(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
def gen_op_headers(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
@ -2618,9 +2619,6 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
f"#include <ATen/ops/{g.root_name}_ops.h>",
]
def functionalization_env_callable(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
return {
"ops_headers": gen_op_headers(g),
"func_definitions": gen_functionalization_definition(
@ -2686,31 +2684,6 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
},
)
cpu_fm.write(
"ViewMetaClasses.h",
lambda: {
"view_meta_declarations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
view_groups,
)
)
},
)
cpu_fm.write(
"ViewMetaClasses.cpp",
lambda: {
"view_meta_implementations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
view_groups,
)
),
"op_headers": list(concatMap(gen_op_headers, view_groups)),
},
)
# Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
# needs to have a corresponding non-aliasing {view}_copy variant.

View File

@ -1,15 +1,16 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING
from torchgen.api import cpp, dispatcher, functionalization
from torchgen.api import cpp, dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
DispatcherSignature,
FunctionalizationLambda,
iTensorListRefT,
NativeSignature,
OptionalCType,
@ -47,7 +48,7 @@ from torchgen.native_function_generation import (
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.utils import concatMap, dataclass_repr, FileManager
from torchgen.utils import dataclass_repr
if TYPE_CHECKING:
@ -364,8 +365,6 @@ def emit_view_functionalization_body(
with native_function_manager(f):
call_sig = DispatcherSignature.from_schema(g.view_copy.func)
spec = ViewMetaSpecialization(g, f=f)
# the "view_copy" op name that the functionalization kernels need to call
api_name = g.view_copy.func.name.unambiguous_name()
# Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
@ -386,6 +385,9 @@ def emit_view_functionalization_body(
for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
]
forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
# The meta API call should use the same arguments, but convert all tensors to meta tensors first.
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [
@ -413,7 +415,19 @@ def emit_view_functionalization_body(
: at::functionalization::InverseReturnMode::NeverView
);
{symbolic_inputs_check}
auto view_meta = {spec.new()};
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname}
);
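// As a sketch (argument names assumed), for `transpose.int` the forward lambda above
// instantiates with a capture list like [reapply_views = reapply_views, dim0 = dim0, dim1 = dim1]
// and a body that calls either at::_ops::transpose_int::call(base, dim0, dim1) or
// at::_ops::transpose_copy_int::call(base, dim0, dim1), depending on reapply_views.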
auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
@ -441,6 +455,7 @@ def emit_view_functionalization_body(
"""
else:
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
{unwrap_tensor_args_str}
@ -474,7 +489,21 @@ def emit_view_functionalization_body(
}}
}}
{symbolic_inputs_check}
auto view_meta = {spec.new()};
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname},
/*is_multi_output=*/{str(is_multi_output_view).lower()},
/*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
);
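// As an illustration of the two extra flags (both derived directly from the schema):
// `as_strided` gets /*is_multi_output=*/false, /*is_as_strided=*/true, while a
// list-returning view such as `unbind.int` gets /*is_multi_output=*/true,
// /*is_as_strided=*/false.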
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass]
if (compute_reference_meta) {{
@ -742,301 +771,6 @@ def gen_functionalization_view_inverse_declaration(
return emit_decl_helper(g)
# Helper class for generating `ViewMeta` specializations.
@dataclass
class ViewMetaSpecialization:
g: NativeFunctionsViewGroup
f: NativeFunction
@property
def is_multi_output(self) -> bool:
return functionalization.is_multi_output(self.f.func)
@property
def is_as_strided(self) -> bool:
return str(self.f.func.name) == "as_strided"
@property
def out_index(self) -> str:
if self.is_multi_output:
return functionalization.out_index_binding.name
return "0"
@property
def classname(self) -> str:
return functionalization.classname(self.f.func)
def decl(self) -> list[str]:
base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
attributes = functionalization.attributes(self.f.func)
# List of types for declaring the `SerializableTuple` type.
serializable_tuple_args = ",\n".join(
f" {binding.type} /* {binding.name} */"
for binding in (base_ctor_arguments + attributes)
)
# Arguments used for forwarding the tuple elements to the constructor.
destructure_tuple_args = ", ".join(
f"std::get<{i}>(tpl)"
for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
)
# List of constructor parameters
ctor_parameters = ", ".join(
binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
)
# Call the base class `ViewMeta` constructor.
#
# Both `is_multi_output` and `is_as_strided` are known values, given the
# operation schema.
is_multi_output_str = str(self.is_multi_output).lower()
is_as_strided_str = str(self.is_as_strided).lower()
base_ctor_bindings = ", ".join(
[
# `has_symbolic_inputs` is always taken as a parameter.
functionalization.has_symbolic_inputs_binding.name,
f"/*is_multi_output=*/{is_multi_output_str}",
f"/*is_as_strided=*/{is_as_strided_str}",
# `out_index` is known if the operation returns only one value. Otherwise,
# we also take it as a parameter.
f"/*out_index=*/{self.out_index}",
]
)
# Assignments of `extra_ctor_arguments` to their corresponding fields.
# These are the extra fields to be declared in this specialization.
#
# We need to set `allow_expensive_conversions`, since we are storing owned versions
# of the non-owning arguments.
ctor_assignments = ",\n".join(
f" {e.type.name}({e.expr})"
for e in translate(
extra_ctor_arguments,
attributes,
method=False,
allow_expensive_conversions=True,
)
)
# List of arguments for constructing the `SerializableTuple` from an instance.
tuple_arguments = ", ".join(
binding.name for binding in (base_ctor_arguments + attributes)
)
# List of field declarations.
attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes)
# Override `to_out_index` if this operation returns more than 1 value.
to_out_index_decl = ""
if self.is_multi_output:
to_out_index_decl = (
" std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
)
return [
f"""
struct TORCH_API {self.classname} : public ViewMeta {{
FUNCTIONALIZATION_VIEWMETA_NAME({self.classname});
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});
{self.classname}(const SerializableTuple& tpl)
: {self.classname}({destructure_tuple_args}) {{}}
{self.classname}({ctor_parameters})
: at::functionalization::ViewMeta({base_ctor_bindings}),
{ctor_assignments} {{}}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
{to_out_index_decl}
SerializableTuple to_serializable_tuple() {{
return std::make_tuple({tuple_arguments});
}}
{attr_declarations}
}};
"""
]
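# A hedged sketch of the emitted specialization (class and field names assumed), e.g.
# for `transpose.int`:
#   struct TORCH_API transpose_int_ViewMeta : public ViewMeta {
#     ...
#     Tensor forward(const Tensor& base) override;
#     Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
#     int64_t dim0;
#     int64_t dim1;
#   };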
# Generate a call to the actual operation.
def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
opname = functionalization.name(
self.g,
is_reverse=is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
# Expected arguments for the operation.
assert self.g.view_copy is not None
op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)
# The context is composed of the constructor arguments (which are also
# the field variables stored in the instance) and the `base` tensor.
context = [functionalization.base_binding]
context += functionalization.base_ctor_arguments(self.f.func)
context += functionalization.attributes(self.f.func)
# If we are generating the call for the reverse function, we also have
# access to the `mutated_view` argument.
if is_reverse:
context.append(functionalization.mutated_view_binding)
arguments = ", ".join(
[e.expr for e in translate(context, op_arguments, method=False)]
)
# Index the result if this operation returns multiple values.
maybe_index = ""
if not is_reverse and self.is_multi_output:
maybe_index = f"[{self.out_index}]"
return f"{opname}({arguments}){maybe_index}"
def impl(self) -> list[str]:
functions = [
f"""
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
if (reapply_views) {{
return {self.opcall(is_reverse=False, reapply_views=True)};
}} else {{
return {self.opcall(is_reverse=False, reapply_views=False)};
}}
}}""",
f"""
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
return {self.opcall(is_reverse=True, reapply_views=True)};
}}""",
]
# If this operation returns multiple values, also generate a `to_out_index`
# implementation.
if self.is_multi_output:
functions.append(f"""
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
return {self.new("out_index")};
}}
""")
return functions
# Create the Python binding for this specialized class.
def binding(self) -> list[str]:
name = functionalization.classname(self.f.func, with_namespace=True)
return [f" create_binding_with_pickle<{name}>(functionalization);"]
# Generate an instantiation of this specialized class.
def new(self, out_index: str = "0") -> str:
name = functionalization.classname(self.f.func, with_namespace=True)
ctor_arguments = functionalization.base_ctor_arguments(
self.f.func
) + functionalization.extra_ctor_arguments(self.f.func)
# Replace the `out_index` parameter with the given `out_index`.
arguments = ", ".join(
binding.name if binding.name != "out_index" else out_index
for binding in ctor_arguments
)
return f"std::make_shared<{name}>({arguments})"
# Run the function `run` for both the `view` and `view_inplace` functions.
@staticmethod
def map(
g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
) -> list[str]:
def maybe_run(f: Optional[NativeFunction]) -> list[str]:
if f is None:
return []
with native_function_manager(f):
return run(ViewMetaSpecialization(g, f))
return list(concatMap(maybe_run, (g.view, g.view_inplace)))
def gen_functionalization_view_meta_classes_base(
selector: SelectiveBuilder,
g: NativeFunctionsViewGroup,
run: Callable[[ViewMetaSpecialization], list[str]],
) -> list[str]:
if not selector.include_all_operators:
return []
if g.composite:
return []
return ViewMetaSpecialization.map(g, run)
def gen_functionalization_view_meta_classes_decl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.decl
)
def gen_functionalization_view_meta_classes_impl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.impl
)
def gen_functionalization_view_meta_classes_binding(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.binding
)
# Generates the Python bindings for the `ViewMeta` specialized classes.
def gen_functionalization_view_meta_classes(
native_functions_path: str,
tags_path: str,
selector: SelectiveBuilder,
install_dir: str,
template_dir: str,
) -> None:
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
# Parse the native_functions.yaml.
# Then, group them into `NativeFunctionsViewGroup`.
#
# These are the same steps we do in gen.py (ATen codegen).
native_functions = parse_native_yaml(
native_functions_path, tags_path
).native_functions
native_functions_with_view_groups = get_grouped_by_view_native_functions(
native_functions
)
view_groups = [
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
]
fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
fm.write(
"ViewMetaClassesPythonBinding.cpp",
lambda: {
"view_meta_bindings": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_binding(
selector, g
),
view_groups,
)
),
},
)
def gen_functionalization_registration(
selector: SelectiveBuilder,
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,