Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-05 16:44:58 +08:00)
Replay view with view_func instead of as_strided in meta_utils for NT (#112205)
Currently meta_utils relies on as_strided when handling the view case (recursively meta-ify the base, then call as_strided to simulate the view), but NestedTensor does not support as_strided today (though maybe it could?). So what we want to do instead is call Tensor._view_func. Conveniently, _view_func IS always available for nested tensors. One detail to note is that _view_func actually incurs a guard, because it needs to perform some metadata checks to make sure the view is still valid. This PR adds Tensor._view_func_unsafe, which avoids those checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205
Approved by: https://github.com/jbschlosser
Committed by: PyTorch MergeBot
Parent: 503955f5ec
Commit: 0cda4c8abe
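To make the mechanism concrete, here is a minimal sketch of view replay (a sketch only, assuming a PyTorch build exposing the private Tensor._view_func API; the tensors and shapes are made up for illustration). For plain strided views, autograd may not have recorded a view_func at all and can fall back to as_strided, which is exactly why the non-nested branch in the diff below keeps the as_strided path:

    import torch

    # Replay a view onto a new base via the private _view_func API instead
    # of reconstructing it with as_strided. Nested tensors always record a
    # view_func; plain strided views may not, hence the fallback below.
    base = torch.randn(4, 5, requires_grad=True)
    view = base[1:3]  # a view whose ._base is `base`

    new_base = base.detach().clone().requires_grad_(True)
    replayed = view._view_func(new_base)  # may be None if nothing was recorded

    if replayed is None:
        # Old-style reconstruction, as meta_utils previously did for all views.
        replayed = new_base.as_strided(
            view.size(), view.stride(), view.storage_offset()
        )

    assert replayed.shape == view.shape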
torch/_subclasses/meta_utils.py
@@ -309,7 +309,7 @@ class MetaConverter:
                     from torch._dynamo.source import AttrSource
                     from torch.fx.experimental.symbolic_shapes import DimDynamic

-                    if shape_env:
+                    if shape_env and not t.is_nested:
                         base_dynamic_dims = [DimDynamic.STATIC] * t._base.dim()
                     else:
                         base_dynamic_dims = None
@@ -369,25 +369,45 @@ class MetaConverter:
                         #
                         # So we may have to do *two* views out of the base to
                         # recreate this situation.
-                        (
-                            sizes,
-                            strides,
-                            storage_offset,
-                        ) = sym_sizes_strides_storage_offset(t, source)
+
+                        def _view_from_base(base, t):
+                            if t.is_nested:
+                                # Nested tensors do not support as_strided, and
+                                # hence, always have _view_func available.
+                                #
+                                # The unsafe version of _view_func omits
+                                # checking whether the base passed in has the same
+                                # metadata as the original base the view_func
+                                # was originally executed with. (1) It is OK here,
+                                # because we're calling it on the meta-ified base,
+                                # so the metadata is guaranteed to be the same.
+                                # (2) It is necessary because we don't actually
+                                # want to guard on the base's metadata here.
+                                return t._view_func_unsafe(base)
+                            else:
+                                (
+                                    sizes,
+                                    strides,
+                                    storage_offset,
+                                ) = sym_sizes_strides_storage_offset(t, source)
+                                return base.as_strided(sizes, strides, storage_offset)
+
                         if safe_is_leaf(t):
                             # Leaf views that track view metadata are created by
                             # creating a view inside a no_grad block
                             with torch.no_grad(), maybe_suppress():
-                                r = base.as_strided(sizes, strides, storage_offset)
+                                r = _view_from_base(base, t)
                             # As it's a leaf, we can directly assign requires_grad
                             r.requires_grad = t.requires_grad
                         else:
                             if t._base.requires_grad == t.requires_grad:
                                 # Easy case, just run the view op
                                 with torch.enable_grad(), maybe_suppress():
-                                    r = base.as_strided(sizes, strides, storage_offset)
+                                    r = _view_from_base(base, t)

                                 # NB: We don't actually faithfully replicate
                                 # autograd connectivity, but that doesn't matter
                                 # today. See following for more info:
                                 # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                             else:
                                 # Obscure case. Create a leaf view and give it the
                                 # correct requires_grad, then do the final view.
@@ -397,7 +417,7 @@ class MetaConverter:
                                 mid = base.view(base.shape)
                                 mid.requires_grad = t.requires_grad
                                 with torch.enable_grad(), maybe_suppress():
-                                    r = mid.as_strided(sizes, strides, storage_offset)
+                                    r = _view_from_base(mid, t)
                                 # The CreationMeta influences whether or not inplace
                                 # mutation is an error or not. So we need to make
                                 # sure we properly propagate this as well.
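As a side note on the leaf branch above: a view created inside a no_grad block comes out as a leaf that still tracks view metadata, which is why requires_grad can be assigned to it directly afterwards. A minimal sketch of that pattern (shapes made up for illustration):

    import torch

    base = torch.randn(4, 5)
    with torch.no_grad():
        # A view taken under no_grad has no grad_fn, so it is a leaf,
        # yet it still carries view metadata.
        r = base.as_strided((2, 5), base.stride(), 5)

    assert r._is_view() and r.is_leaf
    r.requires_grad = True  # legal because r is a leaf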