Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
With FSDP, we sometimes have multiple, non-overlapping views of a single buffer which are all mutated. Previously we considered the original buffer as an allocation and made the mutated buffer the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer as deallocated only when all of its aliases die (and avoid double counting the input buffer size). See the comment inline:

```
When an operation mutates a buffer in-place, the scheduler creates a new
buffer name to track the "before" and "after" states, even though they
share the same memory. The mutated buffer represents a rename with zero
allocation and deallocation cost. During dependency tracking, we transfer
dependencies from the mutated name back to the original buffer, ensuring
the original memory is only freed when all aliases are done. This handles
cases where a buffer has multiple non-overlapping aliases - rather than
trying to assign free costs to individual aliases, we forward all alias
dependencies to the original buffer.

Consider:
    buf0 = op0()
    buf1 = mutation_op_(buf0)
    del buf0
    ...
    op(buf1)
    del buf1

The only memory events are the creation prior to op0, and the deletion
following buf1.
```

As @IvanKobzarev's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can be a bit of a pain to pinpoint which part of our memory calculation is incorrect. This PR also adds a runtime verifier, `config.test_configs.track_memory_lifecycle`, which tracks buffer allocation and deallocation and errors if their lifetimes do not match our expectations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569
Approved by: https://github.com/IvanKobzarev
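As a rough sketch of the bookkeeping described above (not the scheduler's actual code; `forward_mutation_deps` and the plain dicts are hypothetical stand-ins for the scheduler's buffer metadata, mirroring the `size_alloc` / `size_free` / `succ_nodes` fields checked in the test below):

```
# Illustrative sketch only: a mutation is treated as a rename of the original
# buffer. The rename contributes no allocation or deallocation of its own, and
# its successor nodes are forwarded to the original buffer, so the memory is
# freed exactly once, after every alias has died.
def forward_mutation_deps(original, renames):
    for r in renames:
        r["size_alloc"] = 0  # creating the alias uses no new memory
        r["size_free"] = 0   # freeing the alias releases nothing by itself
        # users of the alias keep the original buffer alive
        original["succ_nodes"] |= r["succ_nodes"]
    return original


buf0 = {"size_alloc": 2048, "size_free": 2048, "succ_nodes": {"op2"}}
buf1 = {"size_alloc": 2048, "size_free": 2048, "succ_nodes": {"op3"}}  # in-place mutation of buf0
forward_mutation_deps(buf0, [buf1])
assert (buf0["size_alloc"], buf0["size_free"]) == (2048, 2048)
assert (buf1["size_alloc"], buf1["size_free"]) == (0, 0)
assert buf1["succ_nodes"] <= buf0["succ_nodes"]
```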
417 lines
15 KiB
Python
# Owner(s): ["module: inductor"]
|
|
import unittest
|
|
from unittest import mock
|
|
|
|
import torch
|
|
from torch._C import FileCheck
|
|
from torch._dynamo.utils import same
|
|
from torch._inductor import config, memory
|
|
from torch._inductor.test_case import TestCase
|
|
from torch._inductor.utils import run_and_get_triton_code
|
|
from torch.testing._internal.common_utils import serialTest
|
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
|
|
|
|
|
try:
|
|
import triton
|
|
from triton import language as tl
|
|
|
|
TRITON_AVAILABLE = True
|
|
except ImportError:
|
|
TRITON_AVAILABLE = False
|
|
|
|
|
|
class Foo(torch.nn.Module):
    """
    The default compiled graph is

    graph():
        ...
        %op0 : [num_users=2] = call_function[...](args = (%primals_2, %primals_1), ...)
        %op1 : [num_users=2] = call_function[...](args = (%primals_2, %primals_3), ...)
        %op2 : [num_users=1] = call_function[...](args = (%op0, %primals_4), ...)
        %op3 : [num_users=1] = call_function[...](args = (%op1, %primals_5), ...)
        %op4 : [num_users=1] = call_function[...](args = (%op2,), ...)
        %op5 : [num_users=1] = call_function[...](args = (%op3,), ...)
        %op6_op7 : [num_users=1] = call_function[...](args = (%op5, %op4), ...)
    """

    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.ones(1, 10))
        self.w2 = torch.nn.Parameter(torch.ones(1, 1))
        self.w3 = torch.nn.Parameter(torch.ones(10, 1))
        self.w4 = torch.nn.Parameter(torch.ones(1, 10))

    def forward(self, x):
        t1 = torch.matmul(x, self.w1)
        t2 = torch.matmul(x, self.w2)
        t3 = torch.matmul(t1, self.w3)
        t4 = torch.matmul(t2, self.w4)
        return t3.sum() + t4.sum()

# The tests in this class use very small tensors. The default
# score_fusion_memory threshold will cause different fusion decisions and
# generate a different wrapper. Override the threshold to make these tests
# happy.
@config.patch("score_fusion_memory_threshold", 1)
class TestOperatorReorderForPeakMemory(TestCase):
    def setUp(self):
        super().setUp()

        self.model = Foo().to(GPU_TYPE)
        self.inputs = torch.ones((2048, 1), device=GPU_TYPE)
        self.orig_reorder_method = memory.reorder_for_peak_memory

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory(self):
        outp_corr = self.model(self.inputs)
        compiled_model = torch.compile(self.model)
        code = run_and_get_triton_code(compiled_model, self.inputs)
        (
            FileCheck()
            .check("def call(args):")
            .check("buf1 = ")
            .check("buf0 = ")
            .check("buf2 = ")
            .check("buf4 = ")
            .check("buf3 = ")
            .check("buf5 = ")
            .check("buf7 = ")
            .run(code)
        )
        # check for correctness
        outp = compiled_model(self.inputs)
        self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_lpmf(self):
        outp_corr = self.model(self.inputs)

        def reorder_with_only_lpmf(
            nodes,
            name_to_buf,
            name_to_fused_node,
            graph_inputs,
            graph_outputs,
            methods=None,
        ):
            return self.orig_reorder_method(
                nodes,
                name_to_buf,
                name_to_fused_node,
                graph_inputs,
                graph_outputs,
                methods=[memory.topological_sort_lpmf],
            )

        with mock.patch.object(
            memory, "reorder_for_peak_memory", reorder_with_only_lpmf
        ):
            compiled_model = torch.compile(self.model)

            code = run_and_get_triton_code(compiled_model, self.inputs)
            (
                FileCheck()
                .check("def call(args):")
                .check("buf1 = ")
                .check("buf0 = ")
                .check("buf2 = ")
                .check("buf4 = ")
                .check("buf3 = ")
                .check("buf5 = ")
                .check("buf7 = ")
                .run(code)
            )
            # check for correctness
            outp = compiled_model(self.inputs)
            self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_bfs(self):
        outp_corr = self.model(self.inputs)

        def reorder_with_only_bfs(
            nodes,
            name_to_buf,
            name_to_fused_node,
            graph_inputs,
            graph_outputs,
            methods=None,
        ):
            return self.orig_reorder_method(
                nodes,
                name_to_buf,
                name_to_fused_node,
                graph_inputs,
                graph_outputs,
                methods=[memory.topological_sort_bfs],
            )

        with mock.patch.object(
            memory, "reorder_for_peak_memory", reorder_with_only_bfs
        ):
            compiled_model = torch.compile(self.model)

            code = run_and_get_triton_code(compiled_model, self.inputs)
            (
                FileCheck()
                .check("def call(args):")
                .check("buf0 = ")
                .check("buf1 = ")
                .check("buf2 = ")
                .check("buf3 = ")
                .check("buf4 = ")
                .check("buf5 = ")
                .check("buf7 = ")
                .run(code)
            )
            # check for correctness
            outp = compiled_model(self.inputs)
            self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_dfs(self):
        outp_corr = self.model(self.inputs)

        def reorder_with_only_dfs(
            nodes,
            name_to_buf,
            name_to_fused_node,
            graph_inputs,
            graph_outputs,
            methods=None,
        ):
            return self.orig_reorder_method(
                nodes,
                name_to_buf,
                name_to_fused_node,
                graph_inputs,
                graph_outputs,
                methods=[memory.topological_sort_dfs],
            )

        with mock.patch.object(
            memory, "reorder_for_peak_memory", reorder_with_only_dfs
        ):
            compiled_model = torch.compile(self.model)

            code = run_and_get_triton_code(compiled_model, self.inputs)
            (
                FileCheck()
                .check("def call(args):")
                .check("buf0 = ")
                .check("buf2 = ")
                .check("buf4 = ")
                .check("buf1 = ")
                .check("buf3 = ")
                .check("buf5 = ")
                .check("buf7 = ")
                .run(code)
            )
            # check for correctness
            outp = compiled_model(self.inputs)
            self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "allow_buffer_reuse", False)
    @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
    @config.patch("test_configs.track_memory_lifecycle", "assert")
    def test_mutation_size_propogation(self):
        """
        This tests correct size propagation in the case of mutations.
        In this example, buf1 is a mutation of buf0; we should have:
        * buf0: has size_alloc 2048 and size_free 2048;
        * buf1: has size_alloc 0 and size_free 0.
        This is because
        - buf1 is just a rename of buf0: creating it uses no additional
          memory, and freeing it releases nothing by itself; and
        - the 2048 bytes of memory can only be released once all aliases
          (here, buf1) are dead, so the free is attributed to buf0.
        Similar arguments for buf2 and buf3, buf4 and buf5, etc.
        """

        # use a triton custom kernel to create a small example with mutations
        @triton.jit
        def convert_to_bf16_kernel(
            input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(input_ptr + offsets, mask=mask)
            x_bf16 = x.to(tl.bfloat16)
            tl.store(output_ptr + offsets, x_bf16, mask=mask)

        def convert_to_bf16(x):
            output = torch.empty_like(x, dtype=torch.bfloat16)
            n_elements = x.numel()
            BLOCK_SIZE = 1024
            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
            convert_to_bf16_kernel[grid](
                x.flatten(), output.flatten(), n_elements, BLOCK_SIZE
            )
            return output.view(x.shape)

        # create a custom function to record the buffer size information
        buffer_info = {}
        og_method = memory.assign_memory_planning_info_for_scheduler_buffers

        def assign_memory_planning_info_for_scheduler_buffers_with_records(
            nodes, name_to_buf
        ):
            og_method(nodes, name_to_buf)
            for buf_name, buf in name_to_buf.items():
                buffer_info[buf_name] = (
                    buf.mpi_buffer.size_alloc,
                    buf.mpi_buffer.size_free,
                    buf.mpi_buffer.succ_nodes,
                )

        # test example and checks
        def f(a, p):
            for e in a:
                e = convert_to_bf16(e)
                p = p @ e
            return p

        a = [torch.randn(32, 32, device=GPU_TYPE) for _ in range(4)]
        p = torch.ones(a[0].size(), dtype=torch.bfloat16, device=GPU_TYPE)

        with mock.patch.object(
            memory,
            "assign_memory_planning_info_for_scheduler_buffers",
            assign_memory_planning_info_for_scheduler_buffers_with_records,
        ):
            f_compiled = torch.compile(f)
            f_compiled(a, p)

        pre_mutation = ["buf0", "buf2", "buf4", "buf6"]
        post_mutation = ["buf1", "buf3", "buf5", "buf7"]

        for pre, post in zip(pre_mutation, post_mutation):
            # the original buffer carries both the allocation and the free
            self.assertEqual(buffer_info[pre][0:2], (2048, 2048))
            # the mutation is a rename: zero allocation and zero free
            self.assertEqual(buffer_info[post][0:2], (0, 0))
            # succ nodes should be forwarded to the pre-mutation buffer
            self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2])

    @unittest.skipIf(
        not torch.cuda.is_available()
        or torch.cuda.get_device_properties().total_memory < int(1e10),
        "Need 10GB memory to be safe to run the test",
    )
    def test_fusing_reductions_increase_peak_memory(self):
        @torch.compile
        def f(a, b, c):
            return (a @ c).sum(dim=-1) + (b @ c).sum(dim=-1)

        a = torch.randn(1024 * 32, 16, device=GPU_TYPE)
        b = torch.randn(1024 * 32, 16, device=GPU_TYPE)
        c = torch.randn(16, 1024 * 32, device=GPU_TYPE)
        torch.cuda.reset_peak_memory_stats()
        f(a, b, c)
        peak_mem = torch.cuda.max_memory_allocated()

        expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
        self.assertLess(peak_mem, expected_bound)

    @serialTest()
    def test_fusion_acc_large_reads(self):
        def f(x, y, z):
            res = torch.zeros_like(x[0])
            for i in range(4):
                temp = torch.matmul(x, y) + z
                res = res + temp
            return res

        N = 128
        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)

        # CASE 1: no restriction on the amount of accumulation
        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
                .run(code)
            )

        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes),
        # at most 12 / 4 = 3 reads can be accumulated during fusion
        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
                .run(code)
            )

        # CASE 3: no such fusion allowed
        with config.patch({"realize_acc_reads_size_threshold": N**2}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf4, buf3,")
                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf7, buf6,")
                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf10, buf9,")
                .run(code)
            )

    @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
    def test_multiple_mutations_of_buf(self):
        @torch.compile()
        def foo(inp, inp2):
            inp = inp @ inp
            inp = inp.view(2, -1, 256)
            x = inp[0]
            y = inp[1]
            x, y = torch._foreach_add([x, y], 1.0)
            out = x.sum()
            out2 = y.sum(dim=-1)

            return out, out2, inp2 @ inp2

        inp = torch.rand([256, 256], device="cuda")
        inp2 = torch.rand([256, 256], device="cuda")

        def replace_foreach(gm):
            nodes = gm.find_nodes(
                op="call_function", target=torch.ops.aten._foreach_add.Scalar
            )
            assert len(nodes) == 1
            node = nodes[0]
            nodes[0].target = torch.ops.aten._foreach_add_.Scalar
            for inp, out in zip(node.args[0], list(node.users.keys())):
                out.replace_all_uses_with(inp)
                gm.erase_node(out)

        with torch._inductor.config.patch(
            {
                "post_grad_custom_post_pass": replace_foreach,
                "test_configs.track_memory_lifecycle": "assert",
                "allow_buffer_reuse": False,
                # make sure the mm is at the end so
                # the earlier deallocation is not at the last step,
                # which doesn't distinguish between returned tensors
                # and which tensors are deallocated immediately prior
                "reorder_for_peak_memory": False,
            }
        ):
            code = run_and_get_triton_code(foo, inp, inp2)
            FileCheck().check("allocated=['buf0']").run(code)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()