Fix inductor memory estimation when a single buf has multiple mutations. Add runtime verification of mem tracking (#159569)

With FSDP, we sometimes have multiple non-overlapping views of a single buffer that are all mutated. Previously we treated the original buffer as the allocation and the mutated buffer as the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer deallocated only when all of its aliases die (and avoid double counting the input buffer size). See the comment inline:

```
    When an operation mutates a buffer in-place, the scheduler creates a new buffer name
    to track the "before" and "after" states, even though they share the same memory.
    The mutated buffer represents a rename with zero allocation and deallocation cost.
    During dependency tracking, we transfer dependencies from the mutated name back to
    the original buffer, ensuring the original memory is only freed when all aliases
    are done.
    This handles cases where a buffer has multiple non-overlapping aliases - rather than
    trying to assign free costs to individual aliases, we forward all alias dependencies
    to the original buffer.
    Consider:
        buf0 = op0()
        buf1 = mutation_op_(buf0)
        del buf0
        ...
        op(buf1)
        del buf1
    The only memory events are the creation prior to op0, and the deletion following buf1.
```
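Concretely, rather than assigning free costs to the individual aliases, the successor dependencies of each mutated (renamed) buffer are folded back into the original buffer, and the rename itself gets zero allocation/free cost. A toy sketch of that forwarding step (the real change lives in the `torch/_inductor/memory.py` diff below; the buffer and op names here are illustrative):

```python
from collections import defaultdict

# mutation_real_name: mutated buffer name -> original buffer name
mutation_real_name = {"buf1": "buf0"}
# dep_name_to_succ_nodes: buffer name -> ops that still read it (node names here)
dep_name_to_succ_nodes = defaultdict(set, {"buf1": {"op_reading_buf1"}})

# Iterate in reverse so chained mutations are picked up transitively.
for mutating_buf_name, real_buf_name in reversed(mutation_real_name.items()):
    dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[mutating_buf_name]

# buf0 is now only considered dead once op_reading_buf1 has run,
# and buf1 itself contributes zero allocation/free cost.
assert "op_reading_buf1" in dep_name_to_succ_nodes["buf0"]
```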

As @IvanKobzarev's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can be a bit of a pain to pinpoint which part of our memory calculation is incorrect.

This PR also adds a runtime verifier, `config.test_configs.track_memory_lifecycle`, which tracks buffer allocations and deallocations and errors if their lifetimes do not match our expectations.
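For reference, a minimal sketch of enabling the verifier in a test; the config keys match the diffs below, while the function and inputs are placeholders:

```python
import torch

def f(x):
    return (x @ x).relu()

x = torch.rand(256, 256, device="cuda")

# "assert" raises at runtime on a lifetime mismatch; disabling buffer reuse
# keeps buffer names stable, as the updated tests below do.
with torch._inductor.config.patch(
    {
        "test_configs.track_memory_lifecycle": "assert",
        "allow_buffer_reuse": False,
    }
):
    compiled = torch.compile(f, fullgraph=True)
    compiled(x)
```

With this enabled, the generated wrapper allocates through `tracked_empty_strided` and inserts a `check_memory_step(allocated=[...], freed=[...])` call after each scheduler step, as the wrapper-codegen and scheduler diffs below show.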

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569
Approved by: https://github.com/IvanKobzarev
Author: eellison
Date: 2025-08-04 20:30:00 -07:00
Committed by: PyTorch MergeBot
Parent: 9884d0351e
Commit: eb25a95a6e
9 changed files with 453 additions and 55 deletions


@@ -179,8 +179,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
.check("extern_kernels.mm")
.check("triton_poi_fused_relu")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check_same("buf0")
# mm does not use buf prior to wait_tensor
.check("extern_kernels.mm")
.check_not("buf0")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)


@@ -1745,10 +1745,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
_reorder_communication_preserving_peak_memory,
],
"allow_buffer_reuse": False,
"test_configs.track_memory_lifecycle": "error",
}
):
compiled = torch.compile(func)
compiled = torch.compile(func, fullgraph=True)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
# make sure memory tracking is codegen'd; the ops will then do runtime checking with assertions.
FileCheck().check("check_memory_step").check("tracked_empty_strided").run(code)
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(


@@ -215,6 +215,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
@mock.patch.object(config, "allow_buffer_reuse", False)
@unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
@config.patch("test_configs.track_memory_lifecycle", "assert")
def test_mutation_size_propogation(self):
"""
This tests correct size propogation in the case of mutations.
@@ -262,6 +263,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
buffer_info[buf_name] = (
buf.mpi_buffer.size_alloc,
buf.mpi_buffer.size_free,
buf.mpi_buffer.succ_nodes,
)
# test example and checks
@@ -281,11 +283,15 @@ class TestOperatorReorderForPeakMemory(TestCase):
):
f_compiled = torch.compile(f)
f_compiled(a, p)
for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
self.assertEqual(buffer_info[buf_name], (2048, 0))
for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
self.assertEqual(buffer_info[buf_name], (0, 2048))
pre_mutation = ["buf0", "buf2", "buf4", "buf6"]
post_mutation = ["buf1", "buf3", "buf5", "buf7"]
for pre, post in zip(pre_mutation, post_mutation):
self.assertEqual(buffer_info[pre][0:2], (2048, 2048))
self.assertEqual(buffer_info[post][0:2], (0, 0))
# succ nodes should be forwarded to pre mutation buffer
self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2])
@unittest.skipIf(
not torch.cuda.is_available()
@@ -359,6 +365,49 @@ class TestOperatorReorderForPeakMemory(TestCase):
.run(code)
)
@unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
def test_multiple_mutations_of_buf(self):
@torch.compile()
def foo(inp, inp2):
inp = inp @ inp
inp = inp.view(2, -1, 256)
x = inp[0]
y = inp[1]
x, y = torch._foreach_add([x, y], 1.0)
out = x.sum()
out2 = y.sum(dim=-1)
return out, out2, inp2 @ inp2
inp = torch.rand([256, 256], device="cuda")
inp2 = torch.rand([256, 256], device="cuda")
def replace_foreach(gm):
nodes = gm.find_nodes(
op="call_function", target=torch.ops.aten._foreach_add.Scalar
)
assert len(nodes) == 1
node = nodes[0]
nodes[0].target = torch.ops.aten._foreach_add_.Scalar
for inp, out in zip(node.args[0], list(node.users.keys())):
out.replace_all_uses_with(inp)
gm.erase_node(out)
with torch._inductor.config.patch(
{
"post_grad_custom_post_pass": replace_foreach,
"test_configs.track_memory_lifecycle": "assert",
"allow_buffer_reuse": False,
# make sure the mm is at the end so
# the earlier deallocation is not at the last step,
# which doesn't distinguish between returned tensors
# and tensors that are deallocated immediately prior
"reorder_for_peak_memory": False,
}
):
code = run_and_get_triton_code(foo, inp, inp2)
FileCheck().check("allocated=['buf0']").run(code)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests


@@ -963,9 +963,12 @@ class PythonWrapperCodegen(CodeGen):
aot_config_comment = ""
if context is not None and context.aot_graph_name is not None:
aot_config_comment = f"# AOT ID: {context.aot_graph_name}"
aot_inductor_debug_utils = ""
inductor_debug_utils = ""
if int(config.aot_inductor.debug_intermediate_value_printer) > 0:
aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
elif torch._inductor.config.test_configs.track_memory_lifecycle:
inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n"
self.imports.splice(
f"""
{aot_config_comment}
@@ -983,7 +986,7 @@ class PythonWrapperCodegen(CodeGen):
from torch import device, empty_strided
from {async_compile.__name__} import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
{aot_inductor_debug_utils}
{inductor_debug_utils}
""",
strip=True,
)
@@ -2773,6 +2776,14 @@ class PythonWrapperCodegen(CodeGen):
buffer.get_name(), device, dtype, shape, stride, allocation_shape
)
@cache_on_self
def write_memory_track_allocation_once(self):
import_str = """
from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor
"""
if not V.graph.cpp_wrapper:
self.imports.splice(import_str, strip=True)
def make_allocation(
self, name, device, dtype, shape, stride, allocation_shape=None
):
@@ -2784,7 +2795,16 @@ class PythonWrapperCodegen(CodeGen):
allocation_shape
)
codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
if device.type in ("cpu", "cuda", "xpu", "mtia"):
if torch._inductor.config.test_configs.track_memory_lifecycle:
out = (
f"{name} = tracked_empty_strided("
f"{codegen_allocation_shape_tuple}, "
f"{codegen_stride_tuple}, "
f"dtype={dtype}, "
f"device='{device.type}', "
f"name='{name}')"
)
elif device.type in ("cpu", "cuda", "xpu", "mtia"):
# optimized path for faster allocations, saving ~2us versus the stuff below
out = (
f"{name} = empty_strided_{device.type}("


@@ -1861,6 +1861,8 @@ class test_configs:
graphsafe_rng_func_ignores_fallback_random = False
track_memory_lifecycle: Optional[Literal["assert", "log"]] = None
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403


@@ -5324,6 +5324,11 @@ class ConcatKernel(NopKernel):
@ir_dataclass(frozen=False)
class ExternKernel(InputsKernel):
"""
A class that represents Kernels which are not directly lowered to Inductor
Loop Level IR, such as custom operators, or aten operators which we fallback to.
"""
constant_args: Sequence[Any] = ()
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None
@@ -6120,6 +6125,17 @@ class ExternKernel(InputsKernel):
f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
)
def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None:
"""
Track outputs of fallback operators if config.test_configs.track_memory_lifecycle
"""
if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper:
return
wrapper.write_memory_track_allocation_once()
name = self.get_name()
wrapper.writeline(f"track_tensor({name}, '{name}')")
def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]:
"""
get output sizes and strides, for template_codegen
@@ -7579,6 +7595,7 @@ class FallbackKernel(ExternKernelAlloc):
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
self.codegen_unbacked_symbol_defs(wrapper)
@@ -7720,6 +7737,31 @@ class ComplexView(FallbackKernel):
)
class MemoryCheckKernel(FallbackKernel):
"""
Custom kernel for memory checking that generates direct function calls
TODO - the custom op was erroring with str inputs; we should be able to use a custom op directly.
"""
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
"""Override codegen to write direct function call"""
# Extract our arguments from nontensor_args
wrapper.write_memory_track_allocation_once()
alive_list, dead_list, is_final_step = self.constant_args
alive_repr = repr(alive_list)
dead_repr = repr(dead_list)
if is_final_step:
wrapper.writeline(
"# note: dont currently distinguish between buffers returned and dealloc'd in last step"
)
call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})"
else:
call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})"
wrapper.writeline(call)
@ir_dataclass
class MultiOutputLayout(OutputSpec):
device: torch.device


@@ -124,6 +124,28 @@ def compute_size_for_scheduler_buffer(
buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed
buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed
When an operation mutates a buffer in-place, the scheduler creates a new buffer name
to track the "before" and "after" states, even though they share the same memory.
The mutated buffer represents a rename with zero allocation and deallocation cost.
During dependency tracking, we transfer dependencies from the mutated name back to
the original buffer, ensuring the original memory is only freed when all aliases
are done.
This handles cases where a buffer has multiple non-overlapping aliases - rather than
trying to assign free costs to individual aliases, we forward all alias dependencies
to the original buffer.
Consider:
buf0 = op0()
buf1 = mutation_op_(buf0)
del buf0
...
op(buf1)
del buf1
The only memory events are the creation prior to op0, and the deletion following buf1.
Returns:
A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free).
"""
@@ -135,18 +157,11 @@ def compute_size_for_scheduler_buffer(
def _compute_and_update_buf_size(
sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
) -> int:
if isinstance(sched_buf.node.layout, NoneLayout):
# mutations should inherit the size of the mutated buffer
if sched_buf.get_mutations():
mutated_buf_name = sched_buf.get_mutations()[0]
if mutated_buf_name in sched_buf_to_size:
(_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name]
else:
(_size_alloc, _size_free) = (0, 0)
sched_buf_to_size[sched_buf.get_name()] = (0, _size_free)
sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0)
else:
sched_buf_to_size[sched_buf.get_name()] = (0, 0)
if sched_buf.get_name() in V.graph.scheduler.mutation_real_name:
sched_buf_to_size[sched_buf.get_name()] = (0, 0)
return 0
elif isinstance(sched_buf.node.layout, NoneLayout):
sched_buf_to_size[sched_buf.get_name()] = (0, 0)
return 0
elif isinstance(sched_buf.node.layout, MultiOutputLayout):
size_alloc = 0
@@ -200,6 +215,14 @@ def assign_memory_planning_info_for_scheduler_buffers(
for dep in node.unmet_dependencies:
dep_name_to_succ_nodes[dep.name].add(node)
# iterate in reverse, so dependencies are picked up transitively.
for mutating_buf_name, real_buf_name in reversed(
V.graph.scheduler.mutation_real_name.items()
):
dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[
mutating_buf_name
]
# populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer
# note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs)
for buf_name in name_to_buf.keys():
@@ -219,58 +242,72 @@ def assign_memory_planning_info_for_scheduler_nodes(
"""
Assign to each scheduler node its predecessor and successor nodes.
"""
from .scheduler import SchedulerBuffer
for index, node in enumerate(nodes):
size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs())
pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]()
for dep in node.read_writes.reads:
if dep.name in name_to_buf and dep in node.unmet_dependencies:
pred_buffers.add(name_to_buf[dep.name])
elif dep.name in name_to_freeable_input_buf:
pred_buffers.add(name_to_freeable_input_buf[dep.name])
pred_nodes = OrderedSet(
name_to_fused_node[pred_buffer.defining_op_name()]
for pred_buffer in pred_buffers
if (isinstance(pred_buffer, SchedulerBuffer))
)
node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = (
collections.defaultdict(OrderedSet)
)
node_to_succ_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {}
node_to_pred_buffers: dict[
BaseSchedulerNode, OrderedSet[SchedulerBuffer | FreeableInputBuffer]
] = collections.defaultdict(OrderedSet)
# collect all predecessors using existing successor mappings
for node in nodes:
succ_nodes = OrderedSet(
succ_node
for buffer in node.get_outputs()
for succ_node in buffer.mpi_buffer.succ_nodes
)
node_to_succ_nodes[node] = succ_nodes
# For each successor, add current node as its predecessor
for succ_node in succ_nodes:
node_to_pred_nodes[succ_node].add(node)
# For each output buffer, add it as predecessor to its successor nodes
# TODO - is pred buffers needed ?
for buffer in node.get_outputs():
for succ_node in buffer.mpi_buffer.succ_nodes:
node_to_pred_buffers[succ_node].add(buffer)
for freeable_buffer in name_to_freeable_input_buf.values():
for succ_node in freeable_buffer.mpi_buffer.succ_nodes:
node_to_pred_buffers[succ_node].add(freeable_buffer)
# Second pass: assign memory planning info using completed predecessor mappings
for index, node in enumerate(nodes):
size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs())
succ_nodes = node_to_succ_nodes[node]
node.mpi_node = MemoryPlanningInfoForNode(
index=index,
size=size_alloc,
pred_buffers=pred_buffers,
pred_nodes=pred_nodes,
pred_buffers=node_to_pred_buffers[node],
pred_nodes=node_to_pred_nodes[node],
succ_nodes=succ_nodes,
)
def estimate_peak_memory(
# map each scheduler buffer to its size, start step, and end step
@dataclasses.dataclass
class BufferInfo:
buffer: Union[SchedulerBuffer, FreeableInputBuffer]
size_alloc: int
size_free: int
start_step: int
end_step: int
def compute_memory_timeline(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[int, list[int]]:
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory, by
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
Returns:
int: peak memory
List[int]: memory usage at each node (or each step).
Compute buffer allocation and deallocation sizes and map their
lifetime to the node schedule
"""
# map each scheduler buffer to its size, start step, and end step
@dataclasses.dataclass
class BufferInfo:
buffer: Union[SchedulerBuffer, FreeableInputBuffer]
size_alloc: int
size_free: int
start_step: int
end_step: int
# get the execution step of each node, this will be used to determine
# the end_step of buffers
node_to_step: dict[BaseSchedulerNode, int] = {
@@ -325,6 +362,27 @@ def estimate_peak_memory(
)
)
return buf_info_list, node_to_step
def estimate_peak_memory(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[int, list[int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory, by
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
Returns:
int: peak memory
List[int]: memory usage at each node (or each step).
"""
buf_info_list, _ = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs
)
# incremental memory changes at each step
memory = [0 for _ in range(len(nodes) + 1)]


@@ -0,0 +1,138 @@
import functools
import logging
import threading
import weakref
import torch
from torch.utils._ordered_set import OrderedSet
log = logging.getLogger(__name__)
local = threading.local()
local.memory_tracker = None
class BufferMemoryTracker:
"""
Tracks inductor runtime allocations and deallocations to compare against
expected behavior.
"""
def __init__(self) -> None:
self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = (
weakref.WeakValueDictionary() # type: ignore[assignment]
)
self.died_since_last_step: OrderedSet[str] = OrderedSet()
self.added_since_last_step: OrderedSet[str] = OrderedSet()
self.error = (
torch._inductor.config.test_configs.track_memory_lifecycle == "assert"
)
def set_tensor(self, name: str, tensor: torch.Tensor) -> None:
storage = tensor.untyped_storage()
self.added_since_last_step.add(name)
self.tensor_tracker[name] = storage
def on_tensor_death() -> None:
self.died_since_last_step.add(name)
weakref.finalize(storage, on_tensor_death)
def advance_step(self) -> None:
self.died_since_last_step.clear()
self.added_since_last_step.clear()
def log_or_raise(self, msg: str) -> None:
if self.error:
raise RuntimeError(msg)
else:
log.info(msg)
def check_step_delta(
self,
expected_allocated: list[str],
expected_freed: list[str],
is_final_step: bool,
) -> None:
"""Check only the delta changes since last step"""
# Check expected deaths - we don't currently distinguish between buffers which die in the last step
# and are returned as outputs, so skip if final_step.
if not is_final_step:
missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step
if missing_deaths:
self.log_or_raise(
f"Expected tensors to die but still alive: {missing_deaths}"
)
# Check for unexpected deaths
unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed)
if unexpected_deaths:
self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}")
# Check newly alive tensors - separate messages like deaths
actual_allocated = self.added_since_last_step
expected_allocated_set = OrderedSet(expected_allocated)
extra_alive = actual_allocated - expected_allocated_set
if extra_alive:
self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}")
missing_alive = expected_allocated_set - actual_allocated
if missing_alive:
self.log_or_raise(
f"Expected allocated tensors but missing: {missing_alive}"
)
# Reset for next step
self.advance_step()
if is_final_step:
local.memory_tracker = None
def get_mem_tracker() -> BufferMemoryTracker:
if local.memory_tracker is None:
local.memory_tracker = BufferMemoryTracker()
return local.memory_tracker
def track_tensor(tensor: torch.Tensor, name: str) -> None:
get_mem_tracker().set_tensor(name, tensor)
def tracked_empty_strided(
size: list[int],
stride: list[int],
*,
dtype: torch.dtype,
device: torch.device,
name: str,
) -> torch.Tensor:
o = torch.empty_strided(size, stride, dtype=dtype, device=device)
track_tensor(o, name)
return o
def check_memory_step(
allocated: list[str], freed: list[str], is_final_step: bool = False
) -> None:
tracker = get_mem_tracker()
tracker.check_step_delta(allocated, freed, is_final_step)
@functools.lru_cache(None)
def register_check_mem_op() -> None:
lib = torch.library.Library("_inductor_debug", "FRAGMENT") # noqa: TOR901
lib.define(
"check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()"
)
lib.impl("check_memory_step", check_memory_step, "BackendSelect")
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
_register_effectful_op(
torch.ops._inductor_debug.check_memory_step.default,
_EffectType.ORDERED,
)


@@ -2184,6 +2184,10 @@ class Scheduler:
self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
self.compute_last_usage()
if torch._inductor.config.test_configs.track_memory_lifecycle:
self.insert_memory_check_nodes()
log_ir_post_fusion(self.nodes)
V.debug.graph_diagram(self.nodes)
self.debug_draw_graph()
@@ -2518,6 +2522,83 @@ class Scheduler:
compute_dependencies_log.debug("BUFFER USER LIST\n")
compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)
def insert_memory_check_nodes(self) -> None:
from .memory import (
assign_memory_planning_info_for_scheduler_buffers,
compute_memory_timeline,
FreeableInputBuffer,
get_freeable_input_buf,
)
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = (
get_freeable_input_buf(self.nodes, graph_inputs)
)
if not torch._inductor.config.reorder_for_peak_memory:
assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf
)
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
buf_info_list, _ = compute_memory_timeline(
self.nodes,
name_to_freeable_input_buf,
graph_outputs,
)
step_allocs_deallocs: list[tuple[list[str], list[str]]] = [
([], []) for _ in range(len(self.nodes))
]
for buf_info in buf_info_list:
# Skip zero-size buffers
if buf_info.size_alloc == 0 and buf_info.size_free == 0:
continue
buf_name = buf_info.buffer.get_name()
step_allocs_deallocs[buf_info.start_step][0].append(buf_name)
step_allocs_deallocs[buf_info.end_step][1].append(buf_name)
from torch._inductor.runtime.debug_utils import register_check_mem_op
register_check_mem_op()
def construct_mem_check_node(
step_idx: int, is_final_step: bool
) -> ExternKernelSchedulerNode:
expected_newly_alive = step_allocs_deallocs[step_idx][0]
expected_newly_dead = step_allocs_deallocs[step_idx][1]
nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step]
node = ir.MemoryCheckKernel(
layout=NoneLayout(device=torch.device("cpu")),
kernel=torch.ops._inductor_debug.check_memory_step.default,
tensor_args=[],
nontensor_args=nontensor_args,
unflatten_args=lambda tensor_args, constant_args: (
tensor_args,
{
"alive": constant_args[0],
"dead": constant_args[1],
"is_final_step": constant_args[2],
},
),
)
node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}"
return ExternKernelSchedulerNode(self, node)
new_nodes = []
for i, node in enumerate(self.nodes):
new_nodes.append(node)
new_nodes.append(
construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1))
)
self.nodes = new_nodes
def dead_node_elimination(self) -> None:
"""
Remove any nodes without users