pytorch/torch/csrc/functionalization/Module.h
Yukio Siraichi b8abdaa286 Make functionalization ViewMeta serializable with pickle. (#143712)
Fix: #141974

This PR makes the `ViewMeta` sequence, present in functional tensors,
serializable with pickle. To accomplish that, it makes `ViewMeta` an
abstract class with overridable `forward` and `reverse` functions. Each
operation that once instantiated `ViewMeta` directly should now create
a specialized class that inherits from `ViewMeta`. This PR therefore
also uses codegen to create these specializations.
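
As a rough illustration of the design described above, here is a minimal,
self-contained sketch that does not use ATen at all. `Buffer`,
`ViewMetaSketch` and `SliceViewMetaSketch` are hypothetical names invented
for this example; the real `ViewMeta` operates on `at::Tensor` and its exact
signatures are not reproduced here. The point is only the shape of the
pattern: pure virtual `forward` and `reverse`, and a specialization that keeps
its arguments in plain data members.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a tensor: a flat 1-D buffer.
using Buffer = std::vector<int>;

// Simplified analogue of an abstract ViewMeta (an assumption, not the real API).
struct ViewMetaSketch {
  virtual ~ViewMetaSketch() = default;
  // Replays the view operation on a base buffer.
  virtual Buffer forward(const Buffer& base) const = 0;
  // Writes a mutated view back into a copy of the base buffer.
  virtual Buffer reverse(const Buffer& base, const Buffer& mutated_view) const = 0;
};

// Hypothetical specialization for a 1-D slice [start, start + length).
struct SliceViewMetaSketch final : ViewMetaSketch {
  SliceViewMetaSketch(std::size_t start, std::size_t length)
      : start(start), length(length) {}

  Buffer forward(const Buffer& base) const override {
    check_bounds(base);
    const auto first = base.begin() + static_cast<std::ptrdiff_t>(start);
    return Buffer(first, first + static_cast<std::ptrdiff_t>(length));
  }

  Buffer reverse(const Buffer& base, const Buffer& mutated_view) const override {
    check_bounds(base);
    if (mutated_view.size() != length) {
      throw std::invalid_argument("mutated view has the wrong size");
    }
    Buffer out = base;
    std::copy(
        mutated_view.begin(),
        mutated_view.end(),
        out.begin() + static_cast<std::ptrdiff_t>(start));
    return out;
  }

  // The state is plain data, which is what a serializer can export.
  std::size_t start;
  std::size_t length;

 private:
  void check_bounds(const Buffer& base) const {
    if (start + length > base.size()) {
      throw std::out_of_range("slice exceeds base size");
    }
  }
};
```

With this sketch, `SliceViewMetaSketch(1, 2).forward({10, 20, 30, 40})` yields
the two middle elements, and `reverse` scatters an updated slice back into a
copy of the base.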

In summary, these are the changes this PR introduces:

- `ViewMeta` is turned into an abstract class (see
  _FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual
  functions that need to be implemented. `to_out_index` should be
  implemented by operations that might return more than 1 output.

- New `ViewMeta` specializations for `resize_` and `_unsafe_view` are
  created (see _FunctionalizeFallbackKernel.h_).

- New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the
  declaration and definition of the `ViewMeta` specializations, which
  are automatically generated in the ATen codegen (see _gen.py_).

- New `_functionalization` Python sub-module is created (see
  _Module.cpp_). It serves as a namespace for the `ViewMeta`
  specializations and the `InverseReturnMode` enum.

- New template _ViewMetaClassesPythonBinding.cpp_ is created. It holds
  the automatically generated Python bindings for the `ViewMeta`
  specializations, which are generated in the torch codegen (see
  _generate_code.py_).

Note that this PR uses codegen in two separate stages:

- ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes.
- Torch codegen (_generate_code.py_): generates the Python bindings for
  them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712
Approved by: https://github.com/bdhirsh
2025-01-16 19:41:41 +00:00

37 lines
1.0 KiB
C++

#pragma once

#include <ATen/FunctionalStorageImpl.h>

#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::functionalization {

// Creates the default bindings for `ViewMeta` specializations.
//
// Defines a constructor using the types in `SerializableTuple`, as well
// as pickle methods.
template <class T>
void create_binding_with_pickle(py::module m) {
  py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
      m, T::name())
      .def(py::init<typename T::SerializableTuple>())
      .def(
          "as_tuple",
          [](const std::shared_ptr<T>& meta) {
            return meta->to_serializable_tuple();
          })
      .def(py::pickle(
          [](const std::shared_ptr<T>& meta) {
            return meta->to_serializable_tuple();
          },
          [](const typename T::SerializableTuple& tpl) {
            return std::make_shared<T>(tpl);
          }));
}

void initModule(PyObject* module);
void initGenerated(PyObject* module);

} // namespace torch::functionalization
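
The template above relies on four members of `T`: a static `name()`, a
`SerializableTuple` type, a `to_serializable_tuple()` method, and a
constructor taking that tuple. The following sketch shows a type that meets
those requirements and a round trip that mirrors the two pickle lambdas. It
uses neither pybind11 nor ATen. `NarrowViewMetaSketch`, its three fields, and
`round_trip` are hypothetical names chosen for illustration; they are not the
generated classes.

```cpp
#include <cstdint>
#include <memory>
#include <tuple>

// Hypothetical specialization with the members the binding template uses.
struct NarrowViewMetaSketch {
  // Assumed state: (dim, start, length). The real generated classes may
  // carry different fields.
  using SerializableTuple =
      std::tuple<std::int64_t, std::int64_t, std::int64_t>;

  // Counterpart of the "set state" lambda: rebuild the object from a tuple.
  explicit NarrowViewMetaSketch(const SerializableTuple& tpl)
      : dim(std::get<0>(tpl)),
        start(std::get<1>(tpl)),
        length(std::get<2>(tpl)) {}

  // Name under which a binding would register the class.
  static const char* name() {
    return "NarrowViewMetaSketch";
  }

  // Counterpart of the "get state" lambda: export the state as a tuple.
  SerializableTuple to_serializable_tuple() const {
    return std::make_tuple(dim, start, length);
  }

  std::int64_t dim;
  std::int64_t start;
  std::int64_t length;
};

// Export the state, then construct a fresh object from it, which is what
// pickling followed by unpickling amounts to at the C++ level.
template <class T>
std::shared_ptr<T> round_trip(const std::shared_ptr<T>& meta) {
  typename T::SerializableTuple state = meta->to_serializable_tuple();
  return std::make_shared<T>(state);
}
```

Because the same tuple type feeds both the constructor and the pickle
functions, a single `SerializableTuple` definition per specialization is
enough for `py::init`, `as_tuple`, and `py::pickle` to stay consistent.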