mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Skip leaf check on assert_metadata_eq
if grad tensor level is -2
(#122728)
When fakifying a grad tracking tensor, if the level is -2 (sentinel value) we can just unwrap the grad tensor and return a fake version of it. In this PR, we update the `assert_metadata_eq` to not compare if the grad tensor and the unwrapped ones are leafs or not, as this may not be always true. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122728 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
03439d4c1c
commit
9ff2a9dcdd
@ -77,6 +77,7 @@ def assert_metadata_eq(
|
||||
m2: torch.Tensor,
|
||||
*,
|
||||
skip_symbolic=False,
|
||||
skip_leaf=False,
|
||||
):
|
||||
if isinstance(m1, torch.Tensor):
|
||||
m1 = MetaTensorDescriber().describe_tensor(m1)
|
||||
@ -86,6 +87,7 @@ def assert_metadata_eq(
|
||||
if not skip_symbolic:
|
||||
assert_eq(m1.shape, m2.shape)
|
||||
assert_eq(m1.requires_grad, m2.requires_grad)
|
||||
if not skip_leaf:
|
||||
assert_eq(m1.is_leaf, m2.is_leaf)
|
||||
# MetaTensorDesc doesn't store grad_fn; inferred from leaf
|
||||
# assert_eq(m1.grad_fn is None, m2.grad_fn is None)
|
||||
@ -855,6 +857,8 @@ class MetaConverter:
|
||||
return fake_t
|
||||
|
||||
if self.get_tensor_memo(t) is None:
|
||||
GRAD_TENSOR_SENTINEL_VALUE = -2
|
||||
|
||||
with torch.inference_mode(t.is_inference):
|
||||
if t.is_sparse:
|
||||
is_leaf = t.is_leaf
|
||||
@ -1010,6 +1014,9 @@ class MetaConverter:
|
||||
with disable_functorch():
|
||||
ft = _to_fake_tensor(t.unwrapped)
|
||||
lvl = t.level
|
||||
if lvl == GRAD_TENSOR_SENTINEL_VALUE:
|
||||
r = ft
|
||||
else:
|
||||
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
|
||||
t.functorch_stack
|
||||
):
|
||||
@ -1291,7 +1298,10 @@ class MetaConverter:
|
||||
torch._C._set_conj(r, t.is_conj)
|
||||
torch._C._set_neg(r, t.is_neg)
|
||||
# This can be skipped if necessary for performance reasons
|
||||
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
|
||||
skip_leaf = (
|
||||
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
|
||||
)
|
||||
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
|
||||
self.set_tensor_memo(t, r)
|
||||
|
||||
return self.get_tensor_memo(t)
|
||||
|
Reference in New Issue
Block a user