mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make functionalization ViewMeta
serializable with pickle. (#143712)
Fix: #141974 This PR makes `ViewMeta` sequence, present in functional tensors, serializable with pickle. In order to accomplish that, it makes `ViewMeta` an abstract class with overridable `forward` and `reverse` functions. In this context, each operation that once instanciated `ViewMeta`, should now create a new specialized class that inherits from `ViewMeta. Therefore, this PR also uses codegen for creating these specializations. In summary, these are the changes this PR introduces: - `ViewMeta` is turned into an abstract class (see _FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual functions that need to be implemented. `to_out_index` should be implemented by operations that might return more than 1 output. - New `ViewMeta` specializations for `resize_` and `_unsafe_view` are created (see _FunctionalizeFallbackKernel.h_). - New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the declaration and definition of the `ViewMeta` specializations, which are automatically generated in the ATen codegen (see _gen.py_). - New `_functionalization` Python sub-module is created (see _Module.cpp_). It serves as namespace for the `ViewMeta` specializations and `InverseReturnMode` enum. - New template _ViewMetaClassesPythonBinding.cpp_ is created. It holds the automatically generated Python bindings for the `ViewMeta` specialization, which are generated in the torch codegen (see _generate_code.py_). Note that this PR makes use of codegen at 2 different moments: - ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes. - Torch codegen (_generate_code.py_): generated the Python bindings for them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
7c3aa1da1c
commit
b8abdaa286
1
.gitignore
vendored
1
.gitignore
vendored
@ -79,6 +79,7 @@ torch/return_types.pyi
|
||||
torch/nn/functional.pyi
|
||||
torch/utils/data/datapipes/datapipe.pyi
|
||||
torch/csrc/autograd/generated/*
|
||||
torch/csrc/functionalization/generated/*
|
||||
torch/csrc/lazy/generated/*.[!m]*
|
||||
torch_compile_debug/
|
||||
# Listed manually because some files in this directory are not generated
|
||||
|
@ -90,6 +90,8 @@ generated_cpu_cpp = [
|
||||
"aten/src/ATen/NativeMetaFunctions.h",
|
||||
"aten/src/ATen/RegistrationDeclarations.h",
|
||||
"aten/src/ATen/VmapGeneratedPlumbing.h",
|
||||
"aten/src/ATen/ViewMetaClasses.h",
|
||||
"aten/src/ATen/ViewMetaClasses.cpp",
|
||||
"aten/src/ATen/core/aten_interned_strings.h",
|
||||
"aten/src/ATen/core/enum_tag.h",
|
||||
"aten/src/ATen/core/TensorBody.h",
|
||||
@ -1087,6 +1089,7 @@ test_suite(
|
||||
"aten/src/ATen/templates/LazyNonNativeIr.h",
|
||||
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
|
||||
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
|
||||
"aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
"aten/src/ATen/native/tags.yaml",
|
||||
"aten/src/ATen/native/ts_native_functions.yaml",
|
||||
|
@ -9,11 +9,6 @@
|
||||
|
||||
namespace at::functionalization {
|
||||
|
||||
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
|
||||
if (out_idx == this->out_index) return *this;
|
||||
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
|
||||
}
|
||||
|
||||
// Note [Functionalization: Alias Removal Part 2]
|
||||
// See Note [Functionalization: Alias Removal] for more details.
|
||||
// This function applies a single update from one of the views to the StorageImpl.
|
||||
@ -47,7 +42,7 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
|
||||
std::vector<at::Tensor> tmp_values({base});
|
||||
tmp_values.reserve(update.view_metas.size());
|
||||
for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
|
||||
at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
|
||||
at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
|
||||
// NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
|
||||
// All of these ops require additional information to recover the sizes of the original tensor.
|
||||
// If need to, we could probably apply this optimization and only bother computing tmp_values
|
||||
@ -55,9 +50,8 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
|
||||
tmp_values.push_back(std::move(next_view));
|
||||
}
|
||||
for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
|
||||
int64_t out_idx = update.view_metas[i].out_index;
|
||||
// Each view inverse is implemented in ViewInverses.cpp.
|
||||
t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
|
||||
t = update.view_metas[i]->reverse(tmp_values[i], t);
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
|
||||
return t;
|
||||
@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
|
||||
}
|
||||
|
||||
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
|
||||
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
|
||||
TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
|
||||
|
||||
if (metas.size() > 1) {
|
||||
for (size_t i = 1; i < metas.size(); ++i) {
|
||||
// Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
|
||||
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
|
||||
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
|
||||
"During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
|
||||
" was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
|
||||
"so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "
|
||||
|
@ -8,44 +8,89 @@ namespace at::functionalization {
|
||||
|
||||
// See Note [Functionalization Pass In Core]
|
||||
|
||||
enum class InverseReturnMode {
|
||||
/// Specifies that functional inverses should always return a view.
|
||||
AlwaysView,
|
||||
/// Specifies that functional inverses should always return a non-view / copy.
|
||||
NeverView,
|
||||
/// Specifies that functional inverses should return a view unless a (copying)
|
||||
/// scatter
|
||||
/// inverse exists, in which case that will be used instead.
|
||||
/// This avoids as_strided() calls that can be difficult for subclasses to
|
||||
/// handle.
|
||||
ViewOrScatterInverse,
|
||||
};
|
||||
|
||||
#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
|
||||
static const char* name() { \
|
||||
return #TYPE; \
|
||||
}
|
||||
|
||||
#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
|
||||
using SerializableTuple = std::tuple<__VA_ARGS__>;
|
||||
|
||||
// ViewMeta is a class used by the functionalization pass to navigate between
|
||||
// a base tensor and a view tensor.
|
||||
// For example, if I call `b = a.view1(...)`
|
||||
// the functionalization pass will generate and store a ViewMeta on b that looks
|
||||
// like:
|
||||
// the functionalization pass will generate and store a ViewMeta specialization
|
||||
// for `view1` operation on b that looks like:
|
||||
//
|
||||
// ViewMeta(
|
||||
// [<captures>](const Tensor& base, int64_t mutated_view_idx) {
|
||||
// return base.view1(...);
|
||||
// },
|
||||
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
|
||||
// int64_t mutated_view_idx) -> at::Tensor {
|
||||
// return at::functionalization::impl::view1_inverse(base, mutated_view,
|
||||
// ...);
|
||||
// struct TORCH_API view1_ViewMeta : public ViewMeta {
|
||||
// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
|
||||
// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
|
||||
// bool /* reapply_views */,
|
||||
// const std::vector<int64_t>&);
|
||||
//
|
||||
// view1_ViewMeta(const SerializableTuple& tpl)
|
||||
// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
|
||||
//
|
||||
// view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
|
||||
// : ViewMeta(/*has_symbolic_inputs=*/false),
|
||||
// reapply_views(reapply_views),
|
||||
// size(size) {}
|
||||
//
|
||||
// Tensor forward(const Tensor& base) override {
|
||||
// return base.view1(...);
|
||||
// }
|
||||
//
|
||||
// The forward_fn lambda describes how to replay view1 on a tensor.
|
||||
// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
|
||||
// return at::functionalization::impl::view1_inverse(base, mutated_view,
|
||||
// ...);
|
||||
// }
|
||||
//
|
||||
// The reverse_fn lambda describes how, given a tensor that is already a view,
|
||||
// SerializableTuple to_serializable_tuple() {
|
||||
// return std::make_tuple(reapply_views, size);
|
||||
// }
|
||||
//
|
||||
// bool reapply_views;
|
||||
// std::vector<int64_t> size;
|
||||
// };
|
||||
//
|
||||
// The forward function describes how to replay view1 on a tensor.
|
||||
//
|
||||
// The reverse function describes how, given a tensor that is already a view,
|
||||
// how to get the corresponding base tensor. See Note [Functionalization Pass:
|
||||
// View Inverses] for details.
|
||||
//
|
||||
// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type
|
||||
// representing the `ViewMeta` instance state. Methods that take in/return such
|
||||
// a type are used for supporting pickle serialization.
|
||||
struct ViewMeta {
|
||||
ViewMeta(
|
||||
std::function<Tensor(const Tensor&, int64_t)> forward,
|
||||
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
|
||||
bool has_symbolic_inputs,
|
||||
bool is_multi_output = false,
|
||||
bool is_as_strided = false,
|
||||
int64_t out_idx = 0)
|
||||
: forward_fn(std::move(forward)),
|
||||
reverse_fn(std::move(reverse)),
|
||||
out_index(out_idx),
|
||||
: out_index(out_idx),
|
||||
is_multi_output(is_multi_output),
|
||||
is_as_strided(is_as_strided),
|
||||
has_symbolic_inputs(has_symbolic_inputs) {}
|
||||
|
||||
std::function<Tensor(const Tensor&, int64_t)> forward_fn;
|
||||
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
|
||||
virtual ~ViewMeta() {}
|
||||
|
||||
virtual Tensor forward(const Tensor& base) = 0;
|
||||
virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
|
||||
|
||||
// See Note [out_idx in ViewMeta]
|
||||
int64_t out_index;
|
||||
|
||||
@ -57,10 +102,17 @@ struct ViewMeta {
|
||||
// Tells us if this view operation has any symbolic inputs
|
||||
bool has_symbolic_inputs;
|
||||
|
||||
// Returns a copy of the current ViewMeta, if out_idx matches the current
|
||||
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
|
||||
// Returns a new ViewMeta with the same forward/reverse
|
||||
// functions, but a new out index.
|
||||
ViewMeta to_out_idx(int64_t out_idx);
|
||||
//
|
||||
// This method should be implemented by those `ViewMeta` that have more than
|
||||
// one output.
|
||||
virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"ViewMeta::to_out_index not implemented. ",
|
||||
"Likely because there's only one output.");
|
||||
}
|
||||
};
|
||||
|
||||
// FunctionalStorageImpl is a subclass of StorageImpl used by the
|
||||
@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||
const at::Tensor new_val;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||
const std::vector<ViewMeta> view_metas;
|
||||
const std::vector<std::shared_ptr<ViewMeta>> view_metas;
|
||||
};
|
||||
|
||||
explicit FunctionalStorageImpl(const Tensor& value);
|
||||
|
||||
void add_update(
|
||||
const Tensor& updated_val,
|
||||
const std::vector<ViewMeta>& view_metas);
|
||||
const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
|
||||
bool apply_updates();
|
||||
const Tensor& base() {
|
||||
return base_;
|
||||
|
@ -129,17 +129,19 @@ void FunctionalTensorWrapper::freeze_storage() const {
|
||||
// - view_value: The output tensor that we need to wrap.
|
||||
// - base: The "base" of the view that `view_value` was generated from.
|
||||
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
|
||||
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
|
||||
: c10::TensorImpl(
|
||||
c10::DispatchKeySet(DispatchKey::Functionalize),
|
||||
view_value.dtype(),
|
||||
view_value.device()
|
||||
),
|
||||
value_(view_value),
|
||||
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
|
||||
was_storage_changed_(base->was_storage_changed_),
|
||||
is_symbolic_(base->is_symbolic_)
|
||||
{
|
||||
FunctionalTensorWrapper::FunctionalTensorWrapper(
|
||||
const Tensor& view_value,
|
||||
const FunctionalTensorWrapper* base,
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta)
|
||||
: c10::TensorImpl(
|
||||
c10::DispatchKeySet(DispatchKey::Functionalize),
|
||||
view_value.dtype(),
|
||||
view_value.device()),
|
||||
value_(view_value),
|
||||
is_multi_output_view_(
|
||||
base->is_multi_output_view_ || meta->is_multi_output),
|
||||
was_storage_changed_(base->was_storage_changed_),
|
||||
is_symbolic_(base->is_symbolic_) {
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
set_constructor_metadata();
|
||||
@ -148,11 +150,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
|
||||
view_metas_ = base->view_metas_; // copy
|
||||
}
|
||||
view_metas_.push_back(meta);
|
||||
maybe_mark_symbolic(meta);
|
||||
maybe_mark_symbolic(meta.get());
|
||||
storage_ = base->storage_; // alias this tensor's storage with the base tensor's
|
||||
}
|
||||
|
||||
|
||||
functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
|
||||
return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
|
||||
}
|
||||
@ -176,18 +177,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const {
|
||||
}
|
||||
|
||||
// See Note [Functionalization Pass - Inplace View Ops]
|
||||
void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
|
||||
void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
|
||||
view_metas_.push_back(meta);
|
||||
// Manually track the fact that this tensor recieved a metadata mutation!
|
||||
has_metadata_mutation_ = true;
|
||||
// Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
|
||||
maybe_mark_symbolic(meta);
|
||||
maybe_mark_symbolic(meta.get());
|
||||
// Note [Functionalization Pass - Inplace View Ops]
|
||||
// So, these ops are special - they're mutation AND view ops. They get special codegen.
|
||||
// An example is transpose_, e.g. `a.transpose_()`
|
||||
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
|
||||
at::AutoDispatchSkipFunctionalize guard;
|
||||
value_ = meta.forward_fn(value_, meta.out_index);
|
||||
value_ = meta->forward(value_);
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
}
|
||||
|
||||
@ -368,15 +369,8 @@ void FunctionalTensorWrapper::sync_() {
|
||||
regenerate_from_base();
|
||||
}
|
||||
|
||||
Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
|
||||
auto t = base;
|
||||
|
||||
// Reapply views to get the viewed tensor from the base in alias_
|
||||
for (auto& view_meta: view_metas_) {
|
||||
t = view_meta.forward_fn(t, view_meta.out_index);
|
||||
}
|
||||
|
||||
return t;
|
||||
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const {
|
||||
return view_metas_;
|
||||
}
|
||||
|
||||
void FunctionalTensorWrapper::regenerate_from_base() {
|
||||
@ -385,7 +379,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
|
||||
auto t = storage_impl->base();
|
||||
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
|
||||
t = apply_view_metas(t);
|
||||
t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_);
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
|
||||
|
||||
replace_(t, /*from_lazy_regenerate=*/true);
|
||||
@ -759,20 +753,28 @@ void freeze_functional_tensor(const Tensor& tensor) {
|
||||
functional_base_impl->freeze_storage();
|
||||
}
|
||||
|
||||
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
|
||||
Tensor create_functional_tensor_with_view_meta(
|
||||
const at::Tensor& view_to_wrap,
|
||||
const at::Tensor& base,
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta,
|
||||
int64_t out_idx) {
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
|
||||
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
|
||||
auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
|
||||
auto meta_ = meta;
|
||||
if (out_idx != 0) {
|
||||
// Note [out_idx in ViewMeta]
|
||||
// When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
|
||||
// Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
|
||||
meta = meta.to_out_idx(out_idx);
|
||||
meta_ = meta->to_out_index(out_idx);
|
||||
}
|
||||
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
|
||||
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_);
|
||||
}
|
||||
|
||||
std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
|
||||
std::vector<Tensor> create_functional_tensor_with_view_meta(
|
||||
ITensorListRef view_to_wrap,
|
||||
const at::Tensor& base,
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta) {
|
||||
std::vector<Tensor> outputs(view_to_wrap.size());
|
||||
int64_t i = 0;
|
||||
for (const auto& tensor : view_to_wrap) {
|
||||
@ -782,12 +784,22 @@ std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
|
||||
void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) {
|
||||
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
|
||||
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
|
||||
self_impl->mutate_view_meta(meta);
|
||||
}
|
||||
|
||||
Tensor apply_view_meta_sequence(
|
||||
const Tensor& base,
|
||||
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) {
|
||||
Tensor r = base;
|
||||
for (auto& vm : sequence) {
|
||||
r = vm->forward(r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// Note [Propagating strides in the functionalization pass]
|
||||
// In order to properly compute stride information, the functionalization pass
|
||||
// calls each {view} reference implementations with meta tensors.
|
||||
|
@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
explicit FunctionalTensorWrapper(
|
||||
const Tensor& view_value,
|
||||
const FunctionalTensorWrapper* base,
|
||||
const functionalization::ViewMeta& meta);
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta);
|
||||
|
||||
// Get the underlying, actual tensor, that doesn't know anything about
|
||||
// functionalization.
|
||||
@ -97,17 +97,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
->are_all_mutations_under_no_grad_or_inference_mode();
|
||||
}
|
||||
|
||||
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
|
||||
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
|
||||
void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
|
||||
is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
|
||||
}
|
||||
|
||||
bool is_symbolic() const {
|
||||
return is_symbolic_;
|
||||
}
|
||||
|
||||
// Runs the forward_fn of every ViewMeta collected in the current instance
|
||||
// to some other base.
|
||||
Tensor apply_view_metas(const Tensor& base);
|
||||
// Retrieves the ViewMeta sequence of this tensor.
|
||||
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
|
||||
const;
|
||||
|
||||
// Sync's the underlying tensor with its alias, if it's out of date. This
|
||||
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
|
||||
@ -144,7 +144,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
// from the base tensor. This method is used by inplace-view ops like
|
||||
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
|
||||
// tensor by replaying the views off of the alias.
|
||||
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
|
||||
void mutate_view_meta(
|
||||
const std::shared_ptr<at::functionalization::ViewMeta>& meta);
|
||||
|
||||
// Custom implementation of self.set_(src)
|
||||
void set__impl(const FunctionalTensorWrapper* other);
|
||||
@ -273,7 +274,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
bool is_symbolic_ = false;
|
||||
|
||||
size_t generation_ = 0;
|
||||
std::vector<at::functionalization::ViewMeta> view_metas_;
|
||||
std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
|
||||
|
||||
protected:
|
||||
static void copy_tensor_metadata(
|
||||
@ -365,16 +366,20 @@ TORCH_API void propagate_xla_data_direct(
|
||||
Tensor create_functional_tensor_with_view_meta(
|
||||
const Tensor& view_to_wrap,
|
||||
const Tensor& base,
|
||||
functionalization::ViewMeta meta,
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta,
|
||||
int64_t out_idx = 0);
|
||||
std::vector<Tensor> create_functional_tensor_with_view_meta(
|
||||
ITensorListRef view_to_wrap,
|
||||
const Tensor& base,
|
||||
const functionalization::ViewMeta& meta);
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta);
|
||||
|
||||
void mutate_view_meta(
|
||||
const Tensor& self,
|
||||
const functionalization::ViewMeta& meta);
|
||||
const std::shared_ptr<functionalization::ViewMeta>& meta);
|
||||
|
||||
TORCH_API Tensor apply_view_meta_sequence(
|
||||
const Tensor& base,
|
||||
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
|
||||
|
||||
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
|
||||
void set_sizes_strides_offset(
|
||||
|
@ -1,3 +1,5 @@
|
||||
#include <ATen/FunctionalizeFallbackKernel.h>
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/core/LegacyTypeDispatch.h>
|
||||
#include <ATen/EmptyTensor.h>
|
||||
@ -27,6 +29,31 @@
|
||||
#include <utility>
|
||||
#endif
|
||||
|
||||
namespace at::functionalization {
|
||||
|
||||
Tensor resize__ViewMeta::forward(const Tensor& base) {
|
||||
if (reapply_views) {
|
||||
return base.as_strided(size, c10::contiguous_strides(size));
|
||||
} else {
|
||||
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
|
||||
}
|
||||
}
|
||||
|
||||
Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
|
||||
return base.as_strided_scatter(
|
||||
mutated_view, size, c10::contiguous_strides(size));
|
||||
}
|
||||
|
||||
Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) {
|
||||
return at::_unsafe_view_symint(base, size);
|
||||
}
|
||||
|
||||
Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
|
||||
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
|
||||
}
|
||||
|
||||
} // namespace at::functionalization
|
||||
|
||||
namespace {
|
||||
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
@ -168,19 +195,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
|
||||
// The output of resizing is equivalent to taking a slice of a larger tensor.
|
||||
// We have to emulate this "slicing" with an as_strided call.
|
||||
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
|
||||
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
|
||||
[reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
|
||||
if (reapply_views) {
|
||||
return base.as_strided(size, c10::contiguous_strides(size));
|
||||
} else {
|
||||
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
|
||||
}
|
||||
},
|
||||
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
|
||||
return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
|
||||
},
|
||||
/*has_symbolic_inputs=*/false
|
||||
);
|
||||
auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>(
|
||||
reapply_views, size.vec());
|
||||
at::functionalization::impl::mutate_view_meta(self, view_meta);
|
||||
return self;
|
||||
}
|
||||
@ -299,17 +315,11 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
|
||||
tmp_output = at::_unsafe_view_symint(self_, size);
|
||||
}
|
||||
|
||||
bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
|
||||
|
||||
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
|
||||
[size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
|
||||
return at::_unsafe_view_symint(base, size);
|
||||
},
|
||||
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
|
||||
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
|
||||
},
|
||||
/*has_symbolic_inputs=*/has_symbolic_inputs
|
||||
);
|
||||
bool has_symbolic_inputs = std::any_of(
|
||||
size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
|
||||
auto view_meta =
|
||||
std::make_shared<at::functionalization::_unsafe_view_ViewMeta>(
|
||||
has_symbolic_inputs, size.vec());
|
||||
|
||||
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
|
||||
// See Note [Propagating strides in the functionalization pass]
|
||||
|
58
aten/src/ATen/FunctionalizeFallbackKernel.h
Normal file
58
aten/src/ATen/FunctionalizeFallbackKernel.h
Normal file
@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
|
||||
namespace at::functionalization {
|
||||
|
||||
// `ViewMeta` implementation for `resize_` operation.
|
||||
struct TORCH_API resize__ViewMeta : public ViewMeta {
|
||||
FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta);
|
||||
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
|
||||
bool /* reapply_views */,
|
||||
const std::vector<int64_t>&);
|
||||
|
||||
resize__ViewMeta(const SerializableTuple& tpl)
|
||||
: resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
|
||||
|
||||
resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
|
||||
: ViewMeta(/*has_symbolic_inputs=*/false),
|
||||
reapply_views(reapply_views),
|
||||
size(size) {}
|
||||
|
||||
Tensor forward(const Tensor& base) override;
|
||||
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
|
||||
|
||||
SerializableTuple to_serializable_tuple() {
|
||||
return std::make_tuple(reapply_views, size);
|
||||
}
|
||||
|
||||
bool reapply_views;
|
||||
std::vector<int64_t> size;
|
||||
};
|
||||
|
||||
// `ViewMeta` implementation for `_unsafe_view` operation.
|
||||
struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
|
||||
FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta);
|
||||
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
|
||||
bool /* has_symbolic_inputs */,
|
||||
const std::vector<c10::SymInt>&);
|
||||
|
||||
_unsafe_view_ViewMeta(const SerializableTuple& tpl)
|
||||
: _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
|
||||
|
||||
_unsafe_view_ViewMeta(
|
||||
bool has_symbolic_inputs,
|
||||
const std::vector<c10::SymInt>& size)
|
||||
: ViewMeta(has_symbolic_inputs), size(size) {}
|
||||
|
||||
Tensor forward(const Tensor& base) override;
|
||||
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
|
||||
|
||||
SerializableTuple to_serializable_tuple() {
|
||||
return std::make_tuple(has_symbolic_inputs, size);
|
||||
}
|
||||
|
||||
std::vector<c10::SymInt> size;
|
||||
};
|
||||
|
||||
} // namespace at::functionalization
|
@ -2,22 +2,12 @@
|
||||
|
||||
// ${generated_comment}
|
||||
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace at {
|
||||
namespace functionalization {
|
||||
|
||||
enum class InverseReturnMode {
|
||||
/// Specifies that functional inverses should always return a view.
|
||||
AlwaysView,
|
||||
/// Specifies that functional inverses should always return a non-view / copy.
|
||||
NeverView,
|
||||
/// Specifies that functional inverses should return a view unless a (copying) scatter
|
||||
/// inverse exists, in which case that will be used instead.
|
||||
/// This avoids as_strided() calls that can be difficult for subclasses to handle.
|
||||
ViewOrScatterInverse,
|
||||
};
|
||||
|
||||
struct FunctionalInverses {
|
||||
|
||||
${view_inverse_declarations}
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <ATen/core/LegacyTypeDispatch.h>
|
||||
#include <ATen/EmptyTensor.h>
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/FunctionalInverses.h>
|
||||
#include <ATen/ViewMetaClasses.h>
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
|
19
aten/src/ATen/templates/ViewMetaClasses.cpp
Normal file
19
aten/src/ATen/templates/ViewMetaClasses.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
// ${generated_comment}
|
||||
|
||||
#include <ATen/FunctionalInverses.h>
|
||||
#include <ATen/ViewMetaClasses.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Operators.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
${op_headers}
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
namespace functionalization {
|
||||
|
||||
${view_meta_implementations}
|
||||
|
||||
} // namespace functionalization
|
||||
} // namespace at
|
12
aten/src/ATen/templates/ViewMetaClasses.h
Normal file
12
aten/src/ATen/templates/ViewMetaClasses.h
Normal file
@ -0,0 +1,12 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
// ${generated_comment}
|
||||
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
|
||||
namespace at {
|
||||
namespace functionalization {
|
||||
|
||||
${view_meta_declarations}
|
||||
|
||||
} // namespace functionalization
|
||||
} // namespace at
|
11
aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp
Normal file
11
aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
#include <ATen/ViewMetaClasses.h>
|
||||
#include <torch/csrc/functionalization/Module.h>
|
||||
|
||||
namespace torch::functionalization {
|
||||
|
||||
void initGenerated(PyObject* module) {
|
||||
auto functionalization = py::handle(module).cast<py::module>();
|
||||
$view_meta_bindings
|
||||
}
|
||||
|
||||
} // namespace torch::functionalization
|
@ -117,6 +117,7 @@ def define_targets(rules):
|
||||
":LazyNonNativeIr.h",
|
||||
":RegisterDispatchDefinitions.ini",
|
||||
":RegisterDispatchKey.cpp",
|
||||
":ViewMetaClassesPythonBinding.cpp",
|
||||
":native_functions.yaml",
|
||||
":shape_inference.h",
|
||||
":tags.yaml",
|
||||
@ -297,6 +298,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
|
||||
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
|
||||
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
|
||||
"torch/csrc/autograd/generated/python_variable_methods.cpp",
|
||||
"torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
|
||||
]
|
||||
|
||||
GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP
|
||||
|
@ -929,6 +929,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/utils/disable_torch_function.cpp",
|
||||
"torch/csrc/utils/verbose.cpp",
|
||||
"torch/csrc/cpu/Module.cpp",
|
||||
"torch/csrc/functionalization/Module.cpp",
|
||||
"torch/csrc/instruction_counter/Module.cpp",
|
||||
] + lazy_tensor_core_python_sources
|
||||
|
||||
|
@ -310,6 +310,7 @@ set(GENERATED_CXX_PYTHON
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
|
||||
)
|
||||
|
||||
set(GENERATED_H_PYTHON
|
||||
@ -373,6 +374,7 @@ add_custom_command(
|
||||
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
|
||||
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
|
||||
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
|
||||
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp"
|
||||
${autograd_python}
|
||||
${autograd_yaml}
|
||||
${autograd_templates}
|
||||
|
@ -250,11 +250,7 @@ class AOTAutogradCacheTests(InductorTestCase):
|
||||
@functorch_config.patch(
|
||||
{"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True}
|
||||
)
|
||||
def test_view_replay_bypass(self):
|
||||
"""
|
||||
Shoud bypass when view replay is turned on
|
||||
"""
|
||||
|
||||
def test_view_replay(self):
|
||||
def fn(a):
|
||||
tmp = a.detach()
|
||||
a.mul_(2)
|
||||
@ -262,10 +258,25 @@ class AOTAutogradCacheTests(InductorTestCase):
|
||||
|
||||
with torch.autograd._force_original_view_tracking(True):
|
||||
compiled_fn = torch.compile(fn)
|
||||
compiled_fn(torch.rand(2, 3))
|
||||
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
||||
def run_and_check(miss, hit, bypass):
|
||||
self._clear_dynamo_and_codecache()
|
||||
|
||||
inp = torch.rand(2, 3)
|
||||
compiled_inp = inp.clone().detach()
|
||||
|
||||
with torch.autograd._force_original_view_tracking(True):
|
||||
out = fn(inp)
|
||||
compiled_out = compiled_fn(compiled_inp)
|
||||
|
||||
self.assertEqual(out, compiled_out)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass)
|
||||
|
||||
run_and_check(miss=1, hit=0, bypass=0)
|
||||
run_and_check(miss=1, hit=1, bypass=0)
|
||||
run_and_check(miss=1, hit=2, bypass=0)
|
||||
|
||||
@inductor_config.patch("fx_graph_remote_cache", False)
|
||||
@inductor_config.patch("fx_graph_cache", False)
|
||||
|
@ -6897,7 +6897,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
|
||||
{
|
||||
"enable_autograd_cache": True,
|
||||
"strict_autograd_cache": True,
|
||||
"view_replay_for_aliased_outputs": False,
|
||||
}
|
||||
)
|
||||
@torch._inductor.config.patch("fx_graph_cache", True)
|
||||
|
@ -189,6 +189,12 @@ def main() -> None:
|
||||
)
|
||||
options = parser.parse_args()
|
||||
|
||||
# Path: aten/src/ATen
|
||||
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
|
||||
operator_selector = get_selector(
|
||||
options.selected_op_list_path, options.operators_yaml_path
|
||||
)
|
||||
|
||||
generate_code(
|
||||
options.gen_dir,
|
||||
options.native_functions_path,
|
||||
@ -198,13 +204,32 @@ def main() -> None:
|
||||
options.disable_autograd,
|
||||
options.force_schema_registration,
|
||||
# options.selected_op_list
|
||||
operator_selector=get_selector(
|
||||
options.selected_op_list_path, options.operators_yaml_path
|
||||
),
|
||||
operator_selector=operator_selector,
|
||||
)
|
||||
|
||||
# Generate the python bindings for functionalization's `ViewMeta` classes.
|
||||
from torchgen.gen_functionalization_type import (
|
||||
gen_functionalization_view_meta_classes,
|
||||
)
|
||||
|
||||
functionalization_templates_dir = os.path.join(aten_path, "templates")
|
||||
functionalization_install_dir = os.path.join(
|
||||
options.gen_dir, "torch/csrc/functionalization/generated"
|
||||
)
|
||||
|
||||
os.makedirs(functionalization_install_dir, exist_ok=True)
|
||||
assert os.path.isdir(functionalization_install_dir)
|
||||
assert os.path.isdir(functionalization_templates_dir)
|
||||
|
||||
gen_functionalization_view_meta_classes(
|
||||
options.native_functions_path or NATIVE_FUNCTIONS_PATH,
|
||||
options.tags_path or TAGS_PATH,
|
||||
selector=operator_selector,
|
||||
install_dir=functionalization_install_dir,
|
||||
template_dir=functionalization_templates_dir,
|
||||
)
|
||||
|
||||
if options.gen_lazy_ts_backend:
|
||||
aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
|
||||
ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
|
||||
ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
|
||||
ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
|
||||
|
@ -67,6 +67,7 @@ from . import (
|
||||
_export,
|
||||
_cpu,
|
||||
_dynamo,
|
||||
_functionalization,
|
||||
_functorch,
|
||||
_lazy,
|
||||
_lazy_ts_backend,
|
||||
|
16
torch/_C/_functionalization.pyi
Normal file
16
torch/_C/_functionalization.pyi
Normal file
@ -0,0 +1,16 @@
|
||||
from torch import Tensor
|
||||
from torch.types import _bool
|
||||
|
||||
# Defined in torch/csrc/functionalization/Module.cpp
|
||||
|
||||
class ViewMeta:
|
||||
has_symbolic_inputs: _bool
|
||||
|
||||
# Returns the list of ViewMeta instances of the given functional tensor.
|
||||
#
|
||||
# Although we do have python bindings for their types, we won't
|
||||
# expose them here, since they should not be used by users.
|
||||
def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ...
|
||||
|
||||
# Applies the ViewMeta sequence on top of the given base.
|
||||
def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ...
|
@ -227,19 +227,6 @@ def check_cacheable(gm: torch.fx.GraphModule):
|
||||
check_node_safe(node)
|
||||
|
||||
|
||||
def check_metadata_cacheable(metadata: ViewAndMutationMeta):
|
||||
"""
|
||||
When view replay is turned on, we bypass autograd cache if
|
||||
the output is aliased.
|
||||
"""
|
||||
if config.view_replay_for_aliased_outputs:
|
||||
for info in metadata.output_info:
|
||||
if info.functional_tensor is not None:
|
||||
raise BypassAOTAutogradCache(
|
||||
"Cannot cache a graph with functional tensor"
|
||||
)
|
||||
|
||||
|
||||
class AOTAutogradCacheDetails(FxGraphHashDetails):
|
||||
"""
|
||||
Object to capture all the details for a dynamo graph module relevant to computing
|
||||
@ -875,7 +862,6 @@ class AOTAutogradCache:
|
||||
def save(key: str, entry: AOTAutogradCacheEntry, remote: bool):
|
||||
"""Save a single entry into the cache."""
|
||||
try:
|
||||
check_metadata_cacheable(entry.runtime_metadata)
|
||||
content = pickle.dumps(entry)
|
||||
CacheArtifactManager.record_artifact(
|
||||
CacheArtifactType.AOT_AUTOGRAD, key, content
|
||||
|
@ -36,10 +36,10 @@ from .functional_utils import (
|
||||
has_metadata_mutation,
|
||||
MetadataKey,
|
||||
to_fun,
|
||||
ViewMetaSequence,
|
||||
was_inductor_storage_resized,
|
||||
)
|
||||
from .schemas import (
|
||||
FunctionalTensorMetadataEq,
|
||||
InputAliasInfo,
|
||||
MutationType,
|
||||
OutputAliasInfo,
|
||||
@ -604,7 +604,7 @@ from a multi-output view call"
|
||||
#
|
||||
# The FunctionalTensor will be saved if one of the 2 conditions below
|
||||
# is true:
|
||||
functional_tensor = None
|
||||
view_meta_sequence = None
|
||||
if (
|
||||
# 1. If the output_type is either of:
|
||||
# (i) alias_of_intermediate;
|
||||
@ -636,7 +636,7 @@ from a multi-output view call"
|
||||
and not input_info[base_idx].mutates_metadata
|
||||
):
|
||||
if isinstance(o, FunctionalTensor):
|
||||
functional_tensor = FunctionalTensorMetadataEq(o.elem)
|
||||
view_meta_sequence = ViewMetaSequence(o)
|
||||
|
||||
out_info = OutputAliasInfo(
|
||||
output_type=output_type,
|
||||
@ -644,7 +644,7 @@ from a multi-output view call"
|
||||
base_idx=base_idx,
|
||||
dynamic_dims=dynamic_dims,
|
||||
requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
|
||||
functional_tensor=functional_tensor,
|
||||
view_meta_sequence=view_meta_sequence,
|
||||
)
|
||||
output_info.append(out_info)
|
||||
|
||||
|
@ -13,15 +13,12 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._C import _functionalization
|
||||
from torch._logging import getArtifactLogger
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
definitely_true,
|
||||
sym_eq,
|
||||
SymIntEqByExpr,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._python_dispatch import (
|
||||
is_traceable_wrapper_subclass,
|
||||
@ -227,9 +224,9 @@ def gen_alias_from_base(
|
||||
aliased_base_tensor,
|
||||
target_meta_tensor,
|
||||
target_requires_grad,
|
||||
target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
|
||||
target_view_meta_sequence: Optional[ViewMetaSequence] = None,
|
||||
*,
|
||||
replay_views,
|
||||
replay_views: bool,
|
||||
):
|
||||
# Patch the correct requires_grad field of the output tensor, depending on whether:
|
||||
# (i) the reconstructed output (out) was came from a tensor that requires grad or not;
|
||||
@ -248,13 +245,11 @@ def gen_alias_from_base(
|
||||
# to replay them (view functions) on the aliased_base_tensor.
|
||||
if (
|
||||
replay_views
|
||||
and target_functional_tensor is not None
|
||||
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
|
||||
and target_view_meta_sequence is not None
|
||||
and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence)
|
||||
):
|
||||
functional_tensor = target_functional_tensor.tensor
|
||||
|
||||
out = torch._functionalize_apply_view_metas(
|
||||
functional_tensor, aliased_base_tensor
|
||||
out = _functionalization.apply_view_meta_sequence(
|
||||
aliased_base_tensor, target_view_meta_sequence.sequence
|
||||
)
|
||||
# If re-applying the ViewMeta sequence succeeded, there should be no more
|
||||
# problems going forward. We just check we got to the target shape and
|
||||
@ -315,28 +310,8 @@ def gen_alias_from_base(
|
||||
return aliased_out
|
||||
|
||||
|
||||
def has_same_metadata(t1, t2):
|
||||
return (
|
||||
definitely_true(sym_eq(t1.size(), t2.size()))
|
||||
and definitely_true(t1.layout == t2.layout)
|
||||
and (
|
||||
is_sparse_any(t1)
|
||||
or (
|
||||
definitely_true(sym_eq(t1.stride(), t2.stride()))
|
||||
and definitely_true(t1.storage_offset() == t2.storage_offset())
|
||||
)
|
||||
)
|
||||
and t1.is_conj() == t2.is_conj()
|
||||
and t1.is_neg() == t2.is_neg()
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetadataKey:
|
||||
"""
|
||||
This should be equal whenever has_same_metadata would return True
|
||||
"""
|
||||
|
||||
size: Tuple[SymIntEqByExpr, ...]
|
||||
layout: torch.layout
|
||||
is_sparse: bool
|
||||
@ -360,25 +335,45 @@ class MetadataKey:
|
||||
)
|
||||
|
||||
|
||||
# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
|
||||
# after applying all the ViewMeta operations.
|
||||
class FunctionalTensorMetadataEq:
|
||||
def __init__(self, tensor: torch.Tensor) -> None:
|
||||
assert torch._is_functional_tensor(tensor)
|
||||
self.tensor = tensor
|
||||
# ViewMeta sequence wrapper for equality comparisons.
|
||||
#
|
||||
# Even though we can compare each ViewMeta instance, we compare the resulting
|
||||
# tensor metadata, instead. That's because the creation of synthetic bases + the
|
||||
# re-generation of input views might end-up creating a different sequence of
|
||||
# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same
|
||||
# metadata.
|
||||
#
|
||||
# Therefore, we store what the end result should look like as serializable
|
||||
# metadata.
|
||||
#
|
||||
# When logging, this class should look like:
|
||||
#
|
||||
# ViewMetaSequence(view, select_int, slice_Tensor)
|
||||
#
|
||||
# i.e. a parenthesized list of view operations within that ViewMeta sequence.
|
||||
class ViewMetaSequence:
|
||||
def __init__(self, tensor: FunctionalTensor) -> None:
|
||||
assert torch._is_functional_tensor(tensor.elem)
|
||||
self.sequence = _functionalization.get_view_meta_sequence(tensor.elem)
|
||||
self.metadata = MetadataKey.make(tensor)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
suffix = len("_ViewMeta")
|
||||
types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence)
|
||||
return f"ViewMetaSequence({types})"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# If other is None, then it probably means that we weren't able to recreate
|
||||
# the FunctionalTensorMetadataEq. One of this cases is when we update the
|
||||
# view metadata by calling: create_synthetic_base_metadata.
|
||||
# the ViewMeta sequence. One example is when we update the view metadata by
|
||||
# calling: create_synthetic_base_metadata.
|
||||
if other is None:
|
||||
return True
|
||||
|
||||
# Comparison agains any other type is not implemented.
|
||||
if not isinstance(other, FunctionalTensorMetadataEq):
|
||||
# Comparison against any other type is not implemented.
|
||||
if not isinstance(other, ViewMetaSequence):
|
||||
return NotImplemented
|
||||
|
||||
return has_same_metadata(self.tensor, other.tensor)
|
||||
return self.metadata == other.metadata
|
||||
|
||||
|
||||
# new_arg and arg here are either:
|
||||
|
@ -75,7 +75,7 @@ def remove_dupe_metadata(
|
||||
dynamic_dims=o.dynamic_dims,
|
||||
base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
|
||||
requires_grad=o.requires_grad,
|
||||
functional_tensor=o.functional_tensor,
|
||||
view_meta_sequence=o.view_meta_sequence,
|
||||
)
|
||||
for o in m.output_info
|
||||
],
|
||||
@ -226,7 +226,7 @@ def create_synthetic_base_metadata(
|
||||
# Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
|
||||
base_idx=new_base_idx, # type: ignore[arg-type]
|
||||
requires_grad=o.requires_grad,
|
||||
functional_tensor=o.functional_tensor,
|
||||
view_meta_sequence=o.view_meta_sequence,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -172,7 +172,7 @@ class AliasOfInputHandler:
|
||||
self.base_idx = info.base_idx
|
||||
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
|
||||
self.requires_grad = info.requires_grad
|
||||
self.functional_tensor = info.functional_tensor
|
||||
self.view_meta_sequence = info.view_meta_sequence
|
||||
self.replay_views = config.view_replay_for_aliased_outputs
|
||||
|
||||
def __call__(self, orig_inputs, fw_outs, out):
|
||||
@ -181,7 +181,7 @@ class AliasOfInputHandler:
|
||||
aliased_base_tensor,
|
||||
self.unwrap_out(out),
|
||||
self.requires_grad,
|
||||
self.functional_tensor,
|
||||
self.view_meta_sequence,
|
||||
replay_views=self.replay_views,
|
||||
)
|
||||
|
||||
@ -209,7 +209,7 @@ class AliasOfIntermediateHandler:
|
||||
|
||||
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
|
||||
self.requires_grad = info.requires_grad
|
||||
self.functional_tensor = info.functional_tensor
|
||||
self.view_meta_sequence = info.view_meta_sequence
|
||||
self.replay_views = config.view_replay_for_aliased_outputs
|
||||
|
||||
def __call__(self, orig_inputs, fw_outs, out):
|
||||
@ -218,7 +218,7 @@ class AliasOfIntermediateHandler:
|
||||
aliased_base_tensor,
|
||||
self.unwrap_out(out),
|
||||
self.requires_grad,
|
||||
self.functional_tensor,
|
||||
self.view_meta_sequence,
|
||||
replay_views=self.replay_views,
|
||||
)
|
||||
|
||||
|
@ -5,7 +5,6 @@ input/output types, metadata, config, function signatures etc.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@ -20,10 +19,7 @@ from torch._subclasses.fake_tensor import is_fake
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from .. import config
|
||||
from .functional_utils import (
|
||||
_check_if_mutation_can_be_in_graph,
|
||||
FunctionalTensorMetadataEq,
|
||||
)
|
||||
from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence
|
||||
from .utils import strict_zip
|
||||
|
||||
|
||||
@ -92,15 +88,14 @@ class OutputAliasInfo:
|
||||
dynamic_dims: Optional[Set[int]]
|
||||
# requires_grad
|
||||
requires_grad: bool
|
||||
# FunctionalTensorWrapper that represents this output.
|
||||
# Sequence of ViewMeta objects.
|
||||
#
|
||||
# Provides us the means to replay views from it.
|
||||
# Provides us the means to re-run view functions on other tensors.
|
||||
#
|
||||
# We need to wrap the actual FunctionalTensorWrapper with this class so that
|
||||
# we only compare the tensor's metadata. That's because with the transformations
|
||||
# of the model throughout AOTAutograd, the sequence of ViewMeta and the base
|
||||
# tensor might change.
|
||||
functional_tensor: Optional[FunctionalTensorMetadataEq] = None
|
||||
# We need to wrap the actual list of ViewMeta with this class so that
|
||||
# we compare the ViewMeta elements appropriately, i.e. their type and
|
||||
# the elements returned by the `as_tuple()` call.
|
||||
view_meta_sequence: Optional[ViewMetaSequence] = None
|
||||
|
||||
|
||||
class MutationType(Enum):
|
||||
@ -582,17 +577,6 @@ class ViewAndMutationMeta:
|
||||
self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
|
||||
# Clear traced tangents at runtime
|
||||
self.traced_tangents = []
|
||||
new_output_info = []
|
||||
for out in self.output_info:
|
||||
if config.view_replay_for_aliased_outputs:
|
||||
new_out = out
|
||||
else:
|
||||
# If we're not using view_replay, remove the functional tensor.
|
||||
# Functional tensors are unfortunately not serializable,
|
||||
# so doing this is required for AOTAutograd caching.
|
||||
new_out = dataclasses.replace(out, functional_tensor=None)
|
||||
new_output_info.append(new_out)
|
||||
self.output_info = new_output_info
|
||||
for inp_meta in self.subclass_inp_meta:
|
||||
if isinstance(inp_meta, SubclassCreationMeta):
|
||||
inp_meta.make_runtime_safe()
|
||||
|
@ -71,6 +71,7 @@
|
||||
#include <torch/csrc/cpu/Module.h>
|
||||
#include <torch/csrc/dynamo/init.h>
|
||||
#include <torch/csrc/export/pybind.h>
|
||||
#include <torch/csrc/functionalization/Module.h>
|
||||
#include <torch/csrc/functorch/init.h>
|
||||
#include <torch/csrc/fx/node.h>
|
||||
#include <torch/csrc/inductor/aoti_package/pybind.h>
|
||||
@ -1869,6 +1870,7 @@ PyObject* initModule() {
|
||||
torch::instruction_counter::initModule(module);
|
||||
torch::initVerboseBindings(module);
|
||||
ASSERT_TRUE(THPStorage_init(module));
|
||||
torch::functionalization::initModule(module);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// This will only initialise base classes and attach them to library namespace
|
||||
|
@ -633,15 +633,6 @@ void initTorchFunctions(PyObject* module) {
|
||||
at::functionalization::impl::isFunctionalTensor(t));
|
||||
at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
|
||||
});
|
||||
py_module.def(
|
||||
"_functionalize_apply_view_metas",
|
||||
[](const at::Tensor& tensor, const at::Tensor& base) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
at::functionalization::impl::isFunctionalTensor(tensor));
|
||||
auto impl =
|
||||
at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
|
||||
return impl->apply_view_metas(base);
|
||||
});
|
||||
py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
|
||||
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
|
||||
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
|
||||
|
71
torch/csrc/functionalization/Module.cpp
Normal file
71
torch/csrc/functionalization/Module.cpp
Normal file
@ -0,0 +1,71 @@
|
||||
#include <torch/csrc/functionalization/Module.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/FunctionalizeFallbackKernel.h>
|
||||
#include <memory>
|
||||
|
||||
namespace torch::functionalization {
|
||||
|
||||
void initModule(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
// Create a `torch._C._functionalization` Python module.
|
||||
auto functionalization = m.def_submodule(
|
||||
"_functionalization", "functionalization related pybind.");
|
||||
|
||||
// Retrieve the ViewMeta sequence of a given functional tensor.
|
||||
functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
at::functionalization::impl::isFunctionalTensor(tensor));
|
||||
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
|
||||
return impl->view_metas();
|
||||
});
|
||||
|
||||
// Applies the given ViewMeta sequence to the given base.
|
||||
functionalization.def(
|
||||
"apply_view_meta_sequence",
|
||||
[](const at::Tensor& base,
|
||||
const std::vector<std::shared_ptr<at::functionalization::ViewMeta>>&
|
||||
sequence) {
|
||||
return at::functionalization::impl::apply_view_meta_sequence(
|
||||
base, sequence);
|
||||
});
|
||||
|
||||
// Binding for InverseReturnMode.
|
||||
py::enum_<at::functionalization::InverseReturnMode>(
|
||||
functionalization, "InverseReturnMode")
|
||||
.value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView)
|
||||
.value("NeverView", at::functionalization::InverseReturnMode::NeverView)
|
||||
.value(
|
||||
"ViewOrScatterInverse",
|
||||
at::functionalization::InverseReturnMode::ViewOrScatterInverse);
|
||||
|
||||
// Create bindings for the ViewMeta base class.
|
||||
//
|
||||
// Needed so that we can take a list of ViewMeta objects as parameter.
|
||||
// Specifically, in the Python-side, we will have a list of derived ViewMeta
|
||||
// classes. We need to tell pybind11 that all of those are, in fact, instances
|
||||
// of different ViewMeta sub-types.
|
||||
py::class_<
|
||||
at::functionalization::ViewMeta,
|
||||
std::shared_ptr<at::functionalization::ViewMeta>>(
|
||||
functionalization, "ViewMeta")
|
||||
.def_property_readonly(
|
||||
"has_symbolic_inputs",
|
||||
[](const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
|
||||
return meta->has_symbolic_inputs;
|
||||
});
|
||||
|
||||
// Bindings for `ViewMeta` specializations manually implemented.
|
||||
create_binding_with_pickle<at::functionalization::resize__ViewMeta>(
|
||||
functionalization);
|
||||
create_binding_with_pickle<at::functionalization::_unsafe_view_ViewMeta>(
|
||||
functionalization);
|
||||
|
||||
// Bindings for `ViewMeta` specializations automatically generated.
|
||||
initGenerated(functionalization.ptr());
|
||||
}
|
||||
|
||||
} // namespace torch::functionalization
|
36
torch/csrc/functionalization/Module.h
Normal file
36
torch/csrc/functionalization/Module.h
Normal file
@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch::functionalization {
|
||||
|
||||
// Creates the default bindings for `ViewMeta` specializations.
|
||||
//
|
||||
// Defines a constructor using the types in `SerializableTuple`, as well
|
||||
// as pickle methods.
|
||||
template <class T>
|
||||
void create_binding_with_pickle(py::module m) {
|
||||
py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
|
||||
m, T::name())
|
||||
.def(py::init<typename T::SerializableTuple>())
|
||||
.def(
|
||||
"as_tuple",
|
||||
[](const std::shared_ptr<T>& meta) {
|
||||
return meta->to_serializable_tuple();
|
||||
})
|
||||
.def(py::pickle(
|
||||
[](const std::shared_ptr<T>& meta) {
|
||||
return meta->to_serializable_tuple();
|
||||
},
|
||||
[](const typename T::SerializableTuple& tpl) {
|
||||
return std::make_shared<T>(tpl);
|
||||
}));
|
||||
}
|
||||
|
||||
void initModule(PyObject* module);
|
||||
void initGenerated(PyObject* module);
|
||||
|
||||
} // namespace torch::functionalization
|
@ -23,20 +23,13 @@ from torchgen.model import (
|
||||
|
||||
|
||||
# This file describes the translation of JIT schema to API's used
|
||||
# when creating view lambdas that are used by the functionalization pass.
|
||||
# There are two types of lambdas: forward lambdas and reverse lambdas.
|
||||
# These API's mostly follow the dispatcher API, with a few quirks:
|
||||
# - The lambda capture has to convert reference types to value types
|
||||
# - While the forward lambda just directly calls into the at::_ops API
|
||||
# (following the dispatcher convention), the logic here for the reverse lambda
|
||||
# when creating `ViewMeta` specializations that are used by the functionalization pass.
|
||||
# These API's mostly follow the dispatcher API, with one difference:
|
||||
# - While the forward function just directly calls into the at::_ops API
|
||||
# (following the dispatcher convention), the logic here for the reverse function
|
||||
# is responsible for generating both the call-site, and the declarations
|
||||
# (which are implemented manually in the at::functionalization::impl namespace).
|
||||
|
||||
# The lambdas generated for each view op in the functionalization pass are of the form
|
||||
# [capture_arguments](outer_arguments) -> returns_type {
|
||||
# return name(inner_arguments);
|
||||
# }
|
||||
|
||||
# Define some specific lambda input arguments.
|
||||
base_binding = Binding(
|
||||
name="base",
|
||||
@ -46,6 +39,18 @@ base_binding = Binding(
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
|
||||
has_symbolic_inputs_binding = Binding(
|
||||
name="has_symbolic_inputs",
|
||||
nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
|
||||
argument=Argument(
|
||||
name="has_symbolic_inputs",
|
||||
type=BaseType(BaseTy.bool),
|
||||
default=None,
|
||||
annotation=None,
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
mutated_view_binding = Binding(
|
||||
name="mutated_view",
|
||||
nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
|
||||
@ -54,11 +59,11 @@ mutated_view_binding = Binding(
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
mutated_view_idx_binding = Binding(
|
||||
name="mutated_view_idx",
|
||||
nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
|
||||
out_index_binding = Binding(
|
||||
name="out_index",
|
||||
nctype=NamedCType(name="out_index", type=BaseCType(longT)),
|
||||
argument=Argument(
|
||||
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
|
||||
name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding(
|
||||
)
|
||||
|
||||
|
||||
# The lambda capture itself doesn't have a name.
|
||||
# The name returned here corresponds to the name of the inner function called by the lambda.
|
||||
# Name of the `ViewMeta` specialization class created.
|
||||
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
|
||||
namespace = "at::functionalization::" if with_namespace else ""
|
||||
return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
|
||||
|
||||
|
||||
# Name of the operation called inside the `forward`/`reverse` implementations.
|
||||
def name(
|
||||
g: NativeFunctionsViewGroup,
|
||||
*,
|
||||
@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
|
||||
return f"{api_name}_inverse"
|
||||
|
||||
|
||||
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
|
||||
# capture arguments include all arguments except `self`.
|
||||
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
|
||||
# So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
|
||||
args = func.arguments.flat_all
|
||||
assert args[0].type == BaseType(BaseTy.Tensor)
|
||||
non_self_args = args[1:]
|
||||
non_self_value_bindings = [
|
||||
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
|
||||
]
|
||||
|
||||
all_bindings = [
|
||||
inverse_return_mode_binding if is_reverse else reapply_views_binding
|
||||
]
|
||||
all_bindings.extend(non_self_value_bindings)
|
||||
return all_bindings
|
||||
|
||||
|
||||
def returns_type(func: FunctionSchema) -> CType:
|
||||
# Assertion: all view ops return tensor-like outputs
|
||||
assert len(func.returns) >= 1
|
||||
@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType:
|
||||
return BaseCType(tensorT)
|
||||
|
||||
|
||||
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
|
||||
if is_reverse:
|
||||
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
|
||||
else:
|
||||
return [base_binding, mutated_view_idx_binding]
|
||||
# Checks whether `func` might return more than one value.
|
||||
def is_multi_output(func: FunctionSchema) -> bool:
|
||||
return len(func.returns) > 1 or (
|
||||
len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
|
||||
)
|
||||
|
||||
|
||||
def inner_call_index(func: FunctionSchema) -> Binding | None:
|
||||
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
|
||||
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
|
||||
if len(func.returns) > 1 or (
|
||||
len(func.returns) == 1 and func.returns[0].type.is_list_like()
|
||||
):
|
||||
return mutated_view_idx_binding
|
||||
return None
|
||||
# `ViewMeta` specialization constructor parameters.
|
||||
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
|
||||
# All specializations are paremeterized by `has_symbolic_inputs` flag.
|
||||
arguments = [has_symbolic_inputs_binding]
|
||||
|
||||
# If `func` might return more than 1 value, we also parameterize this specialization
|
||||
# with the output index.
|
||||
if is_multi_output(func):
|
||||
arguments.append(out_index_binding)
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
|
||||
# `ViewMeta` specialized class' constructor arguments.
|
||||
#
|
||||
# Values needed specifically by this specialization, that the base class does not need.
|
||||
# Same as the class' attributes, but non-owning.
|
||||
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
|
||||
return attributes(func, owning=False)
|
||||
|
||||
|
||||
# `ViewMeta` specialized class' non-static member data.
|
||||
#
|
||||
# Essential data for calling the instance's `forward` and `reverse functions. You can
|
||||
# think of them as values that should be captured from the functionalization kernel.
|
||||
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
|
||||
args = func.arguments.flat_all
|
||||
assert args[0].type == BaseType(BaseTy.Tensor)
|
||||
return [
|
||||
reapply_views_binding,
|
||||
inverse_return_mode_binding,
|
||||
*[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
|
||||
]
|
||||
|
||||
|
||||
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
|
||||
args = func.arguments.flat_all
|
||||
assert args[0].type == BaseType(BaseTy.Tensor)
|
||||
non_self_args = args[1:]
|
||||
@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
|
||||
# the reverse lambda does the same, but with an additional "mutated_view" arg
|
||||
# additionally, we have a calling convention: for view ops that return multiple tensor outputs
|
||||
# their corresponding view_inverse function takes in an additional index argument.
|
||||
index_binding = inner_call_index(func)
|
||||
if index_binding is not None:
|
||||
if is_multi_output(func):
|
||||
return [
|
||||
base_binding,
|
||||
mutated_view_binding,
|
||||
inverse_return_mode_binding,
|
||||
index_binding,
|
||||
out_index_binding,
|
||||
] + non_self_bindings
|
||||
else:
|
||||
return [
|
||||
|
@ -300,83 +300,11 @@ class ViewInverseSignature:
|
||||
return_type = functionalization.returns_type(self.g.view.func)
|
||||
decls = [
|
||||
a.decl()
|
||||
for a in functionalization.inner_arguments(
|
||||
self.g.view.func, is_reverse=True
|
||||
)
|
||||
for a in functionalization.op_arguments(self.g.view.func, is_reverse=True)
|
||||
]
|
||||
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionalizationLambda:
|
||||
g: NativeFunctionsViewGroup
|
||||
|
||||
# are we generating the forward lambda or the reverse lambda?
|
||||
is_reverse: bool
|
||||
|
||||
def captures(self) -> list[Expr]:
|
||||
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
|
||||
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
|
||||
# and plumb it into the lambda.
|
||||
outer_ctx = dispatcher.arguments(self.g.view.func) + [
|
||||
functionalization.reapply_views_binding,
|
||||
functionalization.inverse_return_mode_binding,
|
||||
]
|
||||
capture_bindings = functionalization.capture_arguments(
|
||||
self.g.view.func, is_reverse=self.is_reverse
|
||||
)
|
||||
# allow_expensive_conversions is set because we want to convert
|
||||
# some reference types (IntArrayRef) to value types (vector<int64_t>).
|
||||
capture_exprs = translate.translate(
|
||||
outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
|
||||
)
|
||||
return capture_exprs
|
||||
|
||||
def decl(self) -> str:
|
||||
return_type = functionalization.returns_type(self.g.view.func)
|
||||
capture_str = ", ".join(
|
||||
f"{val.type.name} = {val.expr}" for val in self.captures()
|
||||
)
|
||||
decls = [
|
||||
a.decl()
|
||||
for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
|
||||
]
|
||||
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
|
||||
|
||||
def inner_call(self, *, reapply_views: bool | None = None) -> str:
|
||||
inner_call_name = functionalization.name(
|
||||
self.g,
|
||||
is_reverse=self.is_reverse,
|
||||
include_namespace=True,
|
||||
reapply_views=reapply_views,
|
||||
)
|
||||
|
||||
arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
|
||||
capture_ctx = functionalization.capture_arguments(
|
||||
self.g.view.func, is_reverse=self.is_reverse
|
||||
)
|
||||
full_ctx = arg_ctx + capture_ctx
|
||||
|
||||
assert self.g.view_copy is not None
|
||||
call_bindings = functionalization.inner_arguments(
|
||||
self.g.view_copy.func, is_reverse=self.is_reverse
|
||||
)
|
||||
maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
|
||||
call_exprs = [
|
||||
e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
|
||||
]
|
||||
if not self.is_reverse and maybe_index is not None:
|
||||
return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
|
||||
else:
|
||||
return f'{inner_call_name}({", ".join(call_exprs)});'
|
||||
|
||||
@staticmethod
|
||||
def from_func(
|
||||
g: NativeFunctionsViewGroup, *, is_reverse: bool
|
||||
) -> FunctionalizationLambda:
|
||||
return FunctionalizationLambda(g, is_reverse)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StructuredImplSignature:
|
||||
g: NativeFunctionsGroup
|
||||
|
105
torchgen/gen.py
105
torchgen/gen.py
@ -45,6 +45,8 @@ from torchgen.gen_functionalization_type import (
|
||||
gen_functionalization_definition,
|
||||
gen_functionalization_registration,
|
||||
gen_functionalization_view_inverse_declaration,
|
||||
gen_functionalization_view_meta_classes_decl,
|
||||
gen_functionalization_view_meta_classes_impl,
|
||||
GenCompositeViewCopyKernel,
|
||||
)
|
||||
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
|
||||
@ -2577,48 +2579,48 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
||||
},
|
||||
)
|
||||
|
||||
def gen_op_headers(
|
||||
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||
) -> list[str]:
|
||||
if isinstance(g, NativeFunctionsViewGroup):
|
||||
# view ops always get a functionalization kernel
|
||||
headers = [
|
||||
f"#include <ATen/ops/{g.view.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.view.root_name}_ops.h>",
|
||||
]
|
||||
if g.view_copy is not None:
|
||||
headers += [
|
||||
f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
|
||||
]
|
||||
return headers
|
||||
elif isinstance(g, NativeFunctionsGroup):
|
||||
headers = [
|
||||
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
|
||||
f"#include <ATen/ops/{g.out.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
|
||||
]
|
||||
if g.inplace is not None:
|
||||
headers += [
|
||||
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
|
||||
]
|
||||
if g.mutable is not None:
|
||||
headers += [
|
||||
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
|
||||
]
|
||||
return headers
|
||||
else:
|
||||
return [
|
||||
f"#include <ATen/ops/{g.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.root_name}_ops.h>",
|
||||
]
|
||||
|
||||
def functionalization_env_callable(
|
||||
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||
) -> dict[str, list[str]]:
|
||||
def gen_op_headers(
|
||||
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||
) -> list[str]:
|
||||
if isinstance(g, NativeFunctionsViewGroup):
|
||||
# view ops always get a functionalization kernel
|
||||
headers = [
|
||||
f"#include <ATen/ops/{g.view.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.view.root_name}_ops.h>",
|
||||
]
|
||||
if g.view_copy is not None:
|
||||
headers += [
|
||||
f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
|
||||
]
|
||||
return headers
|
||||
elif isinstance(g, NativeFunctionsGroup):
|
||||
headers = [
|
||||
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
|
||||
f"#include <ATen/ops/{g.out.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
|
||||
]
|
||||
if g.inplace is not None:
|
||||
headers += [
|
||||
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
|
||||
]
|
||||
if g.mutable is not None:
|
||||
headers += [
|
||||
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
|
||||
]
|
||||
return headers
|
||||
else:
|
||||
return [
|
||||
f"#include <ATen/ops/{g.root_name}_native.h>",
|
||||
f"#include <ATen/ops/{g.root_name}_ops.h>",
|
||||
]
|
||||
|
||||
return {
|
||||
"ops_headers": gen_op_headers(g),
|
||||
"func_definitions": gen_functionalization_definition(
|
||||
@ -2684,6 +2686,31 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
||||
},
|
||||
)
|
||||
|
||||
cpu_fm.write(
|
||||
"ViewMetaClasses.h",
|
||||
lambda: {
|
||||
"view_meta_declarations": list(
|
||||
concatMap(
|
||||
lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
|
||||
view_groups,
|
||||
)
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
cpu_fm.write(
|
||||
"ViewMetaClasses.cpp",
|
||||
lambda: {
|
||||
"view_meta_implementations": list(
|
||||
concatMap(
|
||||
lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
|
||||
view_groups,
|
||||
)
|
||||
),
|
||||
"op_headers": list(concatMap(gen_op_headers, view_groups)),
|
||||
},
|
||||
)
|
||||
|
||||
# Note [view_copy NativeFunctions]
|
||||
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
|
||||
# needs to have a corresponding non-aliasing {view}_copy variant.
|
||||
|
@ -1,16 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import Callable, Optional, TYPE_CHECKING
|
||||
|
||||
from torchgen.api import cpp, dispatcher
|
||||
from torchgen.api import cpp, dispatcher, functionalization
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
BaseCType,
|
||||
Binding,
|
||||
CType,
|
||||
DispatcherSignature,
|
||||
FunctionalizationLambda,
|
||||
iTensorListRefT,
|
||||
NativeSignature,
|
||||
OptionalCType,
|
||||
@ -48,7 +47,7 @@ from torchgen.native_function_generation import (
|
||||
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
|
||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
|
||||
)
|
||||
from torchgen.utils import dataclass_repr
|
||||
from torchgen.utils import concatMap, dataclass_repr, FileManager
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -365,6 +364,8 @@ def emit_view_functionalization_body(
|
||||
with native_function_manager(f):
|
||||
call_sig = DispatcherSignature.from_schema(g.view_copy.func)
|
||||
|
||||
spec = ViewMetaSpecialization(g, f=f)
|
||||
|
||||
# the "view_copy" op name that the functionalization kernels need to call
|
||||
api_name = g.view_copy.func.name.unambiguous_name()
|
||||
# Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
|
||||
@ -385,9 +386,6 @@ def emit_view_functionalization_body(
|
||||
for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
|
||||
]
|
||||
|
||||
forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
|
||||
reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
|
||||
|
||||
# The meta API call should use the same arguments, but convert all tensors to meta tensors first.
|
||||
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
|
||||
meta_call_args = [
|
||||
@ -415,19 +413,7 @@ def emit_view_functionalization_body(
|
||||
: at::functionalization::InverseReturnMode::NeverView
|
||||
);
|
||||
{symbolic_inputs_check}
|
||||
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
|
||||
{forward_lambda.decl()} {{
|
||||
if (reapply_views) {{
|
||||
return {forward_lambda.inner_call(reapply_views=True)}
|
||||
}} else {{
|
||||
return {forward_lambda.inner_call(reapply_views=False)}
|
||||
}}
|
||||
}},
|
||||
{reverse_lambda.decl()} {{
|
||||
return {reverse_lambda.inner_call()}
|
||||
}},
|
||||
/*has_symbolic_inputs=*/{symbolic_inputs_varname}
|
||||
);
|
||||
auto view_meta = {spec.new()};
|
||||
auto compute_reference_meta =
|
||||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
|
||||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
|
||||
@ -455,7 +441,6 @@ def emit_view_functionalization_body(
|
||||
"""
|
||||
|
||||
else:
|
||||
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
|
||||
return f"""
|
||||
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
|
||||
{unwrap_tensor_args_str}
|
||||
@ -489,21 +474,7 @@ def emit_view_functionalization_body(
|
||||
}}
|
||||
}}
|
||||
{symbolic_inputs_check}
|
||||
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
|
||||
{forward_lambda.decl()} {{
|
||||
if (reapply_views) {{
|
||||
return {forward_lambda.inner_call(reapply_views=True)}
|
||||
}} else {{
|
||||
return {forward_lambda.inner_call(reapply_views=False)}
|
||||
}}
|
||||
}},
|
||||
{reverse_lambda.decl()} {{
|
||||
return {reverse_lambda.inner_call()}
|
||||
}},
|
||||
/*has_symbolic_inputs=*/{symbolic_inputs_varname},
|
||||
/*is_multi_output=*/{str(is_multi_output_view).lower()},
|
||||
/*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
|
||||
);
|
||||
auto view_meta = {spec.new()};
|
||||
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
|
||||
// See Note [Propagating strides in the functionalization pass]
|
||||
if (compute_reference_meta) {{
|
||||
@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration(
|
||||
return emit_decl_helper(g)
|
||||
|
||||
|
||||
# Helper class for generating `ViewMeta` specializations.
|
||||
@dataclass
|
||||
class ViewMetaSpecialization:
|
||||
g: NativeFunctionsViewGroup
|
||||
f: NativeFunction
|
||||
|
||||
@property
|
||||
def is_multi_output(self) -> bool:
|
||||
return functionalization.is_multi_output(self.f.func)
|
||||
|
||||
@property
|
||||
def is_as_strided(self) -> bool:
|
||||
return str(self.f.func.name) == "as_strided"
|
||||
|
||||
@property
|
||||
def out_index(self) -> str:
|
||||
if self.is_multi_output:
|
||||
return functionalization.out_index_binding.name
|
||||
return "0"
|
||||
|
||||
@property
|
||||
def classname(self) -> str:
|
||||
return functionalization.classname(self.f.func)
|
||||
|
||||
def decl(self) -> list[str]:
|
||||
base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func)
|
||||
extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func)
|
||||
attributes = functionalization.attributes(self.f.func)
|
||||
|
||||
# List of types for declaring the `SerializableTuple` type.
|
||||
serializable_tuple_args = ",\n".join(
|
||||
f" {binding.type} /* {binding.name} */"
|
||||
for binding in (base_ctor_arguments + attributes)
|
||||
)
|
||||
|
||||
# Arguments used for forwarding the tuple elements to the constructor.
|
||||
destructure_tuple_args = ", ".join(
|
||||
f"std::get<{i}>(tpl)"
|
||||
for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments))
|
||||
)
|
||||
|
||||
# List of constructor parameters
|
||||
ctor_parameters = ", ".join(
|
||||
binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments)
|
||||
)
|
||||
|
||||
# Call the base class `ViewMeta` constructor.
|
||||
#
|
||||
# Both of `is_multi_output` and `is_as_strided` are known values, given the
|
||||
# operation schema.
|
||||
is_multi_output_str = str(self.is_multi_output).lower()
|
||||
is_as_strided_str = str(self.is_as_strided).lower()
|
||||
|
||||
base_ctor_bindings = ", ".join(
|
||||
[
|
||||
# `has_symbolic_inputs` is always taken as parameter.
|
||||
functionalization.has_symbolic_inputs_binding.name,
|
||||
f"/*is_multi_output=*/{is_multi_output_str}",
|
||||
f"/*is_as_strided=*/{is_as_strided_str}",
|
||||
# `out_index` is know if the operation returns only one value. Otherwise,
|
||||
# we also take it as parameter.
|
||||
f"/*out_index=*/{self.out_index}",
|
||||
]
|
||||
)
|
||||
|
||||
# Assignments of `extra_ctor_arguments` to their corresponding fields.
|
||||
# These are extra fields to-be-declared in this specialization.
|
||||
#
|
||||
# We need to set `allow_expensive_conversions`, since we are storing owned versions
|
||||
# of the non-owning arguments.
|
||||
ctor_assignments = ",\n".join(
|
||||
f" {e.type.name}({e.expr})"
|
||||
for e in translate(
|
||||
extra_ctor_arguments,
|
||||
attributes,
|
||||
method=False,
|
||||
allow_expensive_conversions=True,
|
||||
)
|
||||
)
|
||||
|
||||
# List of arguments for constructing the `SerializableTuple` from an instance.
|
||||
tuple_arguments = ", ".join(
|
||||
binding.name for binding in (base_ctor_arguments + attributes)
|
||||
)
|
||||
|
||||
# List of field declarations.
|
||||
attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes)
|
||||
|
||||
# Override `to_out_index` if this operation returns more than 1 value.
|
||||
to_out_index_decl = ""
|
||||
if self.is_multi_output:
|
||||
to_out_index_decl = (
|
||||
" std::shared_ptr<ViewMeta> to_out_index(int64_t out_idx) override;"
|
||||
)
|
||||
|
||||
return [
|
||||
f"""
|
||||
struct TORCH_API {self.classname} : public ViewMeta {{
|
||||
FUNCTIONALIZATION_VIEWMETA_NAME({self.classname});
|
||||
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args});
|
||||
|
||||
{self.classname}(const SerializableTuple& tpl)
|
||||
: {self.classname}({destructure_tuple_args}) {{}}
|
||||
|
||||
{self.classname}({ctor_parameters})
|
||||
: at::functionalization::ViewMeta({base_ctor_bindings}),
|
||||
{ctor_assignments} {{}}
|
||||
|
||||
Tensor forward(const Tensor& base) override;
|
||||
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
|
||||
{to_out_index_decl}
|
||||
|
||||
SerializableTuple to_serializable_tuple() {{
|
||||
return std::make_tuple({tuple_arguments});
|
||||
}}
|
||||
|
||||
{attr_declarations}
|
||||
}};
|
||||
"""
|
||||
]
|
||||
|
||||
# Generate a call to the actual operation.
|
||||
def opcall(self, is_reverse: bool, reapply_views: bool) -> str:
|
||||
opname = functionalization.name(
|
||||
self.g,
|
||||
is_reverse=is_reverse,
|
||||
include_namespace=True,
|
||||
reapply_views=reapply_views,
|
||||
)
|
||||
|
||||
# Expected arguments for the operation.
|
||||
assert self.g.view_copy is not None
|
||||
op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse)
|
||||
|
||||
# The context is composed by the constructor arguments (which are also
|
||||
# the field variables stored in the instance), and the `base` tensor.
|
||||
context = [functionalization.base_binding]
|
||||
context += functionalization.base_ctor_arguments(self.f.func)
|
||||
context += functionalization.attributes(self.f.func)
|
||||
|
||||
# If we are generating the call for the reverse function, we also have
|
||||
# access to `mutated_view` argument.
|
||||
if is_reverse:
|
||||
context.append(functionalization.mutated_view_binding)
|
||||
|
||||
arguments = ", ".join(
|
||||
[e.expr for e in translate(context, op_arguments, method=False)]
|
||||
)
|
||||
|
||||
# Index the result if this operation returns multiple values.
|
||||
maybe_index = ""
|
||||
if not is_reverse and self.is_multi_output:
|
||||
maybe_index = f"[{self.out_index}]"
|
||||
|
||||
return f"{opname}({arguments}){maybe_index}"
|
||||
|
||||
def impl(self) -> list[str]:
|
||||
functions = [
|
||||
f"""
|
||||
at::Tensor {self.classname}::forward(const at::Tensor& base) {{
|
||||
if (reapply_views) {{
|
||||
return {self.opcall(is_reverse=False, reapply_views=True)};
|
||||
}} else {{
|
||||
return {self.opcall(is_reverse=False, reapply_views=False)};
|
||||
}}
|
||||
}}""",
|
||||
f"""
|
||||
at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{
|
||||
return {self.opcall(is_reverse=True, reapply_views=True)};
|
||||
}}""",
|
||||
]
|
||||
|
||||
# If this operation returns multiple values, also generate a `to_out_index`
|
||||
# implementation.
|
||||
if self.is_multi_output:
|
||||
functions.append(f"""
|
||||
std::shared_ptr<at::functionalization::ViewMeta> {self.classname}::to_out_index(int64_t out_index) {{
|
||||
return {self.new("out_index")};
|
||||
}}
|
||||
""")
|
||||
|
||||
return functions
|
||||
|
||||
# Create the Python binding for this specialized class.
|
||||
def binding(self) -> list[str]:
|
||||
name = functionalization.classname(self.f.func, with_namespace=True)
|
||||
return [f" create_binding_with_pickle<{name}>(functionalization);"]
|
||||
|
||||
# Generate an instanciation of this specialized class.
|
||||
def new(self, out_index: str = "0") -> str:
|
||||
name = functionalization.classname(self.f.func, with_namespace=True)
|
||||
ctor_arguments = functionalization.base_ctor_arguments(
|
||||
self.f.func
|
||||
) + functionalization.extra_ctor_arguments(self.f.func)
|
||||
# Replace the `out_index` parameter with the given `out_index`.
|
||||
arguments = ", ".join(
|
||||
binding.name if binding.name != "out_index" else out_index
|
||||
for binding in ctor_arguments
|
||||
)
|
||||
return f"std::make_shared<{name}>({arguments})"
|
||||
|
||||
# Run the function `run` for both: `view` and `view_inplace` functions.
|
||||
@staticmethod
|
||||
def map(
|
||||
g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]]
|
||||
) -> list[str]:
|
||||
def maybe_run(f: Optional[NativeFunction]) -> list[str]:
|
||||
if f is None:
|
||||
return []
|
||||
with native_function_manager(f):
|
||||
return run(ViewMetaSpecialization(g, f))
|
||||
|
||||
return list(concatMap(maybe_run, (g.view, g.view_inplace)))
|
||||
|
||||
|
||||
def gen_functionalization_view_meta_classes_base(
|
||||
selector: SelectiveBuilder,
|
||||
g: NativeFunctionsViewGroup,
|
||||
run: Callable[[ViewMetaSpecialization], list[str]],
|
||||
) -> list[str]:
|
||||
if not selector.include_all_operators:
|
||||
return []
|
||||
|
||||
if g.composite:
|
||||
return []
|
||||
|
||||
return ViewMetaSpecialization.map(g, run)
|
||||
|
||||
|
||||
def gen_functionalization_view_meta_classes_decl(
|
||||
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
|
||||
) -> list[str]:
|
||||
return gen_functionalization_view_meta_classes_base(
|
||||
selector, g, ViewMetaSpecialization.decl
|
||||
)
|
||||
|
||||
|
||||
def gen_functionalization_view_meta_classes_impl(
|
||||
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
|
||||
) -> list[str]:
|
||||
return gen_functionalization_view_meta_classes_base(
|
||||
selector, g, ViewMetaSpecialization.impl
|
||||
)
|
||||
|
||||
|
||||
def gen_functionalization_view_meta_classes_binding(
|
||||
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
|
||||
) -> list[str]:
|
||||
return gen_functionalization_view_meta_classes_base(
|
||||
selector, g, ViewMetaSpecialization.binding
|
||||
)
|
||||
|
||||
|
||||
# Generates the Python bindings for the `ViewMeta` specialized classes.
|
||||
def gen_functionalization_view_meta_classes(
|
||||
native_functions_path: str,
|
||||
tags_path: str,
|
||||
selector: SelectiveBuilder,
|
||||
install_dir: str,
|
||||
template_dir: str,
|
||||
) -> None:
|
||||
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
|
||||
|
||||
# Parse the native_functions.yaml.
|
||||
# Then, group them into `NativeFunctionsViewGroup`.
|
||||
#
|
||||
# This is the same steps we do in gen.py (ATen codegen).
|
||||
native_functions = parse_native_yaml(
|
||||
native_functions_path, tags_path
|
||||
).native_functions
|
||||
native_functions_with_view_groups = get_grouped_by_view_native_functions(
|
||||
native_functions
|
||||
)
|
||||
view_groups = [
|
||||
g
|
||||
for g in native_functions_with_view_groups
|
||||
if isinstance(g, NativeFunctionsViewGroup)
|
||||
]
|
||||
|
||||
fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
|
||||
fm.write(
|
||||
"ViewMetaClassesPythonBinding.cpp",
|
||||
lambda: {
|
||||
"view_meta_bindings": list(
|
||||
concatMap(
|
||||
lambda g: gen_functionalization_view_meta_classes_binding(
|
||||
selector, g
|
||||
),
|
||||
view_groups,
|
||||
)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def gen_functionalization_registration(
|
||||
selector: SelectiveBuilder,
|
||||
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||
|
Reference in New Issue
Block a user