diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py
index 63ff2fa2bbfe..c05d5edae233 100644
--- a/test/distributed/test_compute_comm_reordering.py
+++ b/test/distributed/test_compute_comm_reordering.py
@@ -179,8 +179,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             .check("extern_kernels.mm")
             .check("triton_poi_fused_relu")
             .check("torch.ops._c10d_functional.all_reduce_.default")
-            .check("torch.ops._c10d_functional.wait_tensor.default")
+            .check_same("buf0")
+            # mm does not use buf prior to wait_tensor
             .check("extern_kernels.mm")
+            .check_not("buf0")
+            .check("torch.ops._c10d_functional.wait_tensor.default")
             .check("extern_kernels.mm")
             .run(code)
         )
diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index 856e1c5f7b3c..d0b8c32497f0 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -1745,10 +1745,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
                     _reorder_communication_preserving_peak_memory,
                 ],
                 "allow_buffer_reuse": False,
+                "test_configs.track_memory_lifecycle": "assert",
             }
         ):
-            compiled = torch.compile(func)
+            compiled = torch.compile(func, fullgraph=True)
             code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
+
+            # make sure memory tracking is codegenned; the ops will then do runtime checking with assertions.
+            FileCheck().check("check_memory_step").check("tracked_empty_strided").run(code)
+
             # NOTE: The first return value should be the output of the first wait_tensor.
             # We want to make sure no unnecessary copy is made.
             (
diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py
index 3e23442b38ec..2231b94316b3 100644
--- a/test/inductor/test_memory.py
+++ b/test/inductor/test_memory.py
@@ -215,6 +215,7 @@ class TestOperatorReorderForPeakMemory(TestCase):

     @mock.patch.object(config, "allow_buffer_reuse", False)
     @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
+    @config.patch("test_configs.track_memory_lifecycle", "assert")
     def test_mutation_size_propogation(self):
         """
         This tests correct size propogation in the case of mutations.
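Aside (illustration, not part of the patch): the reordered assertions above rely on two less common FileCheck directives. Assuming the usual torch.testing.FileCheck semantics, check_same requires the pattern on the same line as the previous match, and check_not requires it to be absent between the surrounding matches. A minimal self-contained sketch with an invented wrapper excerpt and buffer names:

    from torch.testing import FileCheck

    # invented generated-code excerpt: the overlapped mm must not read the collective's buffer
    code = """
    torch.ops._c10d_functional.all_reduce_.default(buf0)
    extern_kernels.mm(arg0, arg1, out=buf2)
    torch.ops._c10d_functional.wait_tensor.default(buf0)
    """
    FileCheck().check("all_reduce_").check_same("buf0").check(
        "extern_kernels.mm"
    ).check_not("buf0").check("wait_tensor").run(code)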
@@ -262,6 +263,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
                 buffer_info[buf_name] = (
                     buf.mpi_buffer.size_alloc,
                     buf.mpi_buffer.size_free,
+                    buf.mpi_buffer.succ_nodes,
                 )

         # test example and checks
@@ -281,11 +283,15 @@ class TestOperatorReorderForPeakMemory(TestCase):
         ):
             f_compiled = torch.compile(f)
             f_compiled(a, p)

-            for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
-                self.assertEqual(buffer_info[buf_name], (2048, 0))
-            for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
-                self.assertEqual(buffer_info[buf_name], (0, 2048))
+            pre_mutation = ["buf0", "buf2", "buf4", "buf6"]
+            post_mutation = ["buf1", "buf3", "buf5", "buf7"]
+
+            for pre, post in zip(pre_mutation, post_mutation):
+                self.assertEqual(buffer_info[pre][0:2], (2048, 2048))
+                self.assertEqual(buffer_info[post][0:2], (0, 0))
+                # succ nodes should be forwarded to the pre-mutation buffer
+                self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2])

     @unittest.skipIf(
         not torch.cuda.is_available()
@@ -359,6 +365,49 @@ class TestOperatorReorderForPeakMemory(TestCase):
             .run(code)
         )

+    @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
+    def test_multiple_mutations_of_buf(self):
+        @torch.compile()
+        def foo(inp, inp2):
+            inp = inp @ inp
+            inp = inp.view(2, -1, 256)
+            x = inp[0]
+            y = inp[1]
+            x, y = torch._foreach_add([x, y], 1.0)
+            out = x.sum()
+            out2 = y.sum(dim=-1)
+
+            return out, out2, inp2 @ inp2
+
+        inp = torch.rand([256, 256], device="cuda")
+        inp2 = torch.rand([256, 256], device="cuda")
+
+        def replace_foreach(gm):
+            nodes = gm.find_nodes(
+                op="call_function", target=torch.ops.aten._foreach_add.Scalar
+            )
+            assert len(nodes) == 1
+            node = nodes[0]
+            nodes[0].target = torch.ops.aten._foreach_add_.Scalar
+            for inp, out in zip(node.args[0], list(node.users.keys())):
+                out.replace_all_uses_with(inp)
+                gm.erase_node(out)
+
+        with torch._inductor.config.patch(
+            {
+                "post_grad_custom_post_pass": replace_foreach,
+                "test_configs.track_memory_lifecycle": "assert",
+                "allow_buffer_reuse": False,
+                # make sure the mm is at the end so
+                # the earlier deallocation is not at the last step,
+                # which doesn't distinguish between returned tensors
+                # and tensors that are deallocated immediately prior
+                "reorder_for_peak_memory": False,
+            }
+        ):
+            code = run_and_get_triton_code(foo, inp, inp2)
+            FileCheck().check("allocated=['buf0']").run(code)
+

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
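Aside: the check("allocated=['buf0']") above matches a line that the memory-check codegen introduced later in this patch is expected to emit into the wrapper. A hypothetical excerpt (buffer names, shapes, and the surrounding kernel invented), assuming the tracked_empty_strided / check_memory_step codegen shown below:

    buf0 = tracked_empty_strided((256, 256), (256, 1), dtype=torch.float32, device='cuda', name='buf0')
    extern_kernels.mm(arg0_1, arg0_1, out=buf0)
    check_memory_step(allocated=['buf0'], freed=[])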
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index f4370e619c1b..dd0316344099 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -963,9 +963,12 @@ class PythonWrapperCodegen(CodeGen):
         aot_config_comment = ""
         if context is not None and context.aot_graph_name is not None:
             aot_config_comment = f"# AOT ID: {context.aot_graph_name}"
-        aot_inductor_debug_utils = ""
+        inductor_debug_utils = ""
         if int(config.aot_inductor.debug_intermediate_value_printer) > 0:
-            aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
+            inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
+        elif torch._inductor.config.test_configs.track_memory_lifecycle:
+            inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n"
+
         self.imports.splice(
             f"""
                 {aot_config_comment}
@@ -983,7 +986,7 @@ class PythonWrapperCodegen(CodeGen):
                 from torch import device, empty_strided
                 from {async_compile.__name__} import AsyncCompile
                 from torch._inductor.select_algorithm import extern_kernels
-                {aot_inductor_debug_utils}
+                {inductor_debug_utils}
             """,
             strip=True,
         )
@@ -2773,6 +2776,14 @@ class PythonWrapperCodegen(CodeGen):
             buffer.get_name(), device, dtype, shape, stride, allocation_shape
         )

+    @cache_on_self
+    def write_memory_track_allocation_once(self):
+        import_str = """
+            from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor
+        """
+        if not V.graph.cpp_wrapper:
+            self.imports.splice(import_str, strip=True)
+
     def make_allocation(
         self, name, device, dtype, shape, stride, allocation_shape=None
     ):
@@ -2784,7 +2795,16 @@ class PythonWrapperCodegen(CodeGen):
             allocation_shape
         )
         codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
-        if device.type in ("cpu", "cuda", "xpu", "mtia"):
+        if torch._inductor.config.test_configs.track_memory_lifecycle:
+            out = (
+                f"{name} = tracked_empty_strided("
+                f"{codegen_allocation_shape_tuple}, "
+                f"{codegen_stride_tuple}, "
+                f"dtype={dtype}, "
+                f"device='{device.type}', "
+                f"name='{name}')"
+            )
+        elif device.type in ("cpu", "cuda", "xpu", "mtia"):
             # optimized path for faster allocations, saving ~2us versus the stuff below
             out = (
                 f"{name} = empty_strided_{device.type}("
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index e5b5fe224cc8..a42eb3cdeda9 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1861,6 +1861,8 @@ class test_configs:

     graphsafe_rng_func_ignores_fallback_random = False

+    track_memory_lifecycle: Optional[Literal["assert", "log"]] = None
+

 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
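Aside: a before/after sketch of the allocation line make_allocation emits (shape, stride, and buffer name invented); the tracked form mirrors the f-string above, and the fast path is unchanged when the config is unset:

    # default path
    buf0 = empty_strided_cuda((8, 8), (8, 1), torch.float32)
    # with test_configs.track_memory_lifecycle set
    buf0 = tracked_empty_strided((8, 8), (8, 1), dtype=torch.float32, device='cuda', name='buf0')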
+ """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Override codegen to write direct function call""" + # Extract our arguments from nontensor_args + wrapper.write_memory_track_allocation_once() + alive_list, dead_list, is_final_step = self.constant_args + + alive_repr = repr(alive_list) + dead_repr = repr(dead_list) + if is_final_step: + wrapper.writeline( + "# note: dont currently distinguish between buffers returned and dealloc'd in last step" + ) + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})" + else: + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})" + wrapper.writeline(call) + + @ir_dataclass class MultiOutputLayout(OutputSpec): device: torch.device diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index d287208419a9..0967bb553e04 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -124,6 +124,28 @@ def compute_size_for_scheduler_buffer( buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed + When an operation mutates a buffer in-place, the scheduler creates a new buffer name + to track the "before" and "after" states, even though they share the same memory. + + The mutated buffer represents a rename with zero allocation and deallocation cost. + During dependency tracking, we transfer dependencies from the mutated name back to + the original buffer, ensuring the original memory is only freed when all aliases + are done. + + This handles cases where a buffer has multiple non-overlapping aliases - rather than + trying to assign free costs to individual aliases, we forward all alias dependencies + to the original buffer. + + Consider: + buf0 = op0() + buf1 = mutation_op_(buf0) + del buf0 + ... + op(buf1) + del buf1 + + The only memory events are the creation prior to op0, and the deletion following buf1. + Returns: A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). """ @@ -135,18 +157,11 @@ def compute_size_for_scheduler_buffer( def _compute_and_update_buf_size( sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False ) -> int: - if isinstance(sched_buf.node.layout, NoneLayout): - # mutations should inherit the size of the mutated buffer - if sched_buf.get_mutations(): - mutated_buf_name = sched_buf.get_mutations()[0] - if mutated_buf_name in sched_buf_to_size: - (_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name] - else: - (_size_alloc, _size_free) = (0, 0) - sched_buf_to_size[sched_buf.get_name()] = (0, _size_free) - sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0) - else: - sched_buf_to_size[sched_buf.get_name()] = (0, 0) + if sched_buf.get_name() in V.graph.scheduler.mutation_real_name: + sched_buf_to_size[sched_buf.get_name()] = (0, 0) + return 0 + elif isinstance(sched_buf.node.layout, NoneLayout): + sched_buf_to_size[sched_buf.get_name()] = (0, 0) return 0 elif isinstance(sched_buf.node.layout, MultiOutputLayout): size_alloc = 0 @@ -200,6 +215,14 @@ def assign_memory_planning_info_for_scheduler_buffers( for dep in node.unmet_dependencies: dep_name_to_succ_nodes[dep.name].add(node) + # iterate in reverse, so dependencies are picked up transitively. 
@@ -200,6 +215,14 @@ def assign_memory_planning_info_for_scheduler_buffers(
         for dep in node.unmet_dependencies:
             dep_name_to_succ_nodes[dep.name].add(node)

+    # iterate in reverse, so dependencies are picked up transitively.
+    for mutating_buf_name, real_buf_name in reversed(
+        V.graph.scheduler.mutation_real_name.items()
+    ):
+        dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[
+            mutating_buf_name
+        ]
+
     # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer
     # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs)
     for buf_name in name_to_buf.keys():
@@ -219,58 +242,72 @@ def assign_memory_planning_info_for_scheduler_nodes(
     """
     Assign to each scheduler node its predecessor and successor nodes.
     """
-    from .scheduler import SchedulerBuffer
-
-    for index, node in enumerate(nodes):
-        size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs())
-        pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]()
-        for dep in node.read_writes.reads:
-            if dep.name in name_to_buf and dep in node.unmet_dependencies:
-                pred_buffers.add(name_to_buf[dep.name])
-            elif dep.name in name_to_freeable_input_buf:
-                pred_buffers.add(name_to_freeable_input_buf[dep.name])
-        pred_nodes = OrderedSet(
-            name_to_fused_node[pred_buffer.defining_op_name()]
-            for pred_buffer in pred_buffers
-            if (isinstance(pred_buffer, SchedulerBuffer))
-        )
+    node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = (
+        collections.defaultdict(OrderedSet)
+    )
+    node_to_succ_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {}
+    node_to_pred_buffers: dict[
+        BaseSchedulerNode, OrderedSet[SchedulerBuffer | FreeableInputBuffer]
+    ] = collections.defaultdict(OrderedSet)
+
+    # collect all predecessors using existing successor mappings
+    for node in nodes:
         succ_nodes = OrderedSet(
             succ_node
             for buffer in node.get_outputs()
             for succ_node in buffer.mpi_buffer.succ_nodes
         )
+        node_to_succ_nodes[node] = succ_nodes
+
+        # For each successor, add the current node as its predecessor
+        for succ_node in succ_nodes:
+            node_to_pred_nodes[succ_node].add(node)
+
+        # For each output buffer, add it as a predecessor to its successor nodes
+        # TODO - is pred_buffers needed?
+        for buffer in node.get_outputs():
+            for succ_node in buffer.mpi_buffer.succ_nodes:
+                node_to_pred_buffers[succ_node].add(buffer)
+
+    for freeable_buffer in name_to_freeable_input_buf.values():
+        for succ_node in freeable_buffer.mpi_buffer.succ_nodes:
+            node_to_pred_buffers[succ_node].add(freeable_buffer)
+
+    # Second pass: assign memory planning info using the completed predecessor mappings
+    for index, node in enumerate(nodes):
+        size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs())
+        succ_nodes = node_to_succ_nodes[node]
+
         node.mpi_node = MemoryPlanningInfoForNode(
             index=index,
             size=size_alloc,
-            pred_buffers=pred_buffers,
-            pred_nodes=pred_nodes,
+            pred_buffers=node_to_pred_buffers[node],
+            pred_nodes=node_to_pred_nodes[node],
             succ_nodes=succ_nodes,
         )


-def estimate_peak_memory(
+# map each scheduler buffer to its size, start step, and end step
+@dataclasses.dataclass
+class BufferInfo:
+    buffer: Union[SchedulerBuffer, FreeableInputBuffer]
+    size_alloc: int
+    size_free: int
+    start_step: int
+    end_step: int
+
+
+def compute_memory_timeline(
     nodes: list[BaseSchedulerNode],
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
     graph_outputs: OrderedSet[str],
-) -> tuple[int, list[int]]:
+) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
     """
-    Given a list of nodes in their execution order, estimate the peak memory, by
-    keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
-
-    Returns:
-        int: peak memory
-        List[int]: memory usage at each node (or each step).
+    Compute buffer allocation and deallocation sizes and map their
+    lifetimes onto the node schedule.
     """
-    # map each scheduler buffer to its size, start step, and end step
-    @dataclasses.dataclass
-    class BufferInfo:
-        buffer: Union[SchedulerBuffer, FreeableInputBuffer]
-        size_alloc: int
-        size_free: int
-        start_step: int
-        end_step: int
-
     # get the execution step of each node, this will be used to determine
     # the end_step of buffers
     node_to_step: dict[BaseSchedulerNode, int] = {
@@ -325,6 +362,27 @@ def estimate_peak_memory(
             )
         )

+    return buf_info_list, node_to_step
+
+
+def estimate_peak_memory(
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    graph_outputs: OrderedSet[str],
+) -> tuple[int, list[int]]:
+    """
+    Given a list of nodes in their execution order, estimate the peak memory, by
+    keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
+
+    Returns:
+        int: peak memory
+        List[int]: memory usage at each node (or each step).
+    """
+
+    buf_info_list, _ = compute_memory_timeline(
+        nodes, name_to_freeable_input_buf, graph_outputs
+    )
+
     # incremental memory changes at each step
     memory = [0 for _ in range(len(nodes) + 1)]
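Aside: a small worked example of how the timeline feeds the peak estimate (sizes and steps invented, graph outputs ignored). With BufferInfo(buf0, 1024, 1024, start_step=0, end_step=2) and BufferInfo(buf1, 512, 512, start_step=1, end_step=1), estimate_peak_memory accumulates:

    # step:    0     1           2
    # live:    buf0  buf0, buf1  buf0
    # memory:  1024  1536        1024   -> peak = 1536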
diff --git a/torch/_inductor/runtime/debug_utils.py b/torch/_inductor/runtime/debug_utils.py
new file mode 100644
index 000000000000..9c15ff890dda
--- /dev/null
+++ b/torch/_inductor/runtime/debug_utils.py
@@ -0,0 +1,138 @@
+import functools
+import logging
+import threading
+import weakref
+
+import torch
+from torch.utils._ordered_set import OrderedSet
+
+
+log = logging.getLogger(__name__)
+
+local = threading.local()
+local.memory_tracker = None
+
+
+class BufferMemoryTracker:
+    """
+    Tracks inductor runtime allocations and deallocations to compare against
+    expected behavior.
+    """
+
+    def __init__(self) -> None:
+        self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = (
+            weakref.WeakValueDictionary()  # type: ignore[assignment]
+        )
+        self.died_since_last_step: OrderedSet[str] = OrderedSet()
+        self.added_since_last_step: OrderedSet[str] = OrderedSet()
+        self.error = (
+            torch._inductor.config.test_configs.track_memory_lifecycle == "assert"
+        )
+
+    def set_tensor(self, name: str, tensor: torch.Tensor) -> None:
+        storage = tensor.untyped_storage()
+
+        self.added_since_last_step.add(name)
+        self.tensor_tracker[name] = storage
+
+        def on_tensor_death() -> None:
+            self.died_since_last_step.add(name)
+
+        weakref.finalize(storage, on_tensor_death)
+
+    def advance_step(self) -> None:
+        self.died_since_last_step.clear()
+        self.added_since_last_step.clear()
+
+    def log_or_raise(self, msg: str) -> None:
+        if self.error:
+            raise RuntimeError(msg)
+        else:
+            log.info(msg)
+
+    def check_step_delta(
+        self,
+        expected_allocated: list[str],
+        expected_freed: list[str],
+        is_final_step: bool,
+    ) -> None:
+        """Check only the delta changes since last step"""
+
+        # Check expected deaths - we don't currently distinguish between buffers that die in the last step
+        # and those returned as outputs, so skip if final_step.
+        if not is_final_step:
+            missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step
+            if missing_deaths:
+                self.log_or_raise(
+                    f"Expected tensors to die but still alive: {missing_deaths}"
+                )
+
+        # Check for unexpected deaths
+        unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed)
+        if unexpected_deaths:
+            self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}")
+
+        # Check newly alive tensors - separate messages like deaths
+        actual_allocated = self.added_since_last_step
+        expected_allocated_set = OrderedSet(expected_allocated)
+
+        extra_alive = actual_allocated - expected_allocated_set
+        if extra_alive:
+            self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}")
+
+        missing_alive = expected_allocated_set - actual_allocated
+        if missing_alive:
+            self.log_or_raise(
+                f"Expected allocated tensors but missing: {missing_alive}"
+            )
+
+        # Reset for next step
+        self.advance_step()
+
+        if is_final_step:
+            local.memory_tracker = None
+
+
+def get_mem_tracker() -> BufferMemoryTracker:
+    if local.memory_tracker is None:
+        local.memory_tracker = BufferMemoryTracker()
+    return local.memory_tracker
+
+
+def track_tensor(tensor: torch.Tensor, name: str) -> None:
+    get_mem_tracker().set_tensor(name, tensor)
+
+
+def tracked_empty_strided(
+    size: list[int],
+    stride: list[int],
+    *,
+    dtype: torch.dtype,
+    device: torch.device,
+    name: str,
+) -> torch.Tensor:
+    o = torch.empty_strided(size, stride, dtype=dtype, device=device)
+    track_tensor(o, name)
+    return o
+
+
+def check_memory_step(
+    allocated: list[str], freed: list[str], is_final_step: bool = False
+) -> None:
+    tracker = get_mem_tracker()
+    tracker.check_step_delta(allocated, freed, is_final_step)
+
+
+@functools.lru_cache(None)
+def register_check_mem_op() -> None:
+    lib = torch.library.Library("_inductor_debug", "FRAGMENT")  # noqa: TOR901
+    lib.define(
+        "check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()"
+    )
+    lib.impl("check_memory_step", check_memory_step, "BackendSelect")
+    from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
+
+    _register_effectful_op(
+        torch.ops._inductor_debug.check_memory_step.default,
+        _EffectType.ORDERED,
+    )
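Aside: a standalone illustration of the death tracking BufferMemoryTracker relies on, assuming untyped storages are weakly referenceable (as the WeakValueDictionary and finalize calls above already require); the finalizer runs once the storage is deallocated:

    import weakref

    import torch

    t = torch.empty(4)
    storage = t.untyped_storage()
    weakref.finalize(storage, lambda: print("buf0 died"))
    del t, storage  # with CPython refcounting, "buf0 died" is printed here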
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 951f07ab7a5b..abd2fe413d1a 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2184,6 +2184,10 @@ class Scheduler:
             self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)

         self.compute_last_usage()
+
+        if torch._inductor.config.test_configs.track_memory_lifecycle:
+            self.insert_memory_check_nodes()
+
         log_ir_post_fusion(self.nodes)
         V.debug.graph_diagram(self.nodes)
         self.debug_draw_graph()
@@ -2518,6 +2522,83 @@ class Scheduler:
             compute_dependencies_log.debug("BUFFER USER LIST\n")
             compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)

+    def insert_memory_check_nodes(self) -> None:
+        from .memory import (
+            assign_memory_planning_info_for_scheduler_buffers,
+            compute_memory_timeline,
+            FreeableInputBuffer,
+            get_freeable_input_buf,
+        )
+
+        graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
+        name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = (
+            get_freeable_input_buf(self.nodes, graph_inputs)
+        )
+
+        if not torch._inductor.config.reorder_for_peak_memory:
+            assign_memory_planning_info_for_scheduler_buffers(
+                self.nodes, self.name_to_buf
+            )
+
+        graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
+        buf_info_list, _ = compute_memory_timeline(
+            self.nodes,
+            name_to_freeable_input_buf,
+            graph_outputs,
+        )
+
+        step_allocs_deallocs: list[tuple[list[str], list[str]]] = [
+            ([], []) for _ in range(len(self.nodes))
+        ]
+        for buf_info in buf_info_list:
+            # Skip zero-size buffers
+            if buf_info.size_alloc == 0 and buf_info.size_free == 0:
+                continue
+
+            buf_name = buf_info.buffer.get_name()
+
+            step_allocs_deallocs[buf_info.start_step][0].append(buf_name)
+            step_allocs_deallocs[buf_info.end_step][1].append(buf_name)
+
+        from torch._inductor.runtime.debug_utils import register_check_mem_op
+
+        register_check_mem_op()
+
+        def construct_mem_check_node(
+            step_idx: int, is_final_step: bool
+        ) -> ExternKernelSchedulerNode:
+            expected_newly_alive = step_allocs_deallocs[step_idx][0]
+            expected_newly_dead = step_allocs_deallocs[step_idx][1]
+
+            nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step]
+
+            node = ir.MemoryCheckKernel(
+                layout=NoneLayout(device=torch.device("cpu")),
+                kernel=torch.ops._inductor_debug.check_memory_step.default,
+                tensor_args=[],
+                nontensor_args=nontensor_args,
+                unflatten_args=lambda tensor_args, constant_args: (
+                    tensor_args,
+                    {
+                        "alive": constant_args[0],
+                        "dead": constant_args[1],
+                        "is_final_step": constant_args[2],
+                    },
+                ),
+            )
+            node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}"
+            return ExternKernelSchedulerNode(self, node)
+
+        new_nodes = []
+
+        for i, node in enumerate(self.nodes):
+            new_nodes.append(node)
+            new_nodes.append(
+                construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1))
+            )
+
+        self.nodes = new_nodes
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users
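Aside: a hedged end-to-end sketch of turning the checks on (model and shapes invented). With "assert" the generated check_memory_step calls raise on any mismatch between planned and observed buffer lifetimes; with "log" they only log:

    import torch

    def f(x):
        return (x @ x).relu().sum()

    with torch._inductor.config.patch({"test_configs.track_memory_lifecycle": "assert"}):
        torch.compile(f)(torch.randn(64, 64, device="cuda"))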