Replay view with view_func instead of as_strided in meta_utils for NT (#112205)

Currently meta_utils relies on as_strided when handling the view case: it recursively meta-ifies the base, then calls as_strided to simulate the view. NestedTensor does not support as_strided today (though maybe it could?), so instead we want to call Tensor._view_func to replay the view. Conveniently, _view_func is always available for nested tensors.
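
For intuition, here is a minimal sketch of the two replay strategies (ordinary dense tensors and made-up variable names, not the actual meta_utils code):

```python
import torch

base = torch.randn(4, 4, requires_grad=True)
view = base[1:, :2]  # some view of `base`

# Stand-in for the meta-ified base that MetaConverter would produce.
new_base = torch.zeros_like(base)

# Strategy used for ordinary strided tensors: simulate the view with as_strided.
replayed = new_base.as_strided(view.size(), view.stride(), view.storage_offset())

# Strategy this PR switches to for nested tensors (which lack as_strided):
# ask autograd to replay the recorded view on the new base.
replayed = view._view_func(new_base)
```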

One detail to note: _view_func incurs a guard because it performs metadata checks on the new base to make sure the view is still valid. This PR adds Tensor._view_func_unsafe, which skips those checks and therefore avoids the guard.
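
A rough illustration of the difference, again with dense tensors and illustrative names:

```python
import torch

base = torch.randn(4, 4, requires_grad=True)
view = base[1:, :2]
new_base = torch.zeros_like(base)

# Checks that new_base's sizes/strides/storage offset match the original base
# before replaying; under a ShapeEnv those checks turn into guards.
checked = view._view_func(new_base)

# Added by this PR: the same replay minus the metadata check. This is safe for
# meta_utils because the base it passes in is the meta-ified copy of the very
# base the view was taken from, so the metadata is guaranteed to match.
unchecked = view._view_func_unsafe(new_base)
```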

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205
Approved by: https://github.com/jbschlosser
Authored by soulitzer on 2023-10-27 14:58:47 -04:00; committed by PyTorch MergeBot
parent 503955f5ec
commit 0cda4c8abe
4 changed files with 118 additions and 14 deletions

View File

@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import functools
+import itertools
 import unittest

 import torch
@@ -739,13 +740,15 @@ class GraphModule(torch.nn.Module):
 class TestNestedTensor(torch._dynamo.test_case.TestCase):
-    def _get_jagged_tensor(self, nested_size, offsets):
+    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
         # Makes a jagged tensor with N constituent tensors with size
         # as specified ((S0, S1, S2), D)
         D = nested_size[1]
         out = []
         for s in nested_size[0]:
-            out.append(torch.randn(s, D, requires_grad=True, dtype=torch.float64))
+            out.append(
+                torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
+            )
         return jagged_from_list(out, offsets)

     def _check_recompiles(self, fn, inputs1, inputs2, recompiles):
@@ -858,6 +861,73 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         self._check_recompiles(fn, (nt,), (nt2,), False)
         self._check_recompiles(fn, (nt,), (nt3,), True)

+    def _get_views(self):
+        # There are three cases to consider here based on the logic in
+        # meta_utils.py
+        #
+        # (1) basic case:
+        # view is not a leaf and has the same requires grad as its base
+        x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
+        self.assertEqual(x.is_leaf, False)
+        yield x.unsqueeze(-1)
+
+        # (2) leaf view case:
+        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
+        # base w/ requires_grad True or requires_grad False
+        for requires_grad_1, requires_grad_2 in itertools.product(
+            [True, False], repeat=2
+        ):
+            x, _ = self._get_jagged_tensor(
+                ((2, 3, 4), 3), None, requires_grad=requires_grad_1
+            )
+            with torch.no_grad():
+                x_view = x.unsqueeze(-1)
+                # The issue is this doesn't quite work
+                x_view.requires_grad_(requires_grad_2)
+            yield x_view
+
+        # (3) obscure case:
+        # view is not a leaf (implies requires_grad True)
+        # base w/ requires_grad False
+        x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
+        # intermediate leaf view
+        with torch.no_grad():
+            x_view = x.unsqueeze(-1)
+        x_view.requires_grad_(True)
+        x_view_view = x_view.unsqueeze(-1)
+        yield x_view_view
+
+    def test_inputs_to_compiled_fn_are_views(self):
+        for nt_view in self._get_views():
+
+            def fn(x):
+                return x.sin()
+
+            out_ref = fn(nt_view)
+            torch._dynamo.reset()
+            compile_fn = torch.compile(
+                fn, fullgraph=True, backend="aot_eager", dynamic=True
+            )
+            out = compile_fn(nt_view)
+
+            # Check metadata and values are correct
+            self.assertTrue(out.size() == out_ref.size())
+            self.assertTrue(out.stride() == out_ref.stride())
+            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
+
+            # Check that no guards are incurred
+            def backend(gm, args):
+                context = torch._guards.TracingContext.get()
+                val_to_guards = context.fake_mode.shape_env.var_to_guards.values()
+                self.assertEqual(len(val_to_guards), 0)
+                return gm
+
+            torch._dynamo.reset()
+            compile_fn = torch.compile(
+                fn, fullgraph=True, backend=backend, dynamic=True
+            )
+            out = compile_fn(nt_view)
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

View File

@@ -309,7 +309,7 @@ class MetaConverter:
                 from torch._dynamo.source import AttrSource
                 from torch.fx.experimental.symbolic_shapes import DimDynamic

-                if shape_env:
+                if shape_env and not t.is_nested:
                     base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
                 else:
                     base_dynamic_dims = None
@@ -369,25 +369,45 @@ class MetaConverter:
                 #
                 # So we may have to do *two* views out of the base to
                 # recreate this situation.
-                (
-                    sizes,
-                    strides,
-                    storage_offset,
-                ) = sym_sizes_strides_storage_offset(t, source)
+                def _view_from_base(base, t):
+                    if t.is_nested:
+                        # Nested tensors do not support as_strided, and
+                        # hence, always have _view_func available.
+                        #
+                        # The unsafe version of _view_func omits
+                        # checking whether the base passed in has the same
+                        # metadata as the original base the view_func
+                        # was originally executed with. (1) It is OK here,
+                        # because we're calling it on the meta-ified base,
+                        # so the metadata is guaranteed to be the same.
+                        # (2) It is necessary because we don't actually
+                        # want to guard on the base's metadata here.
+                        return t._view_func_unsafe(base)
+                    else:
+                        (
+                            sizes,
+                            strides,
+                            storage_offset,
+                        ) = sym_sizes_strides_storage_offset(t, source)
+                        return base.as_strided(sizes, strides, storage_offset)

                 if safe_is_leaf(t):
                     # Leaf views that track view metadata are created by
                     # creating a view inside a no_grad block
                     with torch.no_grad(), maybe_suppress():
-                        r = base.as_strided(sizes, strides, storage_offset)
+                        r = _view_from_base(base, t)
                     # As it's a leaf, we can directly assign requires_grad
                     r.requires_grad = t.requires_grad
                 else:
                     if t._base.requires_grad == t.requires_grad:
                         # Easy case, just run the view op
                         with torch.enable_grad(), maybe_suppress():
-                            r = base.as_strided(sizes, strides, storage_offset)
+                            r = _view_from_base(base, t)
+
+                        # NB: We don't actually faithfully replicate
+                        # autograd connectivity, but that doesn't matter
+                        # today. See following for more info:
+                        # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                     else:
                         # Obscure case. Create a leaf view and give it the
                         # correct requires_grad, then do the final view.
@@ -397,7 +417,7 @@ class MetaConverter:
                         mid = base.view(base.shape)
                         mid.requires_grad = t.requires_grad
                         with torch.enable_grad(), maybe_suppress():
-                            r = mid.as_strided(sizes, strides, storage_offset)
+                            r = _view_from_base(mid, t)
                     # The CreationMeta influences whether or not inplace
                     # mutation is an error or not. So we need to make
                     # sure we properly propagate this as well.

View File

@@ -525,7 +525,10 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
   Py_RETURN_NONE;
 }

-static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
+static PyObject* view_func_impl(
+    PyObject* self_,
+    PyObject* arg,
+    bool check_has_same_meta) {
   HANDLE_TH_ERRORS
   const auto& self = THPVariable_Unpack(self_);
   TORCH_CHECK(
@@ -540,7 +543,8 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
   if (diff_view_meta && diff_view_meta->has_bw_view()) {
     const auto& view_info = diff_view_meta->get_backward_view();
     // Ensure that the newly provided base is similar to the original base
-    if (torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
+    if (!check_has_same_meta ||
+        torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
       // Do the actual view replay
       if (view_info.has_view_fn()) {
         out = view_info.view_fn()(new_base);
@@ -554,6 +558,14 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
+  return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
+}
+
+static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
+  return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
+}
+
 // Instantiates a subclass of self with the same data.
 static PyObject* THPVariable_as_subclass(
     PyObject* _self,
@@ -1637,6 +1649,7 @@ static PyMethodDef extra_methods[] = {
      nullptr},
     {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
     {"_view_func", THPVariable_view_func, METH_O, nullptr},
+    {"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
     {nullptr}};

 struct THPVariableMeta {

View File

@@ -345,6 +345,7 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._reduce_ex_internal,
         Tensor._fix_weakref,
         Tensor._view_func,
+        Tensor._view_func_unsafe,
         Tensor._make_wrapper_subclass,
         Tensor._python_dispatch.__get__,
         Tensor._has_symbolic_sizes_strides.__get__,