mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	**Problem:**
Fusion can accumulate a large number of reads, which can significantly increase peak memory utilization. Consider the following code snippet:
```
total = torch.rand(N, N)
for _ in range(r):
    x = torch.rand(N, N)
    total = total + x
```
The default execution is memory efficient, as only two tensors of size N-by-N are in memory at any given time. However, with fusion, the additions are fused into a single operation and the execution becomes something like:
```
x_1 = torch.rand(N, N)
x_2 = torch.rand(N, N)
...
x_r = torch.rand(N, N)
total = x_1 + x_2 + ... + x_r
```
Though this is runtime efficient, it is not memory efficient for large `N` and/or large `r`.
[internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details
**Solution:**
Our proposed solution is to ban fusions in cases where a large number of reads would be accumulated. This is in addition to existing logic in torch.compile.
* During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which defaults to 8, controls _the number of_ buffers that can be accumulated for a single operator. However, this is oblivious to the size of the buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the total size_ of the buffers that can be accumulated.
* During scheduling (i.e., `scheduler.py`), additional fusions are performed, so we also need to catch this pattern there. The decisions are implemented in `choices.py`.
**Results:**
For a small example similar to the one in the test case (but with a larger `N` and more loop iterations), the memory snapshots before and after are shown below. Note that the snapshot on the right is zoomed out so that the y-axes of the two snapshots match.
<img width="1328" alt="image" src="https://github.com/user-attachments/assets/670b5961-8454-4379-ae0f-62d4e7946c64" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563
Approved by: https://github.com/jansel, https://github.com/mlazos
		
	
		
			
				
	
	
		
			366 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			366 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: inductor"]
 | |
| import unittest
 | |
| from unittest import mock
 | |
| 
 | |
| import torch
 | |
| from torch._C import FileCheck
 | |
| from torch._dynamo.utils import same
 | |
| from torch._inductor import config, memory
 | |
| from torch._inductor.test_case import TestCase
 | |
| from torch._inductor.utils import run_and_get_triton_code
 | |
| from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 | |
| 
 | |
| 
 | |
| try:
 | |
|     import triton
 | |
|     from triton import language as tl
 | |
| 
 | |
|     TRITON_AVAILABLE = True
 | |
| except ImportError:
 | |
|     TRITON_AVAILABLE = False
 | |
| 
 | |
| 
 | |
class Foo(torch.nn.Module):
    """
    Tiny model made of two independent matmul chains whose sums are added.

    The default compiled graph is
    graph():
        ...
        %op0 : [num_users=2] = call_function[...](args = (%primals_2, %primals_1), ...)
        %op1 : [num_users=2] = call_function[...](args = (%primals_2, %primals_3), ...)
        %op2 : [num_users=1] = call_function[...](args = (%op0, %primals_4), ...)
        %op3 : [num_users=1] = call_function[...](args = (%op1, %primals_5), ...)
        %op4 : [num_users=1] = call_function[...](args = (%op2,), ...)
        %op5 : [num_users=1] = call_function[...](args = (%op3,), ...)
        %op6_op7 : [num_users=1] = call_function[...](args = (%op5, %op4), ...)
    """

    def __init__(self):
        super().__init__()

        def ones_param(*shape):
            # Every weight is an all-ones trainable tensor of the given shape.
            return torch.nn.Parameter(torch.ones(*shape))

        # Registration order (w1, w2, w3, w4) is kept as-is.
        self.w1 = ones_param(1, 10)
        self.w2 = ones_param(1, 1)
        self.w3 = ones_param(10, 1)
        self.w4 = ones_param(1, 10)

    def forward(self, x):
        # The order of these matmuls is significant: the tests in this file
        # check the order of buffers in the generated code.
        widened = torch.matmul(x, self.w1)
        narrowed = torch.matmul(x, self.w2)
        widened_back = torch.matmul(widened, self.w3)
        narrowed_out = torch.matmul(narrowed, self.w4)
        total = widened_back.sum() + narrowed_out.sum()
        return total
 | |
| 
 | |
| 
 | |
# The tests in this class use very small tensors. The default
# score_fusion_memory threshold will cause different fusion decisions and
# generate a different wrapper. Override the threshold to make these tests
# happy.
@config.patch("score_fusion_memory_threshold", 1)
class TestOperatorReorderForPeakMemory(TestCase):
    """Tests for Inductor's peak-memory-aware reordering and fusion limits."""

    def setUp(self):
        super().setUp()

        self.model = Foo().to(GPU_TYPE)
        self.inputs = torch.ones((2048, 1), device=GPU_TYPE)
        # Keep a handle to the real implementation so tests can still call it
        # after ``memory.reorder_for_peak_memory`` has been mocked out.
        self.orig_reorder_method = memory.reorder_for_peak_memory

    def _reorder_with_only(self, method):
        """Return a stand-in for ``memory.reorder_for_peak_memory``.

        The stand-in forwards to the original implementation but always passes
        ``methods=[method]``, so only that one topological-sort strategy is
        tried. Any ``methods`` value supplied by the caller is ignored.
        """

        def reorder(
            nodes,
            name_to_buf,
            name_to_fused_node,
            graph_inputs,
            graph_outputs,
            methods=None,
        ):
            return self.orig_reorder_method(
                nodes,
                name_to_buf,
                name_to_fused_node,
                graph_inputs,
                graph_outputs,
                methods=[method],
            )

        return reorder

    def _check_buffer_order(self, code, buf_names):
        """Check the order of buffer assignments in the wrapper's ``call``.

        Asserts that each ``"<buf> = "`` in ``buf_names`` appears, in order,
        after ``def call(args):`` in the generated ``code``.
        """
        checker = FileCheck().check("def call(args):")
        for buf_name in buf_names:
            checker = checker.check(f"{buf_name} = ")
        checker.run(code)

    def _run_reorder_test(self, method, buf_names):
        """Compile ``self.model`` with reordering restricted to ``method``.

        Then check the resulting buffer order and compare the compiled output
        against eager output.
        """
        outp_corr = self.model(self.inputs)

        with mock.patch.object(
            memory, "reorder_for_peak_memory", self._reorder_with_only(method)
        ):
            compiled_model = torch.compile(self.model)

            code = run_and_get_triton_code(compiled_model, self.inputs)
            self._check_buffer_order(code, buf_names)
            # check for correctness
            outp = compiled_model(self.inputs)
            self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory(self):
        outp_corr = self.model(self.inputs)
        compiled_model = torch.compile(self.model)
        code = run_and_get_triton_code(compiled_model, self.inputs)
        self._check_buffer_order(
            code, ["buf1", "buf0", "buf2", "buf4", "buf3", "buf5", "buf7"]
        )
        # check for correctness
        outp = compiled_model(self.inputs)
        self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_lpmf(self):
        self._run_reorder_test(
            memory.topological_sort_lpmf,
            ["buf1", "buf0", "buf2", "buf4", "buf3", "buf5", "buf7"],
        )

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_bfs(self):
        self._run_reorder_test(
            memory.topological_sort_bfs,
            ["buf0", "buf1", "buf2", "buf3", "buf4", "buf5", "buf7"],
        )

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_dfs(self):
        self._run_reorder_test(
            memory.topological_sort_dfs,
            ["buf0", "buf2", "buf4", "buf1", "buf3", "buf5", "buf7"],
        )

    @mock.patch.object(config, "allow_buffer_reuse", False)
    @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
    def test_mutation_size_propogation(self):
        """
        This tests correct size propagation in the case of mutations.
        In this example, buf1 is a mutation of buf0; we should have:
        * buf0: has size_alloc 2048 and size_free 0;
        * buf1: has size_alloc 0 and size_free 2048.
        This is because
        - when buf1 is created, no additional memory is used; and
        - the 2048 bytes of memory can only be released when buf1 is freed.
        Similar arguments for buf2 and buf3, buf4 and buf5, etc.
        """
        # NOTE: the method name keeps its historical spelling so that the
        # test ID stays stable.

        # using a triton custom kernel to create a small example with mutations
        @triton.jit
        def convert_to_bf16_kernel(
            input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(input_ptr + offsets, mask=mask)
            x_bf16 = x.to(tl.bfloat16)
            tl.store(output_ptr + offsets, x_bf16, mask=mask)

        def convert_to_bf16(x):
            output = torch.empty_like(x, dtype=torch.bfloat16)
            n_elements = x.numel()
            BLOCK_SIZE = 1024
            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
            convert_to_bf16_kernel[grid](
                x.flatten(), output.flatten(), n_elements, BLOCK_SIZE
            )
            return output.view(x.shape)

        # create a custom function to record the buffer size information
        buffer_info = {}
        og_method = memory.assign_memory_planning_info_for_scheduler_buffers

        def assign_memory_planning_info_for_scheduler_buffers_with_records(
            nodes, name_to_buf
        ):
            og_method(nodes, name_to_buf)
            for buf_name, buf in name_to_buf.items():
                buffer_info[buf_name] = (
                    buf.mpi_buffer.size_alloc,
                    buf.mpi_buffer.size_free,
                )

        # test example and checks
        def f(a, p):
            for e in a:
                e = convert_to_bf16(e)
                p = p @ e
            return p

        a = [torch.randn(32, 32, device=GPU_TYPE) for _ in range(4)]
        p = torch.ones(a[0].size(), dtype=torch.bfloat16, device=GPU_TYPE)

        with mock.patch.object(
            memory,
            "assign_memory_planning_info_for_scheduler_buffers",
            assign_memory_planning_info_for_scheduler_buffers_with_records,
        ):
            f_compiled = torch.compile(f)
            f_compiled(a, p)
            for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
                self.assertEqual(buffer_info[buf_name], (2048, 0))

            for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
                self.assertEqual(buffer_info[buf_name], (0, 2048))

    # NOTE(review): the skip condition queries the CUDA API specifically, while
    # the tensors below are allocated on GPU_TYPE.
    @unittest.skipIf(
        not torch.cuda.is_available()
        or torch.cuda.get_device_properties().total_memory < int(1e10),
        "Need 10GB memory to be safe to run the test",
    )
    def test_fusing_reductions_increase_peak_memory(self):
        @torch.compile
        def f(a, b, c):
            return (a @ c).sum(dim=-1) + (b @ c).sum(dim=-1)

        a = torch.randn(1024 * 32, 16, device=GPU_TYPE)
        b = torch.randn(1024 * 32, 16, device=GPU_TYPE)
        c = torch.randn(16, 1024 * 32, device=GPU_TYPE)
        torch.cuda.reset_peak_memory_stats()
        f(a, b, c)
        peak_mem = torch.cuda.max_memory_allocated()

        # Each matmul output is a.size(0) x c.size(1) elements; peak memory
        # must stay below the cost of two such outputs.
        expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
        self.assertLess(peak_mem, expected_bound)

    def test_fusion_acc_large_reads(self):
        def f(x, y, z):
            res = torch.zeros_like(x[0])
            for _ in range(4):
                temp = torch.matmul(x, y) + z
                res = res + temp
            return res

        N = 128
        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)

        # CASE 1: no restriction on the amount of accumulation
        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
                .run(code)
            )

        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
        # at most 12 / 4 = 3 reads can be accumulated during fusion
        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
                .run(code)
            )

        # CASE 3: no such fusion allowed
        with config.patch({"realize_acc_reads_size_threshold": N**2}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf4, buf3,")
                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf7, buf6,")
                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf10, buf9,")
                .run(code)
            )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Every test in this file targets a GPU backend, so the runner is only
    # invoked when one is present.
    if HAS_GPU:
        from torch._inductor.test_case import run_tests

        run_tests()
 |