mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make functionalization ViewMeta
serializable with pickle. (#143712)
Fix: #141974 This PR makes `ViewMeta` sequence, present in functional tensors, serializable with pickle. In order to accomplish that, it makes `ViewMeta` an abstract class with overridable `forward` and `reverse` functions. In this context, each operation that once instanciated `ViewMeta`, should now create a new specialized class that inherits from `ViewMeta. Therefore, this PR also uses codegen for creating these specializations. In summary, these are the changes this PR introduces: - `ViewMeta` is turned into an abstract class (see _FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual functions that need to be implemented. `to_out_index` should be implemented by operations that might return more than 1 output. - New `ViewMeta` specializations for `resize_` and `_unsafe_view` are created (see _FunctionalizeFallbackKernel.h_). - New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the declaration and definition of the `ViewMeta` specializations, which are automatically generated in the ATen codegen (see _gen.py_). - New `_functionalization` Python sub-module is created (see _Module.cpp_). It serves as namespace for the `ViewMeta` specializations and `InverseReturnMode` enum. - New template _ViewMetaClassesPythonBinding.cpp_ is created. It holds the automatically generated Python bindings for the `ViewMeta` specialization, which are generated in the torch codegen (see _generate_code.py_). Note that this PR makes use of codegen at 2 different moments: - ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes. - Torch codegen (_generate_code.py_): generated the Python bindings for them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
7c3aa1da1c
commit
b8abdaa286
71
torch/csrc/functionalization/Module.cpp
Normal file
71
torch/csrc/functionalization/Module.cpp
Normal file
@ -0,0 +1,71 @@
|
||||
#include <torch/csrc/functionalization/Module.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <ATen/FunctionalStorageImpl.h>
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/FunctionalizeFallbackKernel.h>
|
||||
#include <memory>
|
||||
|
||||
namespace torch::functionalization {
|
||||
|
||||
void initModule(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
// Create a `torch._C._functionalization` Python module.
|
||||
auto functionalization = m.def_submodule(
|
||||
"_functionalization", "functionalization related pybind.");
|
||||
|
||||
// Retrieve the ViewMeta sequence of a given functional tensor.
|
||||
functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
at::functionalization::impl::isFunctionalTensor(tensor));
|
||||
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
|
||||
return impl->view_metas();
|
||||
});
|
||||
|
||||
// Applies the given ViewMeta sequence to the given base.
|
||||
functionalization.def(
|
||||
"apply_view_meta_sequence",
|
||||
[](const at::Tensor& base,
|
||||
const std::vector<std::shared_ptr<at::functionalization::ViewMeta>>&
|
||||
sequence) {
|
||||
return at::functionalization::impl::apply_view_meta_sequence(
|
||||
base, sequence);
|
||||
});
|
||||
|
||||
// Binding for InverseReturnMode.
|
||||
py::enum_<at::functionalization::InverseReturnMode>(
|
||||
functionalization, "InverseReturnMode")
|
||||
.value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView)
|
||||
.value("NeverView", at::functionalization::InverseReturnMode::NeverView)
|
||||
.value(
|
||||
"ViewOrScatterInverse",
|
||||
at::functionalization::InverseReturnMode::ViewOrScatterInverse);
|
||||
|
||||
// Create bindings for the ViewMeta base class.
|
||||
//
|
||||
// Needed so that we can take a list of ViewMeta objects as parameter.
|
||||
// Specifically, in the Python-side, we will have a list of derived ViewMeta
|
||||
// classes. We need to tell pybind11 that all of those are, in fact, instances
|
||||
// of different ViewMeta sub-types.
|
||||
py::class_<
|
||||
at::functionalization::ViewMeta,
|
||||
std::shared_ptr<at::functionalization::ViewMeta>>(
|
||||
functionalization, "ViewMeta")
|
||||
.def_property_readonly(
|
||||
"has_symbolic_inputs",
|
||||
[](const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
|
||||
return meta->has_symbolic_inputs;
|
||||
});
|
||||
|
||||
// Bindings for `ViewMeta` specializations manually implemented.
|
||||
create_binding_with_pickle<at::functionalization::resize__ViewMeta>(
|
||||
functionalization);
|
||||
create_binding_with_pickle<at::functionalization::_unsafe_view_ViewMeta>(
|
||||
functionalization);
|
||||
|
||||
// Bindings for `ViewMeta` specializations automatically generated.
|
||||
initGenerated(functionalization.ptr());
|
||||
}
|
||||
|
||||
} // namespace torch::functionalization
|
Reference in New Issue
Block a user