mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Summary: This diff substantially refactors PrecompileContext to make it considerably simpler: instead of being a CacheArtifactManager that manages raw bytes, it now stores just two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache. This structure lets us register DynamoCache with the regular MegaCache API instead of maintaining two separate, confusing APIs. It also lets us remove the autotune cache integration, since the MegaCache API automatically stores autotune cache entries. The intent is that users who want caching precompile can simply call torch.compiler.save_cache_artifacts as before, with `torch._dynamo.config.caching_precompile` set to True. Users who specifically want to load only precompile entries can also interact with PrecompileContext directly, using PrecompileContext.create_cache_entries(). Saving individual entries with DynamoCache still works as before. Test Plan: All existing unit tests pass. Rollback Plan: (none provided). Differential Revision: D82380307 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886 Approved by: https://github.com/zhxchen17
		
			
				
	
	
		
			112 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			112 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: dynamo"]
 | |
| import torch
 | |
| import torch._dynamo
 | |
| import torch._dynamo.test_case
 | |
| import torch._functorch
 | |
| from torch._dynamo.precompile_context import BackendCacheArtifact, PrecompileContext
 | |
| from torch._functorch import config as functorch_config
 | |
| from torch._functorch._aot_autograd.autograd_cache import (
 | |
|     BundledAOTAutogradCacheArtifact,
 | |
| )
 | |
| from torch._inductor.test_case import TestCase as InductorTestCase
 | |
| from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton
 | |
| 
 | |
| 
 | |
| @functorch_config.patch({"enable_autograd_cache": True})
 | |
| @torch._dynamo.config.patch(
 | |
|     {"caching_precompile": True}
 | |
| )  # Requires bundledaotautograd cache for now
 | |
| class PrecompileContextTests(InductorTestCase):
 | |
|     def setUp(self):
 | |
|         """
 | |
|         Reset all counters and caches before each unit test
 | |
|         """
 | |
|         super().setUp()
 | |
|         # Clear PrecompileContext cache artifacts
 | |
|         PrecompileContext.clear()
 | |
| 
 | |
|     @requires_triton()
 | |
|     def test_basic(self):
 | |
|         """
 | |
|         Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
 | |
|         """
 | |
| 
 | |
|         def simple_function(x):
 | |
|             return x.sin() + x.cos()
 | |
| 
 | |
|         compiled_fn = torch.compile(simple_function)
 | |
| 
 | |
|         # Run the compiled function
 | |
|         x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
 | |
|         result = compiled_fn(x)
 | |
|         result.sum().backward()
 | |
|         self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
 | |
|         self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
 | |
|         cache_entries, _ = PrecompileContext.create_cache_entries()
 | |
|         self.assertEqual(len(cache_entries), 1)
 | |
| 
 | |
|     @requires_triton()
 | |
|     def test_serialize_by_key(self):
 | |
|         def simple_function(x):
 | |
|             return x.sin() + x.cos()
 | |
| 
 | |
|         compiled_fn = torch.compile(simple_function)
 | |
| 
 | |
|         # Run the compiled function
 | |
|         x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
 | |
|         result = compiled_fn(x)
 | |
|         result.sum().backward()
 | |
|         self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
 | |
|         self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
 | |
|         for key in PrecompileContext._backend_artifacts_by_key.keys():
 | |
|             result = PrecompileContext.serialize_artifact_by_key(key)
 | |
|             assert isinstance(result, BackendCacheArtifact)
 | |
|             self.assertEqual(result.key, key)
 | |
| 
 | |
|         # This should still work
 | |
|         result, _ = PrecompileContext.create_cache_entries()
 | |
|         assert len(result) == 1
 | |
| 
 | |
|     @requires_triton()
 | |
|     def test_editable(self):
 | |
|         """
 | |
|         Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
 | |
|         """
 | |
| 
 | |
|         def simple_function(x):
 | |
|             return x.sin() + x.cos()
 | |
| 
 | |
|         compiled_fn = torch.compile(simple_function)
 | |
| 
 | |
|         # Run the compiled function
 | |
|         x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
 | |
|         result = compiled_fn(x)
 | |
|         result.sum().backward()
 | |
|         self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
 | |
|         self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
 | |
|         # Find the key for the artifact of type "precompile_aot_autograd"
 | |
|         key = next(iter(PrecompileContext._backend_artifacts_by_key))
 | |
| 
 | |
|         def edit_fn(x):
 | |
|             x._my_private_field = 42
 | |
|             return x
 | |
| 
 | |
|         PrecompileContext.edit_artifact(key, edit_fn)
 | |
| 
 | |
|         result = PrecompileContext.serialize_artifact_by_key(key)
 | |
|         assert isinstance(result, BundledAOTAutogradCacheArtifact)
 | |
|         self.assertEqual(result.key, key)
 | |
| 
 | |
|         result, _ = PrecompileContext.create_cache_entries()
 | |
|         assert len(result) == 1
 | |
|         aot_autograd_artifacts = next(iter(result.values())).backends
 | |
|         assert len(aot_autograd_artifacts) == 1
 | |
|         entry = next(iter(aot_autograd_artifacts.values())).content
 | |
|         self.assertEqual(entry._my_private_field, 42)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Running this file directly dispatches to dynamo's shared test runner.
    from torch._dynamo import test_case as dynamo_test_case

    dynamo_test_case.run_tests()
 |