Files
pytorch/torch/distributed/_tools/common_utils.py
Sanket Purandare 9841f0ddcf Add support for non functional collectives under FakeTensorMode and fake_pg for memory tracking (#147566)
This PR adds support for non-functional (in-place) collectives under `FakeTensorMode` with a `fake_pg` process group, which removes the need to patch out collectives for memory and runtime estimation.
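For context, a minimal sketch of the pattern this enables (not code from this PR; it assumes the test-only `FakeStore` from `torch.testing._internal.distributed.fake_pg`):

```python
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.distributed.fake_pg import FakeStore

# Fake process group: collective APIs are callable without any real communication.
dist.init_process_group("fake", rank=0, world_size=4, store=FakeStore())

with FakeTensorMode():
    t = torch.randn(1024, 1024)
    # Non-functional (in-place) collective under FakeTensorMode; previously such
    # calls had to be patched for memory/runtime estimation.
    dist.all_reduce(t)

dist.destroy_process_group()
```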

It also modifies `ModTracker` so that the post-backward hook is still called for modules whose inputs don't require gradients but whose parameters do.
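A rough sketch of the case this targets (the `ModTracker` hook API shown here is an assumption for illustration, not part of this diff):

```python
import torch
import torch.nn as nn
from torch.distributed._tools.mod_tracker import ModTracker

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(4, 8)  # input does not require grad; only the parameters do

mt = ModTracker()
# With this change, the post-backward hook also fires for the first Linear,
# even though its input has requires_grad=False.
mt.register_user_hooks(
    post_bw_hook=lambda mod, *args: print("post-bw:", type(mod).__name__)
)
with mt:
    model(x).sum().backward()
```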

For memory tracking, we now enable the DTensor dispatcher so that custom dispatch functions such as `entropy_loss` are tracked.
The dispatcher is enabled only for the memory-tracking phase and is disabled as soon as tracking completes.
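For reference, memory tracking is typically driven through `MemTracker`; a hedged usage sketch (names assumed from `torch.distributed._tools.mem_tracker`, not taken from this diff):

```python
import torch
import torch.nn as nn
from torch.distributed._tools.mem_tracker import MemTracker

model = nn.Linear(64, 64)
optim = torch.optim.Adam(model.parameters())
inp = torch.randn(8, 64)

mem_tracker = MemTracker()
mem_tracker.track_external(model, optim)  # account for memory created before tracking
with mem_tracker:
    model(inp).sum().backward()
    optim.step()
    optim.zero_grad()
mem_tracker.display_snapshot("peak", units="MiB")
```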

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
2025-03-08 18:00:49 +00:00

34 lines
1.2 KiB
Python

import warnings

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
    """
    Recursively extracts untyped storages from a tensor or its subclasses.

    Args:
        t (torch.Tensor): The tensor to extract storages from.

    Returns:
        Set[torch.UntypedStorage]: A set of untyped storages.
    """
    unflattened_tensors = [t]
    flattened_tensor_storages = set()
    while len(unflattened_tensors) > 0:
        obj = unflattened_tensors.pop()
        if is_traceable_wrapper_subclass(obj):
            attrs, _ = obj.__tensor_flatten__()  # type: ignore[attr-defined]
            unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
        else:
            if not hasattr(obj, "untyped_storage"):
                warnings.warn(
                    f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
                    category=UserWarning,
                    stacklevel=2,
                )
            else:
                flattened_tensor_storages.add(obj.untyped_storage())
    return flattened_tensor_storages
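
For illustration, a small usage sketch of the helper above (not part of the file): for a plain tensor it returns a one-element set containing the tensor's own untyped storage, whose size can then be summed.

import torch

t = torch.randn(16)
storages = get_untyped_storages(t)
total_bytes = sum(st.nbytes() for st in storages)
print(total_bytes)  # 64 bytes: 16 float32 elements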