Replay view with view_func instead of as_strided in meta_utils for NT (#112205)

Currently meta_utils relies on as_strided when handling the view case (recursively meta-ify the base, then call as_strided to simulate the view), but NestedTensor does not support as_strided today (though maybe it could?), so what we want to do instead is call Tensor._view_func. Conveniently, _view_func IS always available for nested tensors.
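
To make the two replay strategies concrete, here is a minimal sketch on a plain strided tensor (variable names are illustrative); the as_strided call is what meta_utils does today, and the _view_func call is the replacement:

```python
import torch

base = torch.randn(4, 4, requires_grad=True)
view = base[1:3, :2]  # a differentiable view of base

# A fresh base with the same metadata, standing in for the meta-ified base.
new_base = torch.randn(4, 4)

# Old path: simulate the view with as_strided. NestedTensor cannot take
# this path, since it does not support as_strided.
replayed = new_base.as_strided(view.size(), view.stride(), view.storage_offset())

# New path: ask autograd to replay the recorded view operation on the
# new base instead.
replayed2 = view._view_func(new_base)
```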

A detail to note is that _view_func actually incurs a guard, because it needs to perform some metadata checks to make sure the view is still valid. This PR adds Tensor._view_func_unsafe, which skips those checks and hence avoids the guard.
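
Continuing the sketch above (the names match the diff below; the skipped work is the comparison of the new base's metadata against the view's original base):

```python
# _view_func first verifies that new_base's metadata (sizes, strides,
# storage offset) matches that of the view's original base; under symbolic
# shapes those comparisons install guards.
guarded = view._view_func(new_base)

# The unsafe variant skips that metadata check. This is safe in meta_utils
# because the meta-ified base is constructed with identical metadata.
unguarded = view._view_func_unsafe(new_base)
```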

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205
Approved by: https://github.com/jbschlosser
Author: soulitzer
Date: 2023-10-27 14:58:47 -04:00
Committed by: PyTorch MergeBot
Parent: 503955f5ec
Commit: 0cda4c8abe
4 changed files with 118 additions and 14 deletions


@@ -309,7 +309,7 @@ class MetaConverter:
                 from torch._dynamo.source import AttrSource
                 from torch.fx.experimental.symbolic_shapes import DimDynamic
 
-                if shape_env:
+                if shape_env and not t.is_nested:
                     base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
                 else:
                     base_dynamic_dims = None
@@ -369,25 +369,45 @@ class MetaConverter:
                 #
                 # So we may have to do *two* views out of the base to
                 # recreate this situation.
-                (
-                    sizes,
-                    strides,
-                    storage_offset,
-                ) = sym_sizes_strides_storage_offset(t, source)
+                def _view_from_base(base, t):
+                    if t.is_nested:
+                        # Nested tensors do not support as_strided, and
+                        # hence, always have _view_func available.
+                        #
+                        # The unsafe version of _view_func omits
+                        # checking whether the base passed in has the same
+                        # metadata as the original base the view_func
+                        # was originally executed with. (1) It is OK here,
+                        # because we're calling it on the meta-ified base,
+                        # so the metadata is guaranteed to be the same.
+                        # (2) It is necessary because we don't actually
+                        # want to guard on the base's metadata here.
+                        return t._view_func_unsafe(base)
+                    else:
+                        (
+                            sizes,
+                            strides,
+                            storage_offset,
+                        ) = sym_sizes_strides_storage_offset(t, source)
+                        return base.as_strided(sizes, strides, storage_offset)
+
                 if safe_is_leaf(t):
                     # Leaf views that track view metadata are created by
                     # creating a view inside a no_grad block
                     with torch.no_grad(), maybe_suppress():
-                        r = base.as_strided(sizes, strides, storage_offset)
+                        r = _view_from_base(base, t)
                     # As it's a leaf, we can directly assign requires_grad
                     r.requires_grad = t.requires_grad
                 else:
                     if t._base.requires_grad == t.requires_grad:
                         # Easy case, just run the view op
                         with torch.enable_grad(), maybe_suppress():
-                            r = base.as_strided(sizes, strides, storage_offset)
+                            r = _view_from_base(base, t)
                         # NB: We don't actually faithfully replicate
                         # autograd connectivity, but that doesn't matter
                         # today. See following for more info:
                         # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                     else:
                         # Obscure case. Create a leaf view and give it the
                         # correct requires_grad, then do the final view.
@@ -397,7 +417,7 @@ class MetaConverter:
                         mid = base.view(base.shape)
                         mid.requires_grad = t.requires_grad
                         with torch.enable_grad(), maybe_suppress():
-                            r = mid.as_strided(sizes, strides, storage_offset)
+                            r = _view_from_base(mid, t)
                 # The CreationMeta influences whether or not inplace
                 # mutation is an error or not. So we need to make
                 # sure we properly propagate this as well.
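
As an aside, the safe_is_leaf branch above relies on a behavior worth illustrating: a view taken inside a no_grad block records no grad_fn, so it is a leaf and its requires_grad flag can be assigned directly afterwards; this is the trick the code uses to recreate leaf views. A minimal sketch on plain tensors:

```python
import torch

base = torch.randn(3, 3, requires_grad=True)

# A view created under grad mode is a non-leaf with a grad_fn.
v1 = base.view(9)
print(v1.is_leaf)  # False

# A view created inside no_grad records no grad_fn, so it is a leaf and
# requires_grad can be assigned to it directly afterwards.
with torch.no_grad():
    v2 = base.view(9)
print(v2.is_leaf)  # True
v2.requires_grad = True
```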