Replay view with view_func instead of as_strided in meta_utils for NT (#112205)
Currently meta_utils relies on as_strided when handling the view case (recursively meta-ify the base, then call as_strided to simulate the view), but NestedTensor does not support as_strided today (though maybe it could?). Instead, we call Tensor._view_func. Conveniently, _view_func IS always available for nested tensors. One detail to note: _view_func incurs a guard because it performs metadata checks to make sure the view is still valid. This PR adds Tensor._view_func_unsafe, which skips those checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205
Approved by: https://github.com/jbschlosser
committed by PyTorch MergeBot
parent 503955f5ec
commit 0cda4c8abe
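For intuition, here is a minimal sketch of the two replay strategies on an ordinary dense tensor; the nested-tensor path in this PR uses the replay strategy because as_strided is unavailable there. Tensor._view_func is a private API and its fallback behavior may differ across versions, so treat this as illustrative only, not as the exact meta_utils code path:

    import torch

    base = torch.randn(4, 3, requires_grad=True)
    view = base.unsqueeze(-1)  # autograd records how this view was created

    # A fresh base with the same metadata stands in for the meta-ified base.
    new_base = torch.randn(4, 3, requires_grad=True)

    # Strategy 1: simulate the view with as_strided (what meta_utils did before;
    # not supported for NestedTensor).
    replayed_strided = new_base.as_strided(
        view.size(), view.stride(), view.storage_offset()
    )

    # Strategy 2: replay the recorded view function on the new base. _view_func
    # first checks that new_base's metadata matches the original base; that check
    # is what incurs a guard, and _view_func_unsafe (added by this PR) skips it.
    replayed = view._view_func(new_base)

    print(replayed_strided.shape, replayed.shape)  # both torch.Size([4, 3, 1])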
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import functools
+import itertools
 import unittest
 
 import torch
@@ -739,13 +740,15 @@ class GraphModule(torch.nn.Module):
 
 
 class TestNestedTensor(torch._dynamo.test_case.TestCase):
-    def _get_jagged_tensor(self, nested_size, offsets):
+    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
         # Makes a jagged tensor with N constituent tensors with size
         # as specified ((S0, S1, S2), D)
         D = nested_size[1]
         out = []
         for s in nested_size[0]:
-            out.append(torch.randn(s, D, requires_grad=True, dtype=torch.float64))
+            out.append(
+                torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
+            )
         return jagged_from_list(out, offsets)
 
     def _check_recompiles(self, fn, inputs1, inputs2, recompiles):
@@ -858,6 +861,73 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         self._check_recompiles(fn, (nt,), (nt2,), False)
         self._check_recompiles(fn, (nt,), (nt3,), True)
 
+    def _get_views(self):
+        # There are three cases to consider here based on the logic in
+        # meta_utils.py
+        #
+        # (1) basic case:
+        # view is not a leaf and has the same requires grad as its basic case
+        x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
+        self.assertEqual(x.is_leaf, False)
+        yield x.unsqueeze(-1)
+
+        # (2) leaf view case:
+        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
+        # base w/ requires_grad True or requires_grad False
+        for requires_grad_1, requires_grad_2 in itertools.product(
+            [True, False], repeat=2
+        ):
+            x, _ = self._get_jagged_tensor(
+                ((2, 3, 4), 3), None, requires_grad=requires_grad_1
+            )
+            with torch.no_grad():
+                x_view = x.unsqueeze(-1)
+                # The issue is this doesn't quite work
+                x_view.requires_grad_(requires_grad_2)
+            yield x_view
+
+        # (3) obscure case:
+        # view is not a leaf (implies requires_grad True)
+        # base w/ requires_grad False)
+        x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
+        # intermediate leaf view
+        with torch.no_grad():
+            x_view = x.unsqueeze(-1)
+        x_view.requires_grad_(True)
+        x_view_view = x_view.unsqueeze(-1)
+        yield x_view_view
+
+    def test_inputs_to_compiled_fn_are_views(self):
+        for nt_view in self._get_views():
+
+            def fn(x):
+                return x.sin()
+
+            out_ref = fn(nt_view)
+            torch._dynamo.reset()
+            compile_fn = torch.compile(
+                fn, fullgraph=True, backend="aot_eager", dynamic=True
+            )
+            out = compile_fn(nt_view)
+
+            # Check metadata and values are correct
+            self.assertTrue(out.size() == out_ref.size())
+            self.assertTrue(out.stride() == out_ref.stride())
+            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
+
+            # Check that no guards are incurred
+            def backend(gm, args):
+                context = torch._guards.TracingContext.get()
+                val_to_guards = context.fake_mode.shape_env.var_to_guards.values()
+                self.assertEqual(len(val_to_guards), 0)
+                return gm
+
+            torch._dynamo.reset()
+            compile_fn = torch.compile(
+                fn, fullgraph=True, backend=backend, dynamic=True
+            )
+            out = compile_fn(nt_view)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -309,7 +309,7 @@ class MetaConverter:
                 from torch._dynamo.source import AttrSource
                 from torch.fx.experimental.symbolic_shapes import DimDynamic
 
-                if shape_env:
+                if shape_env and not t.is_nested:
                     base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
                 else:
                     base_dynamic_dims = None
@@ -369,25 +369,45 @@ class MetaConverter:
                 #
                 # So we may have to do *two* views out of the base to
                 # recreate this situation.
-                (
-                    sizes,
-                    strides,
-                    storage_offset,
-                ) = sym_sizes_strides_storage_offset(t, source)
+                def _view_from_base(base, t):
+                    if t.is_nested:
+                        # Nested tensors do not support as_strided, and
+                        # hence, always have _view_func available.
+                        #
+                        # The unsafe version of _view_func omits
+                        # checking whether the base passed in has the same
+                        # metadata as the original base the view_func
+                        # was originally executed with. (1) It is OK here,
+                        # because we're calling it on the meta-ified base,
+                        # so the metadata is guaranteed to be the same.
+                        # (2) It is necessary because we don't actually
+                        # want to guard on the base's metadata here.
+                        return t._view_func_unsafe(base)
+                    else:
+                        (
+                            sizes,
+                            strides,
+                            storage_offset,
+                        ) = sym_sizes_strides_storage_offset(t, source)
+                        return base.as_strided(sizes, strides, storage_offset)
 
                 if safe_is_leaf(t):
                     # Leaf views that track view metadata are created by
                     # creating a view inside a no_grad block
                     with torch.no_grad(), maybe_suppress():
-                        r = base.as_strided(sizes, strides, storage_offset)
+                        r = _view_from_base(base, t)
                     # As it's a leaf, we can directly assign requires_grad
                     r.requires_grad = t.requires_grad
                 else:
                     if t._base.requires_grad == t.requires_grad:
                         # Easy case, just run the view op
                         with torch.enable_grad(), maybe_suppress():
-                            r = base.as_strided(sizes, strides, storage_offset)
+                            r = _view_from_base(base, t)
+
+                        # NB: We don't actually faithfully replicate
+                        # autograd connectivity, but that doesn't matter
+                        # today. See following for more info:
+                        # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                     else:
                         # Obscure case. Create a leaf view and give it the
                         # correct requires_grad, then do the final view.
@@ -397,7 +417,7 @@ class MetaConverter:
                                 mid = base.view(base.shape)
                             mid.requires_grad = t.requires_grad
                             with torch.enable_grad(), maybe_suppress():
-                                r = mid.as_strided(sizes, strides, storage_offset)
+                                r = _view_from_base(mid, t)
                             # The CreationMeta influences whether or not inplace
                             # mutation is an error or not. So we need to make
                             # sure we properly propagate this as well.
@@ -525,7 +525,10 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
   Py_RETURN_NONE;
 }
 
-static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
+static PyObject* view_func_impl(
+    PyObject* self_,
+    PyObject* arg,
+    bool check_has_same_meta) {
   HANDLE_TH_ERRORS
   const auto& self = THPVariable_Unpack(self_);
   TORCH_CHECK(
@@ -540,7 +543,8 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
   if (diff_view_meta && diff_view_meta->has_bw_view()) {
     const auto& view_info = diff_view_meta->get_backward_view();
     // Ensure that the newly provided base is similar to the original base
-    if (torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
+    if (!check_has_same_meta ||
+        torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
       // Do the actual view replay
       if (view_info.has_view_fn()) {
         out = view_info.view_fn()(new_base);
@@ -554,6 +558,14 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
+  return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
+}
+
+static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
+  return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
+}
+
 // Instantiates a subclass of self with the same data.
 static PyObject* THPVariable_as_subclass(
     PyObject* _self,
@@ -1637,6 +1649,7 @@ static PyMethodDef extra_methods[] = {
      nullptr},
     {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
     {"_view_func", THPVariable_view_func, METH_O, nullptr},
+    {"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
    {nullptr}};
 
 struct THPVariableMeta {
@@ -345,6 +345,7 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._reduce_ex_internal,
         Tensor._fix_weakref,
         Tensor._view_func,
+        Tensor._view_func_unsafe,
         Tensor._make_wrapper_subclass,
         Tensor._python_dispatch.__get__,
         Tensor._has_symbolic_sizes_strides.__get__,
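With the binding registered above, the new method is reachable from Python as a private attribute on Tensor. A hypothetical usage sketch on a jagged NestedTensor, assuming this PR is applied and that torch.nested.nested_tensor with layout=torch.jagged is available (the tests instead use the internal jagged_from_list helper):

    import torch

    # Build a jagged NestedTensor and take a view of it.
    nt = torch.nested.nested_tensor(
        [torch.randn(2, 3), torch.randn(4, 3)],
        layout=torch.jagged,
        requires_grad=True,
    )
    nt_view = nt.unsqueeze(-1)

    # Replay the recorded view on a base without the has_same_meta check, so no
    # metadata guard is introduced when meta_utils does this under fake tensors.
    replayed = nt_view._view_func_unsafe(nt)
    print(replayed.size())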