pytorch/test/test_content_store.py
Prachi Gupta c0142f5c06 [ROCm] Enabling several UTs (#161715)
All these UTs work as-is; this change just removes the skips (a typical enablement is sketched after the list):
- test_p2p_ipc
- test_repros.py: working, added fp8 support
- test_activation_checkpointing.py
- test_content_store.py
- test_cuda_multigpu.py
- test_compute_comm_reordering.py
- test_segment_reductions.py
- test_dataloader.py
- test_math_ops.py
- test_loop_ordering.py
- test_control_flow.py
- distributed_test.py
- test_mem_tracker.py
- test_fsdp_optim_state.py
- test_fully_shard_mixed_precision.py: skipped for ROCm < 7.0
- test_aot_inductor_custom_ops.py
- test_c10d_ops_nccl.py
- test_eager_transforms.py
- test_sparse_csr.py
- test_inductor_collectives.py
- test_fake_tensor.py
- test_cupy_as_tensor.py
- test_cuda.py: enable UTs that are working
- test_matmul_cuda.py: enable UTs that are working
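
A typical enablement just deletes the ROCm skip decorator, e.g. (illustrative
sketch, not an excerpt from this diff; the exact decorator varies per suite):

    -    @skipIfRocm
         def test_p2p_ipc(self):
             ...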

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161715
Approved by: https://github.com/msaroufim

Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
2025-09-09 15:49:21 +00:00

134 lines · 4.8 KiB · Python

# Owner(s): ["oncall: pt2"]

import torch
from torch._prims.debug_prims import load_tensor_reader
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
    run_tests,
    TemporaryDirectoryName,
    TestCase,
)
from torch.utils._content_store import (
    ContentStoreReader,
    ContentStoreWriter,
    hash_storage,
)
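

# The content store saves tensors content-addressed by a hash of their untyped
# storage (hash_storage is built on prims.xor_sum; see test_repeated_hash), so
# tensors sharing a storage, e.g. x and x.view(2, 2), are stored once and read
# back sharing a storage again (asserted in test_basic).  A minimal usage
# sketch, illustrative only, with `loc` a writable directory:
#
#     writer = ContentStoreWriter(loc)
#     writer.write_tensor("x", torch.randn(4))
#     del writer
#     reader = ContentStoreReader(loc)
#     x = reader.read_tensor("x")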
class TestContentStore(TestCase):
    def test_basic(self, device):
        # setup test data
        x = torch.randn(4, device=device)
        y = torch.randn(6, device=device)
        z = x.view(2, 2)
        # start writing
        with TemporaryDirectoryName() as loc:
            writer = ContentStoreWriter(loc)
            writer.write_tensor("x", x)
            writer.write_tensor("y", y)
            writer.write_tensor("z", z)
            # mutate through .data so the change is untracked by the
            # version counter (VC)
            x.data.add_(1)
            writer.write_tensor("x2", x)
            writer.write_tensor("y2", y)
            writer.write_tensor("z2", z)
            del writer

            reader = ContentStoreReader(loc)
            n_x = reader.read_tensor("x")
            n_y = reader.read_tensor("y")
            n_z = reader.read_tensor("z")
            self.assertEqual(n_x + 1, x)
            self.assertEqual(n_y, y)
            self.assertEqual(n_z + 1, z)
            # x and z shared a storage on write, so they share one on read
            self.assertEqual(
                StorageWeakRef(n_x.untyped_storage()),
                StorageWeakRef(n_z.untyped_storage()),
            )
            n_x2 = reader.read_tensor("x2")
            n_y2 = reader.read_tensor("y2")
            n_z2 = reader.read_tensor("z2")
            self.assertEqual(n_x2, x)
            self.assertEqual(n_y2, y)
            self.assertEqual(n_z2, z)
            # y was unchanged, so "y2" dedupes to the same storage as "y"
            self.assertEqual(
                StorageWeakRef(n_y2.untyped_storage()),
                StorageWeakRef(n_y.untyped_storage()),
            )

    def test_scalar(self, device):
        # Should not raise an error
        hash_storage(torch.tensor(2, device=device).untyped_storage())

    @torch._dynamo.config.patch(recompile_limit=1)
    def test_repeated_hash(self, device):
        # Test that repeated hashing doesn't trigger a recompile in dynamo.
        # If it did, we would exceed the recompile limit of 1, fall back to
        # eager, and prims.xor_sum would fail in eager.
        for _ in range(4):
            hash_storage(torch.tensor(2, device=device).untyped_storage())
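
    # debugprims.load_tensor rehydrates a named tensor from the store given
    # explicit size, stride, dtype, and device; load_tensor_reader() installs
    # the ContentStoreReader it pulls from.  Each load below must return a
    # fresh tensor (no aliasing with the original or with earlier loads) and
    # must respect an active FakeTensorMode.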
    def test_load_tensor(self, device):
        with TemporaryDirectoryName() as loc:
            writer = ContentStoreWriter(loc)
            x = torch.randn(4, device=device)

            def same_meta_as_x(t):
                self.assertEqual(t.size(), x.size())
                self.assertEqual(t.stride(), x.stride())
                self.assertEqual(t.dtype, x.dtype)
                self.assertEqual(t.device, x.device)

            writer.write_tensor("x", x)

            with load_tensor_reader(loc):
                x2 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                self.assertEqual(x, x2)
                x3 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                self.assertEqual(x, x3)
                # Must not alias!
                self.assertNotEqual(
                    StorageWeakRef(x.untyped_storage()),
                    StorageWeakRef(x2.untyped_storage()),
                )
                self.assertNotEqual(
                    StorageWeakRef(x2.untyped_storage()),
                    StorageWeakRef(x3.untyped_storage()),
                )

                # Check fake tensor mode works too
                with FakeTensorMode():
                    x4 = torch.ops.debugprims.load_tensor.default(
                        "x", (4,), (1,), dtype=torch.float32, device=device
                    )
                self.assertIsInstance(x4, FakeTensor)
                same_meta_as_x(x4)

                # Check fp64 works on non-MPS platforms, since MPS doesn't
                # currently support fp64.
                if not device.startswith("mps"):
                    x5 = torch.ops.debugprims.load_tensor.default(
                        "x", (4,), (1,), dtype=torch.float64, device=device
                    )
                    self.assertEqual(x5.float(), x)
                    self.assertEqual(x5.dtype, torch.float64)

                x6 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                same_meta_as_x(x6)
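

# instantiate_device_type_tests generates per-device copies of the class
# (e.g. TestContentStoreCPU, TestContentStoreCUDA) that bind the `device`
# argument; allow_mps/allow_xpu opt the suite into MPS and XPU as well.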
instantiate_device_type_tests(
    TestContentStore, globals(), allow_mps=True, allow_xpu=True
)

if __name__ == "__main__":
    run_tests()