mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[real_tensor_prop] Infer Fake kernels during real tensor prop (#139213)
This PR changes real_tensor_prop to also infer fake kernels when the operator doesn't have it. We infer the fake output to be of the same properties as the real output, with unbacked symints in the sizes and some stride order. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/139213 Approved by: https://github.com/pianpwk ghstack dependencies: #139212
This commit is contained in:
committed by
PyTorch MergeBot
parent
03ec25053a
commit
ad0883a288
@@ -422,3 +422,33 @@ class MutationChecker:
|
||||
def hash_tensor(t: torch.Tensor) -> torch.Tensor:
    """Cheap content hash of a tensor.

    Used as a quick-and-dirty indicator for tensor mutation: if the value
    changes, the tensor was mutated (the converse does not necessarily hold,
    since distinct contents can share a mean).
    """
    detached = t.detach()
    # float() guards against integer/bool dtypes that cannot take a mean.
    return detached.float().mean()
def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
    """If an operator (that stays alive until FakeTensorMode) has a Fake kernel.

    Don't use this if the operator decomposes before FakeTensorMode.

    Checks, in order: trivially-generatable fake impls, CompositeImplicitAutograd
    registrations, the torch.library.custom_op path (abstract fn on the opdef),
    and finally the non-custom_op registration paths (CompositeExplicitAutograd,
    a registered fake impl, or a Meta kernel).
    """
    if can_generate_trivial_fake_impl(op):
        return True

    name = op._name
    # Hoist the repeated dispatch-key lookup into a local alias.
    has_kernel_for = torch._C._dispatch_has_kernel_for_dispatch_key
    if has_kernel_for(name, "CompositeImplicitAutograd"):
        return True

    opdef = torch._library.custom_ops._maybe_get_opdef(name)
    if opdef is not None:
        # the torch.library.custom_op path
        return opdef._abstract_fn is not None

    # the non-torch.library.custom_op path
    if has_kernel_for(name, "CompositeExplicitAutograd"):
        return True
    entry = torch._library.simple_registry.singleton.find(name)
    if entry.fake_impl.kernel is not None:
        return True
    return has_kernel_for(name, "Meta")
Reference in New Issue
Block a user