[dynamo] Skip leaf check on assert_metadata_eq if grad tensor level is -2 (#122728)

When fakifying a grad-tracking tensor whose level is -2 (a sentinel
value), we can simply unwrap the grad tensor and return a fake version
of the unwrapped tensor. In this PR, we update `assert_metadata_eq` so
that it no longer compares whether the grad tensor and the unwrapped
tensor are leaves, since their `is_leaf` values are not guaranteed to
match.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122728
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas
2024-03-26 18:39:54 -03:00
committed by PyTorch MergeBot
parent 03439d4c1c
commit 9ff2a9dcdd
9 changed files with 16 additions and 6 deletions

View File

@ -77,6 +77,7 @@ def assert_metadata_eq(
m2: torch.Tensor,
*,
skip_symbolic=False,
skip_leaf=False,
):
if isinstance(m1, torch.Tensor):
m1 = MetaTensorDescriber().describe_tensor(m1)
@ -86,6 +87,7 @@ def assert_metadata_eq(
if not skip_symbolic:
assert_eq(m1.shape, m2.shape)
assert_eq(m1.requires_grad, m2.requires_grad)
if not skip_leaf:
assert_eq(m1.is_leaf, m2.is_leaf)
# MetaTensorDesc doesn't store grad_fn; inferred from leaf
# assert_eq(m1.grad_fn is None, m2.grad_fn is None)
@ -855,6 +857,8 @@ class MetaConverter:
return fake_t
if self.get_tensor_memo(t) is None:
GRAD_TENSOR_SENTINEL_VALUE = -2
with torch.inference_mode(t.is_inference):
if t.is_sparse:
is_leaf = t.is_leaf
@ -1010,6 +1014,9 @@ class MetaConverter:
with disable_functorch():
ft = _to_fake_tensor(t.unwrapped)
lvl = t.level
if lvl == GRAD_TENSOR_SENTINEL_VALUE:
r = ft
else:
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
t.functorch_stack
):
@ -1291,7 +1298,10 @@ class MetaConverter:
torch._C._set_conj(r, t.is_conj)
torch._C._set_neg(r, t.is_neg)
# This can be skipped if necessary for performance reasons
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
skip_leaf = (
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
)
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
self.set_tensor_memo(t, r)
return self.get_tensor_memo(t)