mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix fake tensor caching when output has unbacked (#153034)
We handle fake tensor caching in two ways: 1. If the inputs have no symbols (SymInt, etc) then we cache on the FakeTensorMode. 2. If the inputs have symbols then we cache on the ShapeEnv. This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call. However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output. In this case we shouldn't cache at all because what would that really mean? So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op. Added a test which checks for this case. While in there I also did a couple other related changes: 1. Added negative caching - if we see that an (op, args) failed to cache previously we don't even bother trying to cache it again. 2. Reworked the inner behavior of _cached_dispatch_impl a little to make it more clear which bits we expect to be able to throw _BypassDispatchCache and add some comments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153034 Approved by: https://github.com/masnesral, https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
cbb03e6971
commit
4f425a0397
@ -2265,13 +2265,10 @@ class FakeTensorDispatchCache(TestCase):
|
||||
gc.collect()
|
||||
self.assertTrue(count_invoke_subgraph_keys() == 0)
|
||||
|
||||
|
||||
|
||||
@skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
|
||||
def test_invoke_subgraph_cacheable_inplace(self):
|
||||
invoke_subgraph = torch._higher_order_ops.invoke_subgraph
|
||||
|
||||
|
||||
def fn(x, y):
|
||||
# aten ops are used so that eager backend graph is suitable for fake
|
||||
# tensor testing
|
||||
@ -2317,5 +2314,32 @@ class FakeTensorDispatchCache(TestCase):
|
||||
extract_tensor_metadata(b),
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
|
||||
def test_unbacked_output(self):
|
||||
# The point of this test is to have an op which has no symbols as input
|
||||
# but a symbol as an output and make sure that we skip caching it.
|
||||
class LengthsGather(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
bias = torch.gather(offsets, 0, indices)
|
||||
lengths_selected = torch.gather(lengths, 0, indices)
|
||||
index = torch.repeat_interleave(bias, lengths_selected, dim=0)
|
||||
return index
|
||||
|
||||
input = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
lengths = torch.tensor([0, 2, 3, 1, 4])
|
||||
indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
|
||||
offsets = torch.cumsum(lengths, 0)
|
||||
ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)
|
||||
|
||||
FakeTensorMode.cache_clear()
|
||||
ep.run_decompositions({})
|
||||
self.assertBypasses("unrepresented symbol in output", 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user