mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This PR renames every cache_limit to recompile_limit via sed. Old config options are maintained via Config(alias='xyz'). Pull Request resolved: https://github.com/pytorch/pytorch/pull/143709 Approved by: https://github.com/jansel
136 lines
4.7 KiB
Python
136 lines
4.7 KiB
Python
# Owner(s): ["oncall: pt2"]
|
|
|
|
import tempfile
|
|
import unittest
|
|
|
|
import torch
|
|
from torch._prims.debug_prims import load_tensor_reader
|
|
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
|
from torch.multiprocessing.reductions import StorageWeakRef
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_utils import (
|
|
IS_WINDOWS,
|
|
run_tests,
|
|
skipIfRocm,
|
|
TestCase,
|
|
)
|
|
from torch.utils._content_store import (
|
|
ContentStoreReader,
|
|
ContentStoreWriter,
|
|
hash_storage,
|
|
)
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows")
class TestContentStore(TestCase):
    """Device-generic tests for ContentStoreWriter/ContentStoreReader,
    hash_storage, and the debugprims.load_tensor op backed by a store."""

    def test_basic(self, device):
        # Round-trip tensors through a content store, before and after an
        # in-place update that bypasses the version counter, then check the
        # loaded values and which loaded tensors share a storage.
        base = torch.randn(4, device=device)
        other = torch.randn(6, device=device)
        view = base.view(2, 2)  # aliases base's storage

        with tempfile.TemporaryDirectory() as store_dir:
            writer = ContentStoreWriter(store_dir)
            for key, tensor in (("x", base), ("y", other), ("z", view)):
                writer.write_tensor(key, tensor)
            # Mutate through .data so the change is not tracked by the
            # version counter, then snapshot the same tensors under new keys.
            base.data.add_(1)
            for key, tensor in (("x2", base), ("y2", other), ("z2", view)):
                writer.write_tensor(key, tensor)
            del writer

            reader = ContentStoreReader(store_dir)

            # First snapshot predates the +1 applied to base (and so to view).
            n_x, n_y, n_z = [reader.read_tensor(key) for key in ("x", "y", "z")]
            self.assertEqual(n_x + 1, base)
            self.assertEqual(n_y, other)
            self.assertEqual(n_z + 1, view)
            # The x/z view relationship must survive the round trip.
            self.assertEqual(
                StorageWeakRef(n_x.untyped_storage()),
                StorageWeakRef(n_z.untyped_storage()),
            )

            # Second snapshot reflects the current (mutated) values.
            n_x2, n_y2, n_z2 = [reader.read_tensor(key) for key in ("x2", "y2", "z2")]
            for loaded, live in ((n_x2, base), (n_y2, other), (n_z2, view)):
                self.assertEqual(loaded, live)
            # "y" was never mutated between snapshots; both of its loads are
            # expected to come back on the very same storage.
            self.assertEqual(
                StorageWeakRef(n_y2.untyped_storage()),
                StorageWeakRef(n_y.untyped_storage()),
            )

    def test_scalar(self, device):
        # Hashing the storage behind a 0-dim tensor must not raise.
        scalar_storage = torch.tensor(2, device=device).untyped_storage()
        hash_storage(scalar_storage)

    @torch._dynamo.config.patch(recompile_limit=1)
    def test_repeated_hash(self, device):
        # With recompile_limit=1, a dynamo recompile on a repeated hash would
        # make prims.xor_sum run in eager mode, which fails; so hashing the
        # same kind of storage several times must not recompile.
        for _iteration in range(4):
            hash_storage(torch.tensor(2, device=device).untyped_storage())

    @skipIfRocm
    def test_load_tensor(self, device):
        # Inside load_tensor_reader, debugprims.load_tensor must return the
        # stored values as fresh (non-aliasing) tensors, work under
        # FakeTensorMode, and honor a requested dtype. With no reader active
        # it must still produce a tensor with matching metadata.
        with tempfile.TemporaryDirectory() as store_dir:
            writer = ContentStoreWriter(store_dir)
            original = torch.randn(4, device=device)

            def assert_matches_original_meta(t):
                # Metadata only (size, stride, dtype, device); values ignored.
                expectations = (
                    (t.size(), original.size()),
                    (t.stride(), original.stride()),
                    (t.dtype, original.dtype),
                    (t.device, original.device),
                )
                for actual, expected in expectations:
                    self.assertEqual(actual, expected)

            writer.write_tensor("x", original)

            load_op = torch.ops.debugprims.load_tensor.default

            def load_x(dtype):
                # Every load targets key "x" with the size/stride it was saved with.
                return load_op("x", (4,), (1,), dtype=dtype, device=device)

            with load_tensor_reader(store_dir):
                first = load_x(torch.float32)
                self.assertEqual(original, first)
                second = load_x(torch.float32)
                self.assertEqual(original, second)
                # Must not alias! Neither load may share storage with the
                # source tensor, nor with each other.
                self.assertNotEqual(
                    StorageWeakRef(original.untyped_storage()),
                    StorageWeakRef(first.untyped_storage()),
                )
                self.assertNotEqual(
                    StorageWeakRef(first.untyped_storage()),
                    StorageWeakRef(second.untyped_storage()),
                )

                # The op must also work under fake tensor mode.
                with FakeTensorMode():
                    fake = load_x(torch.float32)
                    self.assertIsInstance(fake, FakeTensor)
                    assert_matches_original_meta(fake)

                # Asking for float64 yields a float64 tensor whose values,
                # cast back to float32, equal the stored data.
                widened = load_x(torch.float64)
                self.assertEqual(widened.float(), original)
                self.assertEqual(widened.dtype, torch.float64)

            # The reader context has exited here; only metadata is checked.
            unbacked = load_x(torch.float32)
            assert_matches_original_meta(unbacked)
|
|
|
|
|
|
# Expand the device-generic TestContentStore into concrete per-device test
# classes, placing them in this module's global namespace (presumably so the
# test runner discovers them there).
instantiate_device_type_tests(TestContentStore, scope=globals())


if __name__ == "__main__":
    # Hand control to PyTorch's shared test entry point.
    run_tests()
|