# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
    BundledAOTAutogradCacheArtifact,
    BundledAOTAutogradCacheEntry,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton


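# These tests exercise PrecompileContext, which records compiled artifacts for
# ahead-of-time (precompile) serialization. bundled_autograd_cache makes
# AOTAutograd cache entries self-contained so they can be serialized on their
# own, which PrecompileContext requires for now (see the decorators below).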
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
    {"bundled_autograd_cache": True}
)  # Requires the bundled AOTAutograd cache for now
class PrecompileContextTests(InductorTestCase):
    def setUp(self):
        """
        Reset all counters and caches before each unit test.
        """
        super().setUp()
        # Clear PrecompileContext cache artifacts
        PrecompileContext.clear()
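
    # Each test compiles a tiny function on GPU_TYPE (the available GPU
    # device) and is skipped by @requires_triton() when Triton is unavailable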
    @requires_triton()
    def test_basic(self):
        """
        Test that after torch.compile, PrecompileContext._new_cache_artifacts_by_key
        has length 1
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function forward and backward
        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()
        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)

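        # serialize() returns the serialized bytes plus a cache-info object
        # summarizing the artifacts that were bundled (here, one AOTAutograd
        # artifact)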
        result = PrecompileContext.serialize()
        assert result is not None
        serialized, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
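
        # Round-trip check: deserializing the bytes should recover exactly one
        # AOTAutograd artifact that unwraps to a BundledAOTAutogradCacheEntry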
        artifacts = PrecompileContext.deserialize(serialized)
        assert artifacts is not None
        deserialized = artifacts["precompile_aot_autograd"]
        assert len(deserialized) == 1
        entry = deserialized[0]
        assert isinstance(entry, BundledAOTAutogradCacheArtifact)
        entry = entry.after_deserialization()
        assert isinstance(entry, BundledAOTAutogradCacheEntry)

        # Now that we've serialized, there should be no new cache artifacts
        self.assertEqual(
            len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
        )

    @requires_triton()
    def test_serialize_by_key(self):
        """
        Test that a single artifact can be serialized by its key via
        PrecompileContext.serialize_artifact_by_key
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function forward and backward
        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()
        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        # TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id
        # once we have torch._dynamo.package implemented
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
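        # serialize_artifact_by_key serializes one staged artifact by its key
        # without consuming it; the full serialize() below still reports it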
        key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
        result = PrecompileContext.serialize_artifact_by_key(key)
        assert isinstance(result, BundledAOTAutogradCacheArtifact)
        self.assertEqual(result.key, key)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        _, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()