[real_tensor_prop] Infer Fake kernels during real tensor prop (#139213)

This PR changes real_tensor_prop to also infer fake kernels when the
operator doesn't have it.

We infer the fake output to be of the same properties as the real
output, with unbacked symints in the sizes and some stride order.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139213
Approved by: https://github.com/pianpwk
ghstack dependencies: #139212
This commit is contained in:
Richard Zou
2024-10-30 09:38:20 -04:00
committed by PyTorch MergeBot
parent 03ec25053a
commit ad0883a288
4 changed files with 197 additions and 5 deletions

View File

@ -422,3 +422,33 @@ class MutationChecker:
def hash_tensor(t: torch.Tensor) -> torch.Tensor:
"""Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation"""
return t.detach().float().mean()
def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
"""If an operator (that stays alive until FakeTensorMode) has a Fake kernel.
Don't use this if the operator decomposes before FakeTensorMode.
"""
if can_generate_trivial_fake_impl(op):
return True
name = op._name
if torch._C._dispatch_has_kernel_for_dispatch_key(
name, "CompositeImplicitAutograd"
):
return True
opdef = torch._library.custom_ops._maybe_get_opdef(name)
if opdef is None:
# the non-torch.library.custom_op path
if torch._C._dispatch_has_kernel_for_dispatch_key(
name, "CompositeExplicitAutograd"
):
return True
entry = torch._library.simple_registry.singleton.find(name)
if entry.fake_impl.kernel is not None:
return True
if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"):
return True
else:
# the torch.library.custom_op path
if opdef._abstract_fn is not None:
return True
return False