pytorch/test/test_transformers_privateuse1.py
FFFrog 29c8ae825f [OpenReg] Move SDPA to OpenReg from open_registration_extension.cpp (#153309)
As the title states.

**Next Changes**:
- Migrate remaining functionality to OpenReg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153309
Approved by: https://github.com/albanD
2025-05-13 03:49:19 +00:00
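
For context, the `openreg` device name used in the test below comes from PyTorch's PrivateUse1 backend hooks. The following is a minimal, illustrative sketch of the Python-side setup such a backend typically performs; the stub module and its `is_available` attribute are assumptions for illustration only, not the actual `pytorch_openreg` implementation, and the SDPA kernels this commit moves live in OpenReg's C++ sources.

```python
# Minimal sketch of the PrivateUse1 setup a test backend like OpenReg relies on.
# This is illustrative only; it is not the pytorch_openreg implementation.
import types

import torch

# Expose the PrivateUse1 dispatch key under a user-facing device name, so
# tensors can be moved with .to("openreg") as in the test below.
torch.utils.rename_privateuse1_backend("openreg")

# Register a device module so torch.openreg.* attributes resolve.
# The stub below is a placeholder, not OpenReg's real device module.
openreg_stub = types.ModuleType("torch.openreg")
openreg_stub.is_available = lambda: True  # illustrative stub
torch._register_device_module("openreg", openreg_stub)

# Generate Tensor.openreg() / Tensor.is_openreg convenience methods.
torch.utils.generate_methods_for_privateuse1_backend()
```
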


# Owner(s): ["module: sdpa"]
import unittest
from collections import namedtuple
from functools import partial
import pytorch_openreg # noqa: F401
import torch
from torch.nn.attention import SDPBackend
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TEST_XPU
SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
class TestSDPAPrivateUse1Only(NNTestCase):
@skipIfTorchDynamo()
def test_fused_sdp_choice_privateuseone(self):
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
q_privateuse1 = q_cpu.to("openreg")
k_privateuse1 = k_cpu.to("openreg")
v_privateuse1 = v_cpu.to("openreg")
assert (
torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
== SDPBackend.OVERRIDEABLE.value
)
def test_scaled_dot_product_fused_attention_overrideable(self):
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
q_privateuse1 = q_cpu.to("openreg")
k_privateuse1 = k_cpu.to("openreg")
v_privateuse1 = v_cpu.to("openreg")
torch.nn.functional.scaled_dot_product_attention(
q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
)
def test_scaled_dot_product_fused_attention_overrideable_backward(self):
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
make_tensor = partial(
torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
)
shape = (batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
q_privateuse1 = q_cpu.to("openreg")
k_privateuse1 = k_cpu.to("openreg")
v_privateuse1 = v_cpu.to("openreg")
attn_mask_privateuse1 = attn_mask.to("openreg")
(
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
)
rand_upward = torch.rand(
shape, device="cpu", dtype=torch.float16, requires_grad=False
)
rand_upward_privateuse1 = rand_upward.to("openreg")
grad_input_mask = [True, True, True, True]
grad_q, grad_k, grad_v, grad_attn_mask = (
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
rand_upward_privateuse1,
q_privateuse1,
k_privateuse1,
v_privateuse1,
attn_mask_privateuse1,
grad_input_mask,
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p=0.0,
is_causal=False,
philox_seed=philox_seed,
philox_offset=philox_offset,
)
)
if __name__ == "__main__":
run_tests()