pytorch/torch/_inductor/runtime/debug_utils.py
eellison eb25a95a6e Fix inductor memory estimation when a single buf has multiple mutations. Add runtime verification of mem tracking (#159569)
With FSDP, we sometimes have multiple, non-overlapping views of a single buffer which are all mutated. Previously we considered the original buffer as the allocation and made the mutated buffer the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer as deallocated only when all of its aliases die (and avoid double counting the input buffer size). See the comment inline:

```
    When an operation mutates a buffer in-place, the scheduler creates a new buffer name
    to track the "before" and "after" states, even though they share the same memory.
    The mutated buffer represents a rename with zero allocation and deallocation cost.
    During dependency tracking, we transfer dependencies from the mutated name back to
    the original buffer, ensuring the original memory is only freed when all aliases
    are done.
    This handles cases where a buffer has multiple non-overlapping aliases - rather than
    trying to assign free costs to individual aliases, we forward all alias dependencies
    to the original buffer.
    Consider:
        buf0 = op0()
        buf1 = mutation_op_(buf0)
        del buf0
        ...
        op(buf1)
        del buf1
    The only memory events are the creation prior to op0, and the deletion following buf1.
```
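
As a minimal, self-contained sketch of this accounting (hypothetical buffer names and sizes, not the scheduler's actual data structures): the mutated name is a zero-cost rename, its reads are forwarded back to the original buffer, and `buf0` is therefore freed only after the last use of `buf1`:

```
# Hypothetical sizes: buf1 is a rename of buf0, so it costs zero bytes.
sizes = {"buf0": 1024, "buf1": 0}
rename_to_original = {"buf1": "buf0"}

# Schedule from the comment above: (node, allocated buffer, buffers read).
schedule = [
    ("op0", "buf0", []),
    ("mutation_op_", "buf1", ["buf0"]),
    ("op", None, ["buf1"]),
]

# Forward every aliased read back to the original buffer, then record the
# index of the last node that uses each original buffer.
last_use: dict[str, int] = {}
for i, (_, _, reads) in enumerate(schedule):
    for buf in reads:
        last_use[rename_to_original.get(buf, buf)] = i

live = peak = 0
for i, (_, out, _) in enumerate(schedule):
    if out is not None:
        live += sizes[out]  # the rename adds zero bytes
    peak = max(peak, live)
    # free original buffers whose last (forwarded) use is this step
    live -= sum(sizes[buf] for buf, last in last_use.items() if last == i)

print(peak, live)  # 1024 0 -> buf0 is counted once and freed after op(buf1)
```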

As @IvanKobzarev's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can be a bit of a pain to pinpoint which part of our memory calculation is incorrect.

This PR also adds a runtime verifier, `config.test_configs.track_memory_lifecycle`, which tracks buffer allocations and deallocations, and errors if their lifetimes do not match our expectations.
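
Enabling the verifier is a config flip; a minimal sketch of assumed usage (the `"assert"` value is taken from the file below, where it upgrades mismatches from log messages to hard errors):

```
import torch

# "assert" raises on any lifetime mismatch; any other enabled value only logs it.
torch._inductor.config.test_configs.track_memory_lifecycle = "assert"

@torch.compile
def f(x):
    return (x @ x).relu()

f(torch.randn(64, 64))
```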

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569
Approved by: https://github.com/IvanKobzarev
2025-08-05 19:58:11 +00:00

139 lines · 4.2 KiB · Python

import functools
import logging
import threading
import weakref

import torch
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)

local = threading.local()
local.memory_tracker = None


class BufferMemoryTracker:
    """
    Tracks inductor runtime allocations and deallocations to compare against
    expected behavior.
    """

    def __init__(self) -> None:
        self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = (
            weakref.WeakValueDictionary()  # type: ignore[assignment]
        )
        self.died_since_last_step: OrderedSet[str] = OrderedSet()
        self.added_since_last_step: OrderedSet[str] = OrderedSet()
        self.error = (
            torch._inductor.config.test_configs.track_memory_lifecycle == "assert"
        )

    def set_tensor(self, name: str, tensor: torch.Tensor) -> None:
        storage = tensor.untyped_storage()
        self.added_since_last_step.add(name)
        self.tensor_tracker[name] = storage

        def on_tensor_death() -> None:
            self.died_since_last_step.add(name)

        weakref.finalize(storage, on_tensor_death)

    def advance_step(self) -> None:
        self.died_since_last_step.clear()
        self.added_since_last_step.clear()

    def log_or_raise(self, msg: str) -> None:
        if self.error:
            raise RuntimeError(msg)
        else:
            log.info(msg)

    def check_step_delta(
        self,
        expected_allocated: list[str],
        expected_freed: list[str],
        is_final_step: bool,
    ) -> None:
        """Check only the delta changes since last step"""

        # Check expected deaths - we don't currently distinguish between nodes which die
        # in the last step and are returned as outputs, so skip if final_step.
        if not is_final_step:
            missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step
            if missing_deaths:
                self.log_or_raise(
                    f"Expected tensors to die but still alive: {missing_deaths}"
                )

        # Check for unexpected deaths
        unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed)
        if unexpected_deaths:
            self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}")

        # Check newly alive tensors - separate messages like deaths
        actual_allocated = self.added_since_last_step
        expected_allocated_set = OrderedSet(expected_allocated)

        extra_alive = actual_allocated - expected_allocated_set
        if extra_alive:
            self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}")

        missing_alive = expected_allocated_set - actual_allocated
        if missing_alive:
            self.log_or_raise(
                f"Expected allocated tensors but missing: {missing_alive}"
            )

        # Reset for next step
        self.advance_step()

        if is_final_step:
            local.memory_tracker = None


def get_mem_tracker() -> BufferMemoryTracker:
    if local.memory_tracker is None:
        local.memory_tracker = BufferMemoryTracker()
    return local.memory_tracker


def track_tensor(tensor: torch.Tensor, name: str) -> None:
    get_mem_tracker().set_tensor(name, tensor)


def tracked_empty_strided(
    size: list[int],
    stride: list[int],
    *,
    dtype: torch.dtype,
    device: torch.device,
    name: str,
) -> torch.Tensor:
    o = torch.empty_strided(size, stride, dtype=dtype, device=device)
    track_tensor(o, name)
    return o


def check_memory_step(
    allocated: list[str], freed: list[str], is_final_step: bool = False
) -> None:
    tracker = get_mem_tracker()
    tracker.check_step_delta(allocated, freed, is_final_step)


@functools.lru_cache(None)
def register_check_mem_op() -> None:
    lib = torch.library.Library("_inductor_debug", "FRAGMENT")  # noqa: TOR901
    lib.define(
        "check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()"
    )
    lib.impl("check_memory_step", check_memory_step, "BackendSelect")
    from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

    _register_effectful_op(
        torch.ops._inductor_debug.check_memory_step.default,
        _EffectType.ORDERED,
    )
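
# Rough, hypothetical shape of the calls that generated wrapper code could emit
# when test_configs.track_memory_lifecycle is enabled (buffer names, shapes, and
# step contents below are illustrative, not taken from a real trace):
#
#   buf0 = torch._inductor.runtime.debug_utils.tracked_empty_strided(
#       [64, 64], [64, 1], dtype=torch.float32, device=device, name="buf0"
#   )
#   ...
#   torch.ops._inductor_debug.check_memory_step(
#       allocated=["buf0"], freed=[], is_final_step=False
#   )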