Move attention kernels back from fake_impls to meta_registrations (#134288)

See #121528 for additional context.

In #120682, we moved the attention kernels from meta_registrations to fake_impls with the intent of fixing the device handling for the seed/offset outputs, which are typically on CPU. The registrations had to live in fake_impls because meta_registrations has no way to specify an output device, whereas fake_impls does. But when we actually tried to fix the device types (#120839), we had to revert that PR because it broke cudagraph handling (under cudagraphs, seed/offset _are_ on CUDA).
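
To make the constraint concrete, here is a hedged sketch (illustrative only, not the actual registration code; the `seed` variable is a stand-in for the real seed/offset outputs):

```python
import torch

# Hedged sketch, not the actual PyTorch source.
# A meta registration (_meta_registrations.py) can only construct meta
# tensors, so there is no way to say "this output lives on CPU":
seed = torch.empty((), dtype=torch.int64, device="meta")

# A fake impl (fake_impls.py) constructs FakeTensors, which pair a meta
# tensor with a concrete device, so it *can* express "seed/offset are
# CPU tensors even though the kernel itself runs on CUDA".
```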

Now, we want to put the registrations back in meta_registrations so that we can call these kernels with meta tensors. The use case comes later in this stack: we want to be able to use the flop counter with these kernels.
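
As a preview of that use case, a minimal sketch of FLOP counting on meta tensors (the shapes, dtype, and which backend SDPA dispatches to on meta tensors are assumptions, not taken from this PR):

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Meta tensors carry only shapes/dtypes, so this needs no GPU and no data.
q = torch.randn(8, 16, 1024, 64, device="meta", dtype=torch.float16)
k = torch.randn(8, 16, 1024, 64, device="meta", dtype=torch.float16)
v = torch.randn(8, 16, 1024, 64, device="meta", dtype=torch.float16)

with FlopCounterMode(display=False) as flop_counter:
    torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(flop_counter.get_total_flops())
```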

Also, I specifically skip the `compare_tensor_meta()` check in the test_fake / test_fake_autocast tests for the `_efficient_attention_forward` and `_flash_attention_forward` kernels; the check would otherwise fail because of the device mismatch on the seed/offset tensors. With the check skipped, we can un-skip these opinfos. I verified that the efficient_attention_forward bug (#120842) is now caught by these opinfos if I revert the fix from this PR.

Differential Revision: [D61687369](https://our.internmc.facebook.com/intern/diff/D61687369)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134288
Approved by: https://github.com/drisspg
commit 289486d007 (parent 39ca96398b)
Author: David Berard
Date: 2024-08-26 11:15:58 -07:00
Committed by: PyTorch MergeBot
5 changed files with 301 additions and 356 deletions

@@ -2476,9 +2476,16 @@ class TestFakeTensor(TestCase):
             # if you see a shape exception here, you may need to add
             # a `dynamic_output_shape` tag to an operator
-            # prims/decomps must correctly model strides,
-            # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
-            prims.utils.compare_tensor_meta(fake_out, real_out, True)
+            if op.op not in [
+                torch.ops.aten._efficient_attention_forward,
+                torch.ops.aten._flash_attention_forward,
+            ]:
+                # prims/decomps must correctly model strides,
+                # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
+                # note: the excluded ops have intentionally incorrect device;
+                # see "Note [Seed and Offset]" (_meta_registrations.py)
+                prims.utils.compare_tensor_meta(fake_out, real_out, True)
             if name not in aliasing_failures:
                 fake_aliasing = outputs_alias_inputs(