mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] fix index.Tensor fallback (#144736)
The original issue is an accuracy problem we saw in a Meta internal model ([meta internal link](https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065/)). Debugging it was hard, but the root cause is relatively simple: the model has mixed-device inputs for index.Tensor, which causes Inductor to fall back, and the meta kernel for index.Tensor returns a tensor whose strides are inconsistent with the eager kernel. The following code snippet

```
import torch
from torch._subclasses import FakeTensorMode

device = "cuda"
x = torch.randn((24, 16, 32, 32), device=device).to(memory_format=torch.channels_last)
x = x.view(2, 12, 16, 32, 32)

i1 = torch.arange(2).unsqueeze(-1)
i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]

print(f"Eager stride: {x[i1, i2].stride()}")

mode = FakeTensorMode()
with mode:
    f_x = mode.from_tensor(x)
    f_i1 = mode.from_tensor(i1)
    f_i2 = mode.from_tensor(i2)
    f_out = f_x[f_i1, f_i2]
    print(f"Meta stride: {f_out.stride()}")
```

would output:

```
Eager stride: (49152, 16384, 1, 512, 16)
Meta stride: (49152, 16384, 1024, 32, 1)
```

In this PR, I fix the problem by running the eager kernel to get the index.Tensor fallback's output layout. A better solution would be to change the meta/eager kernel implementations so that their output layouts match, but I'm not sure how to properly do that. The index.Tensor meta kernel always produces dense output: 6d56277682/torch/_meta_registrations.py (L3184), while the eager kernel seems to leverage TensorIteratorBase to decide some dimension permutation: 6d56277682/aten/src/ATen/TensorIterator.cpp (L232-L308). We could duplicate that logic in the meta kernel implementation if we really want meta to match eager; I can follow up on this if people have strong opinions about doing so.

There is also an issue, https://github.com/pytorch/pytorch/issues/144717, for asserting sizes/strides of fallback kernels. With that in place, the issue debugged here would have been much easier to root-cause.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144736
Approved by: https://github.com/jansel
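As a side note on the "dense output" behavior mentioned above: the meta stride printed by the snippet is exactly the contiguous (row-major) stride of the output shape (2, 3, 16, 32, 32), i.e. the broadcast index shape (2, 3) followed by the un-indexed dims. A minimal pure-Python sketch (the helper name `dense_strides` is hypothetical, for illustration only):

```python
def dense_strides(shape):
    # Row-major (contiguous) strides: stride[i] = prod(shape[i+1:]).
    strides = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

# Output shape of x[i1, i2]: broadcast index shape (2, 3) followed by the
# remaining dims (16, 32, 32) of x.
print(dense_strides((2, 3, 16, 32, 32)))  # (49152, 16384, 1024, 32, 1)
```

The result matches the "Meta stride" printed above, while the eager stride reflects a permutation inherited from the channels_last input, which is exactly the mismatch this PR works around.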
committed by: PyTorch MergeBot
parent: 57d5659c3b
commit: 0c0583254e
```diff
@@ -1805,6 +1805,24 @@ class TestMeta(TestCase):
         self.assertEqual(nz.stride(), torch.Size([1, 24]))
 
+    def test_stride_for_index_Tensor(self):
+        from torch._subclasses import FakeTensorMode
+        x = torch.randn((24, 16, 32, 32)).to(memory_format=torch.channels_last)
+        x = x.view(2, 12, 16, 32, 32)
+
+        i1 = torch.arange(2).unsqueeze(-1)
+        i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]
+
+        out = x[i1, i2]
+
+        mode = FakeTensorMode()
+        with mode:
+            f_x = mode.from_tensor(x)
+            f_i1 = mode.from_tensor(i1)
+            f_i2 = mode.from_tensor(i2)
+            f_out = f_x[f_i1, f_i2]
+
+        self.assertEqual(out.stride(), f_out.stride())
 
 instantiate_device_type_tests(TestMeta, globals())
```
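For context on why `out = x[i1, i2]` in the test has leading dims (2, 3): with multiple advanced indices, the index tensors broadcast together, and their broadcast shape forms the leading output dimensions. A minimal pure-Python sketch of the right-aligned broadcasting rule (the helper name `broadcast_shape` is hypothetical):

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    # NumPy/PyTorch-style broadcasting: align shapes on the right;
    # a size-1 dim stretches to match, otherwise sizes must agree.
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != 1 and y != 1 and x != y:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(max(x, y))
    return tuple(reversed(out))

# i1 has shape (2, 1) and i2 has shape (2, 3), so they broadcast to (2, 3);
# indexing a (2, 12, 16, 32, 32) tensor with them yields shape (2, 3, 16, 32, 32).
print(broadcast_shape((2, 1), (2, 3)))  # (2, 3)
```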