mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-26 00:24:53 +08:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			codex/add-
			...
			ciflow/ind
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| d6d3367233 | |||
| a514a050fa | |||
| 1a71669a22 | |||
| a91c4ceb08 | |||
| 8b3dc0d1b0 | |||
| 485c73a947 | |||
| 06773663b5 | |||
| 0bff65503c | |||
| 5c075d8bcf | 
| @ -39,7 +39,7 @@ struct HostBlock { | ||||
| }; | ||||
|  | ||||
| template <typename B> | ||||
| struct alignas(64) FreeBlockList { | ||||
| struct alignas(hardware_destructive_interference_size) FreeBlockList { | ||||
|   std::mutex mutex_; | ||||
|   std::deque<B*> list_; | ||||
| }; | ||||
| @ -122,7 +122,7 @@ struct TORCH_API HostStats { | ||||
| // Struct containing memory allocator summary statistics for host, as they | ||||
| // are staged for reporting. This is a temporary struct that is used to | ||||
| // avoid locking the allocator while collecting stats. | ||||
| struct alignas(64) HostStatsStaged { | ||||
| struct alignas(hardware_destructive_interference_size) HostStatsStaged { | ||||
|   std::mutex timing_mutex_; | ||||
|   // COUNT: total allocations (active + free) | ||||
|   // LOCK: access to this stat is protected by the allocator's blocks_mutex_ | ||||
| @ -669,7 +669,7 @@ struct CachingHostAllocatorImpl { | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); | ||||
|   } | ||||
|  | ||||
|   alignas(64) std::mutex blocks_mutex_; | ||||
|   alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_; | ||||
|   ska::flat_hash_set<B*> blocks_; // block list | ||||
|   ska::flat_hash_map<void*, B*> ptr_to_block_; | ||||
|  | ||||
| @ -677,17 +677,17 @@ struct CachingHostAllocatorImpl { | ||||
|   // size. This allows us to quickly find a free block of the right size. | ||||
|   // We use deque to store per size free list and guard the list with its own | ||||
|   // mutex. | ||||
|   alignas(64) std::vector<FreeBlockList<B>> free_list_ = | ||||
|   alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ = | ||||
|       std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX); | ||||
|  | ||||
|   alignas(64) std::mutex events_mutex_; | ||||
|   alignas(hardware_destructive_interference_size) std::mutex events_mutex_; | ||||
|   std::deque<std::pair<E, B*>> events_; // event queue paired with block | ||||
|  | ||||
|   // Indicates whether the object is active. | ||||
|   // Set to false in the destructor to signal background threads to stop. | ||||
|   std::atomic<bool> active_{true}; | ||||
| protected: | ||||
|   alignas(64) HostStatsStaged stats_; | ||||
|   alignas(hardware_destructive_interference_size) HostStatsStaged stats_; | ||||
| }; | ||||
|  | ||||
| struct TORCH_API HostAllocator : public at::Allocator { | ||||
|  | ||||
| @ -141,8 +141,6 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   checkTrilTriuMemoryOverlap(result, self); | ||||
|  | ||||
|   bool inplace_op = self.is_same(result); | ||||
|  | ||||
|   bool inplace_update = false; | ||||
|  | ||||
| @ -1,4 +1,3 @@ | ||||
| #include <ATen/MemoryOverlap.h> | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/native/LinearAlgebraUtils.h> | ||||
|  | ||||
| @ -55,13 +54,4 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor | ||||
|   return std::make_tuple(true, tensor); | ||||
| } | ||||
|  | ||||
| static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) { | ||||
|   if (result.is_same(self)) { | ||||
|     at::assert_no_internal_overlap(result); | ||||
|   } else { | ||||
|     at::assert_no_internal_overlap(result); | ||||
|     at::assert_no_overlap(result, self); | ||||
|   } | ||||
| } | ||||
|  | ||||
| }  // namespace at::native | ||||
|  | ||||
| @ -5,7 +5,6 @@ | ||||
| #include <ATen/Dispatch.h> | ||||
| #include <ATen/MemoryOverlap.h> | ||||
| #include <ATen/native/Resize.h> | ||||
| #include <ATen/native/TriangularOpsUtils.h> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -111,8 +110,6 @@ __global__ void triu_tril_kernel( | ||||
|  | ||||
| template <bool upper> | ||||
| void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) { | ||||
|   checkTrilTriuMemoryOverlap(result, self); | ||||
|  | ||||
|   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( | ||||
|       at::ScalarType::ComplexHalf, | ||||
|       at::ScalarType::Half, | ||||
|  | ||||
| @ -9,6 +9,7 @@ | ||||
|  | ||||
| #include <c10/core/Device.h> | ||||
| #include <c10/core/DeviceType.h> | ||||
| #include <c10/core/alignment.h> | ||||
| #include <c10/macros/Export.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/Exception.h> | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <cstddef> | ||||
| #include <new> | ||||
|  | ||||
| namespace c10 { | ||||
|  | ||||
| @ -18,4 +19,12 @@ constexpr size_t gPagesize = 4096; | ||||
| // since the default thp pagesize is 2MB, enable thp only | ||||
| // for buffers of size 2MB or larger to avoid memory bloating | ||||
| constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024; | ||||
|  | ||||
| // Cache line size used to avoid false sharing between threads. Falls back to 64 | ||||
| // bytes if C++17 feature is unavailable. | ||||
| #ifdef __cpp_lib_hardware_interference_size | ||||
| using std::hardware_destructive_interference_size; | ||||
| #else | ||||
| constexpr std::size_t hardware_destructive_interference_size = 64; | ||||
| #endif | ||||
| } // namespace c10 | ||||
|  | ||||
| @ -941,7 +941,7 @@ class EventPool { | ||||
|  | ||||
|  private: | ||||
|   struct PerDevicePool { | ||||
|     alignas(64) std::mutex mutex_; | ||||
|     alignas(hardware_destructive_interference_size) std::mutex mutex_; | ||||
|     std::vector<std::unique_ptr<cudaEvent_t>> event_pool_; | ||||
|   }; | ||||
|   std::vector<PerDevicePool> pools_; | ||||
| @ -3758,11 +3758,6 @@ static void uncached_delete(void* ptr) { | ||||
| static void local_raw_delete(void* ptr); | ||||
| thread_local std::stack<std::string> DeviceCachingAllocator::compile_context; | ||||
| thread_local std::string DeviceCachingAllocator::user_metadata; | ||||
| #ifdef __cpp_lib_hardware_interference_size | ||||
| using std::hardware_destructive_interference_size; | ||||
| #else | ||||
| static constexpr std::size_t hardware_destructive_interference_size = 64; | ||||
| #endif | ||||
|  | ||||
| class NativeCachingAllocator : public CUDAAllocator { | ||||
|  private: | ||||
|  | ||||
| @ -554,7 +554,7 @@ static void local_raw_delete(void* ptr); | ||||
|  | ||||
| class XPUAllocator : public DeviceAllocator { | ||||
|  private: | ||||
|   std::mutex mutex; | ||||
|   alignas(hardware_destructive_interference_size) std::mutex mutex; | ||||
|   ska::flat_hash_map<void*, Block*> allocated_blocks; | ||||
|  | ||||
|   void add_allocated_block(Block* block) { | ||||
|  | ||||
| @ -3064,11 +3064,12 @@ class GraphModule(torch.nn.Module): | ||||
|             primals_6,  # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor') | ||||
|             primals_7,  # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor') | ||||
|             primals_8,  # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0) | ||||
|             primals_10,  # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) | ||||
|             primals_3,  # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) | ||||
|             primals_10,  # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1) | ||||
|             primals_1,  # SavedForBackwardsAOTOutput(idx=0) | ||||
|             primals_8,  # SavedForBackwardsAOTOutput(idx=1) | ||||
|             primals_10,  # SavedForBackwardsAOTOutput(idx=2) | ||||
|             primals_3,  # SavedForBackwardsAOTOutput(idx=1) | ||||
|             primals_8,  # SavedForBackwardsAOTOutput(idx=2) | ||||
|             primals_10,  # SavedForBackwardsAOTOutput(idx=3) | ||||
|         ) | ||||
| """,  # noqa: B950 | ||||
|         ) | ||||
| @ -3080,6 +3081,7 @@ class GraphModule(torch.nn.Module): | ||||
|     def forward( | ||||
|         self, | ||||
|         primals_1: "Sym(s51)",  # PlainAOTInput(idx=0) | ||||
|         primals_3: "Sym(s55)",  # PlainAOTInput(idx=2) | ||||
|         primals_8: "Sym(s51)",  # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0) | ||||
|         primals_10: "Sym(s55)",  # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) | ||||
|         tangents_1: "f64[s64, s55]",  # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values') | ||||
| @ -3097,7 +3099,7 @@ class GraphModule(torch.nn.Module): | ||||
|             tangents_3,  # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor') | ||||
|             tangents_4,  # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor') | ||||
|             primals_8,  # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0) | ||||
|             primals_10,  # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) | ||||
|             primals_3,  # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) | ||||
|             primals_10,  # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) | ||||
|         ) | ||||
| """,  # noqa: B950 | ||||
| @ -3134,7 +3136,7 @@ class GraphModule(torch.nn.Module): | ||||
|         clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4);  primals_4 = None | ||||
|  | ||||
|         cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1);  clone = None | ||||
|         add_2: "Sym(2*s55)" = primals_10 + primals_10 | ||||
|         add_2: "Sym(2*s55)" = primals_10 + primals_3;  primals_3 = None | ||||
|         return ( | ||||
|             cat,  # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values') | ||||
|             primals_5,  # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets') | ||||
|  | ||||
| @ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase): | ||||
|     @parametrize("format", ("binary", "unpacked")) | ||||
|     @parametrize("dynamic", (False, True)) | ||||
|     @parametrize("graph_partition", (False, True)) | ||||
|     @parametrize("is_aot", (False, True)) | ||||
|     def test_basic( | ||||
|         self, device: str, format: str, dynamic: bool, graph_partition: bool | ||||
|         self, | ||||
|         device: str, | ||||
|         format: str, | ||||
|         dynamic: bool, | ||||
|         graph_partition: bool, | ||||
|         is_aot: bool, | ||||
|     ) -> None: | ||||
|         if device == GPU_TYPE and not HAS_GPU: | ||||
|             raise unittest.SkipTest(f"requires {GPU_TYPE}") | ||||
|  | ||||
|         # AOT mode does not support unpacked format | ||||
|         if is_aot and format == "unpacked": | ||||
|             raise unittest.SkipTest("AOT mode does not support unpacked format") | ||||
|  | ||||
|         mod = torch.nn.Linear(1, 3, device=device) | ||||
|         x = torch.randn(4, 1, device=device) | ||||
|         if dynamic: | ||||
| @ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase): | ||||
|                 gm, args, kwargs = self.capture(f)(x) | ||||
|                 assert not kwargs | ||||
|  | ||||
|                 compiled_artifact = torch._inductor.standalone_compile(gm, args) | ||||
|                 compiled_artifact = torch._inductor.standalone_compile( | ||||
|                     gm, args, aot=is_aot | ||||
|                 ) | ||||
|                 compiled_artifact.save(path=path, format=format) | ||||
|  | ||||
|             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) | ||||
| @ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase): | ||||
|                 compiled_out = loaded(*concrete_args) | ||||
|                 self.assertEqual(eager_out, compiled_out) | ||||
|  | ||||
|             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) | ||||
|             if not is_aot: | ||||
|                 self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) | ||||
|  | ||||
|     @config.patch({"fx_graph_cache": True}) | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     @parametrize("dynamic", (False, True)) | ||||
|     def test_call_in_backend(self, dynamic: bool) -> None: | ||||
|     @parametrize("is_aot", (False, True)) | ||||
|     def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None: | ||||
|         mod = torch.nn.Linear(1, 3) | ||||
|         x = torch.randn(4, 1) | ||||
|         if dynamic: | ||||
| @ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase): | ||||
|         eager_out = f(x) | ||||
|  | ||||
|         def backend(gm, args, **kwargs): | ||||
|             return torch._inductor.standalone_compile(gm, args) | ||||
|             return torch._inductor.standalone_compile(gm, args, aot=is_aot) | ||||
|  | ||||
|         with fresh_cache(): | ||||
|             compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x) | ||||
| @ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|     @config.patch({"fx_graph_cache": True}) | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     def test_dynamic_shapes_from_graph(self): | ||||
|     @parametrize("is_aot", (False, True)) | ||||
|     def test_dynamic_shapes_from_graph(self, is_aot: bool): | ||||
|         def f(x): | ||||
|             return x.shape[0] * x | ||||
|  | ||||
| @ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|             assert not kwargs | ||||
|  | ||||
|         compiled_artifact = torch._inductor.standalone_compile( | ||||
|             gm, args, dynamic_shapes="from_graph" | ||||
|             gm, args, dynamic_shapes="from_graph", aot=is_aot | ||||
|         ) | ||||
|         x = torch.ones(4) | ||||
|         (result,) = compiled_artifact(4, x) | ||||
| @ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     @functorch_config.patch({"autograd_cache_normalize_inputs": True}) | ||||
|     def test_split_module(self): | ||||
|     @parametrize("is_aot", (False, True)) | ||||
|     def test_split_module(self, is_aot): | ||||
|         class Mod(torch.nn.Module): | ||||
|             def forward(self, x, a0, a1, b0, b1, c0, c1): | ||||
|                 x = x + (a0**2) + (a1 / 2) | ||||
| @ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|         split = torch.fx.passes.split_module.split_module(gm, gm, split) | ||||
|  | ||||
|         # Each of the split graphs only has one output. | ||||
|         ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1)) | ||||
|         ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1)) | ||||
|         ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1)) | ||||
|         ca0 = torch._inductor.standalone_compile( | ||||
|             split.submod_0, (a0, x, a1), aot=is_aot | ||||
|         ) | ||||
|         ca1 = torch._inductor.standalone_compile( | ||||
|             split.submod_1, (b0, x, b1), aot=is_aot | ||||
|         ) | ||||
|         ca2 = torch._inductor.standalone_compile( | ||||
|             split.submod_2, (c0, x, c1), aot=is_aot | ||||
|         ) | ||||
|  | ||||
|         y = ca0(a0, x, a1) | ||||
|         y = ca1(b0, y, b1) | ||||
|         y = ca2(c0, y, c1) | ||||
|         self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) | ||||
|         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) | ||||
|         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) | ||||
|         if not is_aot: | ||||
|             # fx graph cache doesn't run in AOT mode | ||||
|             self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) | ||||
|             self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) | ||||
|             self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) | ||||
|         # TODO: split_module causes ca1 and ca2 to have different type annotations | ||||
|         # for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice | ||||
|         self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) | ||||
| @ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|     @config.patch({"fx_graph_cache": True}) | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     @parametrize("is_aot", (False, True)) | ||||
|     @parametrize("config_patches", [True, False]) | ||||
|     def test_dynamic_shapes_from_example_inputs(self, config_patches): | ||||
|     def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot): | ||||
|         def f(x): | ||||
|             return x.shape[0] * x | ||||
|  | ||||
| @ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|             (5, torch.ones(4)), | ||||
|             dynamic_shapes="from_example_inputs", | ||||
|             options={"config_patches": config_patches}, | ||||
|             aot=is_aot, | ||||
|         ) | ||||
|         x = torch.ones(4) | ||||
|         (result,) = compiled_artifact(3, x) | ||||
| @ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|     @config.patch({"fx_graph_cache": True}) | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     @parametrize("is_aot", (True, False)) | ||||
|     @parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"]) | ||||
|     def test_static_shapes(self, dynamic_shapes): | ||||
|     def test_static_shapes(self, dynamic_shapes, is_aot): | ||||
|         def f(x): | ||||
|             return x.shape[0] * x | ||||
|  | ||||
| @ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|             static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x) | ||||
|             assert not kwargs | ||||
|         compiled_artifact = torch._inductor.standalone_compile( | ||||
|             static_gm, [static_x], dynamic_shapes=dynamic_shapes | ||||
|             static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot | ||||
|         ) | ||||
|         x = torch.randn(3) | ||||
|         (result,) = compiled_artifact(x) | ||||
| @ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|     @config.patch({"fx_graph_cache": True}) | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     @parametrize("is_aot", (True, False)) | ||||
|     @parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"]) | ||||
|     def test_backend(self, dynamic_shapes): | ||||
|     def test_backend(self, dynamic_shapes, is_aot): | ||||
|         def f(x): | ||||
|             return x.shape[0] * x | ||||
|  | ||||
| @ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|  | ||||
|         def backend(gm, args, **kwargs): | ||||
|             compiled_artifact = torch._inductor.standalone_compile( | ||||
|                 gm, args, dynamic_shapes=dynamic_shapes | ||||
|                 gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot | ||||
|             ) | ||||
|             y = torch.randn(4) | ||||
|             (result,) = compiled_artifact(4, y) | ||||
| @ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|     @config.patch({"fx_graph_cache": True}) | ||||
|     @config.patch({"fx_graph_remote_cache": False}) | ||||
|     @functorch_config.patch({"enable_autograd_cache": True}) | ||||
|     def test_backend_dynamic_shapes_from_example_inputs(self): | ||||
|     @parametrize("is_aot", (True, False)) | ||||
|     def test_backend_dynamic_shapes_from_example_inputs(self, is_aot): | ||||
|         def f(x): | ||||
|             return x.shape[0] * x | ||||
|  | ||||
| @ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): | ||||
|  | ||||
|         def backend(gm, args, **kwargs): | ||||
|             compiled_artifact = torch._inductor.standalone_compile( | ||||
|                 gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs" | ||||
|                 gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot | ||||
|             ) | ||||
|             y = torch.ones(4) | ||||
|             (result,) = compiled_artifact(4, y) | ||||
|  | ||||
| @ -9986,20 +9986,6 @@ scipy_lobpcg  | {eq_err_scipy:10.2e}  | {eq_err_general_scipy:10.2e}  | {iters2: | ||||
|         self.assertEqual(result_triu_min, expected_triu_min) | ||||
|         self.assertEqual(result_tril_min, expected_tril_min) | ||||
|  | ||||
|     @dtypes(torch.float) | ||||
|     def test_triu_tril_inplace_memory_overlap(self, device, dtype): | ||||
|         base = torch.rand((), dtype=dtype, device=device) | ||||
|         expanded = base.expand(3, 3) | ||||
|         msg = ( | ||||
|             "unsupported operation: more than one element of the written-to tensor " | ||||
|             "refers to a single memory location. Please clone() the tensor before " | ||||
|             "performing the operation." | ||||
|         ) | ||||
|         with self.assertRaisesRegex(RuntimeError, msg): | ||||
|             expanded.triu_(1) | ||||
|         with self.assertRaisesRegex(RuntimeError, msg): | ||||
|             expanded.tril_(-1) | ||||
|  | ||||
|     @dtypes(torch.float, torch.double) | ||||
|     @precisionOverride({torch.float32: 1e-4}) | ||||
|     def test_1_sized_with_0_strided(self, device, dtype): | ||||
|  | ||||
| @ -1,4 +1,3 @@ | ||||
| import abc | ||||
| import dataclasses | ||||
| import importlib | ||||
| import inspect | ||||
| @ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type | ||||
| from torch._dynamo.package import SystemInfo | ||||
|  | ||||
| from . import convert_frame | ||||
| from .aot_compile_types import ( | ||||
|     BundledAOTAutogradSerializableCallable, | ||||
|     SerializableCallable, | ||||
| ) | ||||
| from .hooks import Hooks | ||||
|  | ||||
|  | ||||
| @ -26,18 +29,6 @@ if TYPE_CHECKING: | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class SerializableCallable(abc.ABC): | ||||
|     @classmethod | ||||
|     @abc.abstractmethod | ||||
|     def serialize_compile_artifacts(cls, fn: Any) -> bytes: | ||||
|         pass | ||||
|  | ||||
|     @classmethod | ||||
|     @abc.abstractmethod | ||||
|     def deserialize_compile_artifacts(cls, data: bytes) -> Any: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| def bind_locals( | ||||
|     signature: inspect.Signature, *args: Any, **kwargs: Any | ||||
| ) -> dict[str, Any]: | ||||
| @ -149,53 +140,6 @@ class AOTCompiledFunction: | ||||
|         self._guard_check_enabled = False | ||||
|  | ||||
|  | ||||
| class BundledAOTAutogradSerializableCallable(SerializableCallable): | ||||
|     """ | ||||
|     Represents a serializable callable generated by compile_fx. | ||||
|     This class wraps around the compiled function generated by AOTAutograd. | ||||
|  | ||||
|     TODO: Instead of using PrecompileContext to grab it from AOTAutograd, | ||||
|     this object should be what's *returned* by aot_module_simplified. | ||||
|     We'll do that refactor in a later PR. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, compiled_fn: Any) -> None: | ||||
|         """ | ||||
|         Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form | ||||
|         of a compiled function generated by AOTAutograd. | ||||
|         """ | ||||
|         assert hasattr(compiled_fn, "serialize") | ||||
|         self.compiled_fn = compiled_fn | ||||
|  | ||||
|     def __getattr__(self, attr: Any) -> Any: | ||||
|         if hasattr(self, attr): | ||||
|             return getattr(super(), attr) | ||||
|         else: | ||||
|             return getattr(self.compiled_fn, attr) | ||||
|  | ||||
|     @classmethod | ||||
|     def serialize_compile_artifacts( | ||||
|         cls, fn: "BundledAOTAutogradSerializableCallable" | ||||
|     ) -> bytes: | ||||
|         with torch._functorch.config.patch("bundled_autograd_cache", True): | ||||
|             result = pickle.dumps(fn.compiled_fn.serialize()) | ||||
|             return result | ||||
|  | ||||
|     @classmethod | ||||
|     def deserialize_compile_artifacts(cls, data: bytes) -> Any: | ||||
|         from torch._functorch._aot_autograd.autograd_cache import ( | ||||
|             deserialize_bundled_cache_entry, | ||||
|         ) | ||||
|  | ||||
|         entry = pickle.loads(data) | ||||
|  | ||||
|         compiled_fn = deserialize_bundled_cache_entry(entry) | ||||
|         return cls(compiled_fn) | ||||
|  | ||||
|     def __call__(self, *args: Any, **kwargs: Any) -> Any: | ||||
|         return self.compiled_fn(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| def aot_compile_fullgraph( | ||||
|     model: Any, | ||||
|     example_inputs: tuple[tuple[Any, ...], dict[str, Any]], | ||||
|  | ||||
							
								
								
									
										61
									
								
								torch/_dynamo/aot_compile_types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								torch/_dynamo/aot_compile_types.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,61 @@ | ||||
| import abc | ||||
| import pickle | ||||
| from typing import Any | ||||
|  | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class SerializableCallable(abc.ABC): | ||||
|     @classmethod | ||||
|     @abc.abstractmethod | ||||
|     def serialize_compile_artifacts(cls, fn: Any) -> bytes: | ||||
|         pass | ||||
|  | ||||
|     @classmethod | ||||
|     @abc.abstractmethod | ||||
|     def deserialize_compile_artifacts(cls, data: bytes) -> Any: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class BundledAOTAutogradSerializableCallable(SerializableCallable): | ||||
|     """ | ||||
|     Represents a serializable callable generated by compile_fx. | ||||
|     This class wraps around the compiled function generated by AOTAutograd. | ||||
|  | ||||
|     TODO: Instead of using PrecompileContext to grab it from AOTAutograd, | ||||
|     this object should be what's *returned* by aot_module_simplified. | ||||
|     We'll do that refactor in a later PR. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, compiled_fn: Any) -> None: | ||||
|         """ | ||||
|         Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form | ||||
|         of a compiled function generated by AOTAutograd. | ||||
|         """ | ||||
|         assert hasattr(compiled_fn, "serialize") | ||||
|         self.compiled_fn = compiled_fn | ||||
|  | ||||
|     def __getattr__(self, attr: Any) -> Any: | ||||
|         return getattr(self.compiled_fn, attr) | ||||
|  | ||||
|     @classmethod | ||||
|     def serialize_compile_artifacts( | ||||
|         cls, fn: "BundledAOTAutogradSerializableCallable" | ||||
|     ) -> bytes: | ||||
|         with torch._functorch.config.patch("bundled_autograd_cache", True): | ||||
|             result = pickle.dumps(fn.compiled_fn.serialize()) | ||||
|             return result | ||||
|  | ||||
|     @classmethod | ||||
|     def deserialize_compile_artifacts(cls, data: bytes) -> Any: | ||||
|         from torch._functorch._aot_autograd.autograd_cache import ( | ||||
|             deserialize_bundled_cache_entry, | ||||
|         ) | ||||
|  | ||||
|         entry = pickle.loads(data) | ||||
|  | ||||
|         compiled_fn = deserialize_bundled_cache_entry(entry) | ||||
|         return cls(compiled_fn) | ||||
|  | ||||
|     def __call__(self, *args: Any, **kwargs: Any) -> Any: | ||||
|         return self.compiled_fn(*args, **kwargs) | ||||
| @ -2188,20 +2188,6 @@ class OutputGraph(OutputGraphCommon): | ||||
|                 ), | ||||
|             ) | ||||
|             self.call_cleanup_hooks() | ||||
|             old_fake_mode = self.tracing_context.fake_mode | ||||
|             assert old_fake_mode is not None | ||||
|             if not self.export: | ||||
|                 import torch._functorch.config as _config | ||||
|  | ||||
|                 with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): | ||||
|                     # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting | ||||
|                     backend_fake_mode = torch._subclasses.FakeTensorMode( | ||||
|                         shape_env=old_fake_mode.shape_env, | ||||
|                     ) | ||||
|                 # TODO(voz): Ostensibily, this should be scoped and | ||||
|                 # restore back to old_fake_mode, but doing so currently violates | ||||
|                 # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode | ||||
|                 self.tracing_context.fake_mode = backend_fake_mode | ||||
|  | ||||
|             with self.restore_global_state(): | ||||
|                 compiled_fn = self.call_user_compiler(gm, self.example_inputs()) | ||||
| @ -2237,8 +2223,10 @@ class OutputGraph(OutputGraphCommon): | ||||
|             ) | ||||
|  | ||||
|             counters["stats"]["unique_graphs"] += 1 | ||||
|             assert old_fake_mode.shape_env is not None | ||||
|             if specializations := old_fake_mode.shape_env.specializations: | ||||
|             if ( | ||||
|                 specializations | ||||
|                 := self.tracing_context.fake_mode.shape_env.specializations | ||||
|             ): | ||||
|                 specialization_guards = [] | ||||
|                 specialization_cache: dict[Specialization, Callable[[Any], Any]] = {} | ||||
|                 sources = [a.source for a in self.graphargs] | ||||
|  | ||||
| @ -182,7 +182,7 @@ def create_subclass_meta( | ||||
| def enumerate_filter_symints(lst: Iterable[IntLikeType]) -> list[tuple[int, SymInt]]: | ||||
|     # Capture all SymInts from the iterable. | ||||
|     def symint_check(s: IntLikeType) -> TypeGuard[SymInt]: | ||||
|         return isinstance(s, SymInt) and not s.node.is_nested_int() | ||||
|         return isinstance(s, SymInt) and not s.node.is_nested_int() and not s.node.expr.is_number | ||||
|  | ||||
|     return [(i, s) for i, s in enumerate(lst) if symint_check(s)] | ||||
|  | ||||
|  | ||||
| @ -391,6 +391,7 @@ def standalone_compile( | ||||
|         "from_example_inputs", "from_tracing_context", "from_graph" | ||||
|     ] = "from_graph", | ||||
|     options: Optional[dict[str, Any]] = None, | ||||
|     aot: bool = False,  # AOT mode, which uses BundledAOTAutogradCache | ||||
| ) -> CompiledArtifact: | ||||
|     """ | ||||
|     Precompilation API for inductor. | ||||
| @ -422,5 +423,5 @@ def standalone_compile( | ||||
|  | ||||
|     options = options if options else {} | ||||
|     return standalone_compile( | ||||
|         gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options | ||||
|         gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot | ||||
|     ) | ||||
|  | ||||
| @ -5,10 +5,12 @@ import logging | ||||
| import os | ||||
| import pickle | ||||
| import shutil | ||||
| from abc import ABC, abstractmethod | ||||
| from contextlib import AbstractContextManager, nullcontext | ||||
| from typing import Any, Callable, Literal, Optional, TYPE_CHECKING | ||||
|  | ||||
| import torch.fx | ||||
| from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable | ||||
| from torch._dynamo.utils import dynamo_timed | ||||
| from torch._inductor.cpp_builder import normalize_path_separator | ||||
| from torch._inductor.cudagraph_utils import BoxedDeviceIndex | ||||
| @ -30,9 +32,9 @@ if TYPE_CHECKING: | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class CompiledArtifact: | ||||
| class CompiledArtifact(ABC): | ||||
|     """ | ||||
|     CompiledArtifact class represents the precompiled inductor artifact that | ||||
|     CompiledArtifact class represents the inductor cache artifacts that | ||||
|     can be invoked in order to avoid repeated compilation. | ||||
|  | ||||
|     CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs) | ||||
| @ -45,11 +47,68 @@ class CompiledArtifact: | ||||
|     binary or unpacked data. | ||||
|  | ||||
|     Finally, the CompiledArtifact can be invoked via the __call__ method | ||||
|     to execute the precompiled artifact. | ||||
|     to execute the cached artifact. | ||||
|     """ | ||||
|  | ||||
|     _compiled_fn: Callable[..., Any] | ||||
|     _artifacts: Optional[tuple[bytes, CacheInfo]] | ||||
|     def __init__( | ||||
|         self, | ||||
|         compiled_fn: Callable[..., Any], | ||||
|         artifacts: Optional[tuple[bytes, CacheInfo]], | ||||
|     ): | ||||
|         self._compiled_fn = compiled_fn | ||||
|         self._artifacts = artifacts | ||||
|  | ||||
|     @abstractmethod | ||||
|     def __call__(self, *args: Any) -> Any: ... | ||||
|  | ||||
|     @abstractmethod | ||||
|     def save( | ||||
|         self, *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     ) -> None: ... | ||||
|  | ||||
|     @staticmethod | ||||
|     def load( | ||||
|         *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     ) -> CompiledArtifact: | ||||
|         if format == "unpacked": | ||||
|             # If format is unpacked, it must be a CacheCompiledArtifact | ||||
|             return CacheCompiledArtifact.load(path=path, format=format) | ||||
|  | ||||
|         assert format == "binary" | ||||
|         with open(path, "rb") as file: | ||||
|             from torch.utils._appending_byte_serializer import BytesReader | ||||
|  | ||||
|             from .codecache import torch_key | ||||
|  | ||||
|             result_bytes = file.read() | ||||
|             reader = BytesReader(result_bytes) | ||||
|             header = reader.read_bytes() | ||||
|             if header == AOTCompiledArtifact.AOT_HEADER: | ||||
|                 assert reader.read_bytes() == torch_key() | ||||
|                 artifact = reader.read_bytes() | ||||
|                 assert reader.is_finished() | ||||
|                 return AOTCompiledArtifact.deserialize(artifact) | ||||
|             # Otherwise, it's in the CacheCompiledArtifact format | ||||
|             elif header == CacheCompiledArtifact.CACHE_HEADER: | ||||
|                 assert reader.read_bytes() == torch_key() | ||||
|                 key = reader.read_str() | ||||
|                 artifact_bytes = reader.read_bytes() | ||||
|                 assert reader.is_finished() | ||||
|                 torch.compiler.load_cache_artifacts(artifact_bytes) | ||||
|                 return CacheCompiledArtifact._load_impl(nullcontext(), key) | ||||
|             else: | ||||
|                 raise RuntimeError( | ||||
|                     "Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: " | ||||
|                     + header.decode("utf-8") | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| class CacheCompiledArtifact(CompiledArtifact): | ||||
|     """ | ||||
|     CompiledArtifact that depends on torch.compiler.save_cache_artifacts | ||||
|     """ | ||||
|  | ||||
|     CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8") | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -83,6 +142,7 @@ class CompiledArtifact: | ||||
|                 from .codecache import torch_key | ||||
|  | ||||
|                 writer = BytesWriter() | ||||
|                 writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER) | ||||
|                 writer.write_bytes(torch_key()) | ||||
|                 writer.write_str(key) | ||||
|                 writer.write_bytes(artifact_bytes) | ||||
| @ -116,9 +176,51 @@ class CompiledArtifact: | ||||
|                             log.info("Output code written to: %s", output_file) | ||||
|  | ||||
|     @staticmethod | ||||
|     def load( | ||||
|         *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     def _load_impl( | ||||
|         cache_dir_ctx: AbstractContextManager[Any], key: str | ||||
|     ) -> CompiledArtifact: | ||||
|         with ( | ||||
|             cache_dir_ctx, | ||||
|             config.patch(unsafe_skip_cache_dynamic_shape_guards=True), | ||||
|         ): | ||||
|             with torch._functorch.config.patch(strict_autograd_cache=True): | ||||
|                 from torch._functorch._aot_autograd.autograd_cache import ( | ||||
|                     AOTAutogradCache, | ||||
|                 ) | ||||
|  | ||||
|                 result = AOTAutogradCache._lookup( | ||||
|                     key, | ||||
|                     local=True, | ||||
|                     remote=False, | ||||
|                     args=[], | ||||
|                     cache_info={}, | ||||
|                     aot_config=None, | ||||
|                 ) | ||||
|  | ||||
|             assert result is not None | ||||
|             (entry, _) = result | ||||
|  | ||||
|             from .compile_fx import _CompileFxKwargs | ||||
|  | ||||
|             fx_config = _CompileFxKwargs( | ||||
|                 cudagraphs=BoxedBool(False), | ||||
|                 boxed_forward_device_index=BoxedDeviceIndex(0), | ||||
|             ) | ||||
|  | ||||
|             context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv())) | ||||
|             with torch._guards.tracing(context): | ||||
|                 compiled_fn = entry.wrap_post_compile( | ||||
|                     [], entry.sanitized_aot_config, fx_config | ||||
|                 ) | ||||
|         return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _prepare_load( | ||||
|         *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     ) -> tuple[str, AbstractContextManager[Any]]: | ||||
|         """ | ||||
|         Do format specific prep and loads, return a context manager and key | ||||
|         """ | ||||
|         path = normalize_path_separator(path) | ||||
|         with dynamo_timed("CompiledArtifact.load"): | ||||
|             if format == "binary": | ||||
| @ -137,8 +239,7 @@ class CompiledArtifact: | ||||
|                 assert reader.is_finished() | ||||
|  | ||||
|                 torch.compiler.load_cache_artifacts(artifact_bytes) | ||||
|  | ||||
|                 cache_dir_ctx: AbstractContextManager[None] = nullcontext() | ||||
|                 return key, nullcontext() | ||||
|             else: | ||||
|                 assert format == "unpacked" | ||||
|                 assert os.path.isdir(path) | ||||
| @ -148,43 +249,105 @@ class CompiledArtifact: | ||||
|                 assert len(files) == 1 | ||||
|                 key = files[0] | ||||
|                 cache_dir_ctx = temporary_cache_dir(path) | ||||
|                 return key, cache_dir_ctx | ||||
|  | ||||
|             with ( | ||||
|                 cache_dir_ctx, | ||||
|                 config.patch(unsafe_skip_cache_dynamic_shape_guards=True), | ||||
|             ): | ||||
|                 with torch._functorch.config.patch(strict_autograd_cache=True): | ||||
|                     from torch._functorch._aot_autograd.autograd_cache import ( | ||||
|                         AOTAutogradCache, | ||||
|                     ) | ||||
|     @staticmethod | ||||
|     def load( | ||||
|         *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     ) -> CompiledArtifact: | ||||
|         key, cache_dir_ctx = CacheCompiledArtifact._prepare_load( | ||||
|             path=path, format=format | ||||
|         ) | ||||
|         return CacheCompiledArtifact._load_impl(cache_dir_ctx, key) | ||||
|  | ||||
|                     result = AOTAutogradCache._lookup( | ||||
|                         key, | ||||
|                         local=True, | ||||
|                         remote=False, | ||||
|                         args=[], | ||||
|                         cache_info={}, | ||||
|                         aot_config=None, | ||||
|                     ) | ||||
|  | ||||
|                 assert result is not None | ||||
|                 (entry, _) = result | ||||
| class AOTCompiledArtifact(CompiledArtifact): | ||||
|     """ | ||||
|     Similar to CompiledArtifact, but the object is a single, bundled precompiled function. | ||||
|     This object is always a serializable callable function. | ||||
|  | ||||
|                 from .compile_fx import _CompileFxKwargs | ||||
|     This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which | ||||
|     is used by torch._dynamo.aot_compile for AOT Precompilation. | ||||
|     """ | ||||
|  | ||||
|                 fx_config = _CompileFxKwargs( | ||||
|                     cudagraphs=BoxedBool(False), | ||||
|                     boxed_forward_device_index=BoxedDeviceIndex(0), | ||||
|                 ) | ||||
|     AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8") | ||||
|  | ||||
|                 context = torch._guards.TracingContext( | ||||
|                     FakeTensorMode(shape_env=ShapeEnv()) | ||||
|                 ) | ||||
|                 with torch._guards.tracing(context): | ||||
|                     compiled_fn = entry.wrap_post_compile( | ||||
|                         [], entry.sanitized_aot_config, fx_config | ||||
|                     ) | ||||
|             return CompiledArtifact(lambda *args: compiled_fn(list(args)), None) | ||||
|     def __init__( | ||||
|         self, | ||||
|         compiled_fn: Callable[..., Any], | ||||
|     ): | ||||
|         self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn) | ||||
|         self._artifacts = ( | ||||
|             None  # We don't need artifacts, the inner object handles everything | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_bundled_callable( | ||||
|         bundled_fn: BundledAOTAutogradSerializableCallable, | ||||
|     ) -> AOTCompiledArtifact: | ||||
|         return AOTCompiledArtifact(bundled_fn.compiled_fn) | ||||
|  | ||||
|     def __call__(self, *args: Any) -> Any: | ||||
|         return self.inner_fn(*args) | ||||
|  | ||||
|     def save( | ||||
|         self, *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     ) -> None: | ||||
|         if format == "unpacked": | ||||
|             raise RuntimeError( | ||||
|                 "AOTCompiledArtifact does not support unpacked format yet" | ||||
|             ) | ||||
|         result_bytes = self.serialize() | ||||
|         from torch.utils._appending_byte_serializer import BytesWriter | ||||
|  | ||||
|         from .codecache import torch_key | ||||
|  | ||||
|         writer = BytesWriter() | ||||
|         writer.write_bytes(AOTCompiledArtifact.AOT_HEADER) | ||||
|         writer.write_bytes(torch_key()) | ||||
|         writer.write_bytes(result_bytes) | ||||
|  | ||||
|         from torch._inductor.codecache import write_atomic | ||||
|  | ||||
|         # Save a sentinel file to indicate that this is AOT | ||||
|         write_atomic(path, writer.to_bytes()) | ||||
|  | ||||
|     def serialize(self) -> bytes: | ||||
|         return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts( | ||||
|             self.inner_fn | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def deserialize(result_bytes: bytes) -> AOTCompiledArtifact: | ||||
|         deserialized = ( | ||||
|             BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts( | ||||
|                 result_bytes | ||||
|             ) | ||||
|         ) | ||||
|         assert isinstance(deserialized, BundledAOTAutogradSerializableCallable) | ||||
|         return AOTCompiledArtifact.from_bundled_callable(deserialized) | ||||
|  | ||||
|     @staticmethod | ||||
|     def load( | ||||
|         *, path: str, format: Literal["binary", "unpacked"] = "binary" | ||||
|     ) -> CompiledArtifact: | ||||
|         if format == "unpacked": | ||||
|             raise RuntimeError( | ||||
|                 "AOTCompiledArtifact does not support unpacked format yet" | ||||
|             ) | ||||
|         with open(path, "rb") as file: | ||||
|             from torch.utils._appending_byte_serializer import BytesReader | ||||
|  | ||||
|             from .codecache import torch_key | ||||
|  | ||||
|             result_bytes = file.read() | ||||
|             reader = BytesReader(result_bytes) | ||||
|             header = reader.read_bytes() | ||||
|             assert header == AOTCompiledArtifact.AOT_HEADER | ||||
|             assert reader.read_bytes() == torch_key() | ||||
|             artifact = reader.read_bytes() | ||||
|             assert reader.is_finished() | ||||
|             return AOTCompiledArtifact.deserialize(artifact) | ||||
|  | ||||
|  | ||||
| def standalone_compile( | ||||
| @ -193,7 +356,11 @@ def standalone_compile( | ||||
|     *, | ||||
|     dynamic_shapes: Any, | ||||
|     options: Any, | ||||
|     aot: bool = False,  # AOT mode, which uses BundledAOTAutogradCache | ||||
| ) -> CompiledArtifact: | ||||
|     """ | ||||
|     Implementation of torch.inductor.standalone_compile | ||||
|     """ | ||||
|     from torch.compiler._cache import CacheArtifactManager | ||||
|  | ||||
|     from .compile_fx import compile_fx | ||||
| @ -249,6 +416,7 @@ def standalone_compile( | ||||
|         torch._guards.tracing(context), | ||||
|         CacheArtifactManager.with_fresh_cache(), | ||||
|         config.patch("triton.autotune_at_compile_time", True), | ||||
|         torch._functorch.config.patch("bundled_autograd_cache", aot), | ||||
|     ): | ||||
|         # compile_fx can mutate gm | ||||
|         gm = copy.deepcopy(gm) | ||||
| @ -256,7 +424,12 @@ def standalone_compile( | ||||
|             gm, example_inputs, ignore_shape_env=ignore_shape_env, **options | ||||
|         ) | ||||
|         assert callable(compiled_fn) | ||||
|  | ||||
|         if aot: | ||||
|             if not hasattr(compiled_fn, "serialize"): | ||||
|                 raise RuntimeError( | ||||
|                     "Compiled function should have serialize method when aot=True" | ||||
|                 ) | ||||
|             return AOTCompiledArtifact(compiled_fn) | ||||
|         artifacts = torch.compiler.save_cache_artifacts() | ||||
|         if artifacts is None: | ||||
|             log.warning( | ||||
| @ -264,4 +437,4 @@ def standalone_compile( | ||||
|                 "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem" | ||||
|             ) | ||||
|  | ||||
|     return CompiledArtifact(compiled_fn, artifacts) | ||||
|     return CacheCompiledArtifact(compiled_fn, artifacts) | ||||
|  | ||||
| @ -35,7 +35,7 @@ from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedte | ||||
| from torch._library.fake_class_registry import FakeScriptObject | ||||
| from torch._library.fake_profile import MissingOpProfile | ||||
| from torch._logging import dtrace_structured | ||||
| from torch._prims_common import suggest_memory_format | ||||
| from torch._prims_common import check_contiguous_sizes_strides, suggest_memory_format | ||||
| from torch._subclasses.meta_utils import ( | ||||
|     assert_eq, | ||||
|     assert_metadata_eq, | ||||
| @ -1072,21 +1072,46 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata: | ||||
|     Extract the TensorMetadata of a tensor. | ||||
|     """ | ||||
|     memory_format = suggest_memory_format(t) | ||||
|     # Don't call is_contiguous() on a Tensor which has symbolic sizes or things | ||||
|     # will go badly (guards will be messed up?) | ||||
|     if ( | ||||
|         t._has_symbolic_sizes_strides | ||||
|         or is_sparse_any(t) | ||||
|         or not t.is_contiguous(memory_format=memory_format) | ||||
|     ): | ||||
|     shape = tuple(t.shape) | ||||
|     stride = tuple(t.stride()) if t.layout == torch.strided else () | ||||
|  | ||||
|     if is_sparse_any(t): | ||||
|         is_contiguous = False | ||||
|     else: | ||||
|         if t._has_symbolic_sizes_strides: | ||||
|             still_has_symbolic_sizes_strides = False | ||||
|  | ||||
|             def simplify(x: IntLikeType) -> IntLikeType: | ||||
|                 if not isinstance(x, SymInt): | ||||
|                     return x | ||||
|                 value = x.node.expr | ||||
|                 if value.is_number: | ||||
|                     return int(value) | ||||
|                 nonlocal still_has_symbolic_sizes_strides | ||||
|                 still_has_symbolic_sizes_strides = True | ||||
|                 return x | ||||
|  | ||||
|             shape = tuple(simplify(x) for x in shape) | ||||
|             stride = tuple(simplify(x) for x in stride) | ||||
|  | ||||
|             if still_has_symbolic_sizes_strides: | ||||
|                 # Don't call is_contiguous() on a Tensor which has symbolic sizes or things | ||||
|                 # will go badly (guards will be messed up?) | ||||
|                 is_contiguous = False | ||||
|             else: | ||||
|                 is_contiguous = check_contiguous_sizes_strides(shape, stride, True) | ||||
|         else: | ||||
|             is_contiguous = t.is_contiguous(memory_format=memory_format) | ||||
|  | ||||
|     if not is_contiguous: | ||||
|         memory_format = None  # type: ignore[assignment] | ||||
|  | ||||
|     storage_offset = t.storage_offset() | ||||
|  | ||||
|     return TensorMetadata( | ||||
|         t.dtype, | ||||
|         t.shape, | ||||
|         t.stride() if t.layout == torch.strided else (), | ||||
|         shape, | ||||
|         stride, | ||||
|         t.device, | ||||
|         t.layout, | ||||
|         memory_format, | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| #include <torch/csrc/jit/runtime/logging.h> | ||||
|  | ||||
| #include <c10/util/Exception.h> | ||||
| #include <atomic> | ||||
| #include <chrono> | ||||
| #include <mutex> | ||||
| @ -33,7 +34,7 @@ int64_t LockingLogger::getCounterValue(const std::string& name) const { | ||||
|       return raw_counter.sum / raw_counter.count; | ||||
|     } break; | ||||
|   } | ||||
|   throw std::runtime_error("Unknown aggregation type!"); | ||||
|   TORCH_CHECK(false, "Unknown aggregation type!"); | ||||
| } | ||||
|  | ||||
| void LockingLogger::setAggregationType( | ||||
|  | ||||
| @ -11,6 +11,7 @@ | ||||
| #include <torch/csrc/jit/runtime/slice_indices_adjust.h> | ||||
| #include <limits> | ||||
|  | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/irange.h> | ||||
|  | ||||
| namespace torch::jit { | ||||
| @ -112,20 +113,17 @@ void listRemove<at::Tensor>(Stack& stack) { | ||||
| } | ||||
|  | ||||
| void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) { | ||||
|   if (t.requires_grad()) { | ||||
|     throw std::runtime_error( | ||||
|         "Cannot input a tensor that requires grad as a scalar argument"); | ||||
|   } | ||||
|   if (!t.sizes().empty()) { | ||||
|     throw std::runtime_error( | ||||
|         "Cannot input a tensor of dimension other than 0 as a scalar argument"); | ||||
|   } | ||||
|   if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) { | ||||
|     std::stringstream ss; | ||||
|     ss << "Cannot input a tensor of type " << t.scalar_type() | ||||
|        << " as an integral argument"; | ||||
|     throw std::runtime_error(ss.str()); | ||||
|   } | ||||
|   TORCH_CHECK( | ||||
|       !t.requires_grad(), | ||||
|       "Cannot input a tensor that requires grad as a scalar argument"); | ||||
|   TORCH_CHECK( | ||||
|       t.sizes().empty(), | ||||
|       "Cannot input a tensor of dimension other than 0 as a scalar argument"); | ||||
|   TORCH_CHECK( | ||||
|       !toInt || isIntegralType(t.scalar_type(), /*includeBool=*/false), | ||||
|       "Cannot input a tensor of type ", | ||||
|       t.scalar_type(), | ||||
|       " as an integral argument"); | ||||
| } | ||||
|  | ||||
| void checkDoubleInRange(double a) { | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| #include <ATen/autocast_mode.h> | ||||
| #include <ATen/core/Generator.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <torch/csrc/jit/mobile/promoted_prim_ops.h> | ||||
| #include <torch/csrc/jit/runtime/custom_operator.h> | ||||
| @ -159,9 +160,8 @@ void sort_op(Stack& stack) { | ||||
|  | ||||
|   if (!g_list.empty()) { | ||||
|     std::stringstream error_str; | ||||
|     if (!isSortableListOfObjectsOrTuples(g_list, error_str)) { | ||||
|       throw std::runtime_error(error_str.str()); | ||||
|     } | ||||
|     TORCH_CHECK( | ||||
|         isSortableListOfObjectsOrTuples(g_list, error_str), error_str.str()); | ||||
|  | ||||
|     c10::IValueComparator comparator; | ||||
|     if (reverse) { | ||||
| @ -254,9 +254,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{ | ||||
|           int64_t lo = 0, hi = 0, step = 0; | ||||
|           pop(stack, lo, hi, step); | ||||
|           // error handling when step_val = 0 during runtime | ||||
|           if (step == 0) { | ||||
|             throw std::runtime_error("range() arg 3 must not be zero"); | ||||
|           } | ||||
|           TORCH_CHECK(step != 0, "range() arg 3 must not be zero"); | ||||
|           if (step > 0 && lo < hi) { | ||||
|             push(stack, 1 + (hi - 1 - lo) / step); | ||||
|           } else if (step < 0 && lo > hi) { | ||||
| @ -382,14 +380,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{ | ||||
|           auto s = pop(stack).toString(); | ||||
|           std::string::size_type sz = 0; | ||||
|           int64_t val = static_cast<int64_t>(std::stoll(s->string(), &sz)); | ||||
|           if (sz == s->string().size()) { | ||||
|             push(stack, val); | ||||
|           } else { | ||||
|             std::stringstream error_str; | ||||
|             error_str << "invalid literal for int() " | ||||
|                       << "with base 10: '" << s->string() << "'"; | ||||
|             throw std::runtime_error(error_str.str()); | ||||
|           } | ||||
|           TORCH_CHECK( | ||||
|               sz == s->string().size(), | ||||
|               "invalid literal for int() ", | ||||
|               "with base 10: '", | ||||
|               s->string(), | ||||
|               "'"); | ||||
|           push(stack, val); | ||||
|         }, | ||||
|         aliasAnalysisFromSchema()), | ||||
|     OperatorGeneratorArgs( | ||||
| @ -436,14 +433,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{ | ||||
|           auto s = pop(stack).toString(); | ||||
|           std::string::size_type sz = 0; | ||||
|           double b = std::stod(s->string(), &sz); | ||||
|           if (sz == s->string().size()) { | ||||
|             push(stack, b); | ||||
|           } else { | ||||
|             std::stringstream error_str; | ||||
|             error_str << "could not convert string " | ||||
|                       << "to float: '" << s->string() << "'"; | ||||
|             throw std::runtime_error(error_str.str()); | ||||
|           } | ||||
|           TORCH_CHECK( | ||||
|               sz == s->string().size(), | ||||
|               "could not convert string ", | ||||
|               "to float: '", | ||||
|               s->string(), | ||||
|               "'"); | ||||
|           push(stack, b); | ||||
|         }, | ||||
|         aliasAnalysisFromSchema()), | ||||
|     OperatorGeneratorArgs( | ||||
| @ -1793,10 +1789,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{ | ||||
|           } | ||||
|  | ||||
|           const std::string& separator = ivalue.toStringRef(); | ||||
|  | ||||
|           if (separator.empty()) { | ||||
|             throw std::runtime_error("ValueError: empty separator"); | ||||
|           } | ||||
|           TORCH_CHECK(!separator.empty(), "ValueError: empty separator"); | ||||
|  | ||||
|           auto count = 0; | ||||
|  | ||||
| @ -1919,11 +1912,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{ | ||||
|           std::string fillchar = pop(stack).toStringRef(); | ||||
|           int64_t width = pop(stack).toInt(); | ||||
|           std::string string = pop(stack).toStringRef(); | ||||
|           if (fillchar.size() != 1) { | ||||
|             // TODO: this should be a TypeError | ||||
|             throw std::runtime_error( | ||||
|                 "TypeError: The fill character must be exactly one character long"); | ||||
|           } | ||||
|           TORCH_CHECK( | ||||
|               fillchar.size() == 1, | ||||
|               "TypeError: The fill character must be exactly one character long"); | ||||
|           if (string.size() > static_cast<std::string::size_type>(width)) { | ||||
|             push(stack, string); | ||||
|             return; | ||||
| @ -2092,9 +2083,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{ | ||||
|           std::string substr = pop(stack).toStringRef(); | ||||
|           std::string string = pop(stack).toStringRef(); | ||||
|           auto result = stringFindImpl(string, substr, start, end); | ||||
|           if (result < 0) { | ||||
|             throw std::runtime_error("ValueError: substring not found"); | ||||
|           } | ||||
|           TORCH_CHECK(result >= 0, "ValueError: substring not found"); | ||||
|           push(stack, result); | ||||
|         }, | ||||
|         aliasAnalysisFromSchema()), | ||||
| @ -2107,9 +2096,7 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{ | ||||
|           std::string substr = pop(stack).toStringRef(); | ||||
|           std::string string = pop(stack).toStringRef(); | ||||
|           auto result = stringFindImpl(string, substr, start, end, true); | ||||
|           if (result < 0) { | ||||
|             throw std::runtime_error("ValueError: substring not found"); | ||||
|           } | ||||
|           TORCH_CHECK(result >= 0, "ValueError: substring not found"); | ||||
|           push(stack, result); | ||||
|         }, | ||||
|         aliasAnalysisFromSchema()), | ||||
| @ -2183,11 +2170,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{ | ||||
|           std::string fillchar = pop(stack).toStringRef(); | ||||
|           int64_t width = pop(stack).toInt(); | ||||
|           std::string string = pop(stack).toStringRef(); | ||||
|           if (fillchar.size() != 1) { | ||||
|             // TODO: this should be a TypeError | ||||
|             throw std::runtime_error( | ||||
|                 "TypeError: The fill character must be exactly one character long"); | ||||
|           } | ||||
|           TORCH_CHECK( | ||||
|               fillchar.size() == 1, | ||||
|               "TypeError: The fill character must be exactly one character long"); | ||||
|           auto to_append = | ||||
|               std::max(int64_t(0), width - static_cast<int64_t>(string.size())); | ||||
|  | ||||
| @ -2207,11 +2192,9 @@ static const std::vector<OperatorGeneratorArgs> stringOpGenArgs{ | ||||
|           std::string fillchar = pop(stack).toStringRef(); | ||||
|           int64_t width = pop(stack).toInt(); | ||||
|           std::string string = pop(stack).toStringRef(); | ||||
|           if (fillchar.size() != 1) { | ||||
|             // TODO: this should be a TypeError | ||||
|             throw std::runtime_error( | ||||
|                 "TypeError: The fill character must be exactly one character long"); | ||||
|           } | ||||
|           TORCH_CHECK( | ||||
|               fillchar.size() == 1, | ||||
|               "TypeError: The fill character must be exactly one character long"); | ||||
|           auto to_append = | ||||
|               std::max(int64_t(0), width - static_cast<int64_t>(string.size())); | ||||
|  | ||||
| @ -3358,10 +3341,8 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{ | ||||
|           int64_t a = 0, b = 0; | ||||
|           lldiv_t divresult = {}; | ||||
|           pop(stack, a, b); | ||||
|           if (b == 0) { | ||||
|             throw std::runtime_error( | ||||
|                 "ZeroDivisionError: integer division or modulo by zero"); | ||||
|           } | ||||
|           TORCH_CHECK( | ||||
|               b != 0, "ZeroDivisionError: integer division or modulo by zero"); | ||||
|           divresult = lldiv(a, b); | ||||
|           if (divresult.rem && (a < 0) != (b < 0)) { | ||||
|             divresult.quot -= 1; | ||||
| @ -3379,9 +3360,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{ | ||||
|         [](Stack& stack) { | ||||
|           double a = 0, b = 0; | ||||
|           pop(stack, a, b); | ||||
|           if (b == 0) { | ||||
|             throw std::runtime_error("ZeroDivisionError: float divmod()"); | ||||
|           } | ||||
|           TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()"); | ||||
|           double rem = fmod(a, b); | ||||
|           // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) | ||||
|           if (rem && (a < 0) != (b < 0)) { | ||||
| @ -3426,9 +3405,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs2{ | ||||
|         type_a a;                                                            \ | ||||
|         type_b b;                                                            \ | ||||
|         pop(stack, a, b);                                                    \ | ||||
|         if (b == 0) {                                                        \ | ||||
|           throw std::runtime_error("ZeroDivisionError: float divmod()");     \ | ||||
|         }                                                                    \ | ||||
|         TORCH_CHECK(b != 0, "ZeroDivisionError: float divmod()");            \ | ||||
|         double quot = floor(a / b);                                          \ | ||||
|         double rem = a - (quot * b);                                         \ | ||||
|         push(stack, quot, rem);                                              \ | ||||
|  | ||||
| @ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix): | ||||
|  | ||||
|  | ||||
| def _compile_submod(gm, prefix): | ||||
|     from torch._inductor.standalone_compile import AOTCompiledArtifact | ||||
|  | ||||
|     for node in gm.graph.nodes: | ||||
|         if node.op == "call_module" and node.target.startswith(prefix): | ||||
|             fake_inputs = [] | ||||
| @ -56,13 +58,12 @@ def _compile_submod(gm, prefix): | ||||
|  | ||||
|             submod = getattr(gm, node.target) | ||||
|  | ||||
|             # _dummy_wrapper is to make call_function happy | ||||
|             compiled_submod = _dummy_wrapper( | ||||
|                 torch._inductor.standalone_compile( | ||||
|                     submod, fake_inputs, dynamic_shapes="from_tracing_context" | ||||
|                 ) | ||||
|             compiled_fn = torch._inductor.standalone_compile( | ||||
|                 submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True | ||||
|             ) | ||||
|  | ||||
|             assert isinstance(compiled_fn, AOTCompiledArtifact) | ||||
|             # _dummy_wrapper is to make call_function happy | ||||
|             compiled_submod = _dummy_wrapper(compiled_fn) | ||||
|             with gm.graph.inserting_after(node): | ||||
|                 new_node = gm.graph.call_function( | ||||
|                     compiled_submod, args=node.args, kwargs=node.kwargs | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	