[inductor] fix index.Tensor fallback (#144736)

The original issue is an accuracy problem in a Meta-internal model [meta internal link](https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065). The debugging was hard, but the root cause is relatively simple: the model passes mixed-device inputs to index.Tensor, which causes Inductor to fall back to the eager kernel, and the meta kernel for index.Tensor returns a tensor whose strides are inconsistent with the eager kernel's output.

The following code snippet
```
import torch
from torch._subclasses import FakeTensorMode

device = "cuda"

x = torch.randn((24, 16, 32, 32), device=device).to(memory_format=torch.channels_last)
x = x.view(2, 12, 16, 32, 32)

i1 = torch.arange(2).unsqueeze(-1)
i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]

print(f"Eager stride: {x[i1, i2].stride()}")

mode = FakeTensorMode()
with mode:
    f_x = mode.from_tensor(x)
    f_i1 = mode.from_tensor(i1)
    f_i2 = mode.from_tensor(i2)
    f_out = f_x[f_i1, f_i2]
    print(f"Meta stride: {f_out.stride()}")
```

would output:
```
Eager stride: (49152, 16384, 1, 512, 16)
Meta stride: (49152, 16384, 1024, 32, 1)
```

In this PR, I fix the problem by running the eager kernel to get the index.Tensor fallback's output layout. A better solution would be to change the meta/eager kernel implementations so that their output layouts match, but I'm not sure how to do that properly.
The index.Tensor meta kernel always produces dense output: 6d56277682/torch/_meta_registrations.py (L3184). The eager kernel, on the other hand, leverages TensorIteratorBase to decide a dimension permutation: 6d56277682/aten/src/ATen/TensorIterator.cpp (L232-L308). We could duplicate this logic in the meta kernel implementation if we really want meta to match eager. I can follow up on this if people have a strong opinion that we should.
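
To make the stride difference above concrete: after the `view`, `x` keeps channels-last-like strides `(196608, 16384, 1, 512, 16)`, and the eager kernel lays the indexing output out contiguously in the physical dimension order it inherits from `x` (logical dims `0, 1, 3, 4, 2`, i.e. the index dims, then H, then W, with C innermost) rather than in logical order. A small hand-check of that claim (a sketch only; the permutation is read off from the eager stride printed above rather than computed via TensorIteratorBase):

```
# Output of x[i1, i2] has shape (2, 3, 16, 32, 32).  Laying it out
# contiguously in physical dim order (0, 1, 3, 4, 2) -- the order inherited
# from the channels-last source -- reproduces the eager stride.
shape = (2, 3, 16, 32, 32)
perm = (0, 1, 3, 4, 2)  # outermost -> innermost physical order

stride = [0] * len(shape)
acc = 1
for d in reversed(perm):  # walk dims from innermost to outermost
    stride[d] = acc
    acc *= shape[d]

print(tuple(stride))  # (49152, 16384, 1, 512, 16) -- matches the eager stride
```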

I also filed https://github.com/pytorch/pytorch/issues/144717 for asserting sizes/strides of fallback kernels. With that in place, the issue debugged here would have been much easier to root-cause.
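
As a rough illustration of what such an assertion could buy, here is a hypothetical sketch (`check_fallback_strides` is not part of that issue or of Inductor; it just cross-checks eager vs. fake outputs for a single op):

```
import torch
from torch._subclasses import FakeTensorMode
from torch.utils._pytree import tree_map


def check_fallback_strides(op, *args):
    # Hypothetical helper: run the op for real, then again under FakeTensorMode
    # (which uses the meta kernel), and compare the output layouts.
    eager_out = op(*args)

    mode = FakeTensorMode()
    with mode:
        fake_args = tree_map(
            lambda a: mode.from_tensor(a) if isinstance(a, torch.Tensor) else a, args
        )
        fake_out = op(*fake_args)

    assert eager_out.size() == fake_out.size(), (eager_out.size(), fake_out.size())
    assert eager_out.stride() == fake_out.stride(), (eager_out.stride(), fake_out.stride())


# Reusing the repro from above: this fails before the meta kernel fix and passes after it.
x = torch.randn((24, 16, 32, 32)).to(memory_format=torch.channels_last)
x = x.view(2, 12, 16, 32, 32)
i1 = torch.arange(2).unsqueeze(-1)
i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]
check_fallback_strides(torch.ops.aten.index.Tensor, x, [i1, i2])
```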

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144736
Approved by: https://github.com/jansel
Shunting Zhang
2025-01-15 15:20:22 -08:00
committed by PyTorch MergeBot
parent 57d5659c3b
commit 0c0583254e
3 changed files with 91 additions and 1 deletion


@@ -37,6 +37,7 @@ from torch._dynamo.testing import (
     CompileCounterWithBackend,
     expectedFailureCodegenDynamic,
     rand_strided,
+    reset_rng_state,
     same,
     skipIfPy312,
 )
@@ -12269,6 +12270,46 @@ class CommonTemplate:
         with self.assertRaisesRegex(RuntimeError, "Output size is too small"):
             _ = torch.compile(model)(inputs)
 
+    @requires_gpu()
+    @config.patch(fallback_random=True)
+    @unittest.skipIf(
+        config.cpp_wrapper,
+        "cpp wrapper does not support sort properly: https://gist.github.com/shunting314/e58f637f9972f1ad1a033d73cee6e42a",
+    )
+    def test_mix_device_index(self):
+        """
+        A tiny repro for this meta internal issue: https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065
+        whose root cause is Inductor having wrong assumption of index.Tensor's output
+        stride.
+        """
+        image_latent = (
+            torch.randn((24, 16, 32, 32), device=GPU_TYPE)
+            .to(memory_format=torch.channels_last)
+            .view(2, 12, 16, 32, 32)
+        )
+
+        def f(image_latent):
+            indices = torch.argsort(torch.rand(2, 12), dim=-1)
+            tar_latent = image_latent[torch.arange(2).unsqueeze(-1), indices[:, :3]]
+            # The original model uses einops. In this unit test, we use view op directly
+            # to avoid importing einops
+            # tar_latent_rearranged = einops.rearrange(
+            #     tar_latent, "b n c h w -> (b n) c h w"
+            # )
+            tar_latent_rearranged = tar_latent.view(-1, *tar_latent.size()[2:])
+            return tar_latent_rearranged
+
+        reset_rng_state()
+        ref = f(image_latent)
+
+        opt_f = torch.compile(f)
+        reset_rng_state()
+        act = opt_f(image_latent)
+        torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
+
 
 @dataclasses.dataclass
 class TestFailure:


@@ -1805,6 +1805,24 @@ class TestMeta(TestCase):
         self.assertEqual(nz.stride(), torch.Size([1, 24]))
 
+    def test_stride_for_index_Tensor(self):
+        from torch._subclasses import FakeTensorMode
+
+        x = torch.randn((24, 16, 32, 32)).to(memory_format=torch.channels_last)
+        x = x.view(2, 12, 16, 32, 32)
+        i1 = torch.arange(2).unsqueeze(-1)
+        i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]
+        out = x[i1, i2]
+
+        mode = FakeTensorMode()
+        with mode:
+            f_x = mode.from_tensor(x)
+            f_i1 = mode.from_tensor(i1)
+            f_i2 = mode.from_tensor(i2)
+            f_out = f_x[f_i1, f_i2]
+
+        self.assertEqual(out.stride(), f_out.stride())
+
 
 instantiate_device_type_tests(TestMeta, globals())


@@ -3260,7 +3260,38 @@ def meta_index_Tensor(self, indices):
                 before_shape.append(self.shape[dim])
         else:
             replacement_shape = list(index.shape)
-    return self.new_empty(before_shape + replacement_shape + after_shape)
+
+    def _restride_src(self):
+        """
+        This follows restride_src in TensorAdvancedIndexing.cpp
+        """
+        shape = before_shape + replacement_shape + after_shape
+        strides = list(self.stride())
+        strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
+            replacement_shape
+        )
+        return self.as_strided(shape, strides)
+
+    out = self.new_empty(before_shape + replacement_shape + after_shape)
+
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+    if guard_size_oblivious(self.numel() == 0):
+        # No need to worry about the output strides if self is empty.
+        return out
+
+    # Try to follow eager to decide the output stride based on self.
+    # Note that perm here is the reverse of the 'perm_' decided by
+    # TensorIteratorBase::reorder_dimensions
+    restrided_self = _restride_src(self)
+    perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
+
+    # Follow TensorIteratorBase::allocate_or_resize_outputs
+    if list(perm) != list(range(len(perm))):
+        perm_shape = utils.apply_perm(out.shape, perm)
+        new_stride = utils.make_contiguous_strides_for(perm_shape)
+        new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm))
+        out = out.as_strided(out.size(), new_stride)
+    return out
 
 
 @register_meta([aten.convolution_backward.default])
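
A quick hand-trace of the new logic on the repro above, as a sanity check: `before_shape = []`, `replacement_shape = [2, 3]`, `after_shape = [16, 32, 32]`, and `self.stride() == (196608, 16384, 1, 512, 16)`, so `_restride_src` returns a view with strides `(0, 0, 1, 512, 16)` (the advanced-index dims get stride 0). The computed `perm` works out to `(0, 1, 3, 4, 2)`; contiguous strides for the permuted shape `(2, 3, 32, 32, 16)` are `(49152, 16384, 512, 16, 1)`, and permuting them back to logical order gives `(49152, 16384, 1, 512, 16)`, matching the eager stride printed in the snippet at the top.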