# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
BundledAOTAutogradCacheArtifact,
BundledAOTAutogradCacheEntry,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton


@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
    {"bundled_autograd_cache": True}
)  # Requires BundledAOTAutogradCache for now
class PrecompileContextTests(InductorTestCase):
def setUp(self):
"""
Reset all counters and caches before each unit test
"""
super().setUp()
# Clear PrecompileContext cache artifacts
        PrecompileContext.clear()

    @requires_triton()
def test_basic(self):
"""
Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
"""
def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)
# Run the compiled function
x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
result = compiled_fn(x)
        result.sum().backward()

        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
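        # At this point only _new_cache_artifacts_by_key should be populated;
        # the per-type _new_cache_artifacts store is expected to remain empty
        # until serialization happens.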
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
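        # serialize() is expected to bundle every pending artifact and return a
        # (serialized_bytes, CacheInfo) pair, or None when there is nothing to save.
        # The CacheInfo lists artifact keys by type (e.g. precompile_aot_autograd).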
result = PrecompileContext.serialize()
assert result is not None
serialized, cache_info = result
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
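        # Round-trip: deserialize the bytes back into artifacts grouped by type;
        # the AOTAutograd entries live under the "precompile_aot_autograd" key.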
artifacts = PrecompileContext.deserialize(serialized)
assert artifacts is not None
deserialized = artifacts["precompile_aot_autograd"]
assert len(deserialized) == 1
entry = deserialized[0]
assert isinstance(entry, BundledAOTAutogradCacheArtifact)
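        # after_deserialization() should rehydrate the raw artifact payload into
        # a usable BundledAOTAutogradCacheEntry (checked below).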
entry = entry.after_deserialization()
assert isinstance(
entry,
BundledAOTAutogradCacheEntry,
)
# Now that we've serialized, there should be no new cache artifacts
self.assertEqual(
len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
        )

    @requires_triton()
def test_serialize_by_key(self):
"""
Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
"""
def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)
# Run the compiled function
x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
result = compiled_fn(x)
        result.sum().backward()

        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
# TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once
# we have torch._dynamo.package implemented
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
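        # Serialize just the artifact associated with this key rather than the
        # whole context.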
result = PrecompileContext.serialize_artifact_by_key(key)
assert isinstance(result, BundledAOTAutogradCacheArtifact)
self.assertEqual(result.key, key)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
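        # A full serialize() afterwards should still report this artifact in its
        # CacheInfo, even though it was already serialized individually by key.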
result = PrecompileContext.serialize()
assert result is not None
_, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()