mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix: #141974

This PR makes the `ViewMeta` sequence, present in functional tensors, serializable with pickle. To accomplish that, it turns `ViewMeta` into an abstract class with overridable `forward` and `reverse` functions. As a result, each operation that once instantiated `ViewMeta` should now create a new specialized class that inherits from `ViewMeta`. Therefore, this PR also uses codegen to create these specializations.

In summary, these are the changes this PR introduces:

- `ViewMeta` is turned into an abstract class (see _FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual functions that need to be implemented. `to_out_index` should be implemented by operations that might return more than one output.
- New `ViewMeta` specializations for `resize_` and `_unsafe_view` are created (see _FunctionalizeFallbackKernel.h_).
- New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the declaration and definition of the `ViewMeta` specializations, which are automatically generated in the ATen codegen (see _gen.py_).
- A new `_functionalization` Python sub-module is created (see _Module.cpp_). It serves as a namespace for the `ViewMeta` specializations and the `InverseReturnMode` enum.
- A new template _ViewMetaClassesPythonBinding.cpp_ is created. It holds the automatically generated Python bindings for the `ViewMeta` specializations, which are generated in the torch codegen (see _generate_code.py_).

Note that this PR makes use of codegen at two different moments:

- ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes.
- Torch codegen (_generate_code.py_): generates the Python bindings for them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712
Approved by: https://github.com/bdhirsh
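The abstract-class design described in the commit message can be illustrated with a toy model in plain C++. This is a sketch under stated assumptions, not ATen's `at::functionalization::ViewMeta`: the names `ToyViewMeta` and `SliceViewMeta`, the `std::vector<int>` stand-in for a tensor, the `forward`/`reverse` signatures, and the default `to_out_index()` of `0` are all hypothetical. It shows each view operation becoming its own subclass that implements the pure virtual `forward` (base to view) and `reverse` (mutated view back into base).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tensor: a flat buffer of ints.
using Buffer = std::vector<int>;

// Toy analogue of the abstract ViewMeta (NOT the real ATen signatures):
// each view operation says how to produce the view from its base (forward)
// and how to write a mutated view back into the base (reverse).
struct ToyViewMeta {
  virtual ~ToyViewMeta() = default;
  virtual Buffer forward(const Buffer& base) const = 0;
  virtual Buffer reverse(const Buffer& base, const Buffer& mutated_view) const = 0;
  // Multi-output operations would override this to report which output the
  // metadata refers to; returning 0 here is an assumption for single-output ops.
  virtual std::size_t to_out_index() const { return 0; }
};

// Hypothetical specialization for a contiguous slice [start, start + length).
// Its parameters are plain data members, so they could later be serialized.
class SliceViewMeta final : public ToyViewMeta {
 public:
  SliceViewMeta(std::size_t start, std::size_t length)
      : start_(start), length_(length) {}

  Buffer forward(const Buffer& base) const override {
    assert(start_ + length_ <= base.size());
    auto first = base.begin() + static_cast<std::ptrdiff_t>(start_);
    return Buffer(first, first + static_cast<std::ptrdiff_t>(length_));
  }

  Buffer reverse(const Buffer& base, const Buffer& mutated_view) const override {
    assert(start_ + length_ <= base.size());
    assert(mutated_view.size() == length_);
    Buffer updated = base;  // copy the base, then scatter the view back into it
    for (std::size_t i = 0; i < length_; ++i) {
      updated[start_ + i] = mutated_view[i];
    }
    return updated;
  }

 private:
  std::size_t start_;
  std::size_t length_;
};
```

With `base = {1, 2, 3, 4, 5}` and `SliceViewMeta(1, 3)`, `forward` yields `{2, 3, 4}`; after setting the first element of that view to `20`, `reverse` produces `{1, 20, 3, 4, 5}`. One plausible reading of the motivation, consistent with the description but not stated in it: a named subclass keeps its parameters as explicit fields that can be written out and read back, which an opaque callable cannot offer.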
37 lines
1.0 KiB
C++
#pragma once

#include <ATen/FunctionalStorageImpl.h>

#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::functionalization {

// Creates the default bindings for `ViewMeta` specializations.
//
// Defines a constructor using the types in `SerializableTuple`, as well
// as pickle methods.
template <class T>
void create_binding_with_pickle(py::module m) {
  py::class_<T, std::shared_ptr<T>, at::functionalization::ViewMeta>(
      m, T::name())
      .def(py::init<typename T::SerializableTuple>())
      .def(
          "as_tuple",
          [](const std::shared_ptr<T>& meta) {
            return meta->to_serializable_tuple();
          })
      .def(py::pickle(
          [](const std::shared_ptr<T>& meta) {
            return meta->to_serializable_tuple();
          },
          [](const typename T::SerializableTuple& tpl) {
            return std::make_shared<T>(tpl);
          }));
}

void initModule(PyObject* module);
void initGenerated(PyObject* module);

} // namespace torch::functionalization
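`create_binding_with_pickle<T>` only compiles if `T` provides `T::name()`, a nested `SerializableTuple` type, a constructor taking that tuple, and a `to_serializable_tuple()` member. The sketch below models that contract without pybind11 so the round trip can be checked in isolation. `ToySliceMeta`, its `(dim, start, end)` fields, and `round_trip` are hypothetical; the real specializations are generated by the ATen codegen (_gen.py_).

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <tuple>

// Hypothetical type providing the members create_binding_with_pickle<T> uses:
// T::name(), T::SerializableTuple, T(SerializableTuple), to_serializable_tuple().
struct ToySliceMeta {
  // Every field needed to rebuild the object goes into the tuple.
  using SerializableTuple = std::tuple<std::int64_t, std::int64_t, std::int64_t>;  // dim, start, end

  static const char* name() { return "ToySliceMeta"; }

  explicit ToySliceMeta(const SerializableTuple& tpl)
      : dim(std::get<0>(tpl)), start(std::get<1>(tpl)), end(std::get<2>(tpl)) {}

  SerializableTuple to_serializable_tuple() const {
    return std::make_tuple(dim, start, end);
  }

  std::int64_t dim;
  std::int64_t start;
  std::int64_t end;
};

// Mirrors the two lambdas passed to py::pickle: "getstate" turns the object
// into its tuple, "setstate" rebuilds a new shared_ptr<T> from that tuple.
template <class T>
std::shared_ptr<T> round_trip(const std::shared_ptr<T>& meta) {
  typename T::SerializableTuple state = meta->to_serializable_tuple();  // getstate
  auto restored = std::make_shared<T>(state);                           // setstate
  assert(restored->to_serializable_tuple() == state);  // the round trip is lossless
  return restored;
}
```

In the actual binding, pybind11's `py::pickle` attaches the same pair as `__getstate__`/`__setstate__`, and `as_tuple` exposes the tuple to Python directly, so a lossless `to_serializable_tuple()`/constructor pair is what lets a pickled `ViewMeta` specialization be restored.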