Reify view_func() closures as ViewFuncs (#118404)

Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.

```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
  /// Sets the values of any SymInts in the saved state. The input vector size must
  /// match the number of SymInts in the saved state (i.e. the size of the list
  /// returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size must
  /// match the number of Tensors in the saved state (i.e. the size of the list
  /// returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```

New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`

The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.

Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
  {};
  virtual ~SliceTensorViewFunc() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());

}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```

The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.

For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
This commit is contained in:
Joel Schlosser
2024-02-08 17:06:32 -05:00
committed by PyTorch MergeBot
parent 261f0138a2
commit d5a6762263
23 changed files with 718 additions and 117 deletions

View File

@ -525,16 +525,36 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
Py_RETURN_NONE;
}
// Maps the given python callable over a vector of items, returning a vector
// of the same type of items.
template <typename T>
static std::vector<T> map_py_func(
const py::function& func,
const std::vector<T>& items) {
std::vector<T> new_items;
new_items.reserve(items.size());
for (auto& item : items) {
new_items.push_back(py::cast<T>(func(item)));
}
return new_items;
}
static PyObject* view_func_impl(
PyObject* self_,
PyObject* arg,
PyObject* _self,
PyObject* args,
PyObject* kwargs,
bool check_has_same_meta) {
HANDLE_TH_ERRORS
const auto& self = THPVariable_Unpack(self_);
TORCH_CHECK(
THPVariable_Check(arg),
"_view_func expect a single argument that is a Tensor");
const auto& new_base = THPVariable_Unpack(arg);
const auto& self = THPVariable_Unpack(_self);
static PythonArgParser parser({
"_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)",
});
ParsedArgs<3> parsed_args{};
auto r = parser.parse(_self, args, kwargs, parsed_args);
auto new_base = r.tensor(0);
PyObject* symint_visitor_fn = r.pyobject(1);
PyObject* tensor_visitor_fn = r.pyobject(2);
// Ensure that self is indeed a backward differentiable view
// If not, we return an undefined Tensor (None) and let the user handle it.
@ -547,7 +567,29 @@ static PyObject* view_func_impl(
torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
// Do the actual view replay
if (view_info.has_view_fn()) {
out = view_info.view_fn()(new_base);
auto& view_func = view_info.view_fn();
// Determine new SymInt / tensor state as needed.
c10::optional<std::vector<c10::SymInt>> new_symints = c10::nullopt;
if (symint_visitor_fn != Py_None) {
new_symints = map_py_func(
py::cast<py::function>(symint_visitor_fn),
view_func.get_symints());
}
c10::optional<std::vector<at::Tensor>> new_tensors = c10::nullopt;
if (tensor_visitor_fn != Py_None) {
new_tensors = map_py_func(
py::cast<py::function>(tensor_visitor_fn),
view_func.get_tensors());
}
// call view func
if (new_symints.has_value() || new_tensors.has_value()) {
out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base);
} else {
out = view_func(new_base);
}
} else {
out = new_base.as_strided(
self.sizes(), self.strides(), self.storage_offset());
@ -558,12 +600,18 @@ static PyObject* view_func_impl(
END_HANDLE_TH_ERRORS
}
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
static PyObject* THPVariable_view_func(
PyObject* self_,
PyObject* args,
PyObject* kwargs) {
return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true);
}
static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
static PyObject* THPVariable_view_func_unsafe(
PyObject* self_,
PyObject* args,
PyObject* kwargs) {
return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false);
}
static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
@ -1668,8 +1716,14 @@ static PyMethodDef extra_methods[] = {
METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
{"_view_func", THPVariable_view_func, METH_O, nullptr},
{"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
{"_view_func",
castPyCFunctionWithKeywords(THPVariable_view_func),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_view_func_unsafe",
castPyCFunctionWithKeywords(THPVariable_view_func_unsafe),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_rev_view_func_unsafe",
THPVariable_rev_view_func_unsafe,
METH_O,