# Refactor and Improve the OpenReg Module
## Background
Since PrivateUse1 has become the main path for integrating new devices with PyTorch, a number of feature requests have come in around its interfaces, documentation, and reference examples, such as the following:
- https://github.com/pytorch/pytorch/issues/155864
- https://github.com/pytorch/pytorch/issues/144955
- https://github.com/pytorch/pytorch/issues/144845
Taking these requests into account, together with OpenReg's current role as the test backend for PrivateUse1, I'm planning the following optimizations:
- Optimize the implementation of OpenReg so that it aligns with the standard way real backends integrate with PyTorch in C++, serving as a reference for new device integration code (a minimal registration sketch follows this list).
- Add comprehensive documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html) to guide new accelerator integration, serving as a reference manual.
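For orientation, the Python-side entry points involved here are small. The sketch below is illustrative only and is not the actual `torch_openreg` sources: the stub device module is an assumption made for the example, while `torch.utils.rename_privateuse1_backend`, `torch._register_device_module`, and `torch.utils.generate_methods_for_privateuse1_backend` are the public hooks a PrivateUse1 backend uses.

```python
# Minimal sketch of Python-side PrivateUse1 registration; a real backend
# (like torch_openreg) backs this with C++ (allocator, device guard, kernels).
import types

import torch

# Expose the PrivateUse1 dispatch key under the name "openreg".
torch.utils.rename_privateuse1_backend("openreg")

# Register a device module so `torch.openreg` exists. This stub is
# hypothetical; a real backend binds device counts, streams, etc. from C++.
openreg = types.ModuleType("torch.openreg")
openreg.is_available = lambda: True  # placeholder stub for the example
torch._register_device_module("openreg", openreg)

# Generate conveniences such as `Tensor.openreg()` and `Tensor.is_openreg`.
torch.utils.generate_methods_for_privateuse1_backend()
```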
## Design Principles
- Minimization Principle: Keep the code small and clear; implement only the minimal code required for verification and for use as an integration reference.
- Authenticity Principle: Integrate OpenReg in the same way that real accelerators integrate with PyTorch.
## More Information
Please refer to [the OpenReg README](6b8020f1ab/test/cpp_extensions/open_registration_extension/torch_openreg/README.md) for more information about `OpenReg`.
## Current Progress
- Refactored all of OpenReg's code, taking the implementation of [torch_xla](https://github.com/pytorch/xla) as a reference, to make it easier to understand.
- All tests in [test/test_openreg.py](https://github.com/FFFrog/pytorch/blob/openreg/test/test_openreg.py) pass after the refactoring.
## Next Steps
- Add more features to cover all integration points.
- Gradually add user guides and documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158090
Approved by: https://github.com/seemethere, https://github.com/albanD
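The test file below exercises the scaled dot product attention (SDPA) path for the PrivateUse1 backend, verifying that tensors on the `openreg` device dispatch to the overrideable fused SDPA kernels: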
# Owner(s): ["module: sdpa"]

import unittest
from collections import namedtuple
from functools import partial

import torch_openreg  # noqa: F401

import torch
from torch.nn.attention import SDPBackend
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TEST_XPU


SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])


@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
class TestSDPAPrivateUse1Only(NNTestCase):
    @skipIfTorchDynamo()
    def test_fused_sdp_choice_privateuseone(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        # For tensors on a PrivateUse1 device, the SDPA dispatcher should
        # select the overrideable backend so the custom kernel is used.
        assert (
            torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
            == SDPBackend.OVERRIDEABLE.value
        )

    def test_scaled_dot_product_fused_attention_overrideable(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        # Forward smoke test: this should route to the backend's
        # overrideable fused SDPA kernel without raising.
        torch.nn.functional.scaled_dot_product_attention(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
        )

    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(
            torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
        )
        shape = (batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        attn_mask_privateuse1 = attn_mask.to("openreg")
        # Run the overrideable fused forward directly to obtain the
        # intermediates (logsumexp, philox state, ...) the backward needs.
        (
            output,
            logsumexp,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
        )

        rand_upward = torch.rand(
            shape, device="cpu", dtype=torch.float16, requires_grad=False
        )
        rand_upward_privateuse1 = rand_upward.to("openreg")
        # Request gradients for q, k, v, and the attention mask.
        grad_input_mask = [True, True, True, True]
        grad_q, grad_k, grad_v, grad_attn_mask = (
            torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
                rand_upward_privateuse1,
                q_privateuse1,
                k_privateuse1,
                v_privateuse1,
                attn_mask_privateuse1,
                grad_input_mask,
                output,
                logsumexp,
                cum_seq_q,
                cum_seq_k,
                max_q,
                max_k,
                dropout_p=0.0,
                is_causal=False,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
            )
        )


if __name__ == "__main__":
    run_tests()
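Since the file invokes `run_tests()` under its `__main__` guard, the suite can be executed directly once the `torch_openreg` extension is built and importable; see the OpenReg README linked above for build instructions.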