mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] fix index.Tensor fallback (#144736)
The original issue is we see accuracy problem in a meta internal model [meta internal link](https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065/). The debugging is hard but the root cause is relatively simple. The root cause is that the model has mix-device inputs for index.Tensor which causes Inductor to fallback. And the meta kernel for index.Tensor returns a tensor with inconsistent strides to the eager kernel. The following code snippet ``` import torch from torch._subclasses import FakeTensorMode device = "cuda" x = torch.randn((24, 16, 32, 32), device=device).to(memory_format=torch.channels_last) x = x.view(2, 12, 16, 32, 32) i1 = torch.arange(2).unsqueeze(-1) i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3] print(f"Eager stride: {x[i1, i2].stride()}") mode = FakeTensorMode() with mode: f_x = mode.from_tensor(x) f_i1 = mode.from_tensor(i1) f_i2 = mode.from_tensor(i2) f_out = f_x[f_i1, f_i2] print(f"Meta stride: {f_out.stride()}") ``` would output: ``` Eager stride: (49152, 16384, 1, 512, 16) Meta stride: (49152, 16384, 1024, 32, 1) ``` In this PR, I fix the problem to run eager kernel to get the index.Tensor fallback's output layout. A better solution would be to change meta/eager kernel implementation so that their output layout matches. But I'm not sure how to properly do that. In the index.Tensor meta kernel, we always produce dense output:6d56277682/torch/_meta_registrations.py (L3184). While the eager kernel seems to leverage TensorIteratorBase to decide some dimension permutation:6d56277682/aten/src/ATen/TensorIterator.cpp (L232-L308). We can duplicate this logic to the meta kernel implementation if we really want meta matches eager. I can follow up on this if people have strong opinion to do this. And here is an issue https://github.com/pytorch/pytorch/issues/144717 for asserting size/strides for fallback kernels. With that, the issue debugged here would be much easier to root cause. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144736 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
57d5659c3b
commit
0c0583254e
@ -37,6 +37,7 @@ from torch._dynamo.testing import (
|
||||
CompileCounterWithBackend,
|
||||
expectedFailureCodegenDynamic,
|
||||
rand_strided,
|
||||
reset_rng_state,
|
||||
same,
|
||||
skipIfPy312,
|
||||
)
|
||||
@ -12269,6 +12270,46 @@ class CommonTemplate:
|
||||
with self.assertRaisesRegex(RuntimeError, "Output size is too small"):
|
||||
_ = torch.compile(model)(inputs)
|
||||
|
||||
@requires_gpu()
|
||||
@config.patch(fallback_random=True)
|
||||
@unittest.skipIf(
|
||||
config.cpp_wrapper,
|
||||
"cpp wrapper does not support sort properly: https://gist.github.com/shunting314/e58f637f9972f1ad1a033d73cee6e42a",
|
||||
)
|
||||
def test_mix_device_index(self):
|
||||
"""
|
||||
A tiny repro for this meta internal issue: https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065
|
||||
whose root cause is Inductor having wrong assumption of index.Tensor's output
|
||||
stride.
|
||||
"""
|
||||
image_latent = (
|
||||
torch.randn((24, 16, 32, 32), device=GPU_TYPE)
|
||||
.to(memory_format=torch.channels_last)
|
||||
.view(2, 12, 16, 32, 32)
|
||||
)
|
||||
|
||||
def f(image_latent):
|
||||
indices = torch.argsort(torch.rand(2, 12), dim=-1)
|
||||
|
||||
tar_latent = image_latent[torch.arange(2).unsqueeze(-1), indices[:, :3]]
|
||||
|
||||
# The original model uses einops. In this unit test, we use view op directly
|
||||
# to avoid importing einops
|
||||
# tar_latent_rearranged = einops.rearrange(
|
||||
# tar_latent, "b n c h w -> (b n) c h w"
|
||||
# )
|
||||
tar_latent_rearranged = tar_latent.view(-1, *tar_latent.size()[2:])
|
||||
|
||||
return tar_latent_rearranged
|
||||
|
||||
reset_rng_state()
|
||||
ref = f(image_latent)
|
||||
opt_f = torch.compile(f)
|
||||
reset_rng_state()
|
||||
act = opt_f(image_latent)
|
||||
|
||||
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestFailure:
|
||||
|
||||
@ -1805,6 +1805,24 @@ class TestMeta(TestCase):
|
||||
self.assertEqual(nz.stride(), torch.Size([1, 24]))
|
||||
|
||||
|
||||
def test_stride_for_index_Tensor(self):
|
||||
from torch._subclasses import FakeTensorMode
|
||||
x = torch.randn((24, 16, 32, 32)).to(memory_format=torch.channels_last)
|
||||
x = x.view(2, 12, 16, 32, 32)
|
||||
|
||||
i1 = torch.arange(2).unsqueeze(-1)
|
||||
i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]
|
||||
|
||||
out = x[i1, i2]
|
||||
|
||||
mode = FakeTensorMode()
|
||||
with mode:
|
||||
f_x = mode.from_tensor(x)
|
||||
f_i1 = mode.from_tensor(i1)
|
||||
f_i2 = mode.from_tensor(i2)
|
||||
f_out = f_x[f_i1, f_i2]
|
||||
|
||||
self.assertEqual(out.stride(), f_out.stride())
|
||||
|
||||
instantiate_device_type_tests(TestMeta, globals())
|
||||
|
||||
|
||||
@ -3260,7 +3260,38 @@ def meta_index_Tensor(self, indices):
|
||||
before_shape.append(self.shape[dim])
|
||||
else:
|
||||
replacement_shape = list(index.shape)
|
||||
return self.new_empty(before_shape + replacement_shape + after_shape)
|
||||
|
||||
def _restride_src(self):
|
||||
"""
|
||||
This follows restride_src in TensorAdvancedIndexing.cpp
|
||||
"""
|
||||
shape = before_shape + replacement_shape + after_shape
|
||||
strides = list(self.stride())
|
||||
strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
|
||||
replacement_shape
|
||||
)
|
||||
return self.as_strided(shape, strides)
|
||||
|
||||
out = self.new_empty(before_shape + replacement_shape + after_shape)
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
if guard_size_oblivious(self.numel() == 0):
|
||||
# No need to worry about the output strides if self is empty.
|
||||
return out
|
||||
|
||||
# Try to follow eager to decide the output stride based on self.
|
||||
# Note that perm here is the reverse of the 'perm_' decided by
|
||||
# TensorIteratorBase::reorder_dimensions
|
||||
restrided_self = _restride_src(self)
|
||||
perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
|
||||
|
||||
# Follow TensorIteratorBase::allocate_or_resize_outputs
|
||||
if list(perm) != list(range(len(perm))):
|
||||
perm_shape = utils.apply_perm(out.shape, perm)
|
||||
new_stride = utils.make_contiguous_strides_for(perm_shape)
|
||||
new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm))
|
||||
out = out.as_strided(out.size(), new_stride)
|
||||
return out
|
||||
|
||||
|
||||
@register_meta([aten.convolution_backward.default])
|
||||
|
||||
Reference in New Issue
Block a user