mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Reify view_func() closures as ViewFuncs (#118404)
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on. ```cpp /// Base class for view functions, providing reapplication of a view on a new base. /// Each view op should get a codegenerated subclass of this class containing /// any state needed to reconstruct the view. The class also provides convenience /// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification, /// where we want to use symbolic values or fake tensors instead. struct TORCH_API ViewFunc { virtual ~ViewFunc() {} /// Returns any SymInts in the saved state. virtual std::vector<c10::SymInt> get_symints() const { return {}; } /// Returns the number of SymInts in the saved state. virtual size_t num_symints() const { return 0; } /// Returns any tensors in the saved state. virtual std::vector<at::Tensor> get_tensors() const { return {}; } /// Returns the number of tensors in the saved state. virtual size_t num_tensors() const { return 0; } /// Reapplies the view on the given base using the saved state. virtual at::Tensor operator()(const at::Tensor&) const = 0; /// Returns a clone of this ViewFunc, optionally with the specified saved state. virtual std::unique_ptr<ViewFunc> clone_and_set( std::optional<std::vector<c10::SymInt>> = c10::nullopt, std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0; protected: /// Sets the values of any SymInts in the saved state. The input vector size must /// match the number of SymInts in the saved state (i.e. the size of the list /// returned by get_symints()). virtual void set_symints(std::vector<c10::SymInt>) {} /// Sets the values of any Tensors in the saved state. The input vector size must /// match the number of Tensors in the saved state (i.e. the size of the list /// returned by get_tensors()). virtual void set_tensors(std::vector<at::Tensor>) {} }; ``` New codegen files: * `torch/csrc/autograd/generated/ViewFunc.h` * `torch/csrc/autograd/generated/ViewFuncs.cpp` The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd. Example codegen for `slice.Tensor`: ```cpp // torch/csrc/autograd/generated/ViewFuncs.h #define SLICE_TENSOR_VIEW_FUNC_AVAILABLE struct SliceTensorViewFunc : public torch::autograd::ViewFunc { SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step) {}; virtual ~SliceTensorViewFunc() override {}; virtual std::vector<c10::SymInt> get_symints() const override; virtual size_t num_symints() const override; virtual std::vector<at::Tensor> get_tensors() const override; virtual size_t num_tensors() const override; virtual at::Tensor operator()(const at::Tensor&) const override; virtual std::unique_ptr<ViewFunc> clone_and_set( std::optional<std::vector<c10::SymInt>> = c10::nullopt, std::optional<std::vector<at::Tensor>> = c10::nullopt) const override; protected: virtual void set_symints(std::vector<c10::SymInt>) override; virtual void set_tensors(std::vector<at::Tensor>) override; private: int64_t dim; c10::optional<c10::SymInt> start; c10::optional<c10::SymInt> end; c10::SymInt step; }; ... // torch/csrc/autograd/generated/ViewFuncs.cpp std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const { ::std::vector<c10::SymInt> symints; symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1); if(start.has_value()) symints.insert(symints.end(), *(start)); if(end.has_value()) symints.insert(symints.end(), *(end)); symints.push_back(step); return symints; } size_t SliceTensorViewFunc::num_symints() const { return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1); } void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) { TORCH_INTERNAL_ASSERT(symints.size() == num_symints()); auto i = 0; if(start.has_value()) start = symints[i]; i += (start.has_value() ? 1 : 0); if(end.has_value()) end = symints[i]; i += (end.has_value() ? 1 : 0); step = symints[i]; } std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const { ::std::vector<at::Tensor> tensors; return tensors; } size_t SliceTensorViewFunc::num_tensors() const { return static_cast<size_t>(0); } void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) { TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors()); } at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const { return at::_ops::slice_Tensor::call(input_base, dim, start, end, step); } std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set( std::optional<std::vector<c10::SymInt>> symints, std::optional<std::vector<at::Tensor>> tensors) const { auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step); if (symints.has_value()) { output->set_symints(std::move(*(symints))); } if (tensors.has_value()) { output->set_tensors(std::move(*(tensors))); } return output; } ``` The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification. For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly. ```sh python test/test_autograd.py -k test_view_func_replay python test/test_ops.py -k test_view_replay ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
6b04251b87
commit
9ec8dd2467
@ -131,6 +131,8 @@ def get_generate_code_bin_outs():
|
||||
"autograd/generated/VariableType_3.cpp": ["autograd/generated/VariableType_3.cpp"],
|
||||
"autograd/generated/VariableType_4.cpp": ["autograd/generated/VariableType_4.cpp"],
|
||||
"autograd/generated/variable_factories.h": ["autograd/generated/variable_factories.h"],
|
||||
"autograd/generated/ViewFuncs.cpp": ["autograd/generated/ViewFuncs.cpp"],
|
||||
"autograd/generated/ViewFuncs.h": ["autograd/generated/ViewFuncs.h"],
|
||||
}
|
||||
|
||||
if is_arvr_mode():
|
||||
|
Reference in New Issue
Block a user