pytorch/test/test_transformers_privateuse1.py
FFFrog 1b389025ba Refactor and Improve the OpenReg Module (#158090)
----
# Refactor and Improve the OpenReg Module

## Background

PrivateUse1 has become the main path for integrating new devices with PyTorch, and it has attracted a number of feature requests covering interfaces, documentation, and reference examples, including:

- https://github.com/pytorch/pytorch/issues/155864
- https://github.com/pytorch/pytorch/issues/144955
- https://github.com/pytorch/pytorch/issues/144845

Given these requests, and given that OpenReg currently serves as the test backend for PrivateUse1, I plan to make the following improvements:

- Rework the OpenReg implementation so that it follows the way real (C++) backends integrate with PyTorch and can serve as reference code for new device integration.
- Add comprehensive documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html) to guide new accelerator integration, functioning as a reference manual.

## Design Principles:

- Minimization Principle: Keep the code small and clear; only implement the minimum set of code required for verification and as an integration reference.
- Authenticity Principle: Integrate OpenReg in the same way that real accelerators access PyTorch.

## More Information:

Please refer to [this](6b8020f1ab/test/cpp_extensions/open_registration_extension/torch_openreg/README.md) for more information about `OpenReg`.

## Current Progress:
- Refer to the implementation of [torch_xla](https://github.com/pytorch/xla) to refactor all of OpenReg's code, making it easier to understand.
- Ensure all tests in [test/test_openreg.py](https://github.com/FFFrog/pytorch/blob/openreg/test/test_openreg.py) pass after refactoring.

## Next Steps:
- Add more features to cover all integration points.
- Gradually add user guides and documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158090
Approved by: https://github.com/seemethere, https://github.com/albanD
2025-07-15 08:10:05 +00:00

101 lines
3.8 KiB
Python

# Owner(s): ["module: sdpa"]

import unittest
from collections import namedtuple
from functools import partial

import torch_openreg  # noqa: F401

import torch
from torch.nn.attention import SDPBackend
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TEST_XPU


SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])


@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
class TestSDPAPrivateUse1Only(NNTestCase):
    @skipIfTorchDynamo()
    def test_fused_sdp_choice_privateuseone(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        assert (
            torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
            == SDPBackend.OVERRIDEABLE.value
        )

    def test_scaled_dot_product_fused_attention_overrideable(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        torch.nn.functional.scaled_dot_product_attention(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
        )

    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(
            torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
        )
        shape = (batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        attn_mask_privateuse1 = attn_mask.to("openreg")
        (
            output,
            logsumexp,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
        )

        rand_upward = torch.rand(
            shape, device="cpu", dtype=torch.float16, requires_grad=False
        )
        rand_upward_privateuse1 = rand_upward.to("openreg")
        grad_input_mask = [True, True, True, True]
        grad_q, grad_k, grad_v, grad_attn_mask = (
            torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
                rand_upward_privateuse1,
                q_privateuse1,
                k_privateuse1,
                v_privateuse1,
                attn_mask_privateuse1,
                grad_input_mask,
                output,
                logsumexp,
                cum_seq_q,
                cum_seq_k,
                max_q,
                max_k,
                dropout_p=0.0,
                is_causal=False,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
            )
        )


if __name__ == "__main__":
    run_tests()