mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`, which removes the need to patch collectives for memory and runtime estimation. It also modifies `ModTracker` so that the post-backward hook is called for modules whose inputs do not require gradients but whose parameters do. For memory tracking, we now enable DTensor dispatcher tracking for custom dispatch functions such as `entropy_loss`; the dispatcher is enabled only for the memory-tracking step and is disabled as soon as that step completes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
import warnings
|
|
|
|
import torch
|
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
|
|
|
|
|
def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
|
|
"""
|
|
Recursively extracts untyped storages from a tensor or its subclasses.
|
|
|
|
Args:
|
|
t (torch.Tensor): The tensor to extract storages from.
|
|
|
|
Returns:
|
|
Set[torch.UntypedStorage]: A set of untyped storages.
|
|
"""
|
|
unflattened_tensors = [t]
|
|
flattened_tensor_storages = set()
|
|
while len(unflattened_tensors) > 0:
|
|
obj = unflattened_tensors.pop()
|
|
if is_traceable_wrapper_subclass(obj):
|
|
attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined]
|
|
unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
|
|
else:
|
|
if not hasattr(obj, "untyped_storage"):
|
|
warnings.warn(
|
|
f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
|
|
category=UserWarning,
|
|
stacklevel=2,
|
|
)
|
|
else:
|
|
flattened_tensor_storages.add(obj.untyped_storage())
|
|
return flattened_tensor_storages
|