Revert "Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)"
This reverts commit acc4cbea395c25410c26d6fd3c88c072ce24c918. Reverted https://github.com/pytorch/pytorch/pull/124799 on behalf of https://github.com/jeanschmidt due to checking if this diff introduced regressions on linux-focal-py3.11-clang10 and linux-focal-py3.8-clang10 ([comment](https://github.com/pytorch/pytorch/pull/124799#issuecomment-2076756876))
```diff
@@ -1387,28 +1387,6 @@ class TestTorchFunctionMode(TestCase):
         self.assertTrue(called)
 
-    def test_getitem_call(self):
-        # This failed because the parser thinks the function is called to()
-        # but it's actually called _parse_to()
-
-        called = False
-
-        class A(TorchFunctionMode):
-            def __torch_function__(self, func, types, args=(), kwargs=None):
-                nonlocal called
-                if kwargs is None:
-                    kwargs = {}
-                called = True
-                return func(*args, **kwargs)
-
-        a = torch.zeros(5)
-        b = torch.tensor(0)
-        with A():
-            a[b]
-
-        self.assertTrue(called)
-
-
     def test_distributions_bernoulli(self):
         # This failed because improper use of has_torch_function when
         # is_tensor_like should have been used instead, inside the
```
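For context, the test removed by this revert asserts that indexing a tensor with another tensor (`a[b]`) dispatches through `TorchFunctionMode.__torch_function__`, which is the hook the reverted PR relied on to rewrite `__getitem__` calls in a FlexAttention `score_mod` into explicit `aten.index` calls. Below is a minimal, hypothetical sketch of that interception pattern, not the PR's actual implementation; the mode class name and the `bias`/`idx` tensors are invented for illustration.

```python
# Minimal sketch (an assumption, not the reverted PR's code) of intercepting
# Tensor.__getitem__ with a TorchFunctionMode and redirecting it to aten.index,
# i.e. the dispatch behaviour the removed test_getitem_call was checking.
import torch
from torch.overrides import TorchFunctionMode

class GetItemToAtenIndex(TorchFunctionMode):  # hypothetical name
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.Tensor.__getitem__ and isinstance(args[1], torch.Tensor):
            # Rewrite `t[idx]` into an explicit aten.index call.
            return torch.ops.aten.index(args[0], [args[1]])
        return func(*args, **kwargs)

bias = torch.randn(8)
idx = torch.tensor([1, 3])
with GetItemToAtenIndex():
    out = bias[idx]  # routed through aten.index while the mode is active
print(out.shape)  # torch.Size([2])
```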