mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[PT2][fusion] ban fusions with large accumulated reads (#157563)
**Problem:**
Fusion can accumulate large amount of reads, which leads to significant increase in peak memory utilization. Imagine we have the following code snippet
```
total = torch.rand(N, N)
for _ in range(r):
    x = torch.rand(N, N)
    total = total + x
```
The default execution is memory efficient as only two tensors of size N-by-N is in memory at any given time. However, with fusion, the additions are fused into a single operation and the execution becomes something like:
```
x_1 = torch.rand(N, N)
x_2 =  torch.rand(N, N)
...
x_r = torch.rand(N, N)
total = x_1 + x_2 + ... + x_r
```
Though this is run-time efficient, in the case of large `N` and/or large `r`, this is not memory efficient.
[internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details
**Solution:**
Our proposed solution is to ban fusions in case where a large amount of reads are accumulated. This is in addition to some existing logics during torch compile.
* During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which is default to be 8, controls _the number of_ buffers can be accumulated for a single operator. However, this is oblivious to the size of the buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the amount of buffers_ in size that can be accumulated.
* During scheduling (i.e., `scheduler.py`), additional fusion will be performed and thus we also need to capture such pattern there. The decisions are implemented under `choices.py`.
**Results:**
For a small example similar to be one in the test case (but with larger `N` and higher number of loop repeats), the memory snapshot before and after are shown below. Note the snapshot on the right is zoomed out so that the y-axis of the two snapshots match.
<img width="1328" alt="image" src="https://github.com/user-attachments/assets/670b5961-8454-4379-ae0f-62d4e7946c64" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563
Approved by: https://github.com/jansel, https://github.com/mlazos
			
			
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							9345279c6e
						
					
				
				
					commit
					c062550a35
				
			| @ -1,89 +1,23 @@ | ||||
| add_loop_eager,compile_time_instruction_count,3017000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_eager,compile_time_instruction_count,2996000000,0.015 | ||||
| add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor,compile_time_instruction_count,29490000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor,compile_time_instruction_count,33090000000,0.015 | ||||
| add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025 | ||||
| add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015 | ||||
| basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015 | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015 | ||||
| basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2 | ||||
| update_hint_regression,compile_time_instruction_count,1673000000,0.02 | ||||
|  | ||||
|  | ||||
|  | ||||
| sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| symint_sum,compile_time_instruction_count,3166000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| symint_sum,compile_time_instruction_count,3184000000,0.015 | ||||
| symint_sum_loop,compile_time_instruction_count,4202000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015 | ||||
| mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015 | ||||
| basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 | ||||
|  | ||||
| 
 | 
| @ -306,6 +306,57 @@ class TestOperatorReorderForPeakMemory(TestCase): | ||||
|         expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 | ||||
|         self.assertLess(peak_mem, expected_bound) | ||||
|  | ||||
|     def test_fusion_acc_large_reads(self): | ||||
|         def f(x, y, z): | ||||
|             res = torch.zeros_like(x[0]) | ||||
|             for i in range(4): | ||||
|                 temp = torch.matmul(x, y) + z | ||||
|                 res = res + temp | ||||
|             return res | ||||
|  | ||||
|         N = 128 | ||||
|         x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) | ||||
|         y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) | ||||
|         z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) | ||||
|  | ||||
|         # CASE 1: no restriction on the amount of accumulation | ||||
|         with config.patch({"realize_acc_reads_size_threshold": float("inf")}): | ||||
|             f_compiled = torch.compile(f) | ||||
|             code = run_and_get_triton_code(f_compiled, x, y, z) | ||||
|             ( | ||||
|                 FileCheck() | ||||
|                 .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") | ||||
|                 .run(code) | ||||
|             ) | ||||
|  | ||||
|         # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) | ||||
|         # at most 12 / 4 = 3 reads can be accumulated during fusion | ||||
|         with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): | ||||
|             f_compiled = torch.compile(f) | ||||
|             code = run_and_get_triton_code(f_compiled, x, y, z) | ||||
|             ( | ||||
|                 FileCheck() | ||||
|                 .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") | ||||
|                 .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") | ||||
|                 .run(code) | ||||
|             ) | ||||
|  | ||||
|         # CASE 3: no such fusion allowed | ||||
|         with config.patch({"realize_acc_reads_size_threshold": N**2}): | ||||
|             f_compiled = torch.compile(f) | ||||
|             code = run_and_get_triton_code(f_compiled, x, y, z) | ||||
|             ( | ||||
|                 FileCheck() | ||||
|                 .check("triton_poi_fused_add_0.run(buf1, arg2_1,") | ||||
|                 .check("triton_poi_fused_add_0.run(buf3, arg2_1,") | ||||
|                 .check("triton_poi_fused_add_0.run(buf4, buf3,") | ||||
|                 .check("triton_poi_fused_add_0.run(buf6, arg2_1,") | ||||
|                 .check("triton_poi_fused_add_0.run(buf7, buf6,") | ||||
|                 .check("triton_poi_fused_add_0.run(buf9, arg2_1,") | ||||
|                 .check("triton_poi_fused_add_0.run(buf10, buf9,") | ||||
|                 .run(code) | ||||
|             ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     from torch._inductor.test_case import run_tests | ||||
|  | ||||
| @ -13,6 +13,7 @@ from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
|     IS_LINUX, | ||||
|     parametrize, | ||||
|     serialTest, | ||||
| ) | ||||
| from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA | ||||
|  | ||||
| @ -77,12 +78,17 @@ class TestOnlineSoftmax(TestCase): | ||||
|         out, source_codes = run_and_get_code(f, x) | ||||
|         return source_codes[0] | ||||
|  | ||||
|     @serialTest() | ||||
|     def test_codegen_3pass_softmax_due_to_disable(self): | ||||
|         with inductor_config.patch(online_softmax=False): | ||||
|         with inductor_config.patch( | ||||
|             online_softmax=False, | ||||
|             realize_acc_reads_size_threshold=float("inf"), | ||||
|         ): | ||||
|             wrapper_code = self.get_softmax_wrapper() | ||||
|  | ||||
|         self.assertEqual(wrapper_code.count("for r0_offset in"), 3) | ||||
|  | ||||
|     @serialTest() | ||||
|     @parametrize("V", [2048, 50304]) | ||||
|     @parametrize("use_log_softmax", [False, True]) | ||||
|     def test_codegen_online_softmax(self, use_log_softmax, V): | ||||
|  | ||||
| @ -365,6 +365,10 @@ class InductorChoices: | ||||
|             WhyNoFuse(node1, node2)("Fusion will increase peak memory") | ||||
|             return False | ||||
|  | ||||
|         if scheduler.fusion_accumulate_large_reads(node1, node2): | ||||
|             WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") | ||||
|             return False | ||||
|  | ||||
|         return True | ||||
|  | ||||
|     @staticmethod | ||||
|  | ||||
| @ -574,6 +574,7 @@ realize_opcount_threshold = 30 | ||||
|  | ||||
| # Threshold to prevent excessive accumulation of ops in one buffer during lowering | ||||
| realize_acc_reads_threshold = 8 | ||||
| realize_acc_reads_size_threshold = 3 * (1024**3) | ||||
|  | ||||
| # fallback to eager for random/dropout, this is slow but useful for debugging | ||||
| fallback_random = False | ||||
|  | ||||
| @ -123,6 +123,7 @@ if TYPE_CHECKING: | ||||
|     from torch.fx.graph import Graph | ||||
|  | ||||
|     from .codegen.wrapper import PythonWrapperCodegen | ||||
|     from .dependencies import Dep | ||||
|     from .scheduler import BaseSchedulerNode | ||||
|  | ||||
|     CompiledModule = Union[ModuleType, FileBackedGraphModule] | ||||
| @ -485,6 +486,9 @@ class GraphLowering(torch.fx.Interpreter): | ||||
|  | ||||
|         self.bw_donated_idxs = get_donated_idxs() | ||||
|  | ||||
|         # Cache for dep size hints to avoid expensive recomputation | ||||
|         self.dep_size_hint_cache: dict[Dep, int] = {} | ||||
|  | ||||
|     def freeze_runtime_asserts(self) -> None: | ||||
|         self._shape_env.freeze_runtime_asserts() | ||||
|  | ||||
| @ -570,6 +574,23 @@ class GraphLowering(torch.fx.Interpreter): | ||||
|         assert isinstance(feature, BackendFeature), feature | ||||
|         return feature in self.get_backend_features(get_device_type(device)) | ||||
|  | ||||
|     def get_dep_size_hint(self, dep: Dep) -> int: | ||||
|         """ | ||||
|         Get the size hint for a dependency with caching to avoid expensive recomputation. | ||||
|         """ | ||||
|         if dep not in self.dep_size_hint_cache: | ||||
|             res = 0 | ||||
|             try: | ||||
|                 if not dep.has_unbacked_symbols(): | ||||
|                     res = dep.numbytes_hint() | ||||
|             except KeyError: | ||||
|                 # In at least one test (test/inductor/test_torchbind.py) we | ||||
|                 # create a StarDep that doesn't exist in the graph and calling | ||||
|                 # `has_unbacked_symbols()` throws an error. | ||||
|                 pass | ||||
|             self.dep_size_hint_cache[dep] = res | ||||
|         return self.dep_size_hint_cache[dep] | ||||
|  | ||||
|     def get_current_device_or_throw(self) -> torch.device: | ||||
|         if device := self.current_device: | ||||
|             return device | ||||
|  | ||||
| @ -7829,6 +7829,10 @@ class TensorBox(MutableBox): | ||||
|  | ||||
|  | ||||
| class StorageBox(MutableBox): | ||||
|     """ | ||||
|     StorageBox allow in-place mutation of Tensors | ||||
|     """ | ||||
|  | ||||
|     def is_input_buffer(self) -> bool: | ||||
|         if isinstance(self.data, (InputBuffer, ReinterpretView)): | ||||
|             return self.data.get_name() in V.graph.graph_inputs | ||||
| @ -7878,10 +7882,17 @@ class StorageBox(MutableBox): | ||||
|         ): | ||||
|             self.realize() | ||||
|  | ||||
|     def has_accumulated_enough_reads_by_size(self) -> bool: | ||||
|         return ( | ||||
|             sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) | ||||
|             > config.realize_acc_reads_size_threshold | ||||
|         ) | ||||
|  | ||||
|     def has_exceeded_max_reads(self) -> bool: | ||||
|         return isinstance(self.data, Pointwise) and ( | ||||
|             self.num_reads() > config.realize_acc_reads_threshold | ||||
|             or self.has_large_inner_fn() | ||||
|             or self.has_accumulated_enough_reads_by_size() | ||||
|         ) | ||||
|  | ||||
|     def should_realize_on_reuse(self, users: int) -> bool: | ||||
|  | ||||
| @ -78,19 +78,8 @@ def get_freeable_input_buf( | ||||
|         A dictionary containing all freeble input buffers, keyed by their names. | ||||
|     """ | ||||
|  | ||||
|     # this function is copied from torch/_inductor/scheduler.py | ||||
|     # TODO: would be nice to remove the try/except block for both places | ||||
|     def _dep_size_hint(dep: Dep) -> int: | ||||
|         res = 0 | ||||
|         try: | ||||
|             if not dep.has_unbacked_symbols(): | ||||
|                 res = dep.numbytes_hint() | ||||
|         except KeyError: | ||||
|             # In at least one test (test/inductor/test_torchbind.py) we | ||||
|             # create a StarDep that doesn't exist in the graph and calling | ||||
|             # `has_unbacked_symbols()` throws an error. | ||||
|             pass | ||||
|         return res | ||||
|         return V.graph.get_dep_size_hint(dep) | ||||
|  | ||||
|     # get freeable input buffers' successor nodes and their sizes | ||||
|     # note that different deps can have the same name, so we use name as keys | ||||
|  | ||||
| @ -2051,15 +2051,12 @@ class Scheduler: | ||||
|     optimizations such as fusion, reorder, and graph partition. | ||||
|     """ | ||||
|  | ||||
|     __dep_size_hint_cache: dict[Dep, int] | ||||
|  | ||||
|     def __init__(self, nodes: list[ir.Operation]) -> None: | ||||
|         with dynamo_timed("Scheduler.__init__"): | ||||
|             self._init(nodes) | ||||
|  | ||||
|     def _init(self, nodes: list[ir.Operation]) -> None: | ||||
|         super().__init__() | ||||
|         self.__dep_size_hint_cache = {} | ||||
|         V.graph.scheduler = self | ||||
|         self.backends: dict[torch.device, BaseScheduling] = {} | ||||
|         self.post_grad_graph_id = next(_post_grad_graph_counter) | ||||
| @ -3505,6 +3502,17 @@ class Scheduler: | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     def fusion_accumulate_large_reads( | ||||
|         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode | ||||
|     ) -> bool: | ||||
|         all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( | ||||
|             node1.read_writes.writes | node2.read_writes.writes | ||||
|         ) | ||||
|         return ( | ||||
|             sum(self.dep_size_hint(dep) for dep in all_reads) | ||||
|             > config.realize_acc_reads_size_threshold | ||||
|         ) | ||||
|  | ||||
|     def are_long_distant_nodes( | ||||
|         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode | ||||
|     ) -> bool: | ||||
| @ -4010,20 +4018,7 @@ class Scheduler: | ||||
|         return False | ||||
|  | ||||
|     def dep_size_hint(self, dep: Dep) -> int: | ||||
|         res = 0 | ||||
|         if dep not in self.__dep_size_hint_cache: | ||||
|             try: | ||||
|                 if not dep.has_unbacked_symbols(): | ||||
|                     res = dep.numbytes_hint() | ||||
|             except KeyError: | ||||
|                 # In at least one test (test/inductor/test_torchbind.py) we | ||||
|                 # create a StarDep that doesn't exist in the graph and calling | ||||
|                 # `has_unbacked_symbols()` throws an error. | ||||
|                 pass | ||||
|             self.__dep_size_hint_cache[dep] = res | ||||
|         else: | ||||
|             res = self.__dep_size_hint_cache[dep] | ||||
|         return res | ||||
|         return V.graph.get_dep_size_hint(dep) | ||||
|  | ||||
|     def score_fusion_memory( | ||||
|         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode | ||||
|  | ||||
		Reference in New Issue
	
	Block a user