mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move attention kernels back from fake_impls to meta_registrations (#134288)
See #121528 for additional context.

In #120682, we moved the attention kernels from meta_registrations to fake_impls with the intent of fixing the device handling for seed/offset: these are typically on CPU. We needed to put the registrations in fake_impls to do this because meta_registrations doesn't have a way to specify a device, whereas fake_impls does. But when we tried to actually fix the device types (#120839), we had to revert the PR because it broke cudagraph handling (during which seed/offset _are_ on CUDA).

Now, we want to put the registrations back in meta_registrations so that we can call these kernels with meta tensors. The use case comes later in this stack: we want to be able to use the flop counter with these kernels.

Also, I specifically skip the `compare_tensor_meta()` check in the test_fake / test_fake_autocast tests for the `_efficient_attention_forward` and `_flash_attention_forward` kernels, which fails because of the device mismatch from the seed/offset tensors. Then we can un-skip these opinfos. I verified that the efficient_attention_forward bug (#120842) is now caught by these opinfos if I revert the fix from this PR.

Differential Revision: [D61687369](https://our.internmc.facebook.com/intern/diff/D61687369)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134288
Approved by: https://github.com/drisspg
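The device mismatch that forces the test skip can be sketched without PyTorch. Below, `FakeTensorMeta` and `compare_meta` are hypothetical, simplified stand-ins for the real fake-tensor metadata and `prims.utils.compare_tensor_meta`; the shapes and device strings are illustrative only:

```python
from dataclasses import dataclass

@dataclass
class FakeTensorMeta:
    # hypothetical stand-in for the metadata compare_tensor_meta inspects
    shape: tuple
    device: str

def compare_meta(fake: FakeTensorMeta, real: FakeTensorMeta) -> None:
    # simplified model of prims.utils.compare_tensor_meta: shape and
    # device must agree between the fake (meta) output and the real output
    if fake.shape != real.shape:
        raise AssertionError(f"shape mismatch: {fake.shape} vs {real.shape}")
    if fake.device != real.device:
        raise AssertionError(f"device mismatch: {fake.device} vs {real.device}")

# the main attention outputs agree between fake and real runs...
out_fake = FakeTensorMeta(shape=(2, 8, 128, 64), device="cuda")
out_real = FakeTensorMeta(shape=(2, 8, 128, 64), device="cuda")
compare_meta(out_fake, out_real)  # passes

# ...but the seed/offset outputs of _flash_attention_forward /
# _efficient_attention_forward carry an intentionally "incorrect" device
# in the meta registration (see "Note [Seed and Offset]"), so the check
# is skipped for those two ops rather than made to pass
seed_fake = FakeTensorMeta(shape=(), device="cuda")  # meta registration's device
seed_real = FakeTensorMeta(shape=(), device="cpu")   # eager puts seed on CPU
try:
    compare_meta(seed_fake, seed_real)
except AssertionError as e:
    print("skipped in tests because:", e)
```

This is why the diff below excludes exactly those two ops from the `compare_tensor_meta()` call instead of changing the registered device, which #120839 showed breaks cudagraphs.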
This commit is contained in:
committed by PyTorch MergeBot
parent 39ca96398b
commit 289486d007
@@ -2476,9 +2476,16 @@ class TestFakeTensor(TestCase):
         # if you see a shape exception here, you may need to add
         # a `dynamic_output_shape` tag to an operator
 
-        # prims/decomps must correctly model strides,
-        # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
-        prims.utils.compare_tensor_meta(fake_out, real_out, True)
+        if op.op not in [
+            torch.ops.aten._efficient_attention_forward,
+            torch.ops.aten._flash_attention_forward,
+        ]:
+            # prims/decomps must correctly model strides,
+            # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
+
+            # note: the excluded ops have intentionally incorrect device;
+            # see "Note [Seed and Offset]" (_meta_registrations.py)
+            prims.utils.compare_tensor_meta(fake_out, real_out, True)
 
         if name not in aliasing_failures:
             fake_aliasing = outputs_alias_inputs(