Replay view with view_func instead of as_strided in meta_utils for NT (#112205)

Currently meta_utils relies on as_strided when handling the view case (recursively meta-ify the base, and then do as_strided to simulate the view), but NestedTensor does not support as_strided today (though maybe it could?), so what we want to do instead is call Tensor. _view_func. Conveniently,  _view_func IS always available for nested tensors.

A detail to note is that _view_func actually incurs a guard because it needs to perform some metadata checks to make sure the view is still valid. This PR adds Tensor._unsafe_view_func which can avoid that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205
Approved by: https://github.com/jbschlosser
This commit is contained in:
soulitzer
2023-10-27 14:58:47 -04:00
committed by PyTorch MergeBot
parent 503955f5ec
commit 0cda4c8abe
4 changed files with 118 additions and 14 deletions

View File

@ -525,7 +525,10 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
Py_RETURN_NONE;
}
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
static PyObject* view_func_impl(
PyObject* self_,
PyObject* arg,
bool check_has_same_meta) {
HANDLE_TH_ERRORS
const auto& self = THPVariable_Unpack(self_);
TORCH_CHECK(
@ -540,7 +543,8 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
if (diff_view_meta && diff_view_meta->has_bw_view()) {
const auto& view_info = diff_view_meta->get_backward_view();
// Ensure that the newly provided base is similar to the original base
if (torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
if (!check_has_same_meta ||
torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
// Do the actual view replay
if (view_info.has_view_fn()) {
out = view_info.view_fn()(new_base);
@ -554,6 +558,14 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
}
static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
}
// Instantiates a subclass of self with the same data.
static PyObject* THPVariable_as_subclass(
PyObject* _self,
@ -1637,6 +1649,7 @@ static PyMethodDef extra_methods[] = {
nullptr},
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
{"_view_func", THPVariable_view_func, METH_O, nullptr},
{"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
{nullptr}};
struct THPVariableMeta {