mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
### Summary

NOTE: This is a re-export of https://github.com/pytorch/pytorch/pull/161994 ; the changes between these two PRs are exclusively to the buck/build files.

(Summary from #161994) Attempted rebase of https://github.com/pytorch/pytorch/pull/143712. This reverts commit 6c713ccb5e0df227dd5b630057cbccd373cbe7d6.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

imported-using-ghimport

Test Plan: Imported from OSS

Differential Revision: D81524507

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163769
Approved by: https://github.com/dolpm

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
Committed by: PyTorch MergeBot
Parent: 29cbcbac42
Commit: 7d710403b0
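For orientation before the per-file diffs: the core of this change is that functionalization's ViewMeta no longer stores a forward_fn/reverse_fn pair of std::function callbacks; it becomes an abstract base class, and each view operation gets a dedicated subclass that overrides forward and reverse (plus serialization hooks). The following is only a simplified sketch of the new contract, placed in a throwaway `sketch` namespace so it is not mistaken for the real class; the out_index/is_multi_output/is_as_strided members and the helper macros are omitted, and the real declaration is in aten/src/ATen/FunctionalStorageImpl.h below.

#include <ATen/ATen.h>

namespace sketch {

struct ViewMeta {
  explicit ViewMeta(bool has_symbolic_inputs)
      : has_symbolic_inputs(has_symbolic_inputs) {}
  virtual ~ViewMeta() = default;

  // Replay the recorded view op on top of `base`.
  virtual at::Tensor forward(const at::Tensor& base) = 0;
  // Scatter a mutated view back onto its base (the view inverse).
  virtual at::Tensor reverse(const at::Tensor& base, const at::Tensor& mutated_view) = 0;

  bool has_symbolic_inputs;
};

} // namespace sketch

Call sites that previously built a ViewMeta from two lambdas now construct a std::shared_ptr to a per-op subclass, e.g. std::make_shared<at::functionalization::resize__ViewMeta>(reapply_views, size.vec()) in FunctionalizeFallbackKernel.cpp, and view replay walks the recorded sequence calling forward() on each element (apply_view_meta_sequence in FunctionalTensorWrapper.cpp).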
.gitignore (vendored): 1 change
@@ -82,6 +82,7 @@ torch/return_types.pyi
 torch/nn/functional.pyi
 torch/utils/data/datapipes/datapipe.pyi
 torch/csrc/autograd/generated/*
+torch/csrc/functionalization/generated/*
 torch/csrc/lazy/generated/*.[!m]*
 torch_compile_debug/
 # Listed manually because some files in this directory are not generated
@@ -90,6 +90,8 @@ generated_cpu_cpp = [
    "aten/src/ATen/NativeMetaFunctions.h",
    "aten/src/ATen/RegistrationDeclarations.h",
    "aten/src/ATen/VmapGeneratedPlumbing.h",
+    "aten/src/ATen/ViewMetaClasses.h",
+    "aten/src/ATen/ViewMetaClasses.cpp",
    "aten/src/ATen/core/aten_interned_strings.h",
    "aten/src/ATen/core/enum_tag.h",
    "aten/src/ATen/core/TensorBody.h",
@@ -1074,6 +1076,7 @@ test_suite(
    "aten/src/ATen/templates/LazyNonNativeIr.h",
    "aten/src/ATen/templates/RegisterDispatchKey.cpp",
    "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
+    "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
    "aten/src/ATen/native/native_functions.yaml",
    "aten/src/ATen/native/tags.yaml",
    "aten/src/ATen/native/ts_native_functions.yaml",
@@ -9,11 +9,6 @@
 
 namespace at::functionalization {
 
-ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
-  if (out_idx == this->out_index) return *this;
-  return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
-}
-
 // Note [Functionalization: Alias Removal Part 2]
 // See Note [Functionalization: Alias Removal] for more details.
 // This function applies a single update from one of the views to the StorageImpl.
@@ -42,12 +37,12 @@ ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
 static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
   at::Tensor t = update.new_val;
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
-  if (update.view_metas.empty()) return t;
+  if (update.view_metas.empty()) { return t; }
 
   std::vector<at::Tensor> tmp_values({base});
   tmp_values.reserve(update.view_metas.size());
   for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
-    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
+    at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
     // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
     // All of these ops require additional information to recover the sizes of the original tensor.
     // If need to, we could probably apply this optimization and only bother computing tmp_values
@@ -55,9 +50,8 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
     tmp_values.push_back(std::move(next_view));
   }
   for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
-    int64_t out_idx = update.view_metas[i].out_index;
     // Each view inverse is implemented in ViewInverses.cpp.
-    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
+    t = update.view_metas[i]->reverse(tmp_values[i], t);
   }
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
   return t;
@@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
 }
 
-void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
+void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
   TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
 
   if (metas.size() > 1) {
     for (size_t i = 1; i < metas.size(); ++i) {
       // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
-      TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
+      TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
         "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
         " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
         "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "
@@ -8,44 +8,89 @@ namespace at::functionalization {
 
 // See Note [Functionalization Pass In Core]
 
+enum class InverseReturnMode {
+  /// Specifies that functional inverses should always return a view.
+  AlwaysView,
+  /// Specifies that functional inverses should always return a non-view / copy.
+  NeverView,
+  /// Specifies that functional inverses should return a view unless a (copying)
+  /// scatter
+  /// inverse exists, in which case that will be used instead.
+  /// This avoids as_strided() calls that can be difficult for subclasses to
+  /// handle.
+  ViewOrScatterInverse,
+};
+
+#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
+  static const char* name() {                 \
+    return #TYPE;                             \
+  }
+
+#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
+  using SerializableTuple = std::tuple<__VA_ARGS__>
+
 // ViewMeta is a class used by the functionalization pass to navigate between
 // a base tensor and a view tensor.
 // For example, if I call `b = a.view1(...)`
-// the functionalization pass will generate and store a ViewMeta on b that looks
-// like:
+// the functionalization pass will generate and store a ViewMeta specialization
+// for `view1` operation on b that looks like:
 //
-// ViewMeta(
-//   [<captures>](const Tensor& base, int64_t mutated_view_idx) {
+// struct TORCH_API view1_ViewMeta : public ViewMeta {
+//   FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
+//   FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
+//       bool /* reapply_views */,
+//       const std::vector<int64_t>&);
+//
+//   view1_ViewMeta(const SerializableTuple& tpl)
+//       : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
+//
+//   view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
+//       : ViewMeta(/*has_symbolic_inputs=*/false),
+//         reapply_views(reapply_views),
+//         size(size) {}
+//
+//   Tensor forward(const Tensor& base) override {
 //     return base.view1(...);
-//   },
-//   [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
-//    int64_t mutated_view_idx) -> at::Tensor {
+//   }
+//
+//   Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
 //     return at::functionalization::impl::view1_inverse(base, mutated_view,
 //     ...);
 //   }
 //
-// The forward_fn lambda describes how to replay view1 on a tensor.
+//   SerializableTuple to_serializable_tuple() {
+//     return std::make_tuple(reapply_views, size);
+//   }
 //
-// The reverse_fn lambda describes how, given a tensor that is already a view,
+//   bool reapply_views;
+//   std::vector<int64_t> size;
+// };
+//
+// The forward function describes how to replay view1 on a tensor.
+//
+// The reverse function describes how, given a tensor that is already a view,
 // how to get the corresponding base tensor. See Note [Functionalization Pass:
 // View Inverses] for details.
+//
+// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type
+// representing the `ViewMeta` instance state. Methods that take in/return such
+// a type are used for supporting pickle serialization.
 struct ViewMeta {
   ViewMeta(
-      std::function<Tensor(const Tensor&, int64_t)> forward,
-      std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
       bool has_symbolic_inputs,
       bool is_multi_output = false,
       bool is_as_strided = false,
       int64_t out_idx = 0)
-      : forward_fn(std::move(forward)),
-        reverse_fn(std::move(reverse)),
-        out_index(out_idx),
+      : out_index(out_idx),
         is_multi_output(is_multi_output),
         is_as_strided(is_as_strided),
         has_symbolic_inputs(has_symbolic_inputs) {}
 
-  std::function<Tensor(const Tensor&, int64_t)> forward_fn;
-  std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
+  virtual ~ViewMeta() = default;
+
+  virtual Tensor forward(const Tensor& base) = 0;
+  virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
 
   // See Note [out_idx in ViewMeta]
   int64_t out_index;
 
@@ -57,10 +102,17 @@ struct ViewMeta {
   // Tells us if this view operation has any symbolic inputs
   bool has_symbolic_inputs;
 
-  // Returns a copy of the current ViewMeta, if out_idx matches the current
-  // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
+  // Returns a new ViewMeta with the same forward/reverse
   // functions, but a new out index.
-  ViewMeta to_out_idx(int64_t out_idx);
+  //
+  // This method should be implemented by those `ViewMeta` that have more than
+  // one output.
+  virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "ViewMeta::to_out_index not implemented. ",
+        "Likely because there's only one output.");
+  }
 };
 
 // FunctionalStorageImpl is a subclass of StorageImpl used by the
@@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
     const at::Tensor new_val;
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-    const std::vector<ViewMeta> view_metas;
+    const std::vector<std::shared_ptr<ViewMeta>> view_metas;
   };
 
   explicit FunctionalStorageImpl(const Tensor& value);
 
   void add_update(
       const Tensor& updated_val,
-      const std::vector<ViewMeta>& view_metas);
+      const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
   bool apply_updates();
   const Tensor& base() {
     return base_;
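To make the to_out_index contract above concrete: a hedged sketch of a subclass for a multi-output view op. The chunk_ViewMeta name and its reverse logic are hypothetical (real subclasses are generated and call the generated *_inverse helpers); it assumes <ATen/ATen.h>, <ATen/FunctionalStorageImpl.h> and <memory> are included, and that the chunks divide the dimension evenly.

struct chunk_ViewMeta : public at::functionalization::ViewMeta {
  chunk_ViewMeta(int64_t chunks, int64_t out_index)
      : ViewMeta(/*has_symbolic_inputs=*/false,
                 /*is_multi_output=*/true,
                 /*is_as_strided=*/false,
                 out_index),
        chunks(chunks) {}

  // Replay: pick this ViewMeta's output out of the multi-output view op.
  at::Tensor forward(const at::Tensor& base) override {
    return base.chunk(chunks)[out_index];
  }

  // Invert: rebuild the base with the mutated chunk written back in (illustration only).
  at::Tensor reverse(const at::Tensor& base, const at::Tensor& mutated_view) override {
    auto parts = base.chunk(chunks);
    parts[out_index] = mutated_view;
    return at::cat(parts, /*dim=*/0);
  }

  // Multi-output ops hand out one ViewMeta per output index, so they override this.
  std::shared_ptr<ViewMeta> to_out_index(int64_t idx) override {
    return std::make_shared<chunk_ViewMeta>(chunks, idx);
  }

  int64_t chunks;
};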
@@ -129,17 +129,19 @@ void FunctionalTensorWrapper::freeze_storage() const {
 // - view_value: The output tensor that we need to wrap.
 // - base: The "base" of the view that `view_value` was generated from.
 // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
-FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
+FunctionalTensorWrapper::FunctionalTensorWrapper(
+    const Tensor& view_value,
+    const FunctionalTensorWrapper* base,
+    const std::shared_ptr<functionalization::ViewMeta>& meta)
   : c10::TensorImpl(
       c10::DispatchKeySet(DispatchKey::Functionalize),
       view_value.dtype(),
-      base->storage().data_ptr().device()
-    ),
+      base->storage().data_ptr().device()),
     value_(view_value),
-    is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
+    is_multi_output_view_(
+        base->is_multi_output_view_ || meta->is_multi_output),
     was_storage_changed_(base->was_storage_changed_),
-    is_symbolic_(base->is_symbolic_)
-{
+    is_symbolic_(base->is_symbolic_) {
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   set_constructor_metadata();
@@ -148,11 +150,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
     view_metas_ = base->view_metas_; // copy
   }
   view_metas_.push_back(meta);
-  maybe_mark_symbolic(meta);
+  maybe_mark_symbolic(meta.get());
   storage_ = base->storage_; // alias this tensor's storage with the base tensor's
 }
 
-
 functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
   return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
 }
@@ -176,18 +177,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const {
 }
 
 // See Note [Functionalization Pass - Inplace View Ops]
-void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
+void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
   view_metas_.push_back(meta);
   // Manually track the fact that this tensor received a metadata mutation!
   has_metadata_mutation_ = true;
   // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
-  maybe_mark_symbolic(meta);
+  maybe_mark_symbolic(meta.get());
   // Note [Functionalization Pass - Inplace View Ops]
   // So, these ops are special - they're mutation AND view ops. They get special codegen.
   // An example is transpose_, e.g. `a.transpose_()`
   // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
   at::AutoDispatchSkipFunctionalize guard;
-  value_ = meta.forward_fn(value_, meta.out_index);
+  value_ = meta->forward(value_);
   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
 }
 
@@ -368,15 +369,8 @@ void FunctionalTensorWrapper::sync_() {
   regenerate_from_base();
 }
 
-Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
-  auto t = base;
-
-  // Reapply views to get the viewed tensor from the base in alias_
-  for (auto& view_meta: view_metas_) {
-    t = view_meta.forward_fn(t, view_meta.out_index);
-  }
-
-  return t;
+const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const {
+  return view_metas_;
 }
 
 void FunctionalTensorWrapper::regenerate_from_base() {
@@ -385,7 +379,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
   auto t = storage_impl->base();
 
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
-  t = apply_view_metas(t);
+  t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_);
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
 
   replace_(t, /*from_lazy_regenerate=*/true);
@@ -727,11 +721,11 @@ bool isFunctionalTensor(const std::optional<Tensor>& t) {
 }
 
 bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
-  if (t_list.empty()) return false;
+  if (t_list.empty()) { return false; }
   auto functional_count = 0;
   for (const auto i : c10::irange(t_list.size())) {
     auto const & e= t_list[i];
-    if (!e.has_value() || !e->defined()) continue;
+    if (!e.has_value() || !e->defined()) { continue; }
     if (isFunctionalTensor(e)) {
       ++functional_count;
     }
@@ -741,10 +735,10 @@ bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
 
 template <typename T>
 static bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
-  if (list.size() == 0) return false;
+  if (list.size() == 0) { return false; }
   auto functional_count = 0;
   for (const auto& tensor : list) {
-    if (!tensor.defined()) continue;
+    if (!tensor.defined()) { continue; }
     if (isFunctionalTensor(tensor)) {
       ++functional_count;
     }
@@ -762,20 +756,28 @@ void freeze_functional_tensor(const Tensor& tensor) {
   functional_base_impl->freeze_storage();
 }
 
-Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
+Tensor create_functional_tensor_with_view_meta(
+    const at::Tensor& view_to_wrap,
+    const at::Tensor& base,
+    const std::shared_ptr<functionalization::ViewMeta>& meta,
+    int64_t out_idx) {
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
   auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
+  auto meta_ = meta;
   if (out_idx != 0) {
     // Note [out_idx in ViewMeta]
     // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
     // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
-    meta = meta.to_out_idx(out_idx);
+    meta_ = meta->to_out_index(out_idx);
   }
-  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
+  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_);
 }
 
-std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
+std::vector<Tensor> create_functional_tensor_with_view_meta(
+    ITensorListRef view_to_wrap,
+    const at::Tensor& base,
+    const std::shared_ptr<functionalization::ViewMeta>& meta) {
   std::vector<Tensor> outputs(view_to_wrap.size());
   int64_t i = 0;
   for (const auto& tensor : view_to_wrap) {
@@ -785,12 +787,22 @@ std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_
   return outputs;
 }
 
-void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
+void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) {
   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
   auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
   self_impl->mutate_view_meta(meta);
 }
 
+Tensor apply_view_meta_sequence(
+    const Tensor& base,
+    const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) {
+  Tensor r = base;
+  for (auto& vm : sequence) {
+    r = vm->forward(r);
+  }
+  return r;
+}
+
 // Note [Propagating strides in the functionalization pass]
 // In order to properly compute stride information, the functionalization pass
 // calls each {view} reference implementations with meta tensors.
@@ -884,7 +896,7 @@ void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* s
       const auto& ivalue = returns[idx];
       if (ivalue.isTensor()) {
         const auto& t = ivalue.toTensor();
-        if (!t.defined()) continue;
+        if (!t.defined()) { continue; }
         at::functionalization::impl::sync(t);
         auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
         (*stack)[returns_begin + idx] = t_new;
@@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   explicit FunctionalTensorWrapper(
       const Tensor& view_value,
       const FunctionalTensorWrapper* base,
-      const functionalization::ViewMeta& meta);
+      const std::shared_ptr<functionalization::ViewMeta>& meta);
 
   // Get the underlying, actual tensor, that doesn't know anything about
   // functionalization.
@@ -99,17 +99,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
         ->are_all_mutations_under_no_grad_or_inference_mode();
   }
 
-  void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
-    is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
+  void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
+    is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
   }
 
   bool is_symbolic() const {
     return is_symbolic_;
   }
 
-  // Runs the forward_fn of every ViewMeta collected in the current instance
-  // to some other base.
-  Tensor apply_view_metas(const Tensor& base);
+  // Retrieves the ViewMeta sequence of this tensor.
+  const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
+      const;
 
   // Sync's the underlying tensor with its alias, if it's out of date. This
   // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
@@ -146,7 +146,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   // from the base tensor. This method is used by inplace-view ops like
   // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
   // tensor by replaying the views off of the alias.
-  void mutate_view_meta(const at::functionalization::ViewMeta& meta);
+  void mutate_view_meta(
+      const std::shared_ptr<at::functionalization::ViewMeta>& meta);
 
   // Custom implementation of self.set_(src)
   void set__impl(const FunctionalTensorWrapper* other);
@@ -285,7 +286,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   bool is_symbolic_ = false;
 
   size_t generation_ = 0;
-  std::vector<at::functionalization::ViewMeta> view_metas_;
+  std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
 
  protected:
   static void copy_tensor_metadata(
@@ -377,16 +378,20 @@ TORCH_API void propagate_xla_data_direct(
 Tensor create_functional_tensor_with_view_meta(
     const Tensor& view_to_wrap,
     const Tensor& base,
-    functionalization::ViewMeta meta,
+    const std::shared_ptr<functionalization::ViewMeta>& meta,
    int64_t out_idx = 0);
 std::vector<Tensor> create_functional_tensor_with_view_meta(
     ITensorListRef view_to_wrap,
     const Tensor& base,
-    const functionalization::ViewMeta& meta);
+    const std::shared_ptr<functionalization::ViewMeta>& meta);
 
 void mutate_view_meta(
     const Tensor& self,
-    const functionalization::ViewMeta& meta);
+    const std::shared_ptr<functionalization::ViewMeta>& meta);
 
+TORCH_API Tensor apply_view_meta_sequence(
+    const Tensor& base,
+    const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
+
 void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
 void set_sizes_strides_offset(
@@ -1,3 +1,5 @@
+#include <ATen/FunctionalizeFallbackKernel.h>
+
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/EmptyTensor.h>
@@ -7,7 +9,6 @@
 #include <torch/library.h>
 #include <c10/util/irange.h>
 #include <c10/util/strides.h>
-#include <ATen/EmptyTensor.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/ATen.h>
@@ -28,6 +29,31 @@
 #include <utility>
 #endif
 
+namespace at::functionalization {
+
+Tensor resize__ViewMeta::forward(const Tensor& base) {
+  if (reapply_views) {
+    return base.as_strided(size, c10::contiguous_strides(size));
+  } else {
+    return at::as_strided_copy(base, size, c10::contiguous_strides(size));
+  }
+}
+
+Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
+  return base.as_strided_scatter(
+      mutated_view, size, c10::contiguous_strides(size));
+}
+
+Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) {
+  return at::_unsafe_view_symint(base, size);
+}
+
+Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
+  return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
+}
+
+} // namespace at::functionalization
+
 namespace {
 void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
   const auto& schema = op.schema();
@@ -106,7 +132,9 @@ namespace {
       const auto& ivalue = returns[idx];
       if (ivalue.isTensor() && should_wrap_outputs) {
         const auto& t = ivalue.toTensor();
-        if (!t.defined()) continue;
+        if (!t.defined()) {
+          continue;
+        }
         auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
         (*stack)[returns_begin + idx] = t_new;
       } else if (ivalue.isTensorList() && should_wrap_outputs) {
@@ -169,19 +197,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
   // The output of resizing is equivalent to taking a slice of a larger tensor.
   // We have to emulate this "slicing" with an as_strided call.
   auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
-  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      if (reapply_views) {
-        return base.as_strided(size, c10::contiguous_strides(size));
-      } else {
-        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
-      }
-    },
-    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
-    },
-    /*has_symbolic_inputs=*/false
-  );
+  auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>(
+      reapply_views, size.vec());
   at::functionalization::impl::mutate_view_meta(self, view_meta);
   return self;
 }
@@ -300,17 +317,11 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
     tmp_output = at::_unsafe_view_symint(self_, size);
   }
 
-  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
-  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      return at::_unsafe_view_symint(base, size);
-    },
-    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
-      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
-    },
-    /*has_symbolic_inputs=*/has_symbolic_inputs
-  );
+  bool has_symbolic_inputs = std::any_of(
+      size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
+  auto view_meta =
+      std::make_shared<at::functionalization::_unsafe_view_ViewMeta>(
+          has_symbolic_inputs, size.vec());
 
   auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
   // See Note [Propagating strides in the functionalization pass]
aten/src/ATen/FunctionalizeFallbackKernel.h (new file, 58 additions)
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <ATen/FunctionalStorageImpl.h>
+
+namespace at::functionalization {
+
+// `ViewMeta` implementation for `resize_` operation.
+struct TORCH_API resize__ViewMeta : public ViewMeta {
+  FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta)
+  FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
+      bool /* reapply_views */,
+      const std::vector<int64_t>&);
+
+  resize__ViewMeta(const SerializableTuple& tpl)
+      : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
+
+  resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
+      : ViewMeta(/*has_symbolic_inputs=*/false),
+        reapply_views(reapply_views),
+        size(size) {}
+
+  Tensor forward(const Tensor& base) override;
+  Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
+
+  SerializableTuple to_serializable_tuple() {
+    return std::make_tuple(reapply_views, size);
+  }
+
+  bool reapply_views;
+  std::vector<int64_t> size;
+};
+
+// `ViewMeta` implementation for `_unsafe_view` operation.
+struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
+  FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta)
+  FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
+      bool /* has_symbolic_inputs */,
+      const std::vector<c10::SymInt>&);
+
+  _unsafe_view_ViewMeta(const SerializableTuple& tpl)
+      : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
+
+  _unsafe_view_ViewMeta(
+      bool has_symbolic_inputs,
+      const std::vector<c10::SymInt>& size)
+      : ViewMeta(has_symbolic_inputs), size(size) {}
+
+  Tensor forward(const Tensor& base) override;
+  Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
+
+  SerializableTuple to_serializable_tuple() {
+    return std::make_tuple(has_symbolic_inputs, size);
+  }
+
+  std::vector<c10::SymInt> size;
+};
+
+} // namespace at::functionalization
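A quick usage sketch for the classes declared above (not code from this PR; it assumes the new header plus <ATen/ATen.h>, <memory> and made-up tensor sizes): forward() replays the resize_ "view" via as_strided, and reverse() scatters a mutated view back into the base via as_strided_scatter, matching the definitions added to FunctionalizeFallbackKernel.cpp.

#include <ATen/ATen.h>
#include <ATen/FunctionalizeFallbackKernel.h>

#include <memory>
#include <vector>

void replay_resize_example() {
  at::Tensor base = at::zeros({4, 4});
  auto meta = std::make_shared<at::functionalization::resize__ViewMeta>(
      /*reapply_views=*/true, std::vector<int64_t>{2, 2});

  // Replay the recorded view on a base tensor.
  at::Tensor view = meta->forward(base);
  // Map a (possibly mutated) view back onto the base.
  at::Tensor merged = meta->reverse(base, view);
}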
@@ -2,22 +2,12 @@
 
 // ${generated_comment}
 
+#include <ATen/FunctionalStorageImpl.h>
 #include <ATen/Tensor.h>
 
 namespace at {
 namespace functionalization {
 
-enum class InverseReturnMode {
-  /// Specifies that functional inverses should always return a view.
-  AlwaysView,
-  /// Specifies that functional inverses should always return a non-view / copy.
-  NeverView,
-  /// Specifies that functional inverses should return a view unless a (copying) scatter
-  /// inverse exists, in which case that will be used instead.
-  /// This avoids as_strided() calls that can be difficult for subclasses to handle.
-  ViewOrScatterInverse,
-};
-
 struct FunctionalInverses {
 
 ${view_inverse_declarations}
@@ -4,7 +4,7 @@
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/FunctionalTensorWrapper.h>
-#include <ATen/FunctionalInverses.h>
+#include <ATen/ViewMetaClasses.h>
 #include <ATen/MemoryOverlap.h>
 #include <torch/library.h>
 
aten/src/ATen/templates/ViewMetaClasses.cpp (new file, 19 additions)
@@ -0,0 +1,19 @@
+// ${generated_comment}
+
+#include <ATen/FunctionalInverses.h>
+#include <ATen/ViewMetaClasses.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Operators.h>
+#include <ATen/NativeFunctions.h>
+#else
+${op_headers}
+#endif
+
+namespace at {
+namespace functionalization {
+
+${view_meta_implementations}
+
+} // namespace functionalization
+} // namespace at
aten/src/ATen/templates/ViewMetaClasses.h (new file, 12 additions)
@@ -0,0 +1,12 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include <ATen/FunctionalStorageImpl.h>
+
+namespace at {
+namespace functionalization {
+
+${view_meta_declarations}
+
+} // namespace functionalization
+} // namespace at
aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp (new file, 11 additions)
@@ -0,0 +1,11 @@
+#include <ATen/ViewMetaClasses.h>
+#include <torch/csrc/functionalization/Module.h>
+
+namespace torch::functionalization {
+
+void initGenerated(PyObject* module) {
+  auto functionalization = py::handle(module).cast<py::module>();
+  $view_meta_bindings
+}
+
+} // namespace torch::functionalization
@@ -391,6 +391,8 @@ def get_aten_generated_files(enabled_backends):
        "CompositeExplicitAutogradFunctions_inl.h",
        "CompositeExplicitAutogradNonFunctionalFunctions.h",
        "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
+        "ViewMetaClasses.h",
+        "ViewMetaClasses.cpp",
        "VmapGeneratedPlumbing.h",
        "core/ATenOpList.cpp",
        "core/TensorBody.h",
@@ -1193,6 +1195,7 @@ def define_buck_targets(
            "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]",
            "Operators.h": ":gen_aten[Operators.h]",
            "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
+            "ViewMetaClasses.h": ":gen_aten[ViewMetaClasses.h]",
            "core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
            "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
            "core/enum_tag.h": ":gen_aten[core/enum_tag.h]",
@@ -118,6 +118,9 @@ def define_targets(rules):
        ":LazyNonNativeIr.h",
        ":RegisterDispatchDefinitions.ini",
        ":RegisterDispatchKey.cpp",
+        ":ViewMetaClassesPythonBinding.cpp",
+        ":ViewMetaClasses.cpp",
+        ":ViewMetaClasses.h",
        ":native_functions.yaml",
        ":shape_inference.h",
        ":tags.yaml",
@@ -170,6 +173,7 @@ GENERATED_H = [
    "FunctionalInverses.h",
    "RedispatchFunctions.h",
    "RegistrationDeclarations.h",
+    "ViewMetaClasses.h",
    "VmapGeneratedPlumbing.h",
 ]
 
@@ -246,6 +250,7 @@ GENERATED_CPP = [
    "RegisterFunctionalization_1.cpp",
    "RegisterFunctionalization_2.cpp",
    "RegisterFunctionalization_3.cpp",
+    "ViewMetaClasses.cpp",
 ]
 
 GENERATED_CPP_CORE = [
@@ -307,6 +312,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
    "torch/csrc/autograd/generated/python_torch_functions_1.cpp",
    "torch/csrc/autograd/generated/python_torch_functions_2.cpp",
    "torch/csrc/autograd/generated/python_variable_methods.cpp",
+    "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
 ]
 
 GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP
@@ -1010,6 +1010,7 @@ libtorch_python_core_sources = [
    "torch/csrc/utils/disable_torch_function.cpp",
    "torch/csrc/utils/verbose.cpp",
    "torch/csrc/cpu/Module.cpp",
+    "torch/csrc/functionalization/Module.cpp",
    "torch/csrc/instruction_counter/Module.cpp",
    "torch/nativert/python/Bindings.cpp",
 ] + lazy_tensor_core_python_sources
@@ -1052,6 +1053,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
        "torch/csrc/autograd/generated/python_torch_functions_1.cpp",
        "torch/csrc/autograd/generated/python_torch_functions_2.cpp",
        "torch/csrc/autograd/generated/python_variable_methods.cpp",
+        "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp",
    ]]
 
    _libtorch_python_sources.extend(libtorch_python_core_sources)
@@ -316,6 +316,7 @@ set(GENERATED_CXX_PYTHON
    "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
    "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
    "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
+    "${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
 )
 
 set(GENERATED_H_PYTHON
@@ -379,6 +380,9 @@ add_custom_command(
    "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
+    "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.h"
+    "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.cpp"
+    "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp"
    ${autograd_python}
    ${autograd_yaml}
    ${autograd_templates}
@@ -156,6 +156,7 @@ def get_generate_code_bin_outs():
        "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"],
        "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"],
        "autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"],
+        "functionalization/generated/ViewMetaClassesPythonBinding.cpp": ["functionalization/generated/ViewMetaClassesPythonBinding.cpp"],
    })
    return outs
 
@@ -519,11 +519,7 @@ class AOTAutogradCacheTests(InductorTestCase):
     @functorch_config.patch(
         {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True}
     )
-    def test_view_replay_bypass(self):
-        """
-        Should bypass when view replay is turned on
-        """
-
+    def test_view_replay(self):
         def fn(a):
             tmp = a.detach()
             a.mul_(2)
@@ -531,10 +527,25 @@ class AOTAutogradCacheTests(InductorTestCase):
 
         with torch.autograd._force_original_view_tracking(True):
             compiled_fn = torch.compile(fn)
-            compiled_fn(torch.rand(2, 3))
 
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
-        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+        def run_and_check(miss, hit, bypass):
+            self._clear_dynamo_and_codecache()
+
+            inp = torch.rand(2, 3)
+            compiled_inp = inp.clone().detach()
+
+            with torch.autograd._force_original_view_tracking(True):
+                out = fn(inp)
+                compiled_out = compiled_fn(compiled_inp)
+
+            self.assertEqual(out, compiled_out)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass)
+
+        run_and_check(miss=1, hit=0, bypass=0)
+        run_and_check(miss=1, hit=1, bypass=0)
+        run_and_check(miss=1, hit=2, bypass=0)
 
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
@@ -8500,7 +8500,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
         {
             "enable_autograd_cache": True,
             "strict_autograd_cache": True,
-            "view_replay_for_aliased_outputs": False,
         }
     )
     @torch._inductor.config.patch("fx_graph_cache", True)
@ -189,6 +189,12 @@ def main() -> None:
|
|||||||
)
|
)
|
||||||
options = parser.parse_args()
|
options = parser.parse_args()
|
||||||
|
|
||||||
|
# Path: aten/src/ATen
|
||||||
|
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
|
||||||
|
operator_selector = get_selector(
|
||||||
|
options.selected_op_list_path, options.operators_yaml_path
|
||||||
|
)
|
||||||
|
|
||||||
generate_code(
|
generate_code(
|
||||||
options.gen_dir,
|
options.gen_dir,
|
||||||
options.native_functions_path,
|
options.native_functions_path,
|
||||||
@ -198,18 +204,37 @@ def main() -> None:
|
|||||||
options.disable_autograd,
|
options.disable_autograd,
|
||||||
options.force_schema_registration,
|
options.force_schema_registration,
|
||||||
# options.selected_op_list
|
# options.selected_op_list
|
||||||
operator_selector=get_selector(
|
operator_selector=operator_selector,
|
||||||
options.selected_op_list_path, options.operators_yaml_path
|
)
|
||||||
),
|
|
||||||
|
# Generate the python bindings for functionalization's `ViewMeta` classes.
|
||||||
|
from torchgen.gen_functionalization_type import (
|
||||||
|
gen_functionalization_view_meta_classes,
|
||||||
|
)
|
||||||
|
|
||||||
|
functionalization_templates_dir = os.path.join(aten_path, "templates")
|
||||||
|
install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
|
||||||
|
functionalization_install_dir = os.path.join(
|
||||||
|
install_dir, "functionalization", "generated"
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(functionalization_install_dir, exist_ok=True)
|
||||||
|
assert os.path.isdir(functionalization_install_dir)
|
||||||
|
assert os.path.isdir(functionalization_templates_dir)
|
||||||
|
|
||||||
|
gen_functionalization_view_meta_classes(
|
||||||
|
options.native_functions_path or NATIVE_FUNCTIONS_PATH,
|
||||||
|
options.tags_path or TAGS_PATH,
|
||||||
|
selector=operator_selector,
|
||||||
|
install_dir=functionalization_install_dir,
|
||||||
|
template_dir=functionalization_templates_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
if options.gen_lazy_ts_backend:
|
if options.gen_lazy_ts_backend:
|
||||||
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
|
|
||||||
ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
|
ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
|
||||||
ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
|
ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
|
||||||
ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
|
ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
|
||||||
install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
|
lazy_install_dir = os.path.join(install_dir, "lazy", "generated")
|
||||||
lazy_install_dir = os.path.join(install_dir, "lazy/generated")
|
|
||||||
os.makedirs(lazy_install_dir, exist_ok=True)
|
os.makedirs(lazy_install_dir, exist_ok=True)
|
||||||
|
|
||||||
assert os.path.isfile(ts_backend_yaml), (
|
assert os.path.isfile(ts_backend_yaml), (
|
||||||
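For orientation, a hedged sketch of where the new python-binding codegen output is expected to land, given the install-dir wiring above (the `--gen-dir` value is illustrative):

```python
import os

gen_dir = "build"  # hypothetical --gen-dir value
install_dir = os.path.join(gen_dir, "torch/csrc")

# The ViewMeta python-binding file generated by gen_functionalization_view_meta_classes:
print(os.path.join(install_dir, "functionalization", "generated",
                   "ViewMetaClassesPythonBinding.cpp"))
```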
@@ -30,6 +30,7 @@ from torch._C import (
     _cpu,
     _dynamo,
     _export,
+    _functionalization,
     _functorch,
     _lazy,
     _lazy_ts_backend,
torch/_C/_functionalization.pyi (new file)
@@ -0,0 +1,16 @@
+from torch import Tensor
+from torch.types import _bool
+
+# Defined in torch/csrc/functionalization/Module.cpp
+
+class ViewMeta:
+    has_symbolic_inputs: _bool
+
+# Returns the list of ViewMeta instances of the given functional tensor.
+#
+# Although we do have python bindings for their types, we won't
+# expose them here, since they should not be used by users.
+def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ...
+
+# Applies the ViewMeta sequence on top of the given base.
+def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ...
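A hedged usage sketch of the two stubs above, based only on the bindings introduced in this PR. The low-level `torch._to_functional_tensor` / `torch._enable_functionalization` calls are my assumption about how to record a `ViewMeta` sequence outside of AOTAutograd; the usual entry point would be AOTAutograd or `FunctionalTensorMode`.

```python
import torch
from torch._C import _functionalization

base = torch.randn(4, 4)
wrapped = torch._to_functional_tensor(base)

torch._enable_functionalization(reapply_views=True)
try:
    view = wrapped.view(16)  # records a ViewMeta on the functional wrapper
finally:
    torch._disable_functionalization()

# Inspect the recorded ViewMeta sequence and replay it on a fresh base.
seq = _functionalization.get_view_meta_sequence(view)
print([type(vm).__name__ for vm in seq])          # e.g. ['view_ViewMeta'] (illustrative)
print(any(vm.has_symbolic_inputs for vm in seq))  # False for static shapes
replayed = _functionalization.apply_view_meta_sequence(torch.randn(4, 4), seq)
print(replayed.shape)                             # expected torch.Size([16])
```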
@@ -284,19 +284,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
         check_cacheable(gm.saved_tensors_hooks_unpack_0)  # type: ignore[arg-type]
 
 
-def check_metadata_cacheable(metadata: ViewAndMutationMeta):
-    """
-    When view replay is turned on, we bypass autograd cache if
-    the output is aliased.
-    """
-    if config.view_replay_for_aliased_outputs:
-        for info in metadata.output_info:
-            if info.functional_tensor is not None:
-                raise BypassAOTAutogradCache(
-                    "Cannot cache a graph with functional tensor"
-                )
-
-
 class AOTAutogradCacheDetails(FxGraphHashDetails):
     """
     Object to capture all the details for a dynamo graph module relevant to computing
@@ -803,7 +790,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):
         """
        Perform any preparations to make the cache entry ready for serialization.
         """
-        check_metadata_cacheable(self.runtime_metadata)
         self.compiled_fw.pre_save()
         if self.compiled_bw is not None:
             self.compiled_bw.pre_save()
@@ -43,10 +43,10 @@ from .functional_utils import (
     has_metadata_mutation,
     MetadataKey,
     to_fun,
+    ViewMetaSequence,
     was_inductor_storage_resized,
 )
 from .schemas import (
-    FunctionalTensorMetadataEq,
     InputAliasInfo,
     MemoryFormatMeta,
     MutationType,
@@ -640,7 +640,7 @@ from a multi-output view call"
             #
             # The FunctionalTensor will be saved if one of the 2 conditions below
             # is true:
-            functional_tensor = None
+            view_meta_sequence = None
             if (
                 # 1. If the output_type is either of:
                 #    (i) alias_of_intermediate;
@@ -672,7 +672,7 @@ from a multi-output view call"
                 and not input_info[base_idx].mutates_metadata
             ):
                 if isinstance(o, FunctionalTensor):
-                    functional_tensor = FunctionalTensorMetadataEq(o.elem)
+                    view_meta_sequence = ViewMetaSequence(o)
 
             out_info = OutputAliasInfo(
                 output_type=output_type,
@@ -680,7 +680,7 @@ from a multi-output view call"
                 base_idx=base_idx,
                 dynamic_dims=dynamic_dims,
                 requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
-                functional_tensor=functional_tensor,
+                view_meta_sequence=view_meta_sequence,
             )
             output_info.append(out_info)
 
@ -14,6 +14,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch._C import _functionalization
|
||||||
from torch._logging import getArtifactLogger
|
from torch._logging import getArtifactLogger
|
||||||
from torch._subclasses.fake_tensor import FakeTensor
|
from torch._subclasses.fake_tensor import FakeTensor
|
||||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||||
@ -224,9 +225,9 @@ def gen_alias_from_base(
|
|||||||
aliased_base_tensor,
|
aliased_base_tensor,
|
||||||
target_meta_tensor,
|
target_meta_tensor,
|
||||||
target_requires_grad,
|
target_requires_grad,
|
||||||
target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
|
target_view_meta_sequence: Optional[ViewMetaSequence] = None,
|
||||||
*,
|
*,
|
||||||
replay_views,
|
replay_views: bool,
|
||||||
):
|
):
|
||||||
# Patch the correct requires_grad field of the output tensor, depending on whether:
|
# Patch the correct requires_grad field of the output tensor, depending on whether:
|
||||||
# (i) the reconstructed output (out) was came from a tensor that requires grad or not;
|
# (i) the reconstructed output (out) was came from a tensor that requires grad or not;
|
||||||
@ -245,13 +246,11 @@ def gen_alias_from_base(
|
|||||||
# to replay them (view functions) on the aliased_base_tensor.
|
# to replay them (view functions) on the aliased_base_tensor.
|
||||||
if (
|
if (
|
||||||
replay_views
|
replay_views
|
||||||
and target_functional_tensor is not None
|
and target_view_meta_sequence is not None
|
||||||
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
|
and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence)
|
||||||
):
|
):
|
||||||
functional_tensor = target_functional_tensor.tensor
|
out = _functionalization.apply_view_meta_sequence(
|
||||||
|
aliased_base_tensor, target_view_meta_sequence.sequence
|
||||||
out = torch._functionalize_apply_view_metas(
|
|
||||||
functional_tensor, aliased_base_tensor
|
|
||||||
)
|
)
|
||||||
# If re-applying the ViewMeta sequence succeeded, there should be no more
|
# If re-applying the ViewMeta sequence succeeded, there should be no more
|
||||||
# problems going forward. We just check we got to the target shape and
|
# problems going forward. We just check we got to the target shape and
|
||||||
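The replay condition above, restated as a standalone helper so the guard is easier to read (names are hypothetical; the real logic lives inside `gen_alias_from_base`):

```python
from torch._C import _functionalization

def maybe_replay_views(aliased_base_tensor, sequence_wrapper, replay_views: bool):
    # Replay only when enabled, when we actually recorded a sequence, and when
    # none of the ViewMeta instances captured symbolic (SymInt) inputs.
    if (
        replay_views
        and sequence_wrapper is not None
        and not any(vm.has_symbolic_inputs for vm in sequence_wrapper.sequence)
    ):
        return _functionalization.apply_view_meta_sequence(
            aliased_base_tensor, sequence_wrapper.sequence
        )
    return None  # caller falls back to as_strided / manual reconstruction
```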
@@ -357,25 +356,45 @@ class MetadataKey:
         )
 
 
-# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
-# after applying all the ViewMeta operations.
-class FunctionalTensorMetadataEq:
-    def __init__(self, tensor: torch.Tensor) -> None:
-        assert torch._is_functional_tensor(tensor)
-        self.tensor = tensor
+# ViewMeta sequence wrapper for equality comparisons.
+#
+# Even though we can compare each ViewMeta instance, we compare the resulting
+# tensor metadata, instead. That's because the creation of synthetic bases + the
+# re-generation of input views might end-up creating a different sequence of
+# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same
+# metadata.
+#
+# Therefore, we store what the end result should look like as serializable
+# metadata.
+#
+# When logging, this class should look like:
+#
+#     ViewMetaSequence(view, select_int, slice_Tensor)
+#
+# i.e. a parenthesized list of view operations within that ViewMeta sequence.
+class ViewMetaSequence:
+    def __init__(self, tensor: FunctionalTensor) -> None:
+        assert torch._is_functional_tensor(tensor.elem)
+        self.sequence = _functionalization.get_view_meta_sequence(tensor.elem)
+        self.metadata = MetadataKey.make(tensor)
+
+    def __repr__(self) -> str:
+        suffix = len("_ViewMeta")
+        types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence)
+        return f"ViewMetaSequence({types})"
 
     def __eq__(self, other: object) -> bool:
         # If other is None, then it probably means that we weren't able to recreate
-        # the FunctionalTensorMetadataEq. One of this cases is when we update the
-        # view metadata by calling: create_synthetic_base_metadata.
+        # the ViewMeta sequence. One example is when we update the view metadata by
+        # calling: create_synthetic_base_metadata.
         if other is None:
             return True
 
         # Comparison against any other type is not implemented.
-        if not isinstance(other, FunctionalTensorMetadataEq):
+        if not isinstance(other, ViewMetaSequence):
             return NotImplemented
 
-        return has_same_metadata(self.tensor, other.tensor)
+        return self.metadata == other.metadata
 
 
 # new_arg and arg here are either:
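A tiny illustration of the `__repr__` convention above: specialization class names such as `view_ViewMeta` are printed without the `_ViewMeta` suffix, matching the example given in the comment.

```python
names = ["view_ViewMeta", "select_int_ViewMeta", "slice_Tensor_ViewMeta"]
suffix = len("_ViewMeta")
print(f"ViewMetaSequence({', '.join(n[:-suffix] for n in names)})")
# -> ViewMetaSequence(view, select_int, slice_Tensor)
```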
@@ -89,7 +89,7 @@ def remove_dupe_metadata(
                 dynamic_dims=o.dynamic_dims,
                 base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
                 requires_grad=o.requires_grad,
-                functional_tensor=o.functional_tensor,
+                view_meta_sequence=o.view_meta_sequence,
             )
             for o in m.output_info
         ],
@@ -242,7 +242,7 @@ def create_synthetic_base_metadata(
                 # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
                 base_idx=new_base_idx,  # type: ignore[arg-type]
                 requires_grad=o.requires_grad,
-                functional_tensor=o.functional_tensor,
+                view_meta_sequence=o.view_meta_sequence,
             )
         )
 
@@ -150,7 +150,7 @@ class AliasOfInputHandler:
         self.base_idx = info.base_idx
         self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
         self.requires_grad = info.requires_grad
-        self.functional_tensor = info.functional_tensor
+        self.view_meta_sequence = info.view_meta_sequence
         self.replay_views = config.view_replay_for_aliased_outputs
 
     def __call__(self, orig_inputs, fw_outs, out):
@@ -159,7 +159,7 @@ class AliasOfInputHandler:
             aliased_base_tensor,
             self.unwrap_out(out),
             self.requires_grad,
-            self.functional_tensor,
+            self.view_meta_sequence,
             replay_views=self.replay_views,
         )
 
@@ -190,7 +190,7 @@ class AliasOfIntermediateHandler:
 
         self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
         self.requires_grad = info.requires_grad
-        self.functional_tensor = info.functional_tensor
+        self.view_meta_sequence = info.view_meta_sequence
         self.replay_views = config.view_replay_for_aliased_outputs
 
     def __call__(self, orig_inputs, fw_outs, out):
@@ -199,7 +199,7 @@ class AliasOfIntermediateHandler:
             self._unwrap_aliased_base_tensor(aliased_base_tensor),
             self.unwrap_out(out),
             self.requires_grad,
-            self.functional_tensor,
+            self.view_meta_sequence,
             replay_views=self.replay_views,
         )
 
@@ -7,7 +7,6 @@ input/output types, metadata, config, function signatures etc.
 from __future__ import annotations
 
 import collections
-import dataclasses
 import functools
 import itertools
 from dataclasses import dataclass, field
@@ -32,10 +31,7 @@ from torch.fx.experimental._backward_state import BackwardState
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
 from .. import config
-from .functional_utils import (
-    _check_if_mutation_can_be_in_graph,
-    FunctionalTensorMetadataEq,
-)
+from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence
 from .utils import strict_zip
 
 
@@ -117,15 +113,14 @@ class OutputAliasInfo:
     dynamic_dims: Optional[set[int]]
     # requires_grad
     requires_grad: bool
-    # FunctionalTensorWrapper that represents this output.
+    # Sequence of ViewMeta objects.
     #
-    # Provides us the means to replay views from it.
+    # Provides us the means to re-run view functions on other tensors.
     #
-    # We need to wrap the actual FunctionalTensorWrapper with this class so that
-    # we only compare the tensor's metadata. That's because with the transformations
-    # of the model throughout AOTAutograd, the sequence of ViewMeta and the base
-    # tensor might change.
-    functional_tensor: Optional[FunctionalTensorMetadataEq] = None
+    # We need to wrap the actual list of ViewMeta with this class so that
+    # we compare the ViewMeta elements appropriately, i.e. their type and
+    # the elements returned by the `as_tuple()` call.
+    view_meta_sequence: Optional[ViewMetaSequence] = None
 
 
 class MutationType(Enum):
@@ -665,17 +660,6 @@ class ViewAndMutationMeta:
         self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
         # Clear traced tangents at runtime
         self.traced_tangents = []
-        new_output_info = []
-        for out in self.output_info:
-            if config.view_replay_for_aliased_outputs:
-                new_out = out
-            else:
-                # If we're not using view_replay, remove the functional tensor.
-                # Functional tensors are unfortunately not serializable,
-                # so doing this is required for AOTAutograd caching.
-                new_out = dataclasses.replace(out, functional_tensor=None)
-            new_output_info.append(new_out)
-        self.output_info = new_output_info
         for inp_meta in self.subclass_inp_meta:
             if isinstance(inp_meta, SubclassCreationMeta):
                 inp_meta.make_runtime_safe()
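Why the stripping loop above can go away: the removed code existed because functional tensors were not serializable, while the `ViewMeta` specializations bound in this PR are picklable (via their `SerializableTuple`), so `OutputAliasInfo.view_meta_sequence` can stay in the cached metadata. A hedged check one could run to confirm that assumption:

```python
import pickle

def output_info_is_cache_safe(output_info) -> bool:
    # With ViewMetaSequence the aliasing metadata should round-trip through pickle
    # without stripping any field first.
    try:
        pickle.dumps([o.view_meta_sequence for o in output_info])
        return True
    except Exception:
        return False
```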
@@ -72,6 +72,7 @@
 #include <torch/csrc/cpu/Module.h>
 #include <torch/csrc/dynamo/init.h>
 #include <torch/csrc/export/pybind.h>
+#include <torch/csrc/functionalization/Module.h>
 #include <torch/csrc/functorch/init.h>
 #include <torch/csrc/fx/node.h>
 #include <torch/csrc/inductor/aoti_package/pybind.h>
@@ -2080,6 +2081,7 @@ PyObject* initModule() {
   torch::instruction_counter::initModule(module);
   torch::initVerboseBindings(module);
   ASSERT_TRUE(THPStorage_init(module));
+  torch::functionalization::initModule(module);
 
 #ifdef USE_CUDA
   // This will only initialise base classes and attach them to library namespace
@@ -644,15 +644,6 @@ void initTorchFunctions(PyObject* module) {
         at::functionalization::impl::isFunctionalTensor(t));
     at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
   });
-  py_module.def(
-      "_functionalize_apply_view_metas",
-      [](const at::Tensor& tensor, const at::Tensor& base) {
-        TORCH_INTERNAL_ASSERT(
-            at::functionalization::impl::isFunctionalTensor(tensor));
-        auto impl =
-            at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
-        return impl->apply_view_metas(base);
-      });
   py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
     TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
     auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
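The removed private binding is roughly replaced by the two new module-level functions; a hedged Python-side equivalent of what `torch._functionalize_apply_view_metas(view, base)` used to do (the helper name here is illustrative):

```python
from torch._C import _functionalization

def apply_view_metas(functional_view, base):
    # old: torch._functionalize_apply_view_metas(functional_view, base)
    seq = _functionalization.get_view_meta_sequence(functional_view)
    return _functionalization.apply_view_meta_sequence(base, seq)
```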
torch/csrc/functionalization/Module.cpp (new file)
@@ -0,0 +1,71 @@
+#include <torch/csrc/functionalization/Module.h>
+#include <torch/csrc/utils/pybind.h>
+
+#include <ATen/FunctionalStorageImpl.h>
+#include <ATen/FunctionalTensorWrapper.h>
+#include <ATen/FunctionalizeFallbackKernel.h>
+#include <memory>
+
+namespace torch::functionalization {
+
+void initModule(PyObject* module) {
+  auto m = py::handle(module).cast<py::module>();
+
+  // Create a `torch._C._functionalization` Python module.
+  auto functionalization = m.def_submodule(
+      "_functionalization", "functionalization related pybind.");
+
+  // Retrieve the ViewMeta sequence of a given functional tensor.
+  functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) {
+    TORCH_INTERNAL_ASSERT(
+        at::functionalization::impl::isFunctionalTensor(tensor));
+    auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
+    return impl->view_metas();
+  });
+
+  // Applies the given ViewMeta sequence to the given base.
+  functionalization.def(
+      "apply_view_meta_sequence",
+      [](const at::Tensor& base,
+         const std::vector<std::shared_ptr<at::functionalization::ViewMeta>>&
+             sequence) {
+        return at::functionalization::impl::apply_view_meta_sequence(
+            base, sequence);
+      });
+
+  // Binding for InverseReturnMode.
+  py::enum_<at::functionalization::InverseReturnMode>(
+      functionalization, "InverseReturnMode")
+      .value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView)
+      .value("NeverView", at::functionalization::InverseReturnMode::NeverView)
+      .value(
+          "ViewOrScatterInverse",
+          at::functionalization::InverseReturnMode::ViewOrScatterInverse);
+
+  // Create bindings for the ViewMeta base class.
+  //
+  // Needed so that we can take a list of ViewMeta objects as parameter.
+  // Specifically, in the Python-side, we will have a list of derived ViewMeta
+  // classes. We need to tell pybind11 that all of those are, in fact, instances
+  // of different ViewMeta sub-types.
+  py::class_<
+      at::functionalization::ViewMeta,
+      std::shared_ptr<at::functionalization::ViewMeta>>(
+      functionalization, "ViewMeta")
+      .def_property_readonly(
+          "has_symbolic_inputs",
+          [](const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
+            return meta->has_symbolic_inputs;
+          });
+
+  // Bindings for `ViewMeta` specializations manually implemented.
+  create_binding_with_pickle<at::functionalization::resize__ViewMeta>(
+      functionalization);
+  create_binding_with_pickle<at::functionalization::_unsafe_view_ViewMeta>(
+      functionalization);
+
+  // Bindings for `ViewMeta` specializations automatically generated.
+  initGenerated(functionalization.ptr());
+}
+
+} // namespace torch::functionalization
torch/csrc/functionalization/Module.h (new file)
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <ATen/FunctionalStorageImpl.h>
+
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/utils/pybind.h>
+
+namespace torch::functionalization {
+
+// Creates the default bindings for `ViewMeta` specializations.
+//
+// Defines a constructor using the types in `SerializableTuple`, as well
+// as pickle methods.
+template <class T>
+void create_binding_with_pickle(py::module m) {
+  py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
+      m, T::name())
+      .def(py::init<typename T::SerializableTuple>())
+      .def(
+          "as_tuple",
+          [](const std::shared_ptr<T>& meta) {
+            return meta->to_serializable_tuple();
+          })
+      .def(py::pickle(
+          [](const std::shared_ptr<T>& meta) {
+            return meta->to_serializable_tuple();
+          },
+          [](const typename T::SerializableTuple& tpl) {
+            return std::make_shared<T>(tpl);
+          }));
+}
+
+void initModule(PyObject* module);
+void initGenerated(PyObject* module);
+
+} // namespace torch::functionalization
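What these bindings enable from Python, as a hedged sketch: any bound `ViewMeta` specialization exposes its `SerializableTuple` through `as_tuple()` and round-trips through pickle via the same tuple (whether the tuple itself compares equal after the round trip depends on its element types, so the check below only compares the class).

```python
import pickle

def check_pickle_roundtrip(view_meta):
    restored = pickle.loads(pickle.dumps(view_meta))
    assert type(restored) is type(view_meta)
    return restored, view_meta.as_tuple()
```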
@@ -213,7 +213,7 @@ _SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic)
 class SymIntEqByExpr:
     """
     This is a wrapper around SymInt which has alternative semantics for
-    equality. Specifically, instead of erroring or guarding, we
+    equality and pickling. Specifically, instead of erroring or guarding, we
     instead will hash/compare equality based on the underlying sympy
     expression; e.g., s0 and s1 will always compare as False.
 
@@ -222,31 +222,25 @@ class SymIntEqByExpr:
     canonicalize to the same expression via regular simplification.
     """
 
-    val: Union[torch.SymInt, int]
+    @staticmethod
+    def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr:
+        if isinstance(val, torch.SymInt):
+            return val.node.expr
+        else:
+            return sympy.Integer(val)
 
     def __init__(self, val: Union[torch.SymInt, int]) -> None:
-        self.val = val
+        self.val: sympy.Expr = SymIntEqByExpr._extract(val)
 
     def __repr__(self) -> str:
         return repr(self.val)
 
-    def _extract(self) -> sympy.Expr:
-        if isinstance(self.val, torch.SymInt):
-            return self.val.node.expr
-        else:
-            return sympy.Integer(self.val)
-
     def __eq__(self, other: object) -> bool:
         assert isinstance(other, SymIntEqByExpr)
 
-        # int equality fastpath
-        if type(self.val) is int and type(other.val) is int:
-            return self.val == other.val
-
-        return self._extract() == other._extract()
+        return self.val == other.val
 
     def __hash__(self) -> int:
-        return hash(self._extract())
+        return hash(self.val)
 
 
 def _nested_int_aware_sort(
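A minimal illustration of the semantics described above, using a hypothetical stand-in class that mirrors the new implementation (equality and hashing go through the stored sympy expression, so plain ints and equal expressions collapse while distinct symbols do not):

```python
import sympy

class _EqByExpr:  # stand-in for SymIntEqByExpr, ints/exprs only
    def __init__(self, val):
        self.val = val if isinstance(val, sympy.Expr) else sympy.Integer(val)
    def __eq__(self, other):
        return self.val == other.val
    def __hash__(self):
        return hash(self.val)

s0, s1 = sympy.Symbol("s0"), sympy.Symbol("s1")
assert _EqByExpr(3) == _EqByExpr(3)
assert _EqByExpr(s0) != _EqByExpr(s1)                      # s0 and s1 compare as False
assert hash(_EqByExpr(5)) == hash(_EqByExpr(sympy.Integer(5)))
```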
@@ -23,20 +23,13 @@ from torchgen.model import (
 
 
 # This file describes the translation of JIT schema to API's used
-# when creating view lambdas that are used by the functionalization pass.
-# There are two types of lambdas: forward lambdas and reverse lambdas.
-# These API's mostly follow the dispatcher API, with a few quirks:
-# - The lambda capture has to convert reference types to value types
-# - While the forward lambda just directly calls into the at::_ops API
-# (following the dispatcher convention), the logic here for the reverse lambda
+# when creating `ViewMeta` specializations that are used by the functionalization pass.
+# These API's mostly follow the dispatcher API, with one difference:
+# - While the forward function just directly calls into the at::_ops API
+# (following the dispatcher convention), the logic here for the reverse function
 # is responsible for generating both the call-site, and the declarations
 # (which are implemented manually in the at::functionalization::impl namespace).
-
-# The lambdas generated for each view op in the functionalization pass are of the form
-# [capture_arguments](outer_arguments) -> returns_type {
-#   return name(inner_arguments);
-# }
 
 # Define some specific lambda input arguments.
 base_binding = Binding(
     name="base",
@@ -46,6 +39,18 @@ base_binding = Binding(
     ),
     default=None,
 )
+
+has_symbolic_inputs_binding = Binding(
+    name="has_symbolic_inputs",
+    nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
+    argument=Argument(
+        name="has_symbolic_inputs",
+        type=BaseType(BaseTy.bool),
+        default=None,
+        annotation=None,
+    ),
+    default=None,
+)
 mutated_view_binding = Binding(
     name="mutated_view",
     nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
@@ -54,11 +59,11 @@ mutated_view_binding = Binding(
     ),
     default=None,
 )
-mutated_view_idx_binding = Binding(
-    name="mutated_view_idx",
-    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
+out_index_binding = Binding(
+    name="out_index",
+    nctype=NamedCType(name="out_index", type=BaseCType(longT)),
     argument=Argument(
-        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+        name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
     ),
     default=None,
 )
@@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
 )
 
 
-# The lambda capture itself doesn't have a name.
-# The name returned here corresponds to the name of the inner function called by the lambda.
+# Name of the `ViewMeta` specialization class created.
+def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
+    namespace = "at::functionalization::" if with_namespace else ""
+    return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
+
+
+# Name of the operation called inside the `forward`/`reverse` implementations.
 def name(
     g: NativeFunctionsViewGroup,
     *,
@@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
         return f"{api_name}_inverse"
 
 
-def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
-    # capture arguments include all arguments except `self`.
-    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
-    # So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
-    args = func.arguments.flat_all
-    assert args[0].type == BaseType(BaseTy.Tensor)
-    non_self_args = args[1:]
-    non_self_value_bindings = [
-        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
-    ]
-
-    all_bindings = [
-        inverse_return_mode_binding if is_reverse else reapply_views_binding
-    ]
-    all_bindings.extend(non_self_value_bindings)
-    return all_bindings
-
-
 def returns_type(func: FunctionSchema) -> CType:
     # Assertion: all view ops return tensor-like outputs
     assert len(func.returns) >= 1
@@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
     return BaseCType(tensorT)
 
 
-def outer_arguments(*, is_reverse: bool) -> list[Binding]:
-    if is_reverse:
-        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
-    else:
-        return [base_binding, mutated_view_idx_binding]
+# Checks whether `func` might return more than one value.
+def is_multi_output(func: FunctionSchema) -> bool:
+    return len(func.returns) > 1 or (
+        len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
+    )
 
 
-def inner_call_index(func: FunctionSchema) -> Binding | None:
-    # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
-    # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
-    if len(func.returns) > 1 or (
-        len(func.returns) == 1 and func.returns[0].type.is_list_like()
-    ):
-        return mutated_view_idx_binding
-    return None
+# `ViewMeta` specialization constructor parameters.
+def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
+    # All specializations are parematerized by `has_symbolic_inputs` flag.
+    arguments = [has_symbolic_inputs_binding]
+
+    # If `func` might return more than 1 value, we also parameterize this specialization
+    # with the output index.
+    if is_multi_output(func):
+        arguments.append(out_index_binding)
+
+    return arguments
 
 
-def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
+# `ViewMeta` specialized class' constructor arguments.
+#
+# Values needed specifically by this specialization, that the base class does not need.
+# Same as the class' attributes, but non-owning.
+def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
+    return attributes(func, owning=False)
+
+
+# `ViewMeta` specialized class' non-static member data.
+#
+# Essential data for calling the instance's `forward` and `reverse functions. You can
+# think of them as values that should be captured from the functionalization kernel.
+def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
+    args = func.arguments.flat_all
+    assert args[0].type == BaseType(BaseTy.Tensor)
+    return [
+        reapply_views_binding,
+        inverse_return_mode_binding,
+        *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
+    ]
+
+
+def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
     args = func.arguments.flat_all
     assert args[0].type == BaseType(BaseTy.Tensor)
     non_self_args = args[1:]
@@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
     # the reverse lambda does the same, but with an additional "mutated_view" arg
     # additionally, we have a calling convention: for view ops that return multiple tensor outputs
     # their corresponding view_inverse function takes in an additional index argument.
-    index_binding = inner_call_index(func)
-    if index_binding is not None:
+    if is_multi_output(func):
         return [
             base_binding,
             mutated_view_binding,
             inverse_return_mode_binding,
-            index_binding,
+            out_index_binding,
         ] + non_self_bindings
     else:
        return [
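The naming convention `classname` introduces above, spelled out with a hedged stand-alone mirror (the example operator names are illustrative but match the manually bound specializations later in this PR):

```python
def classname(unambiguous_name: str, with_namespace: bool = False) -> str:
    # Mirrors torchgen.api.functionalization.classname, but takes the already
    # unambiguous operator name as a plain string.
    namespace = "at::functionalization::" if with_namespace else ""
    return f"{namespace}{unambiguous_name}_ViewMeta"

print(classname("slice_Tensor"))                        # slice_Tensor_ViewMeta
print(classname("_unsafe_view", with_namespace=True))   # at::functionalization::_unsafe_view_ViewMeta
```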
@@ -300,83 +300,11 @@ class ViewInverseSignature:
         return_type = functionalization.returns_type(self.g.view.func)
         decls = [
             a.decl()
-            for a in functionalization.inner_arguments(
-                self.g.view.func, is_reverse=True
-            )
+            for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
         ]
         return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
 
 
-@dataclass(frozen=True)
-class FunctionalizationLambda:
-    g: NativeFunctionsViewGroup
-
-    # are we generating the forward lambda or the reverse lambda?
-    is_reverse: bool
-
-    def captures(self) -> list[Expr]:
-        # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
-        # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
-        # and plumb it into the lambda.
-        outer_ctx = dispatcher.arguments(self.g.view.func) + [
-            functionalization.reapply_views_binding,
-            functionalization.inverse_return_mode_binding,
-        ]
-        capture_bindings = functionalization.capture_arguments(
-            self.g.view.func, is_reverse=self.is_reverse
-        )
-        # allow_expensive_conversions is set because we want to convert
-        # some reference types (IntArrayRef) to value types (vector<int64_t>).
-        capture_exprs = translate.translate(
-            outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
-        )
-        return capture_exprs
-
-    def decl(self) -> str:
-        return_type = functionalization.returns_type(self.g.view.func)
-        capture_str = ", ".join(
-            f"{val.type.name} = {val.expr}" for val in self.captures()
-        )
-        decls = [
-            a.decl()
-            for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
-        ]
-        return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
-
-    def inner_call(self, *, reapply_views: bool | None = None) -> str:
-        inner_call_name = functionalization.name(
-            self.g,
-            is_reverse=self.is_reverse,
-            include_namespace=True,
-            reapply_views=reapply_views,
-        )
-
-        arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
-        capture_ctx = functionalization.capture_arguments(
-            self.g.view.func, is_reverse=self.is_reverse
-        )
-        full_ctx = arg_ctx + capture_ctx
-
-        assert self.g.view_copy is not None
-        call_bindings = functionalization.inner_arguments(
-            self.g.view_copy.func, is_reverse=self.is_reverse
-        )
-        maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
-        call_exprs = [
-            e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
-        ]
-        if not self.is_reverse and maybe_index is not None:
-            return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
-        else:
-            return f"{inner_call_name}({', '.join(call_exprs)});"
-
-    @staticmethod
-    def from_func(
-        g: NativeFunctionsViewGroup, *, is_reverse: bool
-    ) -> FunctionalizationLambda:
-        return FunctionalizationLambda(g, is_reverse)
-
-
 @dataclass(frozen=True)
 class StructuredImplSignature:
     g: NativeFunctionsGroup
@@ -43,6 +43,8 @@ from torchgen.gen_functionalization_type import (
     gen_functionalization_definition,
     gen_functionalization_registration,
     gen_functionalization_view_inverse_declaration,
+    gen_functionalization_view_meta_classes_decl,
+    gen_functionalization_view_meta_classes_impl,
     GenCompositeViewCopyKernel,
 )
 from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
@@ -2493,9 +2495,6 @@ def gen_source_files(
         },
     )
 
-    def functionalization_env_callable(
-        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
-    ) -> dict[str, list[str]]:
     def gen_op_headers(
         g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
     ) -> list[str]:
@@ -2535,6 +2534,9 @@ def gen_source_files(
             f"#include <ATen/ops/{g.root_name}_ops.h>",
         ]
 
+    def functionalization_env_callable(
+        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+    ) -> dict[str, list[str]]:
         return {
             "ops_headers": gen_op_headers(g),
             "func_definitions": gen_functionalization_definition(
@@ -2600,6 +2602,31 @@ def gen_source_files(
         },
     )
 
+    cpu_fm.write(
+        "ViewMetaClasses.h",
+        lambda: {
+            "view_meta_declarations": list(
+                concatMap(
+                    lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
+                    view_groups,
+                )
+            )
+        },
+    )
+
+    cpu_fm.write(
+        "ViewMetaClasses.cpp",
+        lambda: {
+            "view_meta_implementations": list(
+                concatMap(
+                    lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
+                    view_groups,
+                )
+            ),
+            "op_headers": list(concatMap(gen_op_headers, view_groups)),
+        },
+    )
+
     # Note [view_copy NativeFunctions]
     # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
     # needs to have a corresponding non-aliasing {view}_copy variant.
@@ -1,16 +1,15 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Callable, TYPE_CHECKING
+from typing import Callable, Optional, TYPE_CHECKING
 
-from torchgen.api import cpp, dispatcher
+from torchgen.api import cpp, dispatcher, functionalization
 from torchgen.api.translate import translate
 from torchgen.api.types import (
     BaseCType,
     Binding,
     CType,
     DispatcherSignature,
-    FunctionalizationLambda,
     iTensorListRefT,
     NativeSignature,
     OptionalCType,
@@ -48,7 +47,7 @@ from torchgen.native_function_generation import (
     MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
     OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
 )
-from torchgen.utils import dataclass_repr
+from torchgen.utils import concatMap, dataclass_repr, FileManager
 
 
 if TYPE_CHECKING:
@@ -365,6 +364,8 @@ def emit_view_functionalization_body(
     with native_function_manager(f):
         call_sig = DispatcherSignature.from_schema(g.view_copy.func)
 
+        spec = ViewMetaSpecialization(g, f=f)
+
         # the "view_copy" op name that the functionalization kernels need to call
         api_name = g.view_copy.func.name.unambiguous_name()
         # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
@@ -385,9 +386,6 @@ def emit_view_functionalization_body(
             for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
         ]
 
-        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
-        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
-
         # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
         meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
         meta_call_args = [
@@ -415,19 +413,7 @@ def emit_view_functionalization_body(
            : at::functionalization::InverseReturnMode::NeverView
       );
       {symbolic_inputs_check}
-      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-        {forward_lambda.decl()} {{
-          if (reapply_views) {{
-            return {forward_lambda.inner_call(reapply_views=True)}
-          }} else {{
-            return {forward_lambda.inner_call(reapply_views=False)}
-          }}
-        }},
-        {reverse_lambda.decl()} {{
-          return {reverse_lambda.inner_call()}
-        }},
-        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
-      );
+      auto view_meta = {spec.new()};
       auto compute_reference_meta =
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
@@ -455,7 +441,6 @@ def emit_view_functionalization_body(
       """
 
     else:
-        is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
         return f"""
     {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
       {unwrap_tensor_args_str}
@@ -489,21 +474,7 @@ def emit_view_functionalization_body(
         }}
       }}
       {symbolic_inputs_check}
-      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
-        {forward_lambda.decl()} {{
-          if (reapply_views) {{
-            return {forward_lambda.inner_call(reapply_views=True)}
-          }} else {{
-            return {forward_lambda.inner_call(reapply_views=False)}
-          }}
-        }},
-        {reverse_lambda.decl()} {{
-          return {reverse_lambda.inner_call()}
-        }},
-        /*has_symbolic_inputs=*/{symbolic_inputs_varname},
-        /*is_multi_output=*/{str(is_multi_output_view).lower()},
-        /*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
-      );
+      auto view_meta = {spec.new()};
       auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
       // See Note [Propagating strides in the functionalization pass]
       if (compute_reference_meta && !disable_meta_reference()) {{
@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration(
|
|||||||
return emit_decl_helper(g)
|
return emit_decl_helper(g)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper class for generating `ViewMeta` specializations.
|
||||||
|
@dataclass
|
||||||
|
class ViewMetaSpecialization:
|
||||||
|
g: NativeFunctionsViewGroup
|
||||||
|
f: NativeFunction
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_multi_output(self) -> bool:
|
||||||
|
return functionalization.is_multi_output(self.f.func)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_as_strided(self) -> bool:
|
||||||
|
return str(self.f.func.name) == "as_strided"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def out_index(self) -> str:
|
||||||
|
if self.is_multi_output:
|
||||||
|
return functionalization.out_index_binding.name
|
||||||
|
return "0"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classname(self) -> str:
|
||||||
|
return functionalization.classname(self.f.func)
|
||||||
|
|
||||||
|
def decl(self) -> list[str]:
|
||||||
|
base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
|
||||||
|
extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
|
||||||
|
attributes = functionalization.attributes(self.f.func)
|
||||||
|
|
||||||
|
# List of types for declaring the `SerializableTuple` type.
|
||||||
|
serializable_tuple_args = ",\n".join(
|
||||||
|
f" {binding.type} /* {binding.name} */"
|
||||||
|
for binding in (base_ctor_arguments + attributes)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Arguments used for forwarding the tuple elements to the constructor.
|
||||||
|
destructure_tuple_args = ", ".join(
|
||||||
|
f"std::get<{i}>(tpl)"
|
||||||
|
for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
|
||||||
|
)
|
||||||
|
|
||||||
|
# List of constructor parameters
|
||||||
|
ctor_parameters = ", ".join(
|
||||||
|
binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the base class `ViewMeta` constructor.
|
||||||
|
#
|
||||||
|
# Both of `is_multi_output` and `is_as_strided` are known values, given the
|
||||||
|
# operation schema.
|
||||||
|
is_multi_output_str = str(self.is_multi_output).lower()
|
||||||
|
is_as_strided_str = str(self.is_as_strided).lower()
|
||||||
|
|
||||||
|
base_ctor_bindings = ", ".join(
|
||||||
|
[
|
||||||
|
# `has_symbolic_inputs` is always taken as parameter.
|
||||||
|
functionalization.has_symbolic_inputs_binding.name,
|
||||||
|
f"/*is_multi_output=*/{is_multi_output_str}",
|
||||||
|
f"/*is_as_strided=*/{is_as_strided_str}",
|
||||||
|
# `out_index` is know if the operation returns only one value. Otherwise,
|
||||||
|
# we also take it as parameter.
|
||||||
|
f"/*out_index=*/{self.out_index}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assignments of `extra_ctor_arguments` to their corresponding fields.
|
||||||
|
# These are extra fields to-be-declared in this specialization.
|
||||||
|
#
|
||||||
|
# We need to set `allow_expensive_conversions`, since we are storing owned versions
|
||||||
|
# of the non-owning arguments.
|
||||||
|
ctor_assignments = ",\n".join(
|
||||||
|
f" {e.type.name}({e.expr})"
|
||||||
|
for e in translate(
|
||||||
|
extra_ctor_arguments,
|
||||||
|
attributes,
|
||||||
|
method=False,
|
||||||
|
allow_expensive_conversions=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# List of arguments for constructing the `SerializableTuple` from an instance.
|
||||||
|
tuple_arguments = ", ".join(
|
||||||
|
binding.name for binding in (base_ctor_arguments + attributes)
|
||||||
|
)
|
||||||
|
|
||||||
|
# List of field declarations.
|
||||||
|
attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes)
|
||||||
|
|
||||||
|
# Override `to_out_index` if this operation returns more than 1 value.
|
||||||
|
to_out_index_decl = ""
|
||||||
|
if self.is_multi_output:
|
||||||
|
to_out_index_decl = (
|
||||||
|
" std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
|
||||||
|
)
|
||||||
|
|
||||||
|
        return [
            f"""
struct TORCH_API {self.classname} : public ViewMeta {{
  FUNCTIONALIZATION_VIEWMETA_NAME({self.classname})
  FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});

  {self.classname}(const SerializableTuple& tpl)
      : {self.classname}({destructure_tuple_args}) {{}}

  {self.classname}({ctor_parameters})
      : at::functionalization::ViewMeta({base_ctor_bindings}),
{ctor_assignments} {{}}

  Tensor forward(const Tensor& base) override;
  Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
{to_out_index_decl}

  SerializableTuple to_serializable_tuple() {{
    return std::make_tuple({tuple_arguments});
  }}

{attr_declarations}
}};
"""
        ]
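
    # Illustration only (names below are made up, not produced by this diff): for a
    # hypothetical single-output view op whose only extra attribute is `int64_t dim`,
    # the f-string above expands to roughly
    #
    #   struct TORCH_API example_ViewMeta : public ViewMeta {
    #     FUNCTIONALIZATION_VIEWMETA_NAME(example_ViewMeta)
    #     FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
    #         bool /* has_symbolic_inputs */,
    #         int64_t /* dim */);
    #     ...
    #   };
    #
    # i.e. the joins computed earlier fill in the tuple element types, the constructor
    # parameters, and the stored-attribute declarations.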

    # Generate a call to the actual operation.
    def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
        opname = functionalization.name(
            self.g,
            is_reverse=is_reverse,
            include_namespace=True,
            reapply_views=reapply_views,
        )

        # Expected arguments for the operation.
        assert self.g.view_copy is not None
        op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)

        # The context is composed of the constructor arguments (which are also
        # the field variables stored in the instance), and the `base` tensor.
        context = [functionalization.base_binding]
        context += functionalization.base_ctor_arguments(self.f.func)
        context += functionalization.attributes(self.f.func)

        # If we are generating the call for the reverse function, we also have
        # access to the `mutated_view` argument.
        if is_reverse:
            context.append(functionalization.mutated_view_binding)

        arguments = ", ".join(
            [e.expr for e in translate(context, op_arguments, method=False)]
        )

        # Index the result if this operation returns multiple values.
        maybe_index = ""
        if not is_reverse and self.is_multi_output:
            maybe_index = f"[{self.out_index}]"

        return f"{opname}({arguments}){maybe_index}"

    def impl(self) -> list[str]:
        functions = [
            f"""
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
  if (reapply_views) {{
    return {self.opcall(is_reverse=False, reapply_views=True)};
  }} else {{
    return {self.opcall(is_reverse=False, reapply_views=False)};
  }}
}}""",
            f"""
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
  return {self.opcall(is_reverse=True, reapply_views=True)};
}}""",
        ]

        # If this operation returns multiple values, also generate a `to_out_index`
        # implementation.
        if self.is_multi_output:
            functions.append(f"""
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
  return {self.new("out_index")};
}}
""")

        return functions

    # Create the Python binding for this specialized class.
    def binding(self) -> list[str]:
        name = functionalization.classname(self.f.func, with_namespace=True)
        return [f"  create_binding_with_pickle<{name}>(functionalization);"]

    # Generate an instantiation of this specialized class.
    def new(self, out_index: str = "0") -> str:
        name = functionalization.classname(self.f.func, with_namespace=True)
        ctor_arguments = functionalization.base_ctor_arguments(
            self.f.func
        ) + functionalization.extra_ctor_arguments(self.f.func)
        # Replace the `out_index` parameter with the given `out_index`.
        arguments = ", ".join(
            binding.name if binding.name != "out_index" else out_index
            for binding in ctor_arguments
        )
        return f"std::make_shared<{name}>({arguments})"

    # Run the function `run` for both the `view` and `view_inplace` functions.
    @staticmethod
    def map(
        g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
    ) -> list[str]:
        def maybe_run(f: Optional[NativeFunction]) -> list[str]:
            if f is None:
                return []
            with native_function_manager(f):
                return run(ViewMetaSpecialization(g, f))

        return list(concatMap(maybe_run, (g.view, g.view_inplace)))
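

# Note (illustration only): `ViewMetaSpecialization.map` is the entry point the
# helpers below funnel through; for a view group that defines both a `view` and a
# `view_inplace` variant it yields one specialization per variant, and it skips
# whichever variant the group does not define.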
def gen_functionalization_view_meta_classes_base(
    selector: SelectiveBuilder,
    g: NativeFunctionsViewGroup,
    run: Callable[[ViewMetaSpecialization], list[str]],
) -> list[str]:
    if not selector.include_all_operators:
        return []

    if g.composite:
        return []

    return ViewMetaSpecialization.map(g, run)


def gen_functionalization_view_meta_classes_decl(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
    return gen_functionalization_view_meta_classes_base(
        selector, g, ViewMetaSpecialization.decl
    )


def gen_functionalization_view_meta_classes_impl(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
    return gen_functionalization_view_meta_classes_base(
        selector, g, ViewMetaSpecialization.impl
    )


def gen_functionalization_view_meta_classes_binding(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> list[str]:
    return gen_functionalization_view_meta_classes_base(
        selector, g, ViewMetaSpecialization.binding
    )


# Generates the Python bindings for the `ViewMeta` specialized classes.
def gen_functionalization_view_meta_classes(
    native_functions_path: str,
    tags_path: str,
    selector: SelectiveBuilder,
    install_dir: str,
    template_dir: str,
) -> None:
    from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml

    # Parse native_functions.yaml.
    # Then, group the functions into `NativeFunctionsViewGroup`.
    #
    # These are the same steps we do in gen.py (ATen codegen).
    native_functions = parse_native_yaml(
        native_functions_path, tags_path
    ).native_functions
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
    fm.write(
        "ViewMetaClassesPythonBinding.cpp",
        lambda: {
            "view_meta_bindings": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_binding(
                        selector, g
                    ),
                    view_groups,
                )
            ),
        },
    )
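

# Illustrative usage sketch (argument values are hypothetical, not part of this
# diff): the codegen driver is expected to call this entry point along the lines of
#
#   gen_functionalization_view_meta_classes(
#       native_functions_path="aten/src/ATen/native/native_functions.yaml",
#       tags_path="aten/src/ATen/native/tags.yaml",
#       selector=selector,
#       install_dir=args.install_dir,
#       template_dir=args.template_dir,
#   )
#
# which renders the ViewMetaClassesPythonBinding.cpp template into `install_dir`.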


def gen_functionalization_registration(
    selector: SelectiveBuilder,
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,