mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix inductor memory estimation when a single buf has multiple mutations. Add runtime verification of mem tracking (#159569)
With fsdp, we sometimes have multiple, non-overlapping views of a single buffer which are all mutated. Previously we considered the original buffer as an allocation, and make the mutated buffer the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer as deallocated only when all of its aliases die (and avoid double counting the input buffer size). See comment inline: ``` When an operation mutates a buffer in-place, the scheduler creates a new buffer name to track the "before" and "after" states, even though they share the same memory. The mutated buffer represents a rename with zero allocation and deallocation cost. During dependency tracking, we transfer dependencies from the mutated name back to the original buffer, ensuring the original memory is only freed when all aliases are done. This handles cases where a buffer has multiple non-overlapping aliases - rather than trying to assign free costs to individual aliases, we forward all alias dependencies to the original buffer. Consider: buf0 = op0() buf1 = mutation_op_(buf0) del buf0 ... op(buf1) del buf1 The only memory events are the creation prior to op0, and the deletion following buf1. ``` As @IvanKobzarev 's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can a bit of a pain to pinpoint which part of our memory calculation is incorrect. This pr also adds a runtime verifier `config.test_configs.track_memory_lifecycle` which tracks buffer allocation and deallocation, and errors if their lifetime does not match our expectations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569 Approved by: https://github.com/IvanKobzarev
This commit is contained in:
committed by
PyTorch MergeBot
parent
9884d0351e
commit
eb25a95a6e
@ -179,8 +179,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
.check("extern_kernels.mm")
|
||||
.check("triton_poi_fused_relu")
|
||||
.check("torch.ops._c10d_functional.all_reduce_.default")
|
||||
.check("torch.ops._c10d_functional.wait_tensor.default")
|
||||
.check_same("buf0")
|
||||
# mm not use buf prior to wait_tensor
|
||||
.check("extern_kernels.mm")
|
||||
.check_not("buf0")
|
||||
.check("torch.ops._c10d_functional.wait_tensor.default")
|
||||
.check("extern_kernels.mm")
|
||||
.run(code)
|
||||
)
|
||||
|
@ -1745,10 +1745,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
_reorder_communication_preserving_peak_memory,
|
||||
],
|
||||
"allow_buffer_reuse": False,
|
||||
"test_configs.track_memory_lifecycle": "error",
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
compiled = torch.compile(func, fullgraph=True)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
|
||||
# make sure memory tracking is codegen. the ops will then do runtime checking with assertion.
|
||||
FileCheck().check("check_memory_step").check("tracked_empty_strided").run(code)
|
||||
|
||||
# NOTE: The first return value should be the output of the first wait_tensor.
|
||||
# We want to make sure no unnecessary copy is made.
|
||||
(
|
||||
|
@ -215,6 +215,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
|
||||
|
||||
@mock.patch.object(config, "allow_buffer_reuse", False)
|
||||
@unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
|
||||
@config.patch("test_configs.track_memory_lifecycle", "assert")
|
||||
def test_mutation_size_propogation(self):
|
||||
"""
|
||||
This tests correct size propogation in the case of mutations.
|
||||
@ -262,6 +263,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
|
||||
buffer_info[buf_name] = (
|
||||
buf.mpi_buffer.size_alloc,
|
||||
buf.mpi_buffer.size_free,
|
||||
buf.mpi_buffer.succ_nodes,
|
||||
)
|
||||
|
||||
# test example and checks
|
||||
@ -281,11 +283,15 @@ class TestOperatorReorderForPeakMemory(TestCase):
|
||||
):
|
||||
f_compiled = torch.compile(f)
|
||||
f_compiled(a, p)
|
||||
for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
|
||||
self.assertEqual(buffer_info[buf_name], (2048, 0))
|
||||
|
||||
for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
|
||||
self.assertEqual(buffer_info[buf_name], (0, 2048))
|
||||
pre_mutation = ["buf0", "buf2", "buf4", "buf6"]
|
||||
post_mutation = ["buf1", "buf3", "buf5", "buf7"]
|
||||
|
||||
for pre, post in zip(pre_mutation, post_mutation):
|
||||
self.assertEqual(buffer_info[pre][0:2], (2048, 2048))
|
||||
self.assertEqual(buffer_info[post][0:2], (0, 0))
|
||||
# succ nodes should be forwarded to pre mutation buffer
|
||||
self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2])
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
@ -359,6 +365,49 @@ class TestOperatorReorderForPeakMemory(TestCase):
|
||||
.run(code)
|
||||
)
|
||||
|
||||
@unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
|
||||
def test_multiple_mutations_of_buf(self):
|
||||
@torch.compile()
|
||||
def foo(inp, inp2):
|
||||
inp = inp @ inp
|
||||
inp = inp.view(2, -1, 256)
|
||||
x = inp[0]
|
||||
y = inp[1]
|
||||
x, y = torch._foreach_add([x, y], 1.0)
|
||||
out = x.sum()
|
||||
out2 = y.sum(dim=-1)
|
||||
|
||||
return out, out2, inp2 @ inp2
|
||||
|
||||
inp = torch.rand([256, 256], device="cuda")
|
||||
inp2 = torch.rand([256, 256], device="cuda")
|
||||
|
||||
def replace_foreach(gm):
|
||||
nodes = gm.find_nodes(
|
||||
op="call_function", target=torch.ops.aten._foreach_add.Scalar
|
||||
)
|
||||
assert len(nodes) == 1
|
||||
node = nodes[0]
|
||||
nodes[0].target = torch.ops.aten._foreach_add_.Scalar
|
||||
for inp, out in zip(node.args[0], list(node.users.keys())):
|
||||
out.replace_all_uses_with(inp)
|
||||
gm.erase_node(out)
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"post_grad_custom_post_pass": replace_foreach,
|
||||
"test_configs.track_memory_lifecycle": "assert",
|
||||
"allow_buffer_reuse": False,
|
||||
# make sure the mm is at the end so
|
||||
# the earlier deallocation is not at the last step,
|
||||
# which doesnt distinguish between returned tensors
|
||||
# and which tensors are deallocated immediately prior
|
||||
"reorder_for_peak_memory": False,
|
||||
}
|
||||
):
|
||||
code = run_and_get_triton_code(foo, inp, inp2)
|
||||
FileCheck().check("allocated=['buf0']").run(code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
@ -963,9 +963,12 @@ class PythonWrapperCodegen(CodeGen):
|
||||
aot_config_comment = ""
|
||||
if context is not None and context.aot_graph_name is not None:
|
||||
aot_config_comment = f"# AOT ID: {context.aot_graph_name}"
|
||||
aot_inductor_debug_utils = ""
|
||||
inductor_debug_utils = ""
|
||||
if int(config.aot_inductor.debug_intermediate_value_printer) > 0:
|
||||
aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
|
||||
inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
|
||||
elif torch._inductor.config.test_configs.track_memory_lifecycle:
|
||||
inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n"
|
||||
|
||||
self.imports.splice(
|
||||
f"""
|
||||
{aot_config_comment}
|
||||
@ -983,7 +986,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||
from torch import device, empty_strided
|
||||
from {async_compile.__name__} import AsyncCompile
|
||||
from torch._inductor.select_algorithm import extern_kernels
|
||||
{aot_inductor_debug_utils}
|
||||
{inductor_debug_utils}
|
||||
""",
|
||||
strip=True,
|
||||
)
|
||||
@ -2773,6 +2776,14 @@ class PythonWrapperCodegen(CodeGen):
|
||||
buffer.get_name(), device, dtype, shape, stride, allocation_shape
|
||||
)
|
||||
|
||||
@cache_on_self
|
||||
def write_memory_track_allocation_once(self):
|
||||
import_str = """
|
||||
from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor
|
||||
"""
|
||||
if not V.graph.cpp_wrapper:
|
||||
self.imports.splice(import_str, strip=True)
|
||||
|
||||
def make_allocation(
|
||||
self, name, device, dtype, shape, stride, allocation_shape=None
|
||||
):
|
||||
@ -2784,7 +2795,16 @@ class PythonWrapperCodegen(CodeGen):
|
||||
allocation_shape
|
||||
)
|
||||
codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
|
||||
if device.type in ("cpu", "cuda", "xpu", "mtia"):
|
||||
if torch._inductor.config.test_configs.track_memory_lifecycle:
|
||||
out = (
|
||||
f"{name} = tracked_empty_strided("
|
||||
f"{codegen_allocation_shape_tuple}, "
|
||||
f"{codegen_stride_tuple}, "
|
||||
f"dtype={dtype}, "
|
||||
f"device='{device.type}', "
|
||||
f"name='{name}')"
|
||||
)
|
||||
elif device.type in ("cpu", "cuda", "xpu", "mtia"):
|
||||
# optimized path for faster allocations, saving ~2us versus the stuff below
|
||||
out = (
|
||||
f"{name} = empty_strided_{device.type}("
|
||||
|
@ -1861,6 +1861,8 @@ class test_configs:
|
||||
|
||||
graphsafe_rng_func_ignores_fallback_random = False
|
||||
|
||||
track_memory_lifecycle: Optional[Literal["assert", "log"]] = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
@ -5324,6 +5324,11 @@ class ConcatKernel(NopKernel):
|
||||
|
||||
@ir_dataclass(frozen=False)
|
||||
class ExternKernel(InputsKernel):
|
||||
"""
|
||||
A class that represents Kernels which are not directly lowered to Inductor
|
||||
Loop Level IR, such as custom operators, or aten operators which we fallback to.
|
||||
"""
|
||||
|
||||
constant_args: Sequence[Any] = ()
|
||||
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
|
||||
output_view: Optional[ReinterpretView] = None
|
||||
@ -6120,6 +6125,17 @@ class ExternKernel(InputsKernel):
|
||||
f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
|
||||
)
|
||||
|
||||
def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
"""
|
||||
Track outputs of fallback operators if config.test_configs.track_memory_lifecycle
|
||||
"""
|
||||
if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper:
|
||||
return
|
||||
|
||||
wrapper.write_memory_track_allocation_once()
|
||||
name = self.get_name()
|
||||
wrapper.writeline(f"track_tensor({name}, '{name}')")
|
||||
|
||||
def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]:
|
||||
"""
|
||||
get output sizes and strides, for template_codegen
|
||||
@ -7579,6 +7595,7 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
if isinstance(self.layout, Layout):
|
||||
self.codegen_size_asserts(wrapper)
|
||||
self.codegen_alignment_asserts(wrapper)
|
||||
self.codegen_memory_tracking(wrapper)
|
||||
|
||||
self.codegen_unbacked_symbol_defs(wrapper)
|
||||
|
||||
@ -7720,6 +7737,31 @@ class ComplexView(FallbackKernel):
|
||||
)
|
||||
|
||||
|
||||
class MemoryCheckKernel(FallbackKernel):
|
||||
"""
|
||||
Custom kernel for memory checking that generates direct function calls
|
||||
|
||||
TODO - the custom op was erroring with str inputs. should be able to custom op directly.
|
||||
"""
|
||||
|
||||
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
"""Override codegen to write direct function call"""
|
||||
# Extract our arguments from nontensor_args
|
||||
wrapper.write_memory_track_allocation_once()
|
||||
alive_list, dead_list, is_final_step = self.constant_args
|
||||
|
||||
alive_repr = repr(alive_list)
|
||||
dead_repr = repr(dead_list)
|
||||
if is_final_step:
|
||||
wrapper.writeline(
|
||||
"# note: dont currently distinguish between buffers returned and dealloc'd in last step"
|
||||
)
|
||||
call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})"
|
||||
else:
|
||||
call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})"
|
||||
wrapper.writeline(call)
|
||||
|
||||
|
||||
@ir_dataclass
|
||||
class MultiOutputLayout(OutputSpec):
|
||||
device: torch.device
|
||||
|
@ -124,6 +124,28 @@ def compute_size_for_scheduler_buffer(
|
||||
buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed
|
||||
buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed
|
||||
|
||||
When an operation mutates a buffer in-place, the scheduler creates a new buffer name
|
||||
to track the "before" and "after" states, even though they share the same memory.
|
||||
|
||||
The mutated buffer represents a rename with zero allocation and deallocation cost.
|
||||
During dependency tracking, we transfer dependencies from the mutated name back to
|
||||
the original buffer, ensuring the original memory is only freed when all aliases
|
||||
are done.
|
||||
|
||||
This handles cases where a buffer has multiple non-overlapping aliases - rather than
|
||||
trying to assign free costs to individual aliases, we forward all alias dependencies
|
||||
to the original buffer.
|
||||
|
||||
Consider:
|
||||
buf0 = op0()
|
||||
buf1 = mutation_op_(buf0)
|
||||
del buf0
|
||||
...
|
||||
op(buf1)
|
||||
del buf1
|
||||
|
||||
The only memory events are the creation prior to op0, and the deletion following buf1.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free).
|
||||
"""
|
||||
@ -135,18 +157,11 @@ def compute_size_for_scheduler_buffer(
|
||||
def _compute_and_update_buf_size(
|
||||
sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
|
||||
) -> int:
|
||||
if isinstance(sched_buf.node.layout, NoneLayout):
|
||||
# mutations should inherit the size of the mutated buffer
|
||||
if sched_buf.get_mutations():
|
||||
mutated_buf_name = sched_buf.get_mutations()[0]
|
||||
if mutated_buf_name in sched_buf_to_size:
|
||||
(_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name]
|
||||
else:
|
||||
(_size_alloc, _size_free) = (0, 0)
|
||||
sched_buf_to_size[sched_buf.get_name()] = (0, _size_free)
|
||||
sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0)
|
||||
else:
|
||||
sched_buf_to_size[sched_buf.get_name()] = (0, 0)
|
||||
if sched_buf.get_name() in V.graph.scheduler.mutation_real_name:
|
||||
sched_buf_to_size[sched_buf.get_name()] = (0, 0)
|
||||
return 0
|
||||
elif isinstance(sched_buf.node.layout, NoneLayout):
|
||||
sched_buf_to_size[sched_buf.get_name()] = (0, 0)
|
||||
return 0
|
||||
elif isinstance(sched_buf.node.layout, MultiOutputLayout):
|
||||
size_alloc = 0
|
||||
@ -200,6 +215,14 @@ def assign_memory_planning_info_for_scheduler_buffers(
|
||||
for dep in node.unmet_dependencies:
|
||||
dep_name_to_succ_nodes[dep.name].add(node)
|
||||
|
||||
# iterate in reverse, so dependencies are picked up transitively.
|
||||
for mutating_buf_name, real_buf_name in reversed(
|
||||
V.graph.scheduler.mutation_real_name.items()
|
||||
):
|
||||
dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[
|
||||
mutating_buf_name
|
||||
]
|
||||
|
||||
# populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer
|
||||
# note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs)
|
||||
for buf_name in name_to_buf.keys():
|
||||
@ -219,58 +242,72 @@ def assign_memory_planning_info_for_scheduler_nodes(
|
||||
"""
|
||||
Assign to each scheduler node its predecessor and successor nodes.
|
||||
"""
|
||||
from .scheduler import SchedulerBuffer
|
||||
|
||||
for index, node in enumerate(nodes):
|
||||
size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs())
|
||||
pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]()
|
||||
for dep in node.read_writes.reads:
|
||||
if dep.name in name_to_buf and dep in node.unmet_dependencies:
|
||||
pred_buffers.add(name_to_buf[dep.name])
|
||||
elif dep.name in name_to_freeable_input_buf:
|
||||
pred_buffers.add(name_to_freeable_input_buf[dep.name])
|
||||
pred_nodes = OrderedSet(
|
||||
name_to_fused_node[pred_buffer.defining_op_name()]
|
||||
for pred_buffer in pred_buffers
|
||||
if (isinstance(pred_buffer, SchedulerBuffer))
|
||||
)
|
||||
node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = (
|
||||
collections.defaultdict(OrderedSet)
|
||||
)
|
||||
node_to_succ_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {}
|
||||
node_to_pred_buffers: dict[
|
||||
BaseSchedulerNode, OrderedSet[SchedulerBuffer | FreeableInputBuffer]
|
||||
] = collections.defaultdict(OrderedSet)
|
||||
|
||||
# collect all predecessors using existing successor mappings
|
||||
for node in nodes:
|
||||
succ_nodes = OrderedSet(
|
||||
succ_node
|
||||
for buffer in node.get_outputs()
|
||||
for succ_node in buffer.mpi_buffer.succ_nodes
|
||||
)
|
||||
node_to_succ_nodes[node] = succ_nodes
|
||||
|
||||
# For each successor, add current node as its predecessor
|
||||
for succ_node in succ_nodes:
|
||||
node_to_pred_nodes[succ_node].add(node)
|
||||
|
||||
# For each output buffer, add it as predecessor to its successor nodes
|
||||
# TODO - is pred buffers needed ?
|
||||
for buffer in node.get_outputs():
|
||||
for succ_node in buffer.mpi_buffer.succ_nodes:
|
||||
node_to_pred_buffers[succ_node].add(buffer)
|
||||
|
||||
for freeable_buffer in name_to_freeable_input_buf.values():
|
||||
for succ_node in freeable_buffer.mpi_buffer.succ_nodes:
|
||||
node_to_pred_buffers[succ_node].add(freeable_buffer)
|
||||
|
||||
# Second pass: assign memory planning info using completed predecessor mappings
|
||||
for index, node in enumerate(nodes):
|
||||
size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs())
|
||||
succ_nodes = node_to_succ_nodes[node]
|
||||
|
||||
node.mpi_node = MemoryPlanningInfoForNode(
|
||||
index=index,
|
||||
size=size_alloc,
|
||||
pred_buffers=pred_buffers,
|
||||
pred_nodes=pred_nodes,
|
||||
pred_buffers=node_to_pred_buffers[node],
|
||||
pred_nodes=node_to_pred_nodes[node],
|
||||
succ_nodes=succ_nodes,
|
||||
)
|
||||
|
||||
|
||||
def estimate_peak_memory(
|
||||
# map each scheduler buffer to its size, start step, and end step
|
||||
@dataclasses.dataclass
|
||||
class BufferInfo:
|
||||
buffer: Union[SchedulerBuffer, FreeableInputBuffer]
|
||||
size_alloc: int
|
||||
size_free: int
|
||||
start_step: int
|
||||
end_step: int
|
||||
|
||||
|
||||
def compute_memory_timeline(
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||
graph_outputs: OrderedSet[str],
|
||||
) -> tuple[int, list[int]]:
|
||||
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
|
||||
"""
|
||||
Given a list of nodes in their execution order, estimate the peak memory, by
|
||||
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
|
||||
|
||||
Returns:
|
||||
int: peak memory
|
||||
List[int]: memory usage at each node (or each step).
|
||||
Compute buffer allocation and deallocation sizes and map their
|
||||
lifetime to the node schedule
|
||||
"""
|
||||
|
||||
# map each scheduler buffer to its size, start step, and end step
|
||||
@dataclasses.dataclass
|
||||
class BufferInfo:
|
||||
buffer: Union[SchedulerBuffer, FreeableInputBuffer]
|
||||
size_alloc: int
|
||||
size_free: int
|
||||
start_step: int
|
||||
end_step: int
|
||||
|
||||
# get the execution step of each node, this will be used to determine
|
||||
# the end_step of buffers
|
||||
node_to_step: dict[BaseSchedulerNode, int] = {
|
||||
@ -325,6 +362,27 @@ def estimate_peak_memory(
|
||||
)
|
||||
)
|
||||
|
||||
return buf_info_list, node_to_step
|
||||
|
||||
|
||||
def estimate_peak_memory(
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||
graph_outputs: OrderedSet[str],
|
||||
) -> tuple[int, list[int]]:
|
||||
"""
|
||||
Given a list of nodes in their execution order, estimate the peak memory, by
|
||||
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
|
||||
|
||||
Returns:
|
||||
int: peak memory
|
||||
List[int]: memory usage at each node (or each step).
|
||||
"""
|
||||
|
||||
buf_info_list, _ = compute_memory_timeline(
|
||||
nodes, name_to_freeable_input_buf, graph_outputs
|
||||
)
|
||||
|
||||
# incremental memory changes at each step
|
||||
memory = [0 for _ in range(len(nodes) + 1)]
|
||||
|
||||
|
138
torch/_inductor/runtime/debug_utils.py
Normal file
138
torch/_inductor/runtime/debug_utils.py
Normal file
@ -0,0 +1,138 @@
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
local = threading.local()
|
||||
local.memory_tracker = None
|
||||
|
||||
|
||||
class BufferMemoryTracker:
|
||||
"""
|
||||
Tracks inductor runtime allocations and deallocations to compare against
|
||||
expected behavior.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = (
|
||||
weakref.WeakValueDictionary() # type: ignore[assignment]
|
||||
)
|
||||
self.died_since_last_step: OrderedSet[str] = OrderedSet()
|
||||
self.added_since_last_step: OrderedSet[str] = OrderedSet()
|
||||
self.error = (
|
||||
torch._inductor.config.test_configs.track_memory_lifecycle == "assert"
|
||||
)
|
||||
|
||||
def set_tensor(self, name: str, tensor: torch.Tensor) -> None:
|
||||
storage = tensor.untyped_storage()
|
||||
|
||||
self.added_since_last_step.add(name)
|
||||
self.tensor_tracker[name] = storage
|
||||
|
||||
def on_tensor_death() -> None:
|
||||
self.died_since_last_step.add(name)
|
||||
|
||||
weakref.finalize(storage, on_tensor_death)
|
||||
|
||||
def advance_step(self) -> None:
|
||||
self.died_since_last_step.clear()
|
||||
self.added_since_last_step.clear()
|
||||
|
||||
def log_or_raise(self, msg: str) -> None:
|
||||
if self.error:
|
||||
raise RuntimeError(msg)
|
||||
else:
|
||||
log.info(msg)
|
||||
|
||||
def check_step_delta(
|
||||
self,
|
||||
expected_allocated: list[str],
|
||||
expected_freed: list[str],
|
||||
is_final_step: bool,
|
||||
) -> None:
|
||||
"""Check only the delta changes since last step"""
|
||||
|
||||
# Check expected deaths - we dont currently distinguish between nodes which die in last step
|
||||
# and are returned as outputs, so skip if final_step.
|
||||
if not is_final_step:
|
||||
missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step
|
||||
if missing_deaths:
|
||||
self.log_or_raise(
|
||||
f"Expected tensors to die but still alive: {missing_deaths}"
|
||||
)
|
||||
|
||||
# Check for unexpected deaths
|
||||
unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed)
|
||||
if unexpected_deaths:
|
||||
self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}")
|
||||
|
||||
# Check newly alive tensors - separate messages like deaths
|
||||
actual_allocated = self.added_since_last_step
|
||||
expected_allocated_set = OrderedSet(expected_allocated)
|
||||
|
||||
extra_alive = actual_allocated - expected_allocated_set
|
||||
if extra_alive:
|
||||
self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}")
|
||||
|
||||
missing_alive = expected_allocated_set - actual_allocated
|
||||
if missing_alive:
|
||||
self.log_or_raise(
|
||||
f"Expected allocated tensors but missing: {missing_alive}"
|
||||
)
|
||||
|
||||
# Reset for next step
|
||||
self.advance_step()
|
||||
|
||||
if is_final_step:
|
||||
local.memory_tracker = None
|
||||
|
||||
|
||||
def get_mem_tracker() -> BufferMemoryTracker:
|
||||
if local.memory_tracker is None:
|
||||
local.memory_tracker = BufferMemoryTracker()
|
||||
return local.memory_tracker
|
||||
|
||||
|
||||
def track_tensor(tensor: torch.Tensor, name: str) -> None:
|
||||
get_mem_tracker().set_tensor(name, tensor)
|
||||
|
||||
|
||||
def tracked_empty_strided(
|
||||
size: list[int],
|
||||
stride: list[int],
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
name: str,
|
||||
) -> torch.Tensor:
|
||||
o = torch.empty_strided(size, stride, dtype=dtype, device=device)
|
||||
track_tensor(o, name)
|
||||
return o
|
||||
|
||||
|
||||
def check_memory_step(
|
||||
allocated: list[str], freed: list[str], is_final_step: bool = False
|
||||
) -> None:
|
||||
tracker = get_mem_tracker()
|
||||
tracker.check_step_delta(allocated, freed, is_final_step)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def register_check_mem_op() -> None:
|
||||
lib = torch.library.Library("_inductor_debug", "FRAGMENT") # noqa: TOR901
|
||||
lib.define(
|
||||
"check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()"
|
||||
)
|
||||
lib.impl("check_memory_step", check_memory_step, "BackendSelect")
|
||||
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
|
||||
|
||||
_register_effectful_op(
|
||||
torch.ops._inductor_debug.check_memory_step.default,
|
||||
_EffectType.ORDERED,
|
||||
)
|
@ -2184,6 +2184,10 @@ class Scheduler:
|
||||
self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
|
||||
|
||||
self.compute_last_usage()
|
||||
|
||||
if torch._inductor.config.test_configs.track_memory_lifecycle:
|
||||
self.insert_memory_check_nodes()
|
||||
|
||||
log_ir_post_fusion(self.nodes)
|
||||
V.debug.graph_diagram(self.nodes)
|
||||
self.debug_draw_graph()
|
||||
@ -2518,6 +2522,83 @@ class Scheduler:
|
||||
compute_dependencies_log.debug("BUFFER USER LIST\n")
|
||||
compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)
|
||||
|
||||
def insert_memory_check_nodes(self) -> None:
|
||||
from .memory import (
|
||||
assign_memory_planning_info_for_scheduler_buffers,
|
||||
compute_memory_timeline,
|
||||
FreeableInputBuffer,
|
||||
get_freeable_input_buf,
|
||||
)
|
||||
|
||||
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = (
|
||||
get_freeable_input_buf(self.nodes, graph_inputs)
|
||||
)
|
||||
|
||||
if not torch._inductor.config.reorder_for_peak_memory:
|
||||
assign_memory_planning_info_for_scheduler_buffers(
|
||||
self.nodes, self.name_to_buf
|
||||
)
|
||||
|
||||
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||
buf_info_list, _ = compute_memory_timeline(
|
||||
self.nodes,
|
||||
name_to_freeable_input_buf,
|
||||
graph_outputs,
|
||||
)
|
||||
|
||||
step_allocs_deallocs: list[tuple[list[str], list[str]]] = [
|
||||
([], []) for _ in range(len(self.nodes))
|
||||
]
|
||||
for buf_info in buf_info_list:
|
||||
# Skip zero-size buffers
|
||||
if buf_info.size_alloc == 0 and buf_info.size_free == 0:
|
||||
continue
|
||||
|
||||
buf_name = buf_info.buffer.get_name()
|
||||
|
||||
step_allocs_deallocs[buf_info.start_step][0].append(buf_name)
|
||||
step_allocs_deallocs[buf_info.end_step][1].append(buf_name)
|
||||
|
||||
from torch._inductor.runtime.debug_utils import register_check_mem_op
|
||||
|
||||
register_check_mem_op()
|
||||
|
||||
def construct_mem_check_node(
|
||||
step_idx: int, is_final_step: bool
|
||||
) -> ExternKernelSchedulerNode:
|
||||
expected_newly_alive = step_allocs_deallocs[step_idx][0]
|
||||
expected_newly_dead = step_allocs_deallocs[step_idx][1]
|
||||
|
||||
nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step]
|
||||
|
||||
node = ir.MemoryCheckKernel(
|
||||
layout=NoneLayout(device=torch.device("cpu")),
|
||||
kernel=torch.ops._inductor_debug.check_memory_step.default,
|
||||
tensor_args=[],
|
||||
nontensor_args=nontensor_args,
|
||||
unflatten_args=lambda tensor_args, constant_args: (
|
||||
tensor_args,
|
||||
{
|
||||
"alive": constant_args[0],
|
||||
"dead": constant_args[1],
|
||||
"is_final_step": constant_args[2],
|
||||
},
|
||||
),
|
||||
)
|
||||
node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}"
|
||||
return ExternKernelSchedulerNode(self, node)
|
||||
|
||||
new_nodes = []
|
||||
|
||||
for i, node in enumerate(self.nodes):
|
||||
new_nodes.append(node)
|
||||
new_nodes.append(
|
||||
construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1))
|
||||
)
|
||||
|
||||
self.nodes = new_nodes
|
||||
|
||||
def dead_node_elimination(self) -> None:
|
||||
"""
|
||||
Remove any nodes without users
|
||||
|
Reference in New Issue
Block a user