Revert "Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)"

This reverts commit acc4cbea395c25410c26d6fd3c88c072ce24c918.

Reverted https://github.com/pytorch/pytorch/pull/124799 on behalf of https://github.com/jeanschmidt due to checking if this diff introduced regressions on linux-focal-py3.11-clang10 and linux-focal-py3.8-clang10 ([comment](https://github.com/pytorch/pytorch/pull/124799#issuecomment-2076756876))
This commit is contained in:
PyTorch MergeBot
2024-04-25 09:29:57 +00:00
parent 48a016157d
commit 678662a557
11 changed files with 58 additions and 129 deletions

View File

@ -1387,28 +1387,6 @@ class TestTorchFunctionMode(TestCase):
self.assertTrue(called)
def test_getitem_call(self):
# This failed because the parser thinks the function is called to()
# but it's actually called _parse_to()
called = False
class A(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
nonlocal called
if kwargs is None:
kwargs = {}
called = True
return func(*args, **kwargs)
a = torch.zeros(5)
b = torch.tensor(0)
with A():
a[b]
self.assertTrue(called)
def test_distributions_bernoulli(self):
# This failed because improper use of has_torch_function when
# is_tensor_like should have been used instead, inside the