Make functionalization ViewMeta serializable with pickle. (#143712)

Fix: #141974

This PR makes the `ViewMeta` sequence, present in functional tensors,
serializable with pickle. To accomplish that, it turns `ViewMeta` into
an abstract class with overridable `forward` and `reverse` functions.
Each operation that previously instantiated `ViewMeta` directly now
defines a specialized class that inherits from `ViewMeta`. Therefore,
this PR also uses codegen for creating these specializations.
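
For intuition, here is a rough Python analogue of the new design (a
sketch only: the real classes are generated C++, and `view_ViewMeta`
below is a hand-written stand-in for a codegen'd specialization):

    import pickle

    import torch


    class ViewMeta:
        """Abstract base: replay a view op, or invert it back to the base."""

        def forward(self, base):
            raise NotImplementedError

        def reverse(self, base, mutated_view):
            raise NotImplementedError


    class view_ViewMeta(ViewMeta):
        """Hand-written stand-in for a codegen'd specialization."""

        def __init__(self, size):
            self.size = list(size)  # plain, picklable state

        def forward(self, base):
            return base.view(self.size)

        def reverse(self, base, mutated_view):
            # Inverting view() means viewing back to the base's shape.
            return mutated_view.view(base.shape)


    # Plain data (no closures, unlike the old std::function fields)
    # round-trips through pickle.
    meta = pickle.loads(pickle.dumps(view_ViewMeta((4, 4))))
    base = torch.arange(16)
    assert torch.equal(meta.forward(base), base.view(4, 4))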

In summary, these are the changes this PR introduces:

- `ViewMeta` is turned into an abstract class (see
  _FunctionalStorageImpl.h_). `forward` and `reverse` are pure virtual
  functions that each specialization must implement. `to_out_index`
  should be overridden by operations that may return more than one
  output.

- New `ViewMeta` specializations for `resize_` and `_unsafe_view` are
  created (see _FunctionalizeFallbackKernel.h_).

- New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the
  declarations and definitions of the `ViewMeta` specializations, which
  are automatically generated by the ATen codegen (see _gen.py_).

- A new `_functionalization` Python sub-module is created (see
  _Module.cpp_). It serves as a namespace for the `ViewMeta`
  specializations and the `InverseReturnMode` enum (see the usage
  sketch after this list).

- A new template _ViewMetaClassesPythonBinding.cpp_ is created. It
  holds the automatically generated Python bindings for the `ViewMeta`
  specializations, produced by the torch codegen (see
  _generate_code.py_).
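
As a sketch of the intended end-to-end behavior (an illustration, not
a test from this PR: it assumes `torch._to_functional_tensor` yields a
functional tensor whose view calls record `ViewMeta`; the binding names
come from the new `torch._C._functionalization` stubs):

    import pickle

    import torch
    from torch._C import _functionalization

    base = torch.randn(4, 4)
    # Assumed setup: a FunctionalTensorWrapper with one recorded view.
    t = torch._to_functional_tensor(base).view(16)

    seq = _functionalization.get_view_meta_sequence(t)
    seq = pickle.loads(pickle.dumps(seq))  # round-trips after this PR

    # Replaying the deserialized sequence on the base reproduces the view.
    out = _functionalization.apply_view_meta_sequence(base, seq)
    assert out.shape == (16,)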

Note that this PR invokes codegen at 2 different moments:

- ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes.
- Torch codegen (_generate_code.py_): generates the Python bindings for
  them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712
Approved by: https://github.com/bdhirsh
Author: Yukio Siraichi
Date: 2025-01-16 09:22:22 -03:00
Committed by: PyTorch MergeBot
Parent: 7c3aa1da1c
Commit: b8abdaa286
35 changed files with 951 additions and 425 deletions

.gitignore

@@ -79,6 +79,7 @@ torch/return_types.pyi
 torch/nn/functional.pyi
 torch/utils/data/datapipes/datapipe.pyi
 torch/csrc/autograd/generated/*
+torch/csrc/functionalization/generated/*
 torch/csrc/lazy/generated/*.[!m]*
 torch_compile_debug/
 # Listed manually because some files in this directory are not generated

build.bzl

@@ -90,6 +90,8 @@ generated_cpu_cpp = [
     "aten/src/ATen/NativeMetaFunctions.h",
     "aten/src/ATen/RegistrationDeclarations.h",
     "aten/src/ATen/VmapGeneratedPlumbing.h",
+    "aten/src/ATen/ViewMetaClasses.h",
+    "aten/src/ATen/ViewMetaClasses.cpp",
     "aten/src/ATen/core/aten_interned_strings.h",
     "aten/src/ATen/core/enum_tag.h",
     "aten/src/ATen/core/TensorBody.h",
@@ -1087,6 +1089,7 @@ test_suite(
     "aten/src/ATen/templates/LazyNonNativeIr.h",
     "aten/src/ATen/templates/RegisterDispatchKey.cpp",
     "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
+    "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
     "aten/src/ATen/native/native_functions.yaml",
     "aten/src/ATen/native/tags.yaml",
     "aten/src/ATen/native/ts_native_functions.yaml",

aten/src/ATen/FunctionalStorageImpl.cpp

@@ -9,11 +9,6 @@
 namespace at::functionalization {
 
-ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
-  if (out_idx == this->out_index) return *this;
-  return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
-}
-
 // Note [Functionalization: Alias Removal Part 2]
 // See Note [Functionalization: Alias Removal] for more details.
 // This function applies a single update from one of the views to the StorageImpl.
@@ -47,7 +42,7 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
   std::vector<at::Tensor> tmp_values({base});
   tmp_values.reserve(update.view_metas.size());
   for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
-    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
+    at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
     // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
     // All of these ops require additional information to recover the sizes of the original tensor.
     // If need to, we could probably apply this optimization and only bother computing tmp_values
@@ -55,9 +50,8 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
     tmp_values.push_back(std::move(next_view));
   }
   for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
-    int64_t out_idx = update.view_metas[i].out_index;
     // Each view inverse is implemented in ViewInverses.cpp.
-    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
+    t = update.view_metas[i]->reverse(tmp_values[i], t);
   }
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
   return t;
@@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
 }
 
-void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
+void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
   TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
   if (metas.size() > 1) {
     for (size_t i = 1; i < metas.size(); ++i) {
       // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
-      TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
+      TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
       "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
       " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
       "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "

aten/src/ATen/FunctionalStorageImpl.h

@@ -8,44 +8,89 @@ namespace at::functionalization {
 
 // See Note [Functionalization Pass In Core]
 
+enum class InverseReturnMode {
+  /// Specifies that functional inverses should always return a view.
+  AlwaysView,
+  /// Specifies that functional inverses should always return a non-view / copy.
+  NeverView,
+  /// Specifies that functional inverses should return a view unless a
+  /// (copying) scatter inverse exists, in which case that will be used
+  /// instead.
+  /// This avoids as_strided() calls that can be difficult for subclasses to
+  /// handle.
+  ViewOrScatterInverse,
+};
+
+#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
+  static const char* name() {                 \
+    return #TYPE;                             \
+  }
+
+#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
+  using SerializableTuple = std::tuple<__VA_ARGS__>;
+
 // ViewMeta is a class used by the functionalization pass to navigate between
 // a base tensor and a view tensor.
 // For example, if I call `b = a.view1(...)`
-// the functionalization pass will generate and store a ViewMeta on b that looks
-// like:
+// the functionalization pass will generate and store a ViewMeta specialization
+// for `view1` operation on b that looks like:
 //
-// ViewMeta(
-//   [<captures>](const Tensor& base, int64_t mutated_view_idx) {
-//     return base.view1(...);
-//   },
-//   [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
-//       int64_t mutated_view_idx) -> at::Tensor {
-//     return at::functionalization::impl::view1_inverse(base, mutated_view,
-//       ...);
-//   }
+// struct TORCH_API view1_ViewMeta : public ViewMeta {
+//   FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
+//   FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
+//       bool /* reapply_views */,
+//       const std::vector<int64_t>&);
+//
+//   view1_ViewMeta(const SerializableTuple& tpl)
+//       : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
+//
+//   view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
+//       : ViewMeta(/*has_symbolic_inputs=*/false),
+//         reapply_views(reapply_views),
+//         size(size) {}
+//
+//   Tensor forward(const Tensor& base) override {
+//     return base.view1(...);
+//   }
+//
+//   Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
+//     return at::functionalization::impl::view1_inverse(base, mutated_view,
+//       ...);
+//   }
+//
+//   SerializableTuple to_serializable_tuple() {
+//     return std::make_tuple(reapply_views, size);
+//   }
+//
+//   bool reapply_views;
+//   std::vector<int64_t> size;
+// };
 //
-// The forward_fn lambda describes how to replay view1 on a tensor.
+// The forward function describes how to replay view1 on a tensor.
 //
-// The reverse_fn lambda describes how, given a tensor that is already a view,
+// The reverse function describes how, given a tensor that is already a view,
 // how to get the corresponding base tensor. See Note [Functionalization Pass:
 // View Inverses] for details.
+//
+// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type
+// representing the `ViewMeta` instance state. Methods that take in/return such
+// a type are used for supporting pickle serialization.
 struct ViewMeta {
   ViewMeta(
-      std::function<Tensor(const Tensor&, int64_t)> forward,
-      std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
       bool has_symbolic_inputs,
       bool is_multi_output = false,
       bool is_as_strided = false,
      int64_t out_idx = 0)
-      : forward_fn(std::move(forward)),
-        reverse_fn(std::move(reverse)),
-        out_index(out_idx),
+      : out_index(out_idx),
        is_multi_output(is_multi_output),
        is_as_strided(is_as_strided),
        has_symbolic_inputs(has_symbolic_inputs) {}
 
-  std::function<Tensor(const Tensor&, int64_t)> forward_fn;
-  std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
+  virtual ~ViewMeta() {}
+
+  virtual Tensor forward(const Tensor& base) = 0;
+  virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
 
   // See Note [out_idx in ViewMeta]
   int64_t out_index;
@@ -57,10 +102,17 @@ struct ViewMeta {
   // Tells us if this view operation has any symbolic inputs
   bool has_symbolic_inputs;
 
-  // Returns a copy of the current ViewMeta, if out_idx matches the current
-  // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
+  // Returns a new ViewMeta with the same forward/reverse
   // functions, but a new out index.
-  ViewMeta to_out_idx(int64_t out_idx);
+  //
+  // This method should be implemented by those `ViewMeta` that have more than
+  // one output.
+  virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "ViewMeta::to_out_index not implemented. ",
+        "Likely because there's only one output.");
+  }
 };
 
 // FunctionalStorageImpl is a subclass of StorageImpl used by the
@@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
     const at::Tensor new_val;
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-    const std::vector<ViewMeta> view_metas;
+    const std::vector<std::shared_ptr<ViewMeta>> view_metas;
   };
 
   explicit FunctionalStorageImpl(const Tensor& value);
 
   void add_update(
       const Tensor& updated_val,
-      const std::vector<ViewMeta>& view_metas);
+      const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
   bool apply_updates();
   const Tensor& base() {
     return base_;

aten/src/ATen/FunctionalTensorWrapper.cpp

@@ -129,17 +129,19 @@ void FunctionalTensorWrapper::freeze_storage() const {
 // - view_value: The output tensor that we need to wrap.
 // - base: The "base" of the view that `view_value` was generated from.
 // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
-FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
-  : c10::TensorImpl(
-      c10::DispatchKeySet(DispatchKey::Functionalize),
-      view_value.dtype(),
-      view_value.device()
-    ),
-    value_(view_value),
-    is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
-    was_storage_changed_(base->was_storage_changed_),
-    is_symbolic_(base->is_symbolic_)
-{
+FunctionalTensorWrapper::FunctionalTensorWrapper(
+    const Tensor& view_value,
+    const FunctionalTensorWrapper* base,
+    const std::shared_ptr<functionalization::ViewMeta>& meta)
+    : c10::TensorImpl(
+          c10::DispatchKeySet(DispatchKey::Functionalize),
+          view_value.dtype(),
+          view_value.device()),
+      value_(view_value),
+      is_multi_output_view_(
+          base->is_multi_output_view_ || meta->is_multi_output),
+      was_storage_changed_(base->was_storage_changed_),
+      is_symbolic_(base->is_symbolic_) {
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   set_constructor_metadata();
@@ -148,11 +150,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
     view_metas_ = base->view_metas_; // copy
   }
   view_metas_.push_back(meta);
-  maybe_mark_symbolic(meta);
+  maybe_mark_symbolic(meta.get());
   storage_ = base->storage_; // alias this tensor's storage with the base tensor's
 }
 
 functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
   return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
 }
@@ -176,18 +177,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const {
 }
 
 // See Note [Functionalization Pass - Inplace View Ops]
-void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
+void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
   view_metas_.push_back(meta);
   // Manually track the fact that this tensor recieved a metadata mutation!
   has_metadata_mutation_ = true;
   // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
-  maybe_mark_symbolic(meta);
+  maybe_mark_symbolic(meta.get());
   // Note [Functionalization Pass - Inplace View Ops]
   // So, these ops are special - they're mutation AND view ops. They get special codegen.
   // An example is transpose_, e.g. `a.transpose_()`
   // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
   at::AutoDispatchSkipFunctionalize guard;
-  value_ = meta.forward_fn(value_, meta.out_index);
+  value_ = meta->forward(value_);
   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
 }
@@ -368,15 +369,8 @@ void FunctionalTensorWrapper::sync_() {
   regenerate_from_base();
 }
 
-Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
-  auto t = base;
-
-  // Reapply views to get the viewed tensor from the base in alias_
-  for (auto& view_meta: view_metas_) {
-    t = view_meta.forward_fn(t, view_meta.out_index);
-  }
-
-  return t;
+const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const {
+  return view_metas_;
 }
 
 void FunctionalTensorWrapper::regenerate_from_base() {
@@ -385,7 +379,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
   auto t = storage_impl->base();
 
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
-  t = apply_view_metas(t);
+  t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_);
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
 
   replace_(t, /*from_lazy_regenerate=*/true);
@@ -759,20 +753,28 @@ void freeze_functional_tensor(const Tensor& tensor) {
   functional_base_impl->freeze_storage();
 }
 
-Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
+Tensor create_functional_tensor_with_view_meta(
+    const at::Tensor& view_to_wrap,
+    const at::Tensor& base,
+    const std::shared_ptr<functionalization::ViewMeta>& meta,
+    int64_t out_idx) {
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
   auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
+  auto meta_ = meta;
   if (out_idx != 0) {
     // Note [out_idx in ViewMeta]
     // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
     // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
-    meta = meta.to_out_idx(out_idx);
+    meta_ = meta->to_out_index(out_idx);
   }
-  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
+  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_);
 }
 
-std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
+std::vector<Tensor> create_functional_tensor_with_view_meta(
+    ITensorListRef view_to_wrap,
+    const at::Tensor& base,
+    const std::shared_ptr<functionalization::ViewMeta>& meta) {
   std::vector<Tensor> outputs(view_to_wrap.size());
   int64_t i = 0;
   for (const auto& tensor : view_to_wrap) {
@@ -782,12 +784,22 @@ std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_
   return outputs;
 }
 
-void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
+void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) {
   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
   auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
   self_impl->mutate_view_meta(meta);
 }
 
+Tensor apply_view_meta_sequence(
+    const Tensor& base,
+    const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) {
+  Tensor r = base;
+  for (auto& vm : sequence) {
+    r = vm->forward(r);
+  }
+  return r;
+}
+
 // Note [Propagating strides in the functionalization pass]
 // In order to properly compute stride information, the functionalization pass
 // calls each {view} reference implementations with meta tensors.

aten/src/ATen/FunctionalTensorWrapper.h

@@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   explicit FunctionalTensorWrapper(
       const Tensor& view_value,
       const FunctionalTensorWrapper* base,
-      const functionalization::ViewMeta& meta);
+      const std::shared_ptr<functionalization::ViewMeta>& meta);
 
   // Get the underlying, actual tensor, that doesn't know anything about
   // functionalization.
@@ -97,17 +97,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
         ->are_all_mutations_under_no_grad_or_inference_mode();
   }
 
-  void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
-    is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
+  void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
+    is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
   }
 
   bool is_symbolic() const {
     return is_symbolic_;
   }
 
-  // Runs the forward_fn of every ViewMeta collected in the current instance
-  // to some other base.
-  Tensor apply_view_metas(const Tensor& base);
+  // Retrieves the ViewMeta sequence of this tensor.
+  const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
+      const;
 
   // Sync's the underlying tensor with its alias, if it's out of date. This
   // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
@@ -144,7 +144,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   // from the base tensor. This method is used by inplace-view ops like
   // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
   // tensor by replaying the views off of the alias.
-  void mutate_view_meta(const at::functionalization::ViewMeta& meta);
+  void mutate_view_meta(
+      const std::shared_ptr<at::functionalization::ViewMeta>& meta);
 
   // Custom implementation of self.set_(src)
   void set__impl(const FunctionalTensorWrapper* other);
@@ -273,7 +274,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   bool is_symbolic_ = false;
 
   size_t generation_ = 0;
-  std::vector<at::functionalization::ViewMeta> view_metas_;
+  std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
 
  protected:
   static void copy_tensor_metadata(
@@ -365,16 +366,20 @@ TORCH_API void propagate_xla_data_direct(
 Tensor create_functional_tensor_with_view_meta(
     const Tensor& view_to_wrap,
     const Tensor& base,
-    functionalization::ViewMeta meta,
+    const std::shared_ptr<functionalization::ViewMeta>& meta,
     int64_t out_idx = 0);
 std::vector<Tensor> create_functional_tensor_with_view_meta(
     ITensorListRef view_to_wrap,
     const Tensor& base,
-    const functionalization::ViewMeta& meta);
+    const std::shared_ptr<functionalization::ViewMeta>& meta);
 
 void mutate_view_meta(
     const Tensor& self,
-    const functionalization::ViewMeta& meta);
+    const std::shared_ptr<functionalization::ViewMeta>& meta);
+
+TORCH_API Tensor apply_view_meta_sequence(
+    const Tensor& base,
+    const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
 
 void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
 void set_sizes_strides_offset(

aten/src/ATen/FunctionalizeFallbackKernel.cpp

@@ -1,3 +1,5 @@
+#include <ATen/FunctionalizeFallbackKernel.h>
+
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/EmptyTensor.h>
@@ -27,6 +29,31 @@
 #include <utility>
 #endif
 
+namespace at::functionalization {
+
+Tensor resize__ViewMeta::forward(const Tensor& base) {
+  if (reapply_views) {
+    return base.as_strided(size, c10::contiguous_strides(size));
+  } else {
+    return at::as_strided_copy(base, size, c10::contiguous_strides(size));
+  }
+}
+
+Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
+  return base.as_strided_scatter(
+      mutated_view, size, c10::contiguous_strides(size));
+}
+
+Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) {
+  return at::_unsafe_view_symint(base, size);
+}
+
+Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
+  return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
+}
+
+} // namespace at::functionalization
+
 namespace {
 void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
   const auto& schema = op.schema();
@@ -168,19 +195,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
   // The output of resizing is equivalent to taking a slice of a larger tensor.
   // We have to emulate this "slicing" with an as_strided call.
   auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
-  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      if (reapply_views) {
-        return base.as_strided(size, c10::contiguous_strides(size));
-      } else {
-        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
-      }
-    },
-    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
-    },
-    /*has_symbolic_inputs=*/false
-  );
+  auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>(
+      reapply_views, size.vec());
   at::functionalization::impl::mutate_view_meta(self, view_meta);
   return self;
 }
@@ -299,17 +315,11 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
     tmp_output = at::_unsafe_view_symint(self_, size);
   }
 
-  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
+  bool has_symbolic_inputs = std::any_of(
+      size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
 
-  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      return at::_unsafe_view_symint(base, size);
-    },
-    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
-    },
-    /*has_symbolic_inputs=*/has_symbolic_inputs
-  );
+  auto view_meta =
+      std::make_shared<at::functionalization::_unsafe_view_ViewMeta>(
+          has_symbolic_inputs, size.vec());
 
   auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
   // See Note [Propagating strides in the functionalization pass]

aten/src/ATen/FunctionalizeFallbackKernel.h

@@ -0,0 +1,58 @@
+#pragma once
+
+#include <ATen/FunctionalStorageImpl.h>
+
+namespace at::functionalization {
+
+// `ViewMeta` implementation for `resize_` operation.
+struct TORCH_API resize__ViewMeta : public ViewMeta {
+  FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta);
+  FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
+      bool /* reapply_views */,
+      const std::vector<int64_t>&);
+
+  resize__ViewMeta(const SerializableTuple& tpl)
+      : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
+
+  resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
+      : ViewMeta(/*has_symbolic_inputs=*/false),
+        reapply_views(reapply_views),
+        size(size) {}
+
+  Tensor forward(const Tensor& base) override;
+  Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
+
+  SerializableTuple to_serializable_tuple() {
+    return std::make_tuple(reapply_views, size);
+  }
+
+  bool reapply_views;
+  std::vector<int64_t> size;
+};
+
+// `ViewMeta` implementation for `_unsafe_view` operation.
+struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
+  FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta);
+  FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
+      bool /* has_symbolic_inputs */,
+      const std::vector<c10::SymInt>&);
+
+  _unsafe_view_ViewMeta(const SerializableTuple& tpl)
+      : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
+
+  _unsafe_view_ViewMeta(
+      bool has_symbolic_inputs,
+      const std::vector<c10::SymInt>& size)
+      : ViewMeta(has_symbolic_inputs), size(size) {}
+
+  Tensor forward(const Tensor& base) override;
+  Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
+
+  SerializableTuple to_serializable_tuple() {
+    return std::make_tuple(has_symbolic_inputs, size);
+  }
+
+  std::vector<c10::SymInt> size;
+};
+
+} // namespace at::functionalization

aten/src/ATen/templates/FunctionalInverses.h

@@ -2,22 +2,12 @@
 
 // ${generated_comment}
 
+#include <ATen/FunctionalStorageImpl.h>
 #include <ATen/Tensor.h>
 
 namespace at {
 namespace functionalization {
 
-enum class InverseReturnMode {
-  /// Specifies that functional inverses should always return a view.
-  AlwaysView,
-  /// Specifies that functional inverses should always return a non-view / copy.
-  NeverView,
-  /// Specifies that functional inverses should return a view unless a (copying) scatter
-  /// inverse exists, in which case that will be used instead.
-  /// This avoids as_strided() calls that can be difficult for subclasses to handle.
-  ViewOrScatterInverse,
-};
-
 struct FunctionalInverses {
 
 ${view_inverse_declarations}

aten/src/ATen/templates/RegisterFunctionalization.cpp

@@ -4,7 +4,7 @@
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/FunctionalTensorWrapper.h>
-#include <ATen/FunctionalInverses.h>
+#include <ATen/ViewMetaClasses.h>
 #include <ATen/MemoryOverlap.h>
 #include <torch/library.h>

aten/src/ATen/templates/ViewMetaClasses.cpp

@@ -0,0 +1,19 @@
+// ${generated_comment}
+
+#include <ATen/FunctionalInverses.h>
+#include <ATen/ViewMetaClasses.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Operators.h>
+#include <ATen/NativeFunctions.h>
+#else
+${op_headers}
+#endif
+
+namespace at {
+namespace functionalization {
+
+${view_meta_implementations}
+
+} // namespace functionalization
+} // namespace at

aten/src/ATen/templates/ViewMetaClasses.h

@@ -0,0 +1,12 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+
+// ${generated_comment}
+
+#include <ATen/FunctionalStorageImpl.h>
+
+namespace at {
+namespace functionalization {
+
+${view_meta_declarations}
+
+} // namespace functionalization
+} // namespace at

aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp

@@ -0,0 +1,11 @@
+#include <ATen/ViewMetaClasses.h>
+#include <torch/csrc/functionalization/Module.h>
+
+namespace torch::functionalization {
+
+void initGenerated(PyObject* module) {
+  auto functionalization = py::handle(module).cast<py::module>();
+  $view_meta_bindings
+}
+
+} // namespace torch::functionalization


@@ -117,6 +117,7 @@ def define_targets(rules):
             ":LazyNonNativeIr.h",
             ":RegisterDispatchDefinitions.ini",
             ":RegisterDispatchKey.cpp",
+            ":ViewMetaClassesPythonBinding.cpp",
             ":native_functions.yaml",
             ":shape_inference.h",
             ":tags.yaml",
@@ -297,6 +298,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
    "torch/csrc/autograd/generated/python_torch_functions_1.cpp",
    "torch/csrc/autograd/generated/python_torch_functions_2.cpp",
    "torch/csrc/autograd/generated/python_variable_methods.cpp",
+    "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
 ]
 
 GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP

build_variables.bzl

@@ -929,6 +929,7 @@ libtorch_python_core_sources = [
     "torch/csrc/utils/disable_torch_function.cpp",
     "torch/csrc/utils/verbose.cpp",
     "torch/csrc/cpu/Module.cpp",
+    "torch/csrc/functionalization/Module.cpp",
     "torch/csrc/instruction_counter/Module.cpp",
 ] + lazy_tensor_core_python_sources


@@ -310,6 +310,7 @@ set(GENERATED_CXX_PYTHON
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
+    "${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
     )
 
 set(GENERATED_H_PYTHON
@@ -373,6 +374,7 @@ add_custom_command(
     "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
     "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
+    "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp"
     ${autograd_python}
     ${autograd_yaml}
     ${autograd_templates}

test/dynamo/test_aot_autograd_cache.py

@@ -250,11 +250,7 @@ class AOTAutogradCacheTests(InductorTestCase):
     @functorch_config.patch(
         {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True}
     )
-    def test_view_replay_bypass(self):
-        """
-        Shoud bypass when view replay is turned on
-        """
-
+    def test_view_replay(self):
         def fn(a):
             tmp = a.detach()
             a.mul_(2)
@@ -262,10 +258,25 @@ class AOTAutogradCacheTests(InductorTestCase):
         with torch.autograd._force_original_view_tracking(True):
             compiled_fn = torch.compile(fn)
 
-        compiled_fn(torch.rand(2, 3))
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+        def run_and_check(miss, hit, bypass):
+            self._clear_dynamo_and_codecache()
+            inp = torch.rand(2, 3)
+            compiled_inp = inp.clone().detach()
+            with torch.autograd._force_original_view_tracking(True):
+                out = fn(inp)
+                compiled_out = compiled_fn(compiled_inp)
+            self.assertEqual(out, compiled_out)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass)
+
+        run_and_check(miss=1, hit=0, bypass=0)
+        run_and_check(miss=1, hit=1, bypass=0)
+        run_and_check(miss=1, hit=2, bypass=0)
 
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", False)

test/functorch/test_aotdispatch.py

@@ -6897,7 +6897,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
         {
             "enable_autograd_cache": True,
             "strict_autograd_cache": True,
-            "view_replay_for_aliased_outputs": False,
         }
     )
     @torch._inductor.config.patch("fx_graph_cache", True)

tools/setup_helpers/generate_code.py

@@ -189,6 +189,12 @@ def main() -> None:
     )
     options = parser.parse_args()
 
+    # Path: aten/src/ATen
+    aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
+    operator_selector = get_selector(
+        options.selected_op_list_path, options.operators_yaml_path
+    )
+
     generate_code(
         options.gen_dir,
         options.native_functions_path,
@@ -198,13 +204,32 @@ def main() -> None:
         options.disable_autograd,
         options.force_schema_registration,
         # options.selected_op_list
-        operator_selector=get_selector(
-            options.selected_op_list_path, options.operators_yaml_path
-        ),
+        operator_selector=operator_selector,
     )
 
+    # Generate the python bindings for functionalization's `ViewMeta` classes.
+    from torchgen.gen_functionalization_type import (
+        gen_functionalization_view_meta_classes,
+    )
+
+    functionalization_templates_dir = os.path.join(aten_path, "templates")
+    functionalization_install_dir = os.path.join(
+        options.gen_dir, "torch/csrc/functionalization/generated"
+    )
+    os.makedirs(functionalization_install_dir, exist_ok=True)
+    assert os.path.isdir(functionalization_install_dir)
+    assert os.path.isdir(functionalization_templates_dir)
+    gen_functionalization_view_meta_classes(
+        options.native_functions_path or NATIVE_FUNCTIONS_PATH,
+        options.tags_path or TAGS_PATH,
+        selector=operator_selector,
+        install_dir=functionalization_install_dir,
+        template_dir=functionalization_templates_dir,
+    )
+
     if options.gen_lazy_ts_backend:
-        aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
         ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
         ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
         ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"

torch/_C/__init__.pyi

@@ -67,6 +67,7 @@ from . import (
     _export,
     _cpu,
     _dynamo,
+    _functionalization,
     _functorch,
     _lazy,
     _lazy_ts_backend,

torch/_C/_functionalization.pyi

@@ -0,0 +1,16 @@
+from torch import Tensor
+from torch.types import _bool
+
+# Defined in torch/csrc/functionalization/Module.cpp
+
+class ViewMeta:
+    has_symbolic_inputs: _bool
+
+# Returns the list of ViewMeta instances of the given functional tensor.
+#
+# Although we do have python bindings for their types, we won't
+# expose them here, since they should not be used by users.
+def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ...
+
+# Applies the ViewMeta sequence on top of the given base.
+def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ...

torch/_functorch/_aot_autograd/autograd_cache.py

@@ -227,19 +227,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
         check_node_safe(node)
 
 
-def check_metadata_cacheable(metadata: ViewAndMutationMeta):
-    """
-    When view replay is turned on, we bypass autograd cache if
-    the output is aliased.
-    """
-    if config.view_replay_for_aliased_outputs:
-        for info in metadata.output_info:
-            if info.functional_tensor is not None:
-                raise BypassAOTAutogradCache(
-                    "Cannot cache a graph with functional tensor"
-                )
-
-
 class AOTAutogradCacheDetails(FxGraphHashDetails):
     """
     Object to capture all the details for a dynamo graph module relevant to computing
@@ -875,7 +862,6 @@ class AOTAutogradCache:
     def save(key: str, entry: AOTAutogradCacheEntry, remote: bool):
         """Save a single entry into the cache."""
         try:
-            check_metadata_cacheable(entry.runtime_metadata)
             content = pickle.dumps(entry)
             CacheArtifactManager.record_artifact(
                 CacheArtifactType.AOT_AUTOGRAD, key, content

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

@@ -36,10 +36,10 @@ from .functional_utils import (
     has_metadata_mutation,
     MetadataKey,
     to_fun,
+    ViewMetaSequence,
     was_inductor_storage_resized,
 )
 from .schemas import (
-    FunctionalTensorMetadataEq,
     InputAliasInfo,
     MutationType,
     OutputAliasInfo,
@@ -604,7 +604,7 @@ from a multi-output view call"
             #
             # The FunctionalTensor will be saved if one of the 2 conditions below
             # is true:
-            functional_tensor = None
+            view_meta_sequence = None
             if (
                 # 1. If the output_type is either of:
                 #    (i) alias_of_intermediate;
@@ -636,7 +636,7 @@ from a multi-output view call"
                 and not input_info[base_idx].mutates_metadata
             ):
                 if isinstance(o, FunctionalTensor):
-                    functional_tensor = FunctionalTensorMetadataEq(o.elem)
+                    view_meta_sequence = ViewMetaSequence(o)
 
             out_info = OutputAliasInfo(
                 output_type=output_type,
@@ -644,7 +644,7 @@ from a multi-output view call"
                 base_idx=base_idx,
                 dynamic_dims=dynamic_dims,
                 requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
-                functional_tensor=functional_tensor,
+                view_meta_sequence=view_meta_sequence,
             )
             output_info.append(out_info)

torch/_functorch/_aot_autograd/functional_utils.py

@@ -13,15 +13,12 @@ from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
+from torch._C import _functionalization
 from torch._logging import getArtifactLogger
 from torch._subclasses.fake_tensor import FakeTensor
 from torch._subclasses.functional_tensor import FunctionalTensor
 from torch._subclasses.meta_utils import is_sparse_any
-from torch.fx.experimental.symbolic_shapes import (
-    definitely_true,
-    sym_eq,
-    SymIntEqByExpr,
-)
+from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
@@ -227,9 +224,9 @@ def gen_alias_from_base(
     aliased_base_tensor,
     target_meta_tensor,
     target_requires_grad,
-    target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
+    target_view_meta_sequence: Optional[ViewMetaSequence] = None,
     *,
-    replay_views,
+    replay_views: bool,
 ):
     # Patch the correct requires_grad field of the output tensor, depending on whether:
     # (i) the reconstructed output (out) was came from a tensor that requires grad or not;
@@ -248,13 +245,11 @@ def gen_alias_from_base(
     # to replay them (view functions) on the aliased_base_tensor.
     if (
         replay_views
-        and target_functional_tensor is not None
-        and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
+        and target_view_meta_sequence is not None
+        and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence)
     ):
-        functional_tensor = target_functional_tensor.tensor
-
-        out = torch._functionalize_apply_view_metas(
-            functional_tensor, aliased_base_tensor
+        out = _functionalization.apply_view_meta_sequence(
+            aliased_base_tensor, target_view_meta_sequence.sequence
         )
         # If re-applying the ViewMeta sequence succeeded, there should be no more
        # problems going forward. We just check we got to the target shape and
@@ -315,28 +310,8 @@
     return aliased_out
 
 
-def has_same_metadata(t1, t2):
-    return (
-        definitely_true(sym_eq(t1.size(), t2.size()))
-        and definitely_true(t1.layout == t2.layout)
-        and (
-            is_sparse_any(t1)
-            or (
-                definitely_true(sym_eq(t1.stride(), t2.stride()))
-                and definitely_true(t1.storage_offset() == t2.storage_offset())
-            )
-        )
-        and t1.is_conj() == t2.is_conj()
-        and t1.is_neg() == t2.is_neg()
-    )
-
-
 @dataclass(frozen=True)
 class MetadataKey:
-    """
-    This should be equal whenever has_same_metadata would return True
-    """
-
     size: Tuple[SymIntEqByExpr, ...]
     layout: torch.layout
     is_sparse: bool
@@ -360,25 +335,45 @@ class MetadataKey:
     )
 
 
-# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
-# after applying all the ViewMeta operations.
-class FunctionalTensorMetadataEq:
-    def __init__(self, tensor: torch.Tensor) -> None:
-        assert torch._is_functional_tensor(tensor)
-        self.tensor = tensor
+# ViewMeta sequence wrapper for equality comparisons.
+#
+# Even though we can compare each ViewMeta instance, we compare the resulting
+# tensor metadata, instead. That's because the creation of synthetic bases + the
+# re-generation of input views might end-up creating a different sequence of
+# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same
+# metadata.
+#
+# Therefore, we store what the end result should look like as serializable
+# metadata.
+#
+# When logging, this class should look like:
+#
+#     ViewMetaSequence(view, select_int, slice_Tensor)
+#
+# i.e. a parenthesized list of view operations within that ViewMeta sequence.
+class ViewMetaSequence:
+    def __init__(self, tensor: FunctionalTensor) -> None:
+        assert torch._is_functional_tensor(tensor.elem)
+        self.sequence = _functionalization.get_view_meta_sequence(tensor.elem)
+        self.metadata = MetadataKey.make(tensor)
+
+    def __repr__(self) -> str:
+        suffix = len("_ViewMeta")
+        types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence)
+        return f"ViewMetaSequence({types})"
 
     def __eq__(self, other: object) -> bool:
         # If other is None, then it probably means that we weren't able to recreate
-        # the FunctionalTensorMetadataEq. One of this cases is when we update the
-        # view metadata by calling: create_synthetic_base_metadata.
+        # the ViewMeta sequence. One example is when we update the view metadata by
+        # calling: create_synthetic_base_metadata.
         if other is None:
             return True
 
-        # Comparison agains any other type is not implemented.
-        if not isinstance(other, FunctionalTensorMetadataEq):
+        # Comparison against any other type is not implemented.
+        if not isinstance(other, ViewMetaSequence):
             return NotImplemented
 
-        return has_same_metadata(self.tensor, other.tensor)
+        return self.metadata == other.metadata
 
 
 # new_arg and arg here are either:
# new_arg and arg here are either: # new_arg and arg here are either:

torch/_functorch/_aot_autograd/input_output_analysis.py

@@ -75,7 +75,7 @@ def remove_dupe_metadata(
                 dynamic_dims=o.dynamic_dims,
                 base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
                 requires_grad=o.requires_grad,
-                functional_tensor=o.functional_tensor,
+                view_meta_sequence=o.view_meta_sequence,
             )
             for o in m.output_info
         ],
@@ -226,7 +226,7 @@ def create_synthetic_base_metadata(
                 # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
                 base_idx=new_base_idx,  # type: ignore[arg-type]
                 requires_grad=o.requires_grad,
-                functional_tensor=o.functional_tensor,
+                view_meta_sequence=o.view_meta_sequence,
             )
         )

torch/_functorch/_aot_autograd/runtime_wrappers.py

@@ -172,7 +172,7 @@ class AliasOfInputHandler:
         self.base_idx = info.base_idx
         self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
         self.requires_grad = info.requires_grad
-        self.functional_tensor = info.functional_tensor
+        self.view_meta_sequence = info.view_meta_sequence
         self.replay_views = config.view_replay_for_aliased_outputs
 
     def __call__(self, orig_inputs, fw_outs, out):
@@ -181,7 +181,7 @@ class AliasOfInputHandler:
             aliased_base_tensor,
             self.unwrap_out(out),
             self.requires_grad,
-            self.functional_tensor,
+            self.view_meta_sequence,
             replay_views=self.replay_views,
         )
@@ -209,7 +209,7 @@ class AliasOfIntermediateHandler:
         self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
         self.requires_grad = info.requires_grad
-        self.functional_tensor = info.functional_tensor
+        self.view_meta_sequence = info.view_meta_sequence
         self.replay_views = config.view_replay_for_aliased_outputs
 
     def __call__(self, orig_inputs, fw_outs, out):
@@ -218,7 +218,7 @@ class AliasOfIntermediateHandler:
             aliased_base_tensor,
             self.unwrap_out(out),
             self.requires_grad,
-            self.functional_tensor,
+            self.view_meta_sequence,
             replay_views=self.replay_views,
         )

torch/_functorch/_aot_autograd/schemas.py

@@ -5,7 +5,6 @@ input/output types, metadata, config, function signatures etc.
 """
 
 import collections
-import dataclasses
 import functools
 from dataclasses import dataclass, field
 from enum import Enum
@@ -20,10 +19,7 @@ from torch._subclasses.fake_tensor import is_fake
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
 from .. import config
-from .functional_utils import (
-    _check_if_mutation_can_be_in_graph,
-    FunctionalTensorMetadataEq,
-)
+from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence
 from .utils import strict_zip
@@ -92,15 +88,14 @@ class OutputAliasInfo:
     dynamic_dims: Optional[Set[int]]
     # requires_grad
     requires_grad: bool
-    # FunctionalTensorWrapper that represents this output.
+    # Sequence of ViewMeta objects.
     #
-    # Provides us the means to replay views from it.
+    # Provides us the means to re-run view functions on other tensors.
     #
-    # We need to wrap the actual FunctionalTensorWrapper with this class so that
-    # we only compare the tensor's metadata. That's because with the transformations
-    # of the model throughout AOTAutograd, the sequence of ViewMeta and the base
-    # tensor might change.
-    functional_tensor: Optional[FunctionalTensorMetadataEq] = None
+    # We need to wrap the actual list of ViewMeta with this class so that
+    # we compare the ViewMeta elements appropriately, i.e. their type and
+    # the elements returned by the `as_tuple()` call.
+    view_meta_sequence: Optional[ViewMetaSequence] = None
 
 
 class MutationType(Enum):
@@ -582,17 +577,6 @@ class ViewAndMutationMeta:
         self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
         # Clear traced tangents at runtime
         self.traced_tangents = []
-        new_output_info = []
-        for out in self.output_info:
-            if config.view_replay_for_aliased_outputs:
-                new_out = out
-            else:
-                # If we're not using view_replay, remove the functional tensor.
-                # Functional tensors are unfortunately not serializable,
-                # so doing this is required for AOTAutograd caching.
-                new_out = dataclasses.replace(out, functional_tensor=None)
-            new_output_info.append(new_out)
-        self.output_info = new_output_info
         for inp_meta in self.subclass_inp_meta:
             if isinstance(inp_meta, SubclassCreationMeta):
                 inp_meta.make_runtime_safe()

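Per the new comment above, `ViewMetaSequence` equality is driven by each element's specialization type together with its `as_tuple()` payload. A minimal illustrative sketch of that comparison (the real wrapper lives in _functional_utils.py_; this class is not the actual implementation):

```python
# Illustrative only: approximates the comparison the ViewMetaSequence
# wrapper performs; the real class lives in functional_utils.py.
class ViewMetaSequenceSketch:
    def __init__(self, metas):
        self.metas = list(metas)

    def __eq__(self, other):
        if not isinstance(other, ViewMetaSequenceSketch):
            return NotImplemented
        # Two sequences match when each element pair agrees on both the
        # specialization type and the serializable state.
        return len(self.metas) == len(other.metas) and all(
            type(a) is type(b) and a.as_tuple() == b.as_tuple()
            for a, b in zip(self.metas, other.metas)
        )
```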
View File

@@ -71,6 +71,7 @@
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/export/pybind.h>
+#include <torch/csrc/functionalization/Module.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_package/pybind.h>
@@ -1869,6 +1870,7 @@ PyObject* initModule() {
  torch::instruction_counter::initModule(module);
  torch::initVerboseBindings(module);
  ASSERT_TRUE(THPStorage_init(module));
+  torch::functionalization::initModule(module);
#ifdef USE_CUDA
  // This will only initialise base classes and attach them to library namespace

View File

@@ -633,15 +633,6 @@ void initTorchFunctions(PyObject* module) {
        at::functionalization::impl::isFunctionalTensor(t));
    at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
  });
-  py_module.def(
-      "_functionalize_apply_view_metas",
-      [](const at::Tensor& tensor, const at::Tensor& base) {
-        TORCH_INTERNAL_ASSERT(
-            at::functionalization::impl::isFunctionalTensor(tensor));
-        auto impl =
-            at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
-        return impl->apply_view_metas(base);
-      });
  py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);

View File

@@ -0,0 +1,71 @@
#include <torch/csrc/functionalization/Module.h>
#include <torch/csrc/utils/pybind.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalizeFallbackKernel.h>
#include <memory>
namespace torch::functionalization {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
// Create a `torch._C._functionalization` Python module.
auto functionalization = m.def_submodule(
"_functionalization", "functionalization related pybind.");
// Retrieve the ViewMeta sequence of a given functional tensor.
functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
return impl->view_metas();
});
// Applies the given ViewMeta sequence to the given base.
functionalization.def(
"apply_view_meta_sequence",
[](const at::Tensor& base,
const std::vector<std::shared_ptr<at::functionalization::ViewMeta>>&
sequence) {
return at::functionalization::impl::apply_view_meta_sequence(
base, sequence);
});
// Binding for InverseReturnMode.
py::enum_<at::functionalization::InverseReturnMode>(
functionalization, "InverseReturnMode")
.value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView)
.value("NeverView", at::functionalization::InverseReturnMode::NeverView)
.value(
"ViewOrScatterInverse",
at::functionalization::InverseReturnMode::ViewOrScatterInverse);
// Create bindings for the ViewMeta base class.
//
// Needed so that we can take a list of ViewMeta objects as a parameter.
// Specifically, on the Python side, we will have a list of derived ViewMeta
// classes. We need to tell pybind11 that all of those are, in fact, instances
// of different ViewMeta sub-types.
py::class_<
at::functionalization::ViewMeta,
std::shared_ptr<at::functionalization::ViewMeta>>(
functionalization, "ViewMeta")
.def_property_readonly(
"has_symbolic_inputs",
[](const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
return meta->has_symbolic_inputs;
});
// Bindings for the manually implemented `ViewMeta` specializations.
create_binding_with_pickle<at::functionalization::resize__ViewMeta>(
functionalization);
create_binding_with_pickle<at::functionalization::_unsafe_view_ViewMeta>(
functionalization);
// Bindings for the automatically generated `ViewMeta` specializations.
initGenerated(functionalization.ptr());
}
} // namespace torch::functionalization
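
Together, `get_view_meta_sequence`, the picklable specializations, and `apply_view_meta_sequence` enable the serialization round-trip this PR is after. A minimal usage sketch, assuming `torch._C._functionalization` is importable as below and `t` was recorded as a functional view of `base`:

```python
import pickle

import torch
from torch._C import _functionalization


def roundtrip_views(t: torch.Tensor, base: torch.Tensor) -> torch.Tensor:
    # `t` is assumed to be a functional tensor recorded as a view of `base`.
    seq = _functionalization.get_view_meta_sequence(t)

    # Each ViewMeta specialization pickles via its SerializableTuple, so the
    # whole sequence survives a dump/load round-trip.
    seq = pickle.loads(pickle.dumps(seq))

    # Re-run the recorded view operations on the given base tensor.
    return _functionalization.apply_view_meta_sequence(base, seq)
```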

View File

@@ -0,0 +1,36 @@
#pragma once
#include <ATen/FunctionalStorageImpl.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::functionalization {
// Creates the default bindings for `ViewMeta` specializations.
//
// Defines a constructor using the types in `SerializableTuple`, as well
// as pickle methods.
template <class T>
void create_binding_with_pickle(py::module m) {
py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
m, T::name())
.def(py::init<typename T::SerializableTuple>())
.def(
"as_tuple",
[](const std::shared_ptr<T>& meta) {
return meta->to_serializable_tuple();
})
.def(py::pickle(
[](const std::shared_ptr<T>& meta) {
return meta->to_serializable_tuple();
},
[](const typename T::SerializableTuple& tpl) {
return std::make_shared<T>(tpl);
}));
}
void initModule(PyObject* module);
void initGenerated(PyObject* module);
} // namespace torch::functionalization
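
For each specialization `T`, this template exposes a constructor from `T::SerializableTuple`, an `as_tuple()` accessor, and pickle support built on that same tuple. In Python terms, the generated behavior is roughly equivalent to the following sketch (the class name and tuple contents are hypothetical):

```python
import pickle


# Hypothetical stand-in for a generated specialization: the constructor
# takes the serializable tuple and `as_tuple()` returns it, mirroring the
# py::init/py::pickle pair defined in create_binding_with_pickle.
class narrow_ViewMeta:
    def __init__(self, tpl):
        self.tpl = tpl

    def as_tuple(self):
        return self.tpl

    def __reduce__(self):
        # Rebuild the object from its serializable tuple, like py::pickle does.
        return (type(self), (self.as_tuple(),))


meta = narrow_ViewMeta((False, True, 0, 0, 5))
assert pickle.loads(pickle.dumps(meta)).as_tuple() == meta.as_tuple()
```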

View File

@@ -23,20 +23,13 @@ from torchgen.model import (

# This file describes the translation of JIT schema to API's used
-# when creating view lambdas that are used by the functionalization pass.
-# There are two types of lambdas: forward lambdas and reverse lambdas.
-# These API's mostly follow the dispatcher API, with a few quirks:
-# - The lambda capture has to convert reference types to value types
-# - While the forward lambda just directly calls into the at::_ops API
-# (following the dispatcher convention), the logic here for the reverse lambda
+# when creating `ViewMeta` specializations that are used by the functionalization pass.
+# These API's mostly follow the dispatcher API, with one difference:
+# - While the forward function just directly calls into the at::_ops API
+# (following the dispatcher convention), the logic here for the reverse function
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).
-# The lambdas generated for each view op in the functionalization pass are of the form
-# [capture_arguments](outer_arguments) -> returns_type {
-#   return name(inner_arguments);
-# }

# Define some specific lambda input arguments.
base_binding = Binding(
    name="base",
@@ -46,6 +39,18 @@ base_binding = Binding(
    ),
    default=None,
)
+has_symbolic_inputs_binding = Binding(
+    name="has_symbolic_inputs",
+    nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
+    argument=Argument(
+        name="has_symbolic_inputs",
+        type=BaseType(BaseTy.bool),
+        default=None,
+        annotation=None,
+    ),
+    default=None,
+)
+
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@@ -54,11 +59,11 @@ mutated_view_binding = Binding(
    ),
    default=None,
)
-mutated_view_idx_binding = Binding(
-    name="mutated_view_idx",
-    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
+out_index_binding = Binding(
+    name="out_index",
+    nctype=NamedCType(name="out_index", type=BaseCType(longT)),
    argument=Argument(
-        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+        name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
    ),
    default=None,
)
@@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
)

-# The lambda capture itself doesn't have a name.
-# The name returned here corresponds to the name of the inner function called by the lambda.
+# Name of the `ViewMeta` specialization class created.
+def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
+    namespace = "at::functionalization::" if with_namespace else ""
+    return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
+
+
+# Name of the operation called inside the `forward`/`reverse` implementations.
def name(
    g: NativeFunctionsViewGroup,
    *,
@@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
        return f"{api_name}_inverse"

-def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
-    # capture arguments include all arguments except `self`.
-    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
-    # So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
-    args = func.arguments.flat_all
-    assert args[0].type == BaseType(BaseTy.Tensor)
-    non_self_args = args[1:]
-    non_self_value_bindings = [
-        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
-    ]
-    all_bindings = [
-        inverse_return_mode_binding if is_reverse else reapply_views_binding
-    ]
-    all_bindings.extend(non_self_value_bindings)
-    return all_bindings
-
-
def returns_type(func: FunctionSchema) -> CType:
    # Assertion: all view ops return tensor-like outputs
    assert len(func.returns) >= 1
@@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
    return BaseCType(tensorT)

-def outer_arguments(*, is_reverse: bool) -> list[Binding]:
-    if is_reverse:
-        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
-    else:
-        return [base_binding, mutated_view_idx_binding]
-
-
-def inner_call_index(func: FunctionSchema) -> Binding | None:
-    # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
-    # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
-    if len(func.returns) > 1 or (
-        len(func.returns) == 1 and func.returns[0].type.is_list_like()
-    ):
-        return mutated_view_idx_binding
-    return None
-
-
-def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
+# Checks whether `func` might return more than one value.
+def is_multi_output(func: FunctionSchema) -> bool:
+    return len(func.returns) > 1 or (
+        len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
+    )
+
+
+# `ViewMeta` specialization constructor parameters.
+def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
+    # All specializations are parameterized by the `has_symbolic_inputs` flag.
+    arguments = [has_symbolic_inputs_binding]
+
+    # If `func` might return more than 1 value, we also parameterize this
+    # specialization with the output index.
+    if is_multi_output(func):
+        arguments.append(out_index_binding)
+
+    return arguments
+
+
+# `ViewMeta` specialized class' constructor arguments.
+#
+# Values needed specifically by this specialization, that the base class does not need.
+# Same as the class' attributes, but non-owning.
+def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
+    return attributes(func, owning=False)
+
+
+# `ViewMeta` specialized class' non-static member data.
+#
+# Essential data for calling the instance's `forward` and `reverse` functions. You can
+# think of them as values that should be captured from the functionalization kernel.
+def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
+    args = func.arguments.flat_all
+    assert args[0].type == BaseType(BaseTy.Tensor)
+    return [
+        reapply_views_binding,
+        inverse_return_mode_binding,
+        *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
+    ]
+
+
+def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
@@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
        # the reverse lambda does the same, but with an additional "mutated_view" arg
        # additionally, we have a calling convention: for view ops that return multiple tensor outputs
        # their corresponding view_inverse function takes in an additional index argument.
-        index_binding = inner_call_index(func)
-        if index_binding is not None:
+        if is_multi_output(func):
            return [
                base_binding,
                mutated_view_binding,
                inverse_return_mode_binding,
-                index_binding,
+                out_index_binding,
            ] + non_self_bindings
        else:
            return [

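`is_multi_output` is the predicate that decides whether a specialization carries an `out_index` constructor parameter. A quick sketch of the rule it encodes, using torchgen's existing `FunctionSchema.parse` (the schema strings follow native_functions.yaml conventions; treat the exact strings as illustrative):

```python
from torchgen.model import FunctionSchema

# `split` returns a list of tensors, so it counts as multi-output and its
# ViewMeta specialization gains an `out_index` constructor parameter.
multi = FunctionSchema.parse(
    "split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"
)
assert multi.returns[0].type.is_list_like() is not None

# `narrow` returns a single tensor, so its specialization hard-codes index 0.
single = FunctionSchema.parse(
    "narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"
)
assert len(single.returns) == 1 and single.returns[0].type.is_list_like() is None
```
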
View File

@@ -300,83 +300,11 @@ class ViewInverseSignature:
        return_type = functionalization.returns_type(self.g.view.func)
        decls = [
            a.decl()
-            for a in functionalization.inner_arguments(
-                self.g.view.func, is_reverse=True
-            )
+            for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
        ]
        return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"

-@dataclass(frozen=True)
-class FunctionalizationLambda:
-    g: NativeFunctionsViewGroup
-
-    # are we generating the forward lambda or the reverse lambda?
-    is_reverse: bool
-
-    def captures(self) -> list[Expr]:
-        # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
-        # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
-        # and plumb it into the lambda.
-        outer_ctx = dispatcher.arguments(self.g.view.func) + [
-            functionalization.reapply_views_binding,
-            functionalization.inverse_return_mode_binding,
-        ]
-        capture_bindings = functionalization.capture_arguments(
-            self.g.view.func, is_reverse=self.is_reverse
-        )
-        # allow_expensive_conversions is set because we want to convert
-        # some reference types (IntArrayRef) to value types (vector<int64_t>).
-        capture_exprs = translate.translate(
-            outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
-        )
-        return capture_exprs
-
-    def decl(self) -> str:
-        return_type = functionalization.returns_type(self.g.view.func)
-        capture_str = ", ".join(
-            f"{val.type.name} = {val.expr}" for val in self.captures()
-        )
-        decls = [
-            a.decl()
-            for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
-        ]
-        return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
-
-    def inner_call(self, *, reapply_views: bool | None = None) -> str:
-        inner_call_name = functionalization.name(
-            self.g,
-            is_reverse=self.is_reverse,
-            include_namespace=True,
-            reapply_views=reapply_views,
-        )
-        arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
-        capture_ctx = functionalization.capture_arguments(
-            self.g.view.func, is_reverse=self.is_reverse
-        )
-        full_ctx = arg_ctx + capture_ctx
-
-        assert self.g.view_copy is not None
-        call_bindings = functionalization.inner_arguments(
-            self.g.view_copy.func, is_reverse=self.is_reverse
-        )
-        maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
-        call_exprs = [
-            e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
-        ]
-        if not self.is_reverse and maybe_index is not None:
-            return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
-        else:
-            return f'{inner_call_name}({", ".join(call_exprs)});'
-
-    @staticmethod
-    def from_func(
-        g: NativeFunctionsViewGroup, *, is_reverse: bool
-    ) -> FunctionalizationLambda:
-        return FunctionalizationLambda(g, is_reverse)
-
-
@dataclass(frozen=True)
class StructuredImplSignature:
    g: NativeFunctionsGroup

View File

@@ -45,6 +45,8 @@ from torchgen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
+    gen_functionalization_view_meta_classes_decl,
+    gen_functionalization_view_meta_classes_impl,
    GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@@ -2577,48 +2579,48 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
    },
)

+def gen_op_headers(
+    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+) -> list[str]:
+    if isinstance(g, NativeFunctionsViewGroup):
+        # view ops always get a functionalization kernel
+        headers = [
+            f"#include <ATen/ops/{g.view.root_name}_native.h>",
+            f"#include <ATen/ops/{g.view.root_name}_ops.h>",
+        ]
+        if g.view_copy is not None:
+            headers += [
+                f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
+                f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
+            ]
+        return headers
+    elif isinstance(g, NativeFunctionsGroup):
+        headers = [
+            f"#include <ATen/ops/{g.functional.root_name}_native.h>",
+            f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
+            f"#include <ATen/ops/{g.out.root_name}_native.h>",
+            f"#include <ATen/ops/{g.out.root_name}_ops.h>",
+        ]
+        if g.inplace is not None:
+            headers += [
+                f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
+                f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
+            ]
+        if g.mutable is not None:
+            headers += [
+                f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
+                f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
+            ]
+        return headers
+    else:
+        return [
+            f"#include <ATen/ops/{g.root_name}_native.h>",
+            f"#include <ATen/ops/{g.root_name}_ops.h>",
+        ]
+
+
def functionalization_env_callable(
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
-    def gen_op_headers(
-        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
-    ) -> list[str]:
-        if isinstance(g, NativeFunctionsViewGroup):
-            # view ops always get a functionalization kernel
-            headers = [
-                f"#include <ATen/ops/{g.view.root_name}_native.h>",
-                f"#include <ATen/ops/{g.view.root_name}_ops.h>",
-            ]
-            if g.view_copy is not None:
-                headers += [
-                    f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
-                    f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
-                ]
-            return headers
-        elif isinstance(g, NativeFunctionsGroup):
-            headers = [
-                f"#include <ATen/ops/{g.functional.root_name}_native.h>",
-                f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
-                f"#include <ATen/ops/{g.out.root_name}_native.h>",
-                f"#include <ATen/ops/{g.out.root_name}_ops.h>",
-            ]
-            if g.inplace is not None:
-                headers += [
-                    f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
-                    f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
-                ]
-            if g.mutable is not None:
-                headers += [
-                    f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
-                    f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
-                ]
-            return headers
-        else:
-            return [
-                f"#include <ATen/ops/{g.root_name}_native.h>",
-                f"#include <ATen/ops/{g.root_name}_ops.h>",
-            ]
-
    return {
        "ops_headers": gen_op_headers(g),
        "func_definitions": gen_functionalization_definition(
@@ -2684,6 +2686,31 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
        },
    )
cpu_fm.write(
"ViewMetaClasses.h",
lambda: {
"view_meta_declarations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
view_groups,
)
)
},
)
cpu_fm.write(
"ViewMetaClasses.cpp",
lambda: {
"view_meta_implementations": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
view_groups,
)
),
"op_headers": list(concatMap(gen_op_headers, view_groups)),
},
)
    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.

View File

@@ -1,16 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass
-from typing import Callable, TYPE_CHECKING
+from typing import Callable, Optional, TYPE_CHECKING

-from torchgen.api import cpp, dispatcher
+from torchgen.api import cpp, dispatcher, functionalization
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    DispatcherSignature,
-    FunctionalizationLambda,
    iTensorListRefT,
    NativeSignature,
    OptionalCType,
@@ -48,7 +47,7 @@ from torchgen.native_function_generation import (
    MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
-from torchgen.utils import dataclass_repr
+from torchgen.utils import concatMap, dataclass_repr, FileManager

if TYPE_CHECKING:
@@ -365,6 +364,8 @@ def emit_view_functionalization_body(
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

+        spec = ViewMetaSpecialization(g, f=f)
+
        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
@@ -385,9 +386,6 @@ def emit_view_functionalization_body(
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

-        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
-        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
-
        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
@@ -415,19 +413,7 @@ def emit_view_functionalization_body(
              : at::functionalization::InverseReturnMode::NeverView
          );
          {symbolic_inputs_check}
-          at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-            {forward_lambda.decl()} {{
-              if (reapply_views) {{
-                return {forward_lambda.inner_call(reapply_views=True)}
-              }} else {{
-                return {forward_lambda.inner_call(reapply_views=False)}
-              }}
-            }},
-            {reverse_lambda.decl()} {{
-              return {reverse_lambda.inner_call()}
-            }},
-            /*has_symbolic_inputs=*/{symbolic_inputs_varname}
-          );
+          auto view_meta = {spec.new()};
          auto compute_reference_meta =
            {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
            {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
@@ -455,7 +441,6 @@ def emit_view_functionalization_body(
      """
    else:
-        is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
        return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      {unwrap_tensor_args_str}
@@ -489,21 +474,7 @@ def emit_view_functionalization_body(
          }}
        }}
        {symbolic_inputs_check}
-        at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-          {forward_lambda.decl()} {{
-            if (reapply_views) {{
-              return {forward_lambda.inner_call(reapply_views=True)}
-            }} else {{
-              return {forward_lambda.inner_call(reapply_views=False)}
-            }}
-          }},
-          {reverse_lambda.decl()} {{
-            return {reverse_lambda.inner_call()}
-          }},
-          /*has_symbolic_inputs=*/{symbolic_inputs_varname},
-          /*is_multi_output=*/{str(is_multi_output_view).lower()},
-          /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
-        );
+        auto view_meta = {spec.new()};
        auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
        // See Note [Propagating strides in the functionalization pass]
        if (compute_reference_meta) {{
@@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration(
    return emit_decl_helper(g)
# Helper class for generating `ViewMeta` specializations.
@dataclass
class ViewMetaSpecialization:
g: NativeFunctionsViewGroup
f: NativeFunction
@property
def is_multi_output(self) -> bool:
return functionalization.is_multi_output(self.f.func)
@property
def is_as_strided(self) -> bool:
return str(self.f.func.name) == "as_strided"
@property
def out_index(self) -> str:
if self.is_multi_output:
return functionalization.out_index_binding.name
return "0"
@property
def classname(self) -> str:
return functionalization.classname(self.f.func)
def decl(self) -> list[str]:
base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
attributes = functionalization.attributes(self.f.func)
# List of types for declaring the `SerializableTuple` type.
serializable_tuple_args = ",\n".join(
f" {binding.type} /* {binding.name} */"
for binding in (base_ctor_arguments + attributes)
)
# Arguments used for forwarding the tuple elements to the constructor.
destructure_tuple_args = ", ".join(
f"std::get<{i}>(tpl)"
for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
)
# List of constructor parameters
ctor_parameters = ", ".join(
binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
)
# Call the base class `ViewMeta` constructor.
#
# Both `is_multi_output` and `is_as_strided` are known values, given the
# operation schema.
is_multi_output_str = str(self.is_multi_output).lower()
is_as_strided_str = str(self.is_as_strided).lower()
base_ctor_bindings = ", ".join(
[
# `has_symbolic_inputs` is always taken as a parameter.
functionalization.has_symbolic_inputs_binding.name,
f"/*is_multi_output=*/{is_multi_output_str}",
f"/*is_as_strided=*/{is_as_strided_str}",
# `out_index` is known if the operation returns only one value. Otherwise,
# we also take it as a parameter.
f"/*out_index=*/{self.out_index}",
]
)
# Assignments of `extra_ctor_arguments` to their corresponding fields.
# These are extra fields to-be-declared in this specialization.
#
# We need to set `allow_expensive_conversions`, since we are storing owned versions
# of the non-owning arguments.
ctor_assignments = ",\n".join(
f" {e.type.name}({e.expr})"
for e in translate(
extra_ctor_arguments,
attributes,
method=False,
allow_expensive_conversions=True,
)
)
# List of arguments for constructing the `SerializableTuple` from an instance.
tuple_arguments = ", ".join(
binding.name for binding in (base_ctor_arguments + attributes)
)
# List of field declarations.
attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes)
# Override `to_out_index` if this operation returns more than 1 value.
to_out_index_decl = ""
if self.is_multi_output:
to_out_index_decl = (
" std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
)
return [
f"""
struct TORCH_API {self.classname} : public ViewMeta {{
FUNCTIONALIZATION_VIEWMETA_NAME({self.classname});
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});
{self.classname}(const SerializableTuple& tpl)
: {self.classname}({destructure_tuple_args}) {{}}
{self.classname}({ctor_parameters})
: at::functionalization::ViewMeta({base_ctor_bindings}),
{ctor_assignments} {{}}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
{to_out_index_decl}
SerializableTuple to_serializable_tuple() {{
return std::make_tuple({tuple_arguments});
}}
{attr_declarations}
}};
"""
]
# Generate a call to the actual operation.
def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
opname = functionalization.name(
self.g,
is_reverse=is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
# Expected arguments for the operation.
assert self.g.view_copy is not None
op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)
# The context is composed of the constructor arguments (which are also
# the field variables stored in the instance), and the `base` tensor.
context = [functionalization.base_binding]
context += functionalization.base_ctor_arguments(self.f.func)
context += functionalization.attributes(self.f.func)
# If we are generating the call for the reverse function, we also have
# access to the `mutated_view` argument.
if is_reverse:
context.append(functionalization.mutated_view_binding)
arguments = ", ".join(
[e.expr for e in translate(context, op_arguments, method=False)]
)
# Index the result if this operation returns multiple values.
maybe_index = ""
if not is_reverse and self.is_multi_output:
maybe_index = f"[{self.out_index}]"
return f"{opname}({arguments}){maybe_index}"
def impl(self) -> list[str]:
functions = [
f"""
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
if (reapply_views) {{
return {self.opcall(is_reverse=False, reapply_views=True)};
}} else {{
return {self.opcall(is_reverse=False, reapply_views=False)};
}}
}}""",
f"""
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
return {self.opcall(is_reverse=True, reapply_views=True)};
}}""",
]
# If this operation returns multiple values, also generate a `to_out_index`
# implementation.
if self.is_multi_output:
functions.append(f"""
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
return {self.new("out_index")};
}}
""")
return functions
# Create the Python binding for this specialized class.
def binding(self) -> list[str]:
name = functionalization.classname(self.f.func, with_namespace=True)
return [f" create_binding_with_pickle<{name}>(functionalization);"]
# Generate an instantiation of this specialized class.
def new(self, out_index: str = "0") -> str:
name = functionalization.classname(self.f.func, with_namespace=True)
ctor_arguments = functionalization.base_ctor_arguments(
self.f.func
) + functionalization.extra_ctor_arguments(self.f.func)
# Replace the `out_index` parameter with the given `out_index`.
arguments = ", ".join(
binding.name if binding.name != "out_index" else out_index
for binding in ctor_arguments
)
return f"std::make_shared<{name}>({arguments})"
# Run the function `run` for both the `view` and `view_inplace` functions.
@staticmethod
def map(
g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
) -> list[str]:
def maybe_run(f: Optional[NativeFunction]) -> list[str]:
if f is None:
return []
with native_function_manager(f):
return run(ViewMetaSpecialization(g, f))
return list(concatMap(maybe_run, (g.view, g.view_inplace)))
def gen_functionalization_view_meta_classes_base(
selector: SelectiveBuilder,
g: NativeFunctionsViewGroup,
run: Callable[[ViewMetaSpecialization], list[str]],
) -> list[str]:
if not selector.include_all_operators:
return []
if g.composite:
return []
return ViewMetaSpecialization.map(g, run)
def gen_functionalization_view_meta_classes_decl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.decl
)
def gen_functionalization_view_meta_classes_impl(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.impl
)
def gen_functionalization_view_meta_classes_binding(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
return gen_functionalization_view_meta_classes_base(
selector, g, ViewMetaSpecialization.binding
)
# Generates the Python bindings for the `ViewMeta` specialized classes.
def gen_functionalization_view_meta_classes(
native_functions_path: str,
tags_path: str,
selector: SelectiveBuilder,
install_dir: str,
template_dir: str,
) -> None:
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
# Parse the native_functions.yaml.
# Then, group them into `NativeFunctionsViewGroup`.
#
# This is the same steps we do in gen.py (ATen codegen).
native_functions = parse_native_yaml(
native_functions_path, tags_path
).native_functions
native_functions_with_view_groups = get_grouped_by_view_native_functions(
native_functions
)
view_groups = [
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
]
fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
fm.write(
"ViewMetaClassesPythonBinding.cpp",
lambda: {
"view_meta_bindings": list(
concatMap(
lambda g: gen_functionalization_view_meta_classes_binding(
selector, g
),
view_groups,
)
),
},
)
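
For orientation, a rough sketch of how this entry point could be invoked (the paths and selector below are placeholders; the real call site in _generate_code.py_ supplies the build's actual locations):

```python
from torchgen.gen_functionalization_type import (
    gen_functionalization_view_meta_classes,
)
from torchgen.selective_build.selector import SelectiveBuilder

# Hypothetical paths; the real values come from the build system.
gen_functionalization_view_meta_classes(
    native_functions_path="aten/src/ATen/native/native_functions.yaml",
    tags_path="aten/src/ATen/native/tags.yaml",
    selector=SelectiveBuilder.get_nop_selector(),
    install_dir="torch/csrc/functionalization/generated",
    template_dir="torch/csrc/functionalization",
)
```
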
def gen_functionalization_registration(
    selector: SelectiveBuilder,
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,