All these UTs are working as is; this change just removes the skips:
- test_p2p_ipc
- test_repros.py: working, added fp8 support
- test_activation_checkpointing.py
- test_content_store.py
- test_cuda_multigpu.py
- test_compute_comm_reordering.py
- test_segment_reductions.py
- test_dataloader.py
- test_math_ops.py
- test_loop_ordering.py
- test_control_flow.py
- distributed_test.py
- test_mem_tracker.py
- test_fsdp_optim_state.py
- test_fully_shard_mixed_precision.py: skipped for < ROCm 7.0 (see the sketch below)
- test_aot_inductor_custom_ops.py
- test_c10d_ops_nccl.py
- test_eager_transforms.py
- test_sparse_csr.py
- test_inductor_collectives.py
- test_fake_tensor.py
- test_cupy_as_tensor.py
- test_cuda.py: enable UTs that are working
- test_matmul_cuda.py: enable UTs that are working

Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161715
Approved by: https://github.com/msaroufim
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
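For context, "removing the skip" means deleting (or version-gating) the decorators that previously disabled these tests. The snippet below is a minimal sketch, not taken from this PR, of the kind of version gate the test_fully_shard_mixed_precision.py entry describes (run only on ROCm >= 7.0). The helper `_rocm_version` and the test body are hypothetical; only `torch.version.hip` and `unittest.skipIf` are real APIs.

# Minimal sketch of a ROCm-version-gated skip (hypothetical helper; not the
# actual decorator used in the PR).
import unittest

import torch


def _rocm_version():
    # torch.version.hip is None on non-ROCm builds, otherwise a string
    # such as "7.0.0" (possibly with a build suffix).
    hip = torch.version.hip
    if hip is None:
        return (0, 0)
    major, minor = hip.split(".")[:2]
    return (int(major), int(minor))


class ExampleTest(unittest.TestCase):
    @unittest.skipIf(
        torch.version.hip is not None and _rocm_version() < (7, 0),
        "skipped for < ROCm 7.0",
    )
    def test_fully_shard_mixed_precision_like(self):
        # The re-enabled test would run its normal assertions here.
        self.assertEqual(torch.ones(1).sum().item(), 1.0)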
# Owner(s): ["oncall: pt2"]

import torch
from torch._prims.debug_prims import load_tensor_reader
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
    run_tests,
    TemporaryDirectoryName,
    TestCase,
)
from torch.utils._content_store import (
    ContentStoreReader,
    ContentStoreWriter,
    hash_storage,
)


class TestContentStore(TestCase):
    def test_basic(self, device):
        # setup test data
        x = torch.randn(4, device=device)
        y = torch.randn(6, device=device)
        z = x.view(2, 2)
        # start writing
        with TemporaryDirectoryName() as loc:
            writer = ContentStoreWriter(loc)
            writer.write_tensor("x", x)
            writer.write_tensor("y", y)
            writer.write_tensor("z", z)
            # do some mutation that is VC UNTRACKED
            x.data.add_(1)
            writer.write_tensor("x2", x)
            writer.write_tensor("y2", y)
            writer.write_tensor("z2", z)
            del writer

            reader = ContentStoreReader(loc)
            n_x = reader.read_tensor("x")
            n_y = reader.read_tensor("y")
            n_z = reader.read_tensor("z")
            self.assertEqual(n_x + 1, x)
            self.assertEqual(n_y, y)
            self.assertEqual(n_z + 1, z)
            self.assertEqual(
                StorageWeakRef(n_x.untyped_storage()),
                StorageWeakRef(n_z.untyped_storage()),
            )
            n_x2 = reader.read_tensor("x2")
            n_y2 = reader.read_tensor("y2")
            n_z2 = reader.read_tensor("z2")
            self.assertEqual(n_x2, x)
            self.assertEqual(n_y2, y)
            self.assertEqual(n_z2, z)
            self.assertEqual(
                StorageWeakRef(n_y2.untyped_storage()),
                StorageWeakRef(n_y.untyped_storage()),
            )

    def test_scalar(self, device):
        # Should not raise an error
        hash_storage(torch.tensor(2, device=device).untyped_storage())

    @torch._dynamo.config.patch(recompile_limit=1)
    def test_repeated_hash(self, device):
        # Test that repeated hashing doesn't trigger a recompile in dynamo
        # If it does, we will execute prims.xor_sum in eager which fails
        for _ in range(4):
            hash_storage(torch.tensor(2, device=device).untyped_storage())

    def test_load_tensor(self, device):
        with TemporaryDirectoryName() as loc:
            writer = ContentStoreWriter(loc)
            x = torch.randn(4, device=device)

            def same_meta_as_x(t):
                self.assertEqual(t.size(), x.size())
                self.assertEqual(t.stride(), x.stride())
                self.assertEqual(t.dtype, x.dtype)
                self.assertEqual(t.device, x.device)

            writer.write_tensor("x", x)

            with load_tensor_reader(loc):
                x2 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                self.assertEqual(x, x2)
                x3 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                self.assertEqual(x, x3)
                # Must not alias!
                self.assertNotEqual(
                    StorageWeakRef(x.untyped_storage()),
                    StorageWeakRef(x2.untyped_storage()),
                )
                self.assertNotEqual(
                    StorageWeakRef(x2.untyped_storage()),
                    StorageWeakRef(x3.untyped_storage()),
                )

                # Check fake tensor mode works too
                with FakeTensorMode():
                    x4 = torch.ops.debugprims.load_tensor.default(
                        "x", (4,), (1,), dtype=torch.float32, device=device
                    )
                    self.assertIsInstance(x4, FakeTensor)
                    same_meta_as_x(x4)

                # Check fp64 works on non-MPS platforms, since MPS doesn't currently
                # support fp64.
                if not device.startswith("mps"):
                    x5 = torch.ops.debugprims.load_tensor.default(
                        "x", (4,), (1,), dtype=torch.float64, device=device
                    )
                    self.assertEqual(x5.float(), x)
                    self.assertEqual(x5.dtype, torch.float64)

                x6 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                same_meta_as_x(x6)


instantiate_device_type_tests(
    TestContentStore, globals(), allow_mps=True, allow_xpu=True
)


if __name__ == "__main__":
    run_tests()