mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[OpenReg] Migrate OpenReg Tests from tests/test_openreg.py into torch_openreg/tests (#161917)
**Background:** Almost all the tests in `test/test_openreg.py` are designed for `torch_openreg`, so placing these testcases in the test directory is not a good idea. Instead, they should be moved to the `tests` directory under `torch_openreg`, coordinating these tests with their corresponding functional logic. **How to do:** So how do we verify the quality of the third-party device integration mechanism? We will maintain a `test_openreg` entrypoint in `test/run_test.py`. This entrypoint will install `torch_openreg` and run all the testcases located in `torch_openreg`. As long as all testcases pass, we can guarantee that the out-of-tree backend integration mechanism is available. **Next:** We will also improve `torch_openreg's` test coverage in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161917 Approved by: https://github.com/albanD
This commit is contained in:
@ -0,0 +1,42 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import os
|
||||
|
||||
import psutil
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfMPS,
|
||||
skipIfTorchDynamo,
|
||||
skipIfWindows,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
class TestAutograd(TestCase):
|
||||
# Support MPS and Windows platform later and fix torchdynamo issue
|
||||
@skipIfMPS
|
||||
@skipIfWindows()
|
||||
@skipIfTorchDynamo()
|
||||
def test_autograd_init(self):
|
||||
# Make sure autograd is initialized
|
||||
torch.ones(2, requires_grad=True, device="openreg").sum().backward()
|
||||
|
||||
pid = os.getpid()
|
||||
task_path = f"/proc/{pid}/task"
|
||||
all_threads = psutil.Process(pid).threads()
|
||||
|
||||
all_thread_names = set()
|
||||
|
||||
for t in all_threads:
|
||||
with open(f"{task_path}/{t.id}/comm") as file:
|
||||
thread_name = file.read().strip()
|
||||
all_thread_names.add(thread_name)
|
||||
|
||||
for i in range(torch.accelerator.device_count()):
|
||||
self.assertIn(f"pt_autograd_{i}", all_thread_names)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,40 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
class TestEvent(TestCase):
|
||||
@skipIfTorchDynamo()
|
||||
def test_record_event(self):
|
||||
stream = torch.Stream(device="openreg:1")
|
||||
event1 = stream.record_event()
|
||||
self.assertNotEqual(0, event1.event_id)
|
||||
event2 = stream.record_event()
|
||||
self.assertNotEqual(0, event2.event_id)
|
||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
def test_event_elapsed_time(self):
|
||||
stream = torch.Stream(device="openreg:1")
|
||||
e1 = torch.Event(device="openreg:1", enable_timing=True)
|
||||
e1.record(stream)
|
||||
e2 = torch.Event(device="openreg:1", enable_timing=True)
|
||||
e2.record(stream)
|
||||
|
||||
e2.synchronize()
|
||||
self.assertTrue(e2.query())
|
||||
|
||||
ms = e1.elapsed_time(e2)
|
||||
self.assertTrue(ms > 0)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
def test_event_wait_stream(self):
|
||||
s1 = torch.Stream(device="openreg")
|
||||
s2 = torch.Stream(device="openreg")
|
||||
e1 = s1.record_event()
|
||||
e1.wait(s2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,31 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
class TestPinMemory(TestCase):
|
||||
@skipIfTorchDynamo("unsupported aten.is_pinned.default")
|
||||
def test_pin_memory(self):
|
||||
tensor = torch.randn(10)
|
||||
self.assertFalse(tensor.is_pinned())
|
||||
pinned_tensor = tensor.pin_memory()
|
||||
self.assertTrue(pinned_tensor.is_pinned())
|
||||
slice_tensor = pinned_tensor[2:5]
|
||||
self.assertTrue(slice_tensor.is_pinned())
|
||||
|
||||
tensor = torch.randn(10)
|
||||
storage = tensor.storage()
|
||||
self.assertFalse(storage.is_pinned("openreg"))
|
||||
pinned_storage = storage.pin_memory("openreg")
|
||||
self.assertTrue(pinned_storage.is_pinned("openreg"))
|
||||
|
||||
tensor = torch.randn(10)
|
||||
untyped_storage = tensor.untyped_storage()
|
||||
self.assertFalse(untyped_storage.is_pinned("openreg"))
|
||||
pinned_untyped_storage = untyped_storage.pin_memory("openreg")
|
||||
self.assertTrue(pinned_untyped_storage.is_pinned("openreg"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,162 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import types
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
class TestBackendModule(TestCase):
|
||||
def test_backend_module_name(self):
|
||||
self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")
|
||||
# backend can be renamed to the same name multiple times
|
||||
torch.utils.rename_privateuse1_backend("openreg")
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been set"):
|
||||
torch.utils.rename_privateuse1_backend("dev")
|
||||
|
||||
def test_backend_module_registration(self):
|
||||
def generate_faked_module():
|
||||
return types.ModuleType("fake_module")
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
|
||||
torch._register_device_module("dev", generate_faked_module())
|
||||
with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
|
||||
torch._register_device_module("openreg", generate_faked_module())
|
||||
|
||||
def test_backend_module_function(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Try to call torch.openreg"):
|
||||
torch.utils.backend_registration._get_custom_mod_func("func_name_")
|
||||
self.assertTrue(
|
||||
torch.utils.backend_registration._get_custom_mod_func("device_count")() == 2
|
||||
)
|
||||
|
||||
|
||||
class TestBackendProperty(TestCase):
|
||||
def test_backend_generate_methods(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
|
||||
torch.utils.generate_methods_for_privateuse1_backend()
|
||||
|
||||
self.assertTrue(hasattr(torch.Tensor, "is_openreg"))
|
||||
self.assertTrue(hasattr(torch.Tensor, "openreg"))
|
||||
self.assertTrue(hasattr(torch.TypedStorage, "is_openreg"))
|
||||
self.assertTrue(hasattr(torch.TypedStorage, "openreg"))
|
||||
self.assertTrue(hasattr(torch.UntypedStorage, "is_openreg"))
|
||||
self.assertTrue(hasattr(torch.UntypedStorage, "openreg"))
|
||||
self.assertTrue(hasattr(torch.nn.Module, "openreg"))
|
||||
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_openreg"))
|
||||
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "openreg"))
|
||||
|
||||
def test_backend_tensor_methods(self):
|
||||
x = torch.empty(4, 4)
|
||||
self.assertFalse(x.is_openreg)
|
||||
|
||||
y = x.openreg(torch.device("openreg"))
|
||||
self.assertTrue(y.is_openreg)
|
||||
z = x.openreg(torch.device("openreg:0"))
|
||||
self.assertTrue(z.is_openreg)
|
||||
n = x.openreg(0)
|
||||
self.assertTrue(n.is_openreg)
|
||||
|
||||
@unittest.skip("Need to support Parameter in openreg")
|
||||
def test_backend_module_methods(self):
|
||||
class FakeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x = torch.nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
module = FakeModule()
|
||||
self.assertEqual(module.x.device.type, "cpu")
|
||||
module.openreg() # type: ignore[misc]
|
||||
self.assertEqual(module.x.device.type, "openreg")
|
||||
|
||||
@unittest.skip("Need to support untyped_storage in openreg")
|
||||
def test_backend_storage_methods(self):
|
||||
x = torch.empty(4, 4)
|
||||
|
||||
x_cpu = x.storage()
|
||||
self.assertFalse(x_cpu.is_openreg)
|
||||
x_openreg = x_cpu.openreg()
|
||||
self.assertTrue(x_openreg.is_openreg)
|
||||
|
||||
y = torch.empty(4, 4)
|
||||
|
||||
y_cpu = y.untyped_storage()
|
||||
self.assertFalse(y_cpu.is_openreg)
|
||||
y_openreg = y_cpu.openreg()
|
||||
self.assertTrue(y_openreg.is_openreg)
|
||||
|
||||
def test_backend_packed_sequence_methods(self):
|
||||
x = torch.rand(5, 3)
|
||||
y = torch.tensor([1, 1, 1, 1, 1])
|
||||
|
||||
z_cpu = torch.nn.utils.rnn.PackedSequence(x, y)
|
||||
self.assertFalse(z_cpu.is_openreg)
|
||||
|
||||
z_openreg = z_cpu.openreg()
|
||||
self.assertTrue(z_openreg.is_openreg)
|
||||
|
||||
|
||||
class TestTensorType(TestCase):
|
||||
def test_backend_tensor_type(self):
|
||||
dtypes_map = {
|
||||
torch.bool: "torch.openreg.BoolTensor",
|
||||
torch.double: "torch.openreg.DoubleTensor",
|
||||
torch.float32: "torch.openreg.FloatTensor",
|
||||
torch.half: "torch.openreg.HalfTensor",
|
||||
torch.int32: "torch.openreg.IntTensor",
|
||||
torch.int64: "torch.openreg.LongTensor",
|
||||
torch.int8: "torch.openreg.CharTensor",
|
||||
torch.short: "torch.openreg.ShortTensor",
|
||||
torch.uint8: "torch.openreg.ByteTensor",
|
||||
}
|
||||
|
||||
for dtype, str in dtypes_map.items():
|
||||
x = torch.empty(4, 4, dtype=dtype, device="openreg")
|
||||
self.assertTrue(x.type() == str)
|
||||
|
||||
# Note that all dtype-d Tensor objects here are only for legacy reasons
|
||||
# and should NOT be used.
|
||||
@skipIfTorchDynamo()
|
||||
def test_backend_type_methods(self):
|
||||
# Tensor
|
||||
tensor_cpu = torch.randn([8]).float()
|
||||
self.assertEqual(tensor_cpu.type(), "torch.FloatTensor")
|
||||
|
||||
tensor_openreg = tensor_cpu.openreg()
|
||||
self.assertEqual(tensor_openreg.type(), "torch.openreg.FloatTensor")
|
||||
|
||||
# Storage
|
||||
storage_cpu = tensor_cpu.storage()
|
||||
self.assertEqual(storage_cpu.type(), "torch.FloatStorage")
|
||||
|
||||
tensor_openreg = tensor_cpu.openreg()
|
||||
storage_openreg = tensor_openreg.storage()
|
||||
self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage")
|
||||
|
||||
class CustomFloatStorage:
|
||||
@property
|
||||
def __module__(self):
|
||||
return "torch." + torch._C._get_privateuse1_backend_name()
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return "FloatStorage"
|
||||
|
||||
try:
|
||||
torch.openreg.FloatStorage = CustomFloatStorage()
|
||||
self.assertEqual(storage_openreg.type(), "torch.openreg.FloatStorage")
|
||||
|
||||
# test custom int storage after defining FloatStorage
|
||||
tensor_openreg = tensor_cpu.int().openreg()
|
||||
storage_openreg = tensor_openreg.storage()
|
||||
self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage")
|
||||
finally:
|
||||
torch.openreg.FloatStorage = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,291 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.nn.attention import SDPBackend
|
||||
from torch.testing._internal.common_nn import NNTestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
skipIfXpu,
|
||||
TEST_XPU,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
SDPAShape = collections.namedtuple(
|
||||
"Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"]
|
||||
)
|
||||
|
||||
|
||||
class TestFactory(TestCase):
|
||||
def test_empty(self):
|
||||
x = torch.empty(3, device="openreg")
|
||||
self.assertEqual(x.device.type, "openreg")
|
||||
self.assertEqual(x.shape, torch.Size([3]))
|
||||
|
||||
x = torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])
|
||||
self.assertEqual(x.device.type, "openreg")
|
||||
self.assertEqual(x.shape, torch.Size([2, 3, 4, 5]))
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
x = torch.empty(3, 3, device="openreg")
|
||||
y = torch.empty(3, 3, device="openreg:0")
|
||||
z = x + y
|
||||
self.assertEqual(z.device.type, "openreg")
|
||||
self.assertEqual(z.shape, torch.Size([3, 3]))
|
||||
|
||||
def test_zeros(self):
|
||||
y = torch.zeros(3, device="openreg")
|
||||
self.assertEqual(y.device.type, "openreg")
|
||||
self.assertEqual(y.shape, torch.Size([3]))
|
||||
|
||||
def test_tensor(self):
|
||||
z = torch.tensor((), device="openreg")
|
||||
self.assertEqual(z.device.type, "openreg")
|
||||
self.assertEqual(z.shape, torch.Size([0]))
|
||||
|
||||
|
||||
class TestCopy(TestCase):
|
||||
def test_copy_same_device(self):
|
||||
a = torch.ones(10, device="openreg").clone()
|
||||
self.assertEqual(a, torch.ones(10, device="openreg"))
|
||||
|
||||
def test_cross_device_copy(self):
|
||||
a = torch.rand(10)
|
||||
b = a.to(device="openreg").add(2).to(device="cpu")
|
||||
self.assertEqual(b, a + 2)
|
||||
|
||||
def test_cross_diff_devices_copy(self):
|
||||
a = torch.ones(10, device="openreg:0").to(device="openreg:1").to(device="cpu")
|
||||
self.assertEqual(a, torch.ones(10))
|
||||
|
||||
|
||||
class TestOps(TestCase):
|
||||
def test_masked_select(self):
|
||||
tensor_cpu = torch.randn(10)
|
||||
tensor_openreg = tensor_cpu.to(device="openreg")
|
||||
mask = tensor_openreg.gt(0)
|
||||
out = torch.masked_select(tensor_openreg, mask)
|
||||
|
||||
self.assertEqual(out, tensor_cpu.masked_select(tensor_cpu.gt(0)))
|
||||
|
||||
def test_expand(self):
|
||||
x = torch.tensor([[1], [2], [3]], device="openreg")
|
||||
y = x.expand(3, 2)
|
||||
self.assertEqual(y.to(device="cpu"), torch.tensor([[1, 1], [2, 2], [3, 3]]))
|
||||
self.assertEqual(x.data_ptr(), y.data_ptr())
|
||||
|
||||
def test_resize(self):
|
||||
tensor_cpu = torch.randn([4, 4])
|
||||
|
||||
tensor_openreg = tensor_cpu.openreg()
|
||||
self.assertTrue(tensor_openreg.size() == torch.Size([4, 4]))
|
||||
|
||||
storage_openreg = tensor_openreg.storage()
|
||||
self.assertTrue(storage_openreg.size() == 16)
|
||||
|
||||
tensor_openreg.resize_(2, 2, 2, 2)
|
||||
self.assertTrue(tensor_openreg.size() == torch.Size([2, 2, 2, 2]))
|
||||
|
||||
storage_openreg = tensor_openreg.storage()
|
||||
self.assertTrue(storage_openreg.size() == 16)
|
||||
|
||||
def test_printing(self):
|
||||
a = torch.ones(20, device="openreg")
|
||||
print(a)
|
||||
|
||||
|
||||
class TestSTUB(TestCase):
|
||||
def test_backend_dispatchstub(self):
|
||||
x_cpu = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
|
||||
x_openreg = x_cpu.to("openreg")
|
||||
|
||||
y_cpu = torch.abs(x_cpu)
|
||||
y_openreg = torch.abs(x_openreg)
|
||||
self.assertEqual(y_cpu, y_openreg.cpu())
|
||||
|
||||
o_cpu = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
|
||||
o_openreg = o_cpu.to("openreg")
|
||||
# output operand with resize flag is False in TensorIterator.
|
||||
torch.abs(x_cpu, out=o_cpu[:, :, 0:6:2])
|
||||
torch.abs(x_openreg, out=o_openreg[:, :, 0:6:2])
|
||||
self.assertEqual(o_cpu, o_openreg.cpu())
|
||||
|
||||
# output operand with resize flag is True in TensorIterator and
|
||||
# convert output to contiguous tensor in TensorIterator.
|
||||
torch.abs(x_cpu, out=o_cpu[:, :, 0:6:3])
|
||||
torch.abs(x_openreg, out=o_openreg[:, :, 0:6:3])
|
||||
self.assertEqual(o_cpu, o_openreg.cpu())
|
||||
|
||||
|
||||
class TestQuantization(TestCase):
|
||||
@skipIfXpu(msg="missing kernel for openreg")
|
||||
def test_quantize(self):
|
||||
x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")
|
||||
quantized_tensor = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8)
|
||||
self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
|
||||
self.assertEqual(quantized_tensor.dtype, torch.qint8)
|
||||
|
||||
|
||||
class TestAutogradFunction(TestCase):
|
||||
def test_compile_autograd_function_returns_self(self):
|
||||
in_ref = torch.randn(4, device="openreg", requires_grad=True)
|
||||
out_ref = torch.ops.openreg.custom_autograd_fn_returns_self(in_ref)
|
||||
out_ref.sum().backward()
|
||||
|
||||
in_test = in_ref.detach().clone().requires_grad_(True)
|
||||
# TODO(FFFrog): Need to support inductor for OpenReg first.
|
||||
out_test = torch.compile(backend="aot_eager")(
|
||||
torch.ops.openreg.custom_autograd_fn_returns_self
|
||||
)(in_test)
|
||||
out_test.sum().backward()
|
||||
|
||||
self.assertEqual(out_ref, out_test)
|
||||
self.assertEqual(in_ref.grad, in_test.grad)
|
||||
|
||||
@skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
|
||||
def test_compile_autograd_function_aliasing(self):
|
||||
in_ref = torch.randn(4, device="openreg", requires_grad=True)
|
||||
out_ref = torch.ops.openreg.custom_autograd_fn_aliasing(in_ref)
|
||||
out_ref.sum().backward()
|
||||
|
||||
in_test = in_ref.detach().clone().requires_grad_(True)
|
||||
# TODO(FFFrog): Need to support inductor for OpenReg first.
|
||||
out_test = torch.compile(backend="aot_eager")(
|
||||
torch.ops.openreg.custom_autograd_fn_aliasing
|
||||
)(in_test)
|
||||
out_test.sum().backward()
|
||||
|
||||
self.assertEqual(out_ref, out_test)
|
||||
self.assertEqual(in_ref.grad, in_test.grad)
|
||||
|
||||
|
||||
class TestFallback(TestCase):
|
||||
def test_scalar_type_fallback(self):
|
||||
x_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
|
||||
x = torch.triu_indices(3, 3, device="openreg")
|
||||
self.assertEqual(x_cpu, x)
|
||||
|
||||
def test_tensor_type_fallback(self):
|
||||
x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("openreg")
|
||||
y = torch.Tensor([1, 0, 2]).to("openreg")
|
||||
self.assertTrue(x.device.type, "openreg")
|
||||
self.assertFalse(x.is_cpu)
|
||||
|
||||
z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
|
||||
# call sub op, which will fallback to cpu
|
||||
z = torch.sub(x, y)
|
||||
self.assertEqual(z_cpu, z)
|
||||
|
||||
# call index op, which will fallback to cpu
|
||||
z_cpu = torch.Tensor([3, 1])
|
||||
y = torch.Tensor([1, 0]).long().to("openreg")
|
||||
z = x[y, y]
|
||||
self.assertEqual(z_cpu, z)
|
||||
|
||||
def test_tensorlist_type_fallback(self):
|
||||
# create tensors located in custom device
|
||||
v_openreg = torch.Tensor([1, 2, 3]).to("openreg")
|
||||
# create result tensor located in cpu
|
||||
z_cpu = torch.Tensor([2, 4, 6])
|
||||
# create tensorlist for foreach_add op
|
||||
x = (v_openreg, v_openreg)
|
||||
y = (v_openreg, v_openreg)
|
||||
|
||||
# Check that our device is correct.
|
||||
self.assertTrue(v_openreg.device.type == "openreg")
|
||||
self.assertFalse(v_openreg.is_cpu)
|
||||
|
||||
# call _foreach_add op, which will fallback to cpu
|
||||
z = torch._foreach_add(x, y)
|
||||
self.assertEqual(z_cpu, z[0])
|
||||
self.assertEqual(z_cpu, z[1])
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
|
||||
class TestSDPA(NNTestCase):
|
||||
@skipIfTorchDynamo()
|
||||
def test_fused_sdp_choice_privateuseone(self):
|
||||
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
|
||||
make_tensor = functools.partial(torch.rand, device="cpu", dtype=torch.float16)
|
||||
shape = SDPAShape(batch_size, num_heads, seq_len, head_dim)
|
||||
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
|
||||
q_privateuse1 = q_cpu.to("openreg")
|
||||
k_privateuse1 = k_cpu.to("openreg")
|
||||
v_privateuse1 = v_cpu.to("openreg")
|
||||
assert (
|
||||
torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
|
||||
== SDPBackend.OVERRIDEABLE.value
|
||||
)
|
||||
|
||||
def test_scaled_dot_product_fused_attention_overrideable(self):
|
||||
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
|
||||
make_tensor = functools.partial(torch.rand, device="cpu", dtype=torch.float16)
|
||||
shape = SDPAShape(batch_size, num_heads, seq_len, head_dim)
|
||||
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
|
||||
q_privateuse1 = q_cpu.to("openreg")
|
||||
k_privateuse1 = k_cpu.to("openreg")
|
||||
v_privateuse1 = v_cpu.to("openreg")
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
|
||||
)
|
||||
|
||||
def test_scaled_dot_product_fused_attention_overrideable_backward(self):
|
||||
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
|
||||
make_tensor = functools.partial(
|
||||
torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
shape = (batch_size, num_heads, seq_len, head_dim)
|
||||
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
|
||||
attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
|
||||
q_privateuse1 = q_cpu.to("openreg")
|
||||
k_privateuse1 = k_cpu.to("openreg")
|
||||
v_privateuse1 = v_cpu.to("openreg")
|
||||
attn_mask_privateuse1 = attn_mask.to("openreg")
|
||||
(
|
||||
output,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
debug_attn_mask,
|
||||
) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
|
||||
q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
|
||||
)
|
||||
|
||||
rand_upward = torch.rand(
|
||||
shape, device="cpu", dtype=torch.float16, requires_grad=False
|
||||
)
|
||||
rand_upward_privateuse1 = rand_upward.to("openreg")
|
||||
grad_input_mask = [True, True, True, True]
|
||||
grad_q, grad_k, grad_v, grad_attn_mask = (
|
||||
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
|
||||
rand_upward_privateuse1,
|
||||
q_privateuse1,
|
||||
k_privateuse1,
|
||||
v_privateuse1,
|
||||
attn_mask_privateuse1,
|
||||
grad_input_mask,
|
||||
output,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset=philox_offset,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,23 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestRNG(TestCase):
|
||||
def test_generator(self):
|
||||
generator = torch.Generator(device="openreg:1")
|
||||
self.assertEqual(generator.device.type, "openreg")
|
||||
self.assertEqual(generator.device.index, 1)
|
||||
|
||||
def test_rng_state(self):
|
||||
state = torch.openreg.get_rng_state(0)
|
||||
torch.openreg.set_rng_state(state, 0)
|
||||
|
||||
def test_manual_seed(self):
|
||||
torch.openreg.manual_seed_all(2024)
|
||||
self.assertEqual(torch.openreg.initial_seed(), 2024)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,174 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import _codecs
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy
|
||||
|
||||
import torch
|
||||
from torch.serialization import safe_globals
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TemporaryFileName,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
class TestStorage(TestCase):
|
||||
@skipIfTorchDynamo("unsupported aten.is_pinned.default")
|
||||
def test_rewrapped_storage(self):
|
||||
pinned_a = torch.randn(10).pin_memory()
|
||||
rewrapped_a = torch.tensor((), dtype=torch.float32).set_(
|
||||
pinned_a.untyped_storage()[2:],
|
||||
size=(5,),
|
||||
stride=(1,),
|
||||
storage_offset=0,
|
||||
)
|
||||
self.assertTrue(rewrapped_a.is_pinned())
|
||||
self.assertNotEqual(pinned_a.data_ptr(), rewrapped_a.data_ptr())
|
||||
|
||||
|
||||
class TestSerialization(TestCase):
|
||||
def test_serialization(self):
|
||||
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
||||
|
||||
storage = torch.UntypedStorage(4, device=torch.device("openreg:0"))
|
||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
||||
|
||||
storage_cpu = torch.empty(4, 4).storage()
|
||||
storage_openreg = torch.serialization.default_restore_location(
|
||||
storage_cpu, "openreg:0"
|
||||
)
|
||||
self.assertTrue(storage_openreg.is_openreg)
|
||||
|
||||
tensor = torch.empty(3, 3, device="openreg")
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor), {})
|
||||
metadata = {"version_number": True, "format_number": True}
|
||||
torch._utils.set_tensor_metadata(tensor, metadata)
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor), metadata)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "data.pt")
|
||||
torch.save(tensor, path)
|
||||
|
||||
tensor_openreg = torch.load(path)
|
||||
self.assertTrue(tensor_openreg.is_openreg)
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor_openreg), metadata)
|
||||
|
||||
tensor_cpu = torch.load(path, map_location="cpu")
|
||||
self.assertFalse(tensor_cpu.is_openreg)
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor_cpu), {})
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@unittest.skipIf(
|
||||
numpy.__version__ < "1.25",
|
||||
"versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy",
|
||||
)
|
||||
def test_open_device_numpy_serialization(self):
|
||||
"""
|
||||
This tests the legacy _rebuild_device_tensor_from_numpy serialization path
|
||||
"""
|
||||
data_legacy_numpy = (
|
||||
b"PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02}q\x00X\x01"
|
||||
b"\x00\x00\x00xq\x01ctorch._utils\n_rebuild_device_tensor_from_numpy\nq\x02(cnumpy.core.m"
|
||||
b"ultiarray\n_reconstruct\nq\x03cnumpy\nndarray\nq\x04K\x00\x85q\x05c_codecs\nencode\nq\x06"
|
||||
b"X\x01\x00\x00\x00bq\x07X\x06\x00\x00\x00latin1q\x08\x86q\tRq\n\x87q\x0bRq\x0c(K\x01K\x02K"
|
||||
b"\x03\x86q\rcnumpy\ndtype\nq\x0eX\x02\x00\x00\x00f4q\x0f\x89\x88\x87q\x10Rq\x11(K\x03X\x01"
|
||||
b"\x00\x00\x00<q\x12NNNJ\xff\xff\xff\xffJ\xff\xff\xff\xffK\x00tq\x13b\x89h\x06X\x1c\x00\x00"
|
||||
b"\x00\x00\x00\xc2\x80?\x00\x00\x00@\x00\x00@@\x00\x00\xc2\x80@\x00\x00\xc2\xa0@\x00\x00\xc3"
|
||||
b"\x80@q\x14h\x08\x86q\x15Rq\x16tq\x17bctorch\nfloat32\nq\x18X\t\x00\x00\x00openreg:0q\x19\x89"
|
||||
b"tq\x1aRq\x1bs.PK\x07\x08\xdfE\xd6\xcaS\x01\x00\x00S\x01\x00\x00PK\x03\x04\x00\x00\x08"
|
||||
b"\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11\x00.\x00"
|
||||
b"archive/byteorderFB*\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZlittlePK\x07\x08"
|
||||
b"\x85=\xe3\x19\x06\x00\x00\x00\x06\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0f\x00=\x00archive/versionFB9\x00"
|
||||
b"ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ3\nPK\x07\x08\xd1\x9egU\x02\x00\x00"
|
||||
b"\x00\x02\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x1e\x002\x00archive/.data/serialization_idFB.\x00ZZZZZZZZZZZZZ"
|
||||
b"ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ0636457737946401051300000025273995036293PK\x07\x08\xee(\xcd"
|
||||
b"\x8d(\x00\x00\x00(\x00\x00\x00PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00"
|
||||
b"\xdfE\xd6\xcaS\x01\x00\x00S\x01\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00archive/data.pklPK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00"
|
||||
b"\x00\x00\x85=\xe3\x19\x06\x00\x00\x00\x06\x00\x00\x00\x11\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\xa3\x01\x00\x00archive/byteorderPK\x01\x02\x00\x00\x00\x00\x08\x08\x00"
|
||||
b"\x00\x00\x00\x00\x00\xd1\x9egU\x02\x00\x00\x00\x02\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x16\x02\x00\x00archive/versionPK\x01\x02\x00\x00\x00\x00\x08"
|
||||
b"\x08\x00\x00\x00\x00\x00\x00\xee(\xcd\x8d(\x00\x00\x00(\x00\x00\x00\x1e\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x92\x02\x00\x00archive/.data/serialization_idPK\x06"
|
||||
b"\x06,\x00\x00\x00\x00\x00\x00\x00\x1e\x03-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x06\x01\x00\x00\x00\x00\x00\x008\x03\x00"
|
||||
b"\x00\x00\x00\x00\x00PK\x06\x07\x00\x00\x00\x00>\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00"
|
||||
b"PK\x05\x06\x00\x00\x00\x00\x04\x00\x04\x00\x06\x01\x00\x008\x03\x00\x00\x00\x00"
|
||||
)
|
||||
buf_data_legacy_numpy = io.BytesIO(data_legacy_numpy)
|
||||
|
||||
with safe_globals(
|
||||
[
|
||||
(
|
||||
(
|
||||
numpy.core.multiarray._reconstruct,
|
||||
"numpy.core.multiarray._reconstruct",
|
||||
)
|
||||
if numpy.__version__ >= "2.1"
|
||||
else numpy.core.multiarray._reconstruct
|
||||
),
|
||||
numpy.ndarray,
|
||||
numpy.dtype,
|
||||
_codecs.encode,
|
||||
numpy.dtypes.Float32DType,
|
||||
]
|
||||
):
|
||||
sd_loaded = torch.load(buf_data_legacy_numpy, weights_only=True)
|
||||
buf_data_legacy_numpy.seek(0)
|
||||
# Test map_location
|
||||
sd_loaded_cpu = torch.load(
|
||||
buf_data_legacy_numpy, weights_only=True, map_location="cpu"
|
||||
)
|
||||
|
||||
expected = torch.tensor(
|
||||
[[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device="openreg"
|
||||
)
|
||||
self.assertEqual(sd_loaded["x"].cpu(), expected.cpu())
|
||||
self.assertFalse(sd_loaded["x"].is_cpu)
|
||||
self.assertTrue(sd_loaded_cpu["x"].is_cpu)
|
||||
|
||||
def test_open_device_cpu_serialization(self):
|
||||
default_protocol = torch.serialization.DEFAULT_PROTOCOL
|
||||
|
||||
with unittest.mock.patch.object(torch._C, "_has_storage", return_value=False):
|
||||
x = torch.randn(2, 3)
|
||||
x_openreg = x.to("openreg")
|
||||
sd = {"x": x_openreg}
|
||||
rebuild_func = x_openreg._reduce_ex_internal(default_protocol)[0]
|
||||
self.assertTrue(
|
||||
rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor
|
||||
)
|
||||
|
||||
# Test map_location
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(sd, f)
|
||||
sd_loaded = torch.load(f, weights_only=True)
|
||||
# Test map_location
|
||||
sd_loaded_cpu = torch.load(f, weights_only=True, map_location="cpu")
|
||||
self.assertFalse(sd_loaded["x"].is_cpu)
|
||||
self.assertEqual(sd_loaded["x"].cpu(), x)
|
||||
self.assertTrue(sd_loaded_cpu["x"].is_cpu)
|
||||
|
||||
# Test metadata_only
|
||||
with TemporaryFileName() as f:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Cannot serialize tensors on backends with no storage under skip_data context manager",
|
||||
):
|
||||
with torch.serialization.skip_data():
|
||||
torch.save(sd, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,27 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
|
||||
|
||||
class TestStream(TestCase):
|
||||
def test_stream_synchronize(self):
|
||||
stream = torch.Stream(device="openreg:1")
|
||||
stream.synchronize()
|
||||
self.assertEqual(True, stream.query())
|
||||
|
||||
def test_stream_wait_stream(self):
|
||||
stream_1 = torch.Stream(device="openreg:0")
|
||||
stream_2 = torch.Stream(device="openreg:1")
|
||||
stream_2.wait_stream(stream_1)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
def test_stream_wait_event(self):
|
||||
s1 = torch.Stream(device="openreg")
|
||||
s2 = torch.Stream(device="openreg")
|
||||
e = s1.record_event()
|
||||
s2.wait_event(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -0,0 +1,20 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestDLPack(TestCase):
|
||||
def test_open_device_dlpack(self):
|
||||
x_in = torch.randn(2, 3).to("openreg")
|
||||
capsule = torch.utils.dlpack.to_dlpack(x_in)
|
||||
x_out = torch.from_dlpack(capsule)
|
||||
self.assertTrue(x_out.device == x_in.device)
|
||||
|
||||
x_in = x_in.to("cpu")
|
||||
x_out = x_out.to("cpu")
|
||||
self.assertEqual(x_in, x_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
Reference in New Issue
Block a user