Add support for non-functional collectives under FakeTensorMode and fake_pg for memory tracking (#147566)

This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`. This eliminates the need to monkey-patch collectives for memory and runtime estimation.
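
As a rough illustration (a minimal sketch, not the commit's test verbatim: the actual test uses CUDA tensors on a 4-rank fake group, while this sketch uses CPU tensors for brevity), a non-functional collective can now be issued directly under `FakeTensorMode` with the `fake` backend:

```python
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.distributed.fake_pg import FakeStore

# Single fake rank in a 4-rank fake process group; no real communication happens.
dist.init_process_group("fake", rank=0, world_size=4, store=FakeStore())
with FakeTensorMode():
    t = torch.randn(100)
    dist.all_reduce(t)        # dispatches to the newly registered meta impl
    dist.broadcast(t, src=0)  # internally returns a FakeWork handle
dist.destroy_process_group()
```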

It also modifies `ModTracker` so that the post-backward hook fires for modules whose inputs don't require gradients but whose parameters do.
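
For example (a hedged sketch; the `*args` lambda below is an illustrative stand-in for a real user hook), a `Linear` whose input does not require grad but whose weight and bias do will now trigger the post-backward hook:

```python
import torch
from torch import nn
from torch.distributed._tools.mod_tracker import ModTracker

events = []
mt = ModTracker()
# Only the post-backward user hook matters for this example.
mt.register_user_hooks(post_bw_hook=lambda *args: events.append("post_bw"))

lin = nn.Linear(4, 4)
with mt:
    x = torch.randn(2, 4)        # input does not require grad
    lin(x).sum().backward()      # parameters still receive gradients
# With this change, events contains a "post_bw" entry for the Linear module.
```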

For memory tracking, we now track the DTensor dispatcher for custom dispatch functions such as `entropy_loss`.
The dispatcher patch is enabled only for the duration of memory tracking and is removed as soon as tracking finishes.
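
Concretely (a sketch under the assumption of a plain `nn.Sequential` toy model; `MemTracker`, `track_external`, and `display_snapshot` are the existing public tracker APIs), the tracker installs its patches, including the DTensor dispatcher hook, on `__enter__` and removes them on `__exit__`:

```python
import torch
from torch import nn
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.mem_tracker import MemTracker

with FakeTensorMode():
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))
    mem_tracker = MemTracker()
    mem_tracker.track_external(model)
    with mem_tracker:  # dispatcher / resize_ / module hooks are patched only in here
        model(torch.randn(8, 16)).sum().backward()
    # On exit, DTensor._op_dispatcher.dispatch and the other patches are restored.
    mem_tracker.display_snapshot("peak")
```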

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
Sanket Purandare
2025-03-08 18:00:49 +00:00
committed by PyTorch MergeBot
parent 439782960c
commit 9841f0ddcf
11 changed files with 843 additions and 331 deletions

View File

@ -0,0 +1,216 @@
# Owner(s): ["module: unknown"]
import unittest
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import FakeWork, ProcessGroup
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._functional_collectives import (
all_gather_into_tensor_coalesced,
all_gather_tensor,
all_gather_tensor_autograd,
all_reduce,
all_reduce_coalesced,
all_to_all_single,
all_to_all_single_autograd,
broadcast,
reduce_scatter_tensor,
reduce_scatter_tensor_autograd,
reduce_scatter_tensor_coalesced,
wait_tensor,
)
from torch.distributed._tools.fake_collectives import (
collective_ops,
CollectiveOp,
non_functional_collectives,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._python_dispatch import TorchDispatchMode
aten = torch.ops.aten
c10d = torch.ops.c10d
_c10d_functional = torch.ops._c10d_functional
_c10d_functional_autograd = torch.ops._c10d_functional_autograd
class TestFakeCollectives(TestCase):
def _setup_distributed(self):
world_size = 4
store = FakeStore()
dist.init_process_group("fake", rank=0, world_size=world_size, store=store)
torch.cuda.set_device(torch.cuda.current_device())
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_collectives(self):
try:
self._setup_distributed()
with FakeTensorMode(), CollectiveTest(test=self):
test_tensor_list = [torch.randn(100, device="cuda") for _ in range(4)]
test_tensor_list_2 = [torch.randn(400, device="cuda") for _ in range(4)]
test_tensor = torch.randn(100, device="cuda")
# Used as gather output or scatter input
test_tensor2 = torch.randn(400, device="cuda")
# Testing non-functional collective operations
dist.broadcast(test_tensor, src=0)
dist.all_reduce(test_tensor)
dist.reduce(test_tensor, dst=0)
dist.send(test_tensor, dst=1)
dist.recv(test_tensor, src=1)
dist.all_gather(test_tensor_list, test_tensor)
dist.reduce_scatter(test_tensor, test_tensor_list)
dist.reduce_scatter_tensor(test_tensor, test_tensor2)
dist.scatter(test_tensor, scatter_list=test_tensor_list, src=0)
dist.gather(test_tensor, gather_list=test_tensor_list, dst=0)
dist.all_gather_into_tensor(test_tensor2, test_tensor)
dist.all_to_all(test_tensor_list, test_tensor_list)
dist.all_to_all_single(test_tensor2, test_tensor2)
dist.barrier()
# Testing functional collectives
wait_tensor(test_tensor)
broadcast(test_tensor, src=0, group=dist.group.WORLD)
all_reduce(test_tensor, reduceOp="avg", group=dist.group.WORLD)
all_gather_tensor(test_tensor, gather_dim=0, group=dist.group.WORLD)
all_gather_tensor_autograd(
test_tensor, gather_dim=0, group=dist.group.WORLD
)
reduce_scatter_tensor(
test_tensor2, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD
)
reduce_scatter_tensor_autograd(
test_tensor2, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD
)
all_to_all_single(
test_tensor,
output_split_sizes=[0],
input_split_sizes=[1],
group=dist.group.WORLD,
)
all_reduce_coalesced(
test_tensor_list, reduceOp="avg", group=dist.group.WORLD
)
all_gather_into_tensor_coalesced(
test_tensor_list, group=dist.group.WORLD
)
reduce_scatter_tensor_coalesced(
test_tensor_list_2,
scatter_dim=[0] * 4,
reduceOp="sum",
group=dist.group.WORLD,
)
all_to_all_single_autograd(
test_tensor,
output_split_sizes=[0],
input_split_sizes=[1],
group=dist.group.WORLD,
)
finally:
if dist.group.WORLD is not None:
dist.destroy_process_group()
class CollectiveTest(TorchDispatchMode):
collective_size_exclude = {
c10d.barrier.default,
c10d.monitored_barrier_.default,
_c10d_functional.wait_tensor.default,
}
def __init__(self, test: TestFakeCollectives, _dispatch_key=None):
super().__init__(_dispatch_key)
self.test = test
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
res = func(*args, **(kwargs or {}))
if func in collective_ops:
if func != _c10d_functional.wait_tensor.default:
pg = CollectiveOp.get_process_group(func, args)
self.test.assertIsInstance(
pg, ProcessGroup, "Error: pg is not an instance of ProcessGroup"
)
self.test.assertEqual(
pg, dist.group.WORLD, "Error: pg is not equal to dist.group.WORLD"
)
self.test.assertEqual(
pg.size(),
4,
f"Error: Expected pg.size() to be 4, but got {pg.size()}",
)
self.test.assertNotEqual(
pg.name(), "", "Error: pg.name() should not be an empty string"
)
if func not in CollectiveTest.collective_size_exclude:
# Compute expected communication tensor size
computed_size = CollectiveOp.get_comm_tensor_size(
func, res, args, kwargs
)
expected_size = self.get_expected_size(func, res, args, kwargs)
self.test.assertEqual(
computed_size,
expected_size,
msg=f"Size mismatch for {func.__name__}: expected {expected_size}, got {computed_size}",
)
if (
func in non_functional_collectives
and func != c10d.monitored_barrier_.default
):
work = res[-1] if isinstance(res, (tuple, list)) else res
self.test.assertIsInstance(FakeWork.unbox(work), FakeWork)
return res
@staticmethod
def get_expected_size(func, res, args, kwargs):
"""Return expected tensor size for collectives explicitly used in run_test()."""
WORLD_SIZE, TENSOR_100, TENSOR_400 = 4, 100 * 4, 400 * 4
TENSOR_LIST_100, TENSOR_LIST_400 = (
WORLD_SIZE * TENSOR_100,
WORLD_SIZE * TENSOR_400,
)
size_map = {
# Non-functional collectives
c10d.broadcast_.default: TENSOR_100,
c10d.allreduce_.default: TENSOR_100,
c10d.reduce_.default: TENSOR_100,
c10d.send.default: TENSOR_100,
c10d.recv_.default: TENSOR_100,
c10d.allgather_.default: TENSOR_LIST_100,
c10d.reduce_scatter_.default: TENSOR_LIST_100,
c10d._reduce_scatter_base_.default: TENSOR_400,
c10d.scatter_.default: TENSOR_LIST_100,
c10d.gather_.default: TENSOR_LIST_100,
c10d._allgather_base_.default: TENSOR_400,
c10d.alltoall_.default: TENSOR_LIST_100,
c10d.alltoall_base_.default: TENSOR_400,
# Functional collectives
_c10d_functional.broadcast.default: TENSOR_100,
_c10d_functional.all_reduce.default: TENSOR_100,
_c10d_functional.all_gather_into_tensor.default: TENSOR_LIST_100,
_c10d_functional_autograd.all_gather_into_tensor.default: TENSOR_LIST_100,
_c10d_functional.reduce_scatter_tensor.default: TENSOR_400,
_c10d_functional_autograd.reduce_scatter_tensor.default: TENSOR_400,
_c10d_functional.all_to_all_single.default: TENSOR_100,
_c10d_functional_autograd.all_to_all_single.default: TENSOR_100,
_c10d_functional.all_reduce_coalesced.default: TENSOR_LIST_100,
_c10d_functional.all_gather_into_tensor_coalesced.default: TENSOR_LIST_400,
_c10d_functional.reduce_scatter_tensor_coalesced.default: TENSOR_LIST_100,
}
if func in size_map:
return size_map[func]
raise ValueError(f"Unhandled function: {func}")
if __name__ == "__main__":
run_tests()

View File

@ -192,6 +192,8 @@ class TestModTracker(TestCase):
("post_fw", "Foo.linears.1", True, True),
("post_bw", "Foo.linears.1", True, True),
("pre_bw", "Foo.linears.0", True, True),
("post_bw", "Foo.linears.0", True, True),
("post_bw", "Foo", True, True),
]
self.assertEqual(test_op, expected_op)

View File

@ -535,6 +535,12 @@ class ProcessGroup:
class FakeProcessGroup(Backend):
def __init__(self, rank: int, world_size: int) -> None: ...
class FakeWork(Work):
seq_id: int
def __init__(self) -> None: ...
def wait(self, timeout: timedelta = ...) -> bool: ...
def getFuture(self) -> Future: ...
class ProcessGroupGloo(Backend):
class Device: ...

View File

@ -6,7 +6,8 @@ namespace c10d {
class FakeWork : public Work {
public:
bool wait(std::chrono::milliseconds timeout) override {
int seq_id = -1;
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
return true;
}
@ -177,6 +178,18 @@ class FakeProcessGroup : public Backend {
return c10::make_intrusive<FakeWork>();
}
void startCoalescing() override {
// No-op
}
c10::intrusive_ptr<Work> endCoalescing(OpType /* optype */) {
return c10::make_intrusive<FakeWork>();
}
c10::intrusive_ptr<Work> endCoalescing() override {
return c10::make_intrusive<FakeWork>();
}
c10::intrusive_ptr<Work> barrier(
const BarrierOptions& /* opts */ = BarrierOptions()) override {
return c10::make_intrusive<FakeWork>();

View File

@ -3221,67 +3221,68 @@ Example::
.def_readonly("time_started", &::c10d::WorkInfo::timeStarted)
.def_readonly("time_finished", &::c10d::WorkInfo::timeFinished)
.def_readonly("active_duration", &::c10d::WorkInfo::activeDuration);
py::class_<
::c10d::Work,
c10::intrusive_ptr<::c10d::Work>,
::c10d::PyProcessGroup::PyWork>(module, "Work", R"(
auto work =
py::class_<
::c10d::Work,
c10::intrusive_ptr<::c10d::Work>,
::c10d::PyProcessGroup::PyWork>(module, "Work", R"(
A `Work` object represents the handle to a pending asynchronous operation in
PyTorch's distributed package. It is returned by non-blocking collective operations,
such as `dist.all_reduce(tensor, async_op=True)`.
)")
.def(py::init<>())
.def("is_completed", &::c10d::Work::isCompleted)
.def(
"is_success",
[](::c10d::Work& work) -> bool {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::is_success"));
return work.isSuccess();
})
.def(
"exception",
[](::c10d::Work& work) -> std::exception_ptr {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::exception"));
return work.exception();
})
.def(
"source_rank",
[](::c10d::Work& work) -> int {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::source_rank"));
return work.sourceRank();
})
.def("_source_rank", &::c10d::Work::sourceRank)
.def(
"result",
[](::c10d::Work& work) -> std::vector<at::Tensor> {
// Deprecation reason:
// Work.result() returns a vector of tensors. This signature is
// problematic as some collectives may just return one tensor
// (e.g all-reduce), while some others may return multiple
// tensors (e.g. all-gather).
// Deprecating work.result() would
// also allow us to remove the `outputs_` field in the Work
// class, avoiding an "artificial" reference to the tensors,
// which could potentially hold up the tensors' memory.
TORCH_WARN_ONCE(fmt::format(kDeprecationWarning, "Work::result"));
return work.result();
})
.def(
"synchronize",
[](::c10d::Work& work) -> void {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::synchronize"));
work.synchronize();
})
.def(
"wait",
&::c10d::Work::wait,
py::arg("timeout") = kNoTimeout,
py::call_guard<py::gil_scoped_release>(),
R"(
.def(py::init<>())
.def("is_completed", &::c10d::Work::isCompleted)
.def(
"is_success",
[](::c10d::Work& work) -> bool {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::is_success"));
return work.isSuccess();
})
.def(
"exception",
[](::c10d::Work& work) -> std::exception_ptr {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::exception"));
return work.exception();
})
.def(
"source_rank",
[](::c10d::Work& work) -> int {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::source_rank"));
return work.sourceRank();
})
.def("_source_rank", &::c10d::Work::sourceRank)
.def(
"result",
[](::c10d::Work& work) -> std::vector<at::Tensor> {
// Deprecation reason:
// Work.result() returns a vector of tensors. This signature is
// problematic as some collectives may just return one tensor
// (e.g all-reduce), while some others may return multiple
// tensors (e.g. all-gather).
// Deprecating work.result() would
// also allow us to remove the `outputs_` field in the Work
// class, avoiding an "artificial" reference to the tensors,
// which could potentially hold up the tensors' memory.
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::result"));
return work.result();
})
.def(
"synchronize",
[](::c10d::Work& work) -> void {
TORCH_WARN_ONCE(
fmt::format(kDeprecationWarning, "Work::synchronize"));
work.synchronize();
})
.def(
"wait",
&::c10d::Work::wait,
py::arg("timeout") = kNoTimeout,
py::call_guard<py::gil_scoped_release>(),
R"(
Returns:
true/false.
@ -3298,13 +3299,14 @@ such as `dist.all_reduce(tensor, async_op=True)`.
However, if timeout is set, it will block the CPU thread until the NCCL work is completed
or timed out. If timeout, exception will be thrown.
)")
.def(
"get_future_result",
[](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
return std::make_shared<jit::PythonFutureWrapper>(
work.getFutureResult());
},
R"(
.def(
"get_future_result",
[](::c10d::Work& work)
-> std::shared_ptr<jit::PythonFutureWrapper> {
return std::make_shared<jit::PythonFutureWrapper>(
work.getFutureResult());
},
R"(
Returns:
A ``torch.futures.Future`` object of int type which maps to the enum type of WorkResult
As an example, a future object can be retrieved
@ -3319,12 +3321,14 @@ such as `dist.all_reduce(tensor, async_op=True)`.
.. warning ::
``get_future_result`` API supports NCCL
)")
.def(
"get_future",
[](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
},
R"(
.def(
"get_future",
[](::c10d::Work& work)
-> std::shared_ptr<jit::PythonFutureWrapper> {
return std::make_shared<jit::PythonFutureWrapper>(
work.getFuture());
},
R"(
Returns:
A ``torch.futures.Future`` object which is associated with the completion of
the ``Work``. As an example, a future object can be retrieved
@ -3363,16 +3367,16 @@ such as `dist.all_reduce(tensor, async_op=True)`.
true when tensors have arrived on respective nodes, but not yet necessarily synched on
respective GPUs (similarly to GPU work).
)")
.def(
"_get_op_type",
[](::c10d::Work& work) -> int {
return static_cast<int>(work.retrieveOpType());
})
.def(
"_get_duration",
&::c10d::Work::getDuration,
py::call_guard<py::gil_scoped_release>(),
R"(
.def(
"_get_op_type",
[](::c10d::Work& work) -> int {
return static_cast<int>(work.retrieveOpType());
})
.def(
"_get_duration",
&::c10d::Work::getDuration,
py::call_guard<py::gil_scoped_release>(),
R"(
Returns:
Duration of the corresponding collective communication.
@ -3380,17 +3384,17 @@ such as `dist.all_reduce(tensor, async_op=True)`.
This API only works for NCCL backend for now and must set
TORCH_NCCL_ENABLE_TIMING environment variable.
)")
.def(
"boxed",
[](c10::intrusive_ptr<::c10d::Work> self) {
return torch::jit::toPyObject(c10::IValue(std::move(self)));
})
.def_static("unbox", [](py::object obj) {
auto typePtr =
torch::getCustomClass("__torch__.torch.classes.c10d.Work");
auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
return ivalue.toCustomClass<::c10d::Work>();
});
.def(
"boxed",
[](c10::intrusive_ptr<::c10d::Work> self) {
return torch::jit::toPyObject(c10::IValue(std::move(self)));
})
.def_static("unbox", [](py::object obj) {
auto typePtr =
torch::getCustomClass("__torch__.torch.classes.c10d.Work");
auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
return ivalue.toCustomClass<::c10d::Work>();
});
auto fakeProcessGroup =
intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>(
@ -3402,6 +3406,13 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("rank"),
py::arg("world_size"));
auto fakeWork =
intrusive_ptr_no_gil_destructor_class_<::c10d::FakeWork>(
module, "FakeWork", work)
.def(py::init<>())
.def_readwrite("seq_id", &::c10d::FakeWork::seq_id) // Expose seq_id
.def("wait", &::c10d::FakeWork::wait, py::arg("timeout") = kNoTimeout)
.def("getFuture", &::c10d::FakeWork::getFuture);
py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
.def(py::init<>())

View File

@ -0,0 +1,33 @@
import warnings
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Recursively extracts untyped storages from a tensor or its subclasses.
Args:
t (torch.Tensor): The tensor to extract storages from.
Returns:
Set[torch.UntypedStorage]: A set of untyped storages.
"""
unflattened_tensors = [t]
flattened_tensor_storages = set()
while len(unflattened_tensors) > 0:
obj = unflattened_tensors.pop()
if is_traceable_wrapper_subclass(obj):
attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined]
unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
else:
if not hasattr(obj, "untyped_storage"):
warnings.warn(
f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
category=UserWarning,
stacklevel=2,
)
else:
flattened_tensor_storages.add(obj.untyped_storage())
return flattened_tensor_storages

View File

@ -0,0 +1,307 @@
import random
from typing import Any
import torch
from torch._C._distributed_c10d import (
_resolve_process_group,
FakeWork,
ProcessGroup,
Work,
)
from torch.utils._pytree import tree_map_only
torch.distributed.batch_isend_irecv
c10d = torch.ops.c10d
_c10d_functional = torch.ops._c10d_functional
_c10d_functional_autograd = torch.ops._c10d_functional_autograd
_dtensor = torch.ops._dtensor
used_ids: set[int] = set()
def generate_unique_id() -> int:
while True:
new_id = random.randint(1, 10**9)
if new_id not in used_ids:
used_ids.add(new_id)
return new_id
# Function to create and return FakeWork object
def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def]
work = FakeWork()
work.seq_id = generate_unique_id()
fakework_script_obj = work.boxed()
return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj
# Dictionary mapping collective operations to their meta functions
# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included
# _DEPRECATED_META_FUNCTIONS = {
# "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# }
_META_FUNCTIONS = {
"broadcast_": lambda *args: create_fakework(args),
"allreduce_": lambda *args: create_fakework(args),
"allgather_": lambda *args: create_fakework(args),
"_allgather_base_": lambda *args: create_fakework(args),
"reduce_scatter_": lambda *args: create_fakework(args),
"_reduce_scatter_base_": lambda *args: create_fakework(args),
"reduce_": lambda *args: create_fakework(args, return_first_arg=False),
"gather_": lambda *args: create_fakework(args, return_first_arg=False),
"scatter_": lambda *args: create_fakework(args),
"alltoall_": lambda *args: create_fakework(args),
"alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False),
"barrier": lambda *args: create_fakework(args, return_first_arg=False),
"monitored_barrier_": lambda *args: None,
"send": lambda *args: create_fakework(args, return_first_arg=False),
"recv_": lambda *args: create_fakework(args, return_first_arg=False),
"recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
}
if not torch._running_with_deploy():
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
for op, meta_func in _META_FUNCTIONS.items():
lib_impl.impl(op, meta_func, "Meta")
# List of collective operation functions including functional collectives
# Note: The following collectives might be deprecated soon, hence not adding them
# deprecated_non_functional_collectives = [
# c10d.allreduce_coalesced_.default,
# c10d.reduce_scatter_tensor_coalesced_.default,
# c10d.allgather_into_tensor_coalesced_.default,
# c10d.allgather_coalesced_.default,
# ]
non_functional_collectives: set[torch._ops.OpOverload] = {
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.reduce_.default,
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.allgather_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d._allgather_base_.default,
c10d.gather_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
c10d.alltoall_base_.default,
c10d.barrier.default,
c10d.monitored_barrier_.default,
}
functional_collectives: set[torch._ops.OpOverload] = {
_c10d_functional.broadcast.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional.reduce_scatter_tensor.default,
_c10d_functional.all_to_all_single.default,
_c10d_functional_autograd.all_to_all_single.default,
_c10d_functional.wait_tensor.default,
_c10d_functional.all_reduce_.default,
_c10d_functional.all_reduce_coalesced.default,
_c10d_functional.all_reduce_coalesced_.default,
_c10d_functional.all_gather_into_tensor_out.default,
_c10d_functional.all_gather_into_tensor_coalesced.default,
_c10d_functional_autograd.all_gather_into_tensor.default,
_c10d_functional.reduce_scatter_tensor_coalesced.default,
_c10d_functional_autograd.reduce_scatter_tensor.default,
_c10d_functional.broadcast_.default,
_dtensor.shard_dim_alltoall.default,
}
sync_ops: set[torch._ops.OpOverload] = {
c10d.barrier.default,
c10d.monitored_barrier_.default,
_c10d_functional.wait_tensor.default,
}
collective_ops = set.union(functional_collectives, non_functional_collectives)
class CollectiveOp:
# Static sets for performance optimization
PG_ARG_1 = {
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.reduce_.default,
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.barrier.default,
# c10d.allreduce_coalesced_.default
}
PG_ARG_2 = {
c10d.allgather_.default,
c10d._allgather_base_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d.gather_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
c10d.alltoall_base_.default,
# c10d.allgather_coalesced_.default,
# c10d.allgather_into_tensor_coalesced_.default
# c10d.reduce_scatter_tensor_coalesced_.default
}
PG_ARG_3 = {
_c10d_functional.broadcast.default,
_c10d_functional.broadcast_.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_reduce_.default,
_c10d_functional.all_reduce_coalesced.default,
_c10d_functional.all_reduce_coalesced_.default,
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor_out.default,
_c10d_functional_autograd.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor_coalesced.default,
}
PG_ARG_4 = {
_c10d_functional.reduce_scatter_tensor.default,
_c10d_functional.reduce_scatter_tensor_coalesced.default,
_c10d_functional_autograd.reduce_scatter_tensor.default,
_c10d_functional.all_to_all_single.default,
_c10d_functional_autograd.all_to_all_single.default,
_dtensor.shard_dim_alltoall.default,
}
WK_ARG_1 = {
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.allgather_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d._allgather_base_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
}
WK = {
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.reduce_.default,
c10d.gather_.default,
c10d.alltoall_base_.default,
c10d.barrier.default,
}
COMM_TENSOR_ARG_0 = {
c10d.allreduce_.default,
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.allgather_.default,
c10d.gather_.default,
c10d.reduce_.default,
c10d.broadcast_.default,
_c10d_functional.all_reduce_coalesced.default,
_c10d_functional.all_reduce_coalesced_.default,
# c10d.allreduce_coalesced_.default
# c10d.allgather_coalesced_.default
# c10d.allgather_into_tensor_coalesced_.default,
}
COMM_TENSOR_ARG_1 = {
c10d.reduce_scatter_.default,
c10d.scatter_.default,
# c10d.reduce_scatter_tensor_coalesced_.default,
}
COMM_TENSOR_ARG_RES = {
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional_autograd.all_gather_into_tensor.default,
}
COMM_TENSOR_SINGLE_UNTYPED_STORAGE = {
c10d._allgather_base_.default,
_c10d_functional.broadcast.default,
_c10d_functional.broadcast_.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_reduce_.default,
_c10d_functional.reduce_scatter_tensor.default,
_c10d_functional_autograd.reduce_scatter_tensor.default,
}
COMM_TENSOR_ARG_0_AND_RES = {
_c10d_functional.all_to_all_single.default,
_c10d_functional_autograd.all_to_all_single.default,
_dtensor.shard_dim_alltoall.default,
}
COMM_TENSOR_RES_SUM = {
_c10d_functional.all_gather_into_tensor_coalesced.default,
_c10d_functional.reduce_scatter_tensor_coalesced.default,
}
@staticmethod
def sum_tensors(arg: Any) -> int:
"""Calculate total memory consumed by the tensors in the argument."""
total_memory = 0
def sum_bytes(t: torch.Tensor) -> None:
nonlocal total_memory
total_memory += t.untyped_storage().nbytes()
tree_map_only(torch.Tensor, sum_bytes, arg)
return total_memory
@staticmethod
def get_process_group(func, args) -> ProcessGroup: # type: ignore[no-untyped-def]
"""Retrieve the process group for collective operations, except `wait_tensor`."""
if func in CollectiveOp.PG_ARG_1:
return ProcessGroup.unbox(args[1])
if func in CollectiveOp.PG_ARG_2:
return ProcessGroup.unbox(args[2])
if func in CollectiveOp.PG_ARG_3:
return _resolve_process_group(args[2])
if func in CollectiveOp.PG_ARG_4:
return _resolve_process_group(args[3])
raise TypeError(f"Func {func} not found in {collective_ops}")
@staticmethod
def get_comm_tensor_size(func, res, args, kwargs) -> int: # type: ignore[no-untyped-def]
"""Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`."""
if func in CollectiveOp.COMM_TENSOR_ARG_0:
return CollectiveOp.sum_tensors(args[0])
if func in CollectiveOp.COMM_TENSOR_ARG_1:
return CollectiveOp.sum_tensors(args[1])
if func in CollectiveOp.COMM_TENSOR_ARG_RES:
return res.untyped_storage().nbytes()
if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE:
return args[0].untyped_storage().nbytes()
if func == c10d._reduce_scatter_base_.default:
return args[1].untyped_storage().nbytes()
if func == c10d.alltoall_.default:
# TODO(@sanketpurandare) - Confirm size computation
return max(
CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1])
)
if func == c10d.alltoall_base_.default:
# TODO(@sanketpurandare) - Confirm size computation
return max(
args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes()
)
if func == _c10d_functional.all_gather_into_tensor_out.default:
return args[-1].untyped_storage().nbytes()
if func in CollectiveOp.COMM_TENSOR_RES_SUM:
return CollectiveOp.sum_tensors(res)
if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES:
# TODO(@sanketpurandare) - Confirm size computation
return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes()
raise TypeError(f"Unknown function: {func} in {collective_ops}")
@staticmethod
def get_work(func, res) -> Work: # type: ignore[no-untyped-def]
if func in CollectiveOp.WK:
return FakeWork.unbox(res)
elif func in CollectiveOp.WK_ARG_1:
return FakeWork.unbox(res[1])
raise TypeError(f"Func {func} not found in {collective_ops}")

View File

@ -1,23 +1,16 @@
from copy import deepcopy
from datetime import timedelta
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
import torch.distributed as dist
import torch.distributed._tools.fake_collectives
from torch import nn, optim
from torch._guards import active_fake_mode
from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker
from torch.distributed.distributed_c10d import (
_IllegalWork,
ProcessGroup,
ReduceOp,
Work,
)
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
from torch.futures import Future
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref
@ -31,6 +24,8 @@ _P = ParamSpec("_P")
_R = TypeVar("_R")
_Ts = TypeVarTuple("_Ts")
c10d = torch.ops.c10d
class _FSDPRefType(_RefType):
"""
@ -68,13 +63,6 @@ class _SavedFSDPMethods(NamedTuple):
post_backward: Callable
class _SavedCollectives(NamedTuple):
all_gather_into_tensor: Callable
reduce_scatter_tensor: Callable
all_reduce: Callable
barrier: Callable
class _FSDPModState(_State):
"""
Enumerates the states of FSDP modules during the forward and backward passes.
@ -117,6 +105,15 @@ class _FSDPModMemStats:
] = {}
class _FSDPState(Enum):
PRE_FW = auto()
FW = auto()
POST_FW = auto()
PRE_BW = auto()
BW = auto()
POST_BW = auto()
class FSDPMemTracker(MemTracker):
"""
A ``TorchDispatchMode`` based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track
@ -166,9 +163,8 @@ class FSDPMemTracker(MemTracker):
assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules"
self._root_mod = mod
self._optm = optm
self._in_fake_mode: bool = False
self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary()
self._saved_collectives: _SavedCollectives
self._fsdp_state: _FSDPState = _FSDPState.PRE_FW
self._ref_class: type[_RefType] = _FSDPRefType
def _instrument_fsdp_sharded_params_grads(
@ -209,6 +205,7 @@ class FSDPMemTracker(MemTracker):
def inner(
*args: _P.args, **kwargs: _P.kwargs
) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]:
self._fsdp_state = _FSDPState.PRE_FW
mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
assert mod_fqn is not None
if fsdp_mod not in self.memory_tracking:
@ -251,6 +248,7 @@ class FSDPMemTracker(MemTracker):
else:
state = _FSDPModState.AFT_PRE_FW
mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
self._fsdp_state = _FSDPState.FW
return args, kwargs
return inner
@ -276,6 +274,7 @@ class FSDPMemTracker(MemTracker):
else:
state = _FSDPModState.BEF_POST_FW
mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
self._fsdp_state = _FSDPState.POST_FW
output = orig_fsdp_state_post_fw(*args, **kwargs)
@ -296,6 +295,7 @@ class FSDPMemTracker(MemTracker):
# and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module.
@wraps(orig_fsdp_param_group_pre_backward)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> None:
self._fsdp_state = _FSDPState.PRE_BW
mod_stat = self.memory_tracking[fsdp_mod]
snapshot = self.get_tracker_snapshot()
mod_stat.local_peak = {
@ -310,6 +310,7 @@ class FSDPMemTracker(MemTracker):
mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append(
self.get_tracker_snapshot()
)
self._fsdp_state = _FSDPState.BW
return inner
@ -338,7 +339,7 @@ class FSDPMemTracker(MemTracker):
mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append(
self.get_tracker_snapshot()
)
self._fsdp_state = _FSDPState.POST_BW
orig_fsdp_param_group_post_backward(*args, **kwargs)
if fsdp_param_group := fsdp_state._fsdp_param_group:
@ -453,110 +454,6 @@ class FSDPMemTracker(MemTracker):
handle.remove()
self._optimizer_hook_handles = None
def _instrument_and_maybe_bypass_collectives(self) -> None:
# Monkey-patching collectives is required because they do not work with `FakeTensorMode`
# It's also easier to track `all_gather` and `reduce_scatter` buffers faithfully.
self._saved_collectives = _SavedCollectives(
dist.all_gather_into_tensor,
dist.reduce_scatter_tensor,
dist.all_reduce,
dist.barrier,
)
class FakeWork(Work):
def __init__(self) -> None:
super().__init__()
def get_future(self) -> Future:
future: Future = Future()
future.set_result(None)
return future
def wait(self, timeout: Optional[timedelta] = None) -> bool:
return True
@wraps(dist.all_gather_into_tensor)
def all_gather_into_tensor(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
group: Union[ProcessGroup, None] = None,
async_op: bool = False,
) -> Union[Work, _IllegalWork, None]:
self._update_and_maybe_create_winfos(
output_tensor,
_FSDPRefType.ALL_GATHER,
update_existing=True,
)
if self._in_fake_mode:
if async_op:
return FakeWork()
return None
else:
return self._saved_collectives.all_gather_into_tensor(
output_tensor, input_tensor, group, async_op
)
@wraps(dist.reduce_scatter_tensor)
def reduce_scatter_tensor(
output: torch.Tensor,
input: torch.Tensor,
op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
group: Union[ProcessGroup, None] = None,
async_op: bool = False,
) -> Union[Work, _IllegalWork, None]:
self._update_and_maybe_create_winfos(
input,
_FSDPRefType.REDUCE_SCATTER,
update_existing=True,
)
if self._in_fake_mode:
if async_op:
return FakeWork()
return None
else:
return self._saved_collectives.reduce_scatter_tensor(
output, input, op, group, async_op
)
@wraps(dist.all_reduce)
def all_reduce(
tensor: torch.Tensor,
op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
group: Union[ProcessGroup, None] = None,
async_op: bool = False,
) -> Union[Work, _IllegalWork, None]:
if self._in_fake_mode:
if async_op:
return FakeWork()
return None
else:
return self._saved_collectives.all_reduce(tensor, op, group, async_op)
@wraps(dist.barrier)
def barrier(
group: Union[ProcessGroup, None] = dist.GroupMember.WORLD,
async_op: bool = False,
device_ids: Union[list[int], None] = None,
) -> Union[Work, None]:
if self._in_fake_mode:
return None
else:
return self._saved_collectives.barrier(group, async_op, device_ids)
dist.all_gather_into_tensor = all_gather_into_tensor
dist.reduce_scatter_tensor = reduce_scatter_tensor
dist.all_reduce = all_reduce
dist.barrier = barrier
def _restore_collectives(self) -> None:
dist.all_gather_into_tensor = self._saved_collectives.all_gather_into_tensor
dist.reduce_scatter_tensor = self._saved_collectives.reduce_scatter_tensor
dist.all_reduce = self._saved_collectives.all_reduce
dist.barrier = self._saved_collectives.barrier
del self._saved_collectives
def track_inputs(self, inputs: tuple[Any, ...]) -> None:
"""
This is used to track the input tensors to the model and annotate them as ``Inputs``.
@ -579,27 +476,39 @@ class FSDPMemTracker(MemTracker):
"""This is no-op for ``FSDPMemTracker``"""
def __enter__(self) -> "FSDPMemTracker":
self._in_fake_mode = True if active_fake_mode() else False
self._register_module_and_optimizer_hooks()
self._instrument_and_maybe_bypass_collectives()
self._track_resize()
self._peak_mem_snap = self.get_tracker_snapshot()
self._peak_mem = {
dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items()
}
self._mod_tracker.__enter__()
if self._depth == 0:
self._register_module_and_optimizer_hooks()
self._track_resize()
self._track_dtensor_dispatch()
self._peak_mem_snap = self.get_tracker_snapshot()
self._peak_mem = {
dev: dev_snap[_TOTAL_KEY]
for dev, dev_snap in self._peak_mem_snap.items()
}
self._mod_tracker.__enter__()
TorchDispatchMode.__enter__(self)
self._depth += 1
return self
def __exit__(self, *args: Any) -> None:
self._deregister_module_and_optimizer_hooks()
self._restore_collectives()
self._restore_resize()
self._depth -= 1
if self._depth == 0:
self._deregister_module_and_optimizer_hooks()
self._restore_resize()
self._restore_dtensor_dispatch()
self._mod_tracker.__exit__(*args)
TorchDispatchMode.__exit__(self, *args)
self._mod_tracker.__exit__(*args)
def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def]
res = func(*args, **kwargs or {})
if (
func == torch.ops._c10d_functional.wait_tensor.default
and active_fake_mode()
):
# N.B.: This is a hacky way to override the Meta impl of wait_tensor. The original impl returns
# a new tensor, which does not happen in eager mode when wait_tensor is called.
res = args[0]
else:
res = func(*args, **kwargs or {})
# If we are tracking an optimizer state, we use the optimizer reference type.
# If we are in backward region and not in AC region, we use the backward reference type.
# Else we use the forward reference type.
@ -609,6 +518,27 @@ class FSDPMemTracker(MemTracker):
reftype = _FSDPRefType.TEMP
else:
reftype = _FSDPRefType.ACT
if func == c10d._allgather_base_.default and self._fsdp_state in [
_FSDPState.PRE_FW,
_FSDPState.PRE_BW,
]:
output_tensor = args[0]
self._update_and_maybe_create_winfos(
output_tensor,
_FSDPRefType.ALL_GATHER,
update_existing=True,
)
if (
func == c10d._reduce_scatter_base_.default
and self._fsdp_state == _FSDPState.POST_BW
):
input_tensor = args[1]
self._update_and_maybe_create_winfos(
input_tensor,
_FSDPRefType.REDUCE_SCATTER,
update_existing=True,
)
tree_map_only(torch.Tensor, partial(self._track, reftype), res)
peak_state = (
_FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW

View File

@ -2,6 +2,7 @@ import math
import os
import re
import warnings
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
@ -9,16 +10,17 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
import torch
import torch.distributed._tools.fake_collectives
from torch import nn, optim
from torch._guards import active_fake_mode
from torch.distributed._tools.common_utils import get_untyped_storages
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed.tensor import DTensor
from torch.optim.optimizer import (
register_optimizer_step_post_hook,
register_optimizer_step_pre_hook,
)
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
TorchDispatchMode,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref
@ -170,35 +172,6 @@ class _WeakRefInfo:
self.mem_consumed = self._calculate_mem_consumed()
return self.mem_consumed
@staticmethod
def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Recursively extracts untyped storages from a tensor or its subclasses.
Args:
t (torch.Tensor): The tensor to extract storages from.
Returns:
set[torch.UntypedStorage]: A set of untyped storages.
"""
unflattened_tensors = [t]
flattened_tensor_storages = set()
while len(unflattened_tensors) > 0:
obj = unflattened_tensors.pop()
if is_traceable_wrapper_subclass(obj):
attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined]
unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
else:
if not hasattr(obj, "untyped_storage"):
warnings.warn(
f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
category=UserWarning,
stacklevel=2,
)
else:
flattened_tensor_storages.add(obj.untyped_storage())
return flattened_tensor_storages
@classmethod
def create_winfo(
cls,
@ -253,7 +226,9 @@ def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) ->
print(
f"Device: {dev}",
*(
f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}"
f"\t{k.value}: {_rounding_fn(v, divisor, 2)} {units}"
if isinstance(k, _RefType)
else f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}"
for k, v in dev_snap.items()
),
sep="\n",
@ -275,7 +250,9 @@ def _print_snapshot_tabular(
divisor = _get_mem_divisor(units)
table_data = []
key_list = list(next(iter(snapshot.values())).keys())
headers = ["Device"] + [f"{key}" for key in key_list]
headers = ["Device"] + [
f"{key.value}" if isinstance(key, _RefType) else f"{key}" for key in key_list
]
for dev, dev_snap in snapshot.items():
if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
@ -290,7 +267,7 @@ def _print_state_snapshots(
snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str
) -> None:
for state, snapshot_list in snapshots.items():
print(f"{state}")
print(f"{state.value}")
for i, snapshot in enumerate(snapshot_list):
print(f"# {i + 1}:")
_print_snapshot(snapshot, units)
@ -312,7 +289,7 @@ def _print_state_snapshots_tabular(
divisor = _get_mem_divisor(units)
for state, snapshot_list in snapshots.items():
for i, snapshot in enumerate(snapshot_list):
state_call = f"{state} # {i + 1}"
state_call = f"{state.value} # {i + 1}"
for dev, dev_snap in snapshot.items():
if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
continue
@ -324,7 +301,9 @@ def _print_state_snapshots_tabular(
}
last_state_call = state_call
for k, v in dev_snap.items():
row[f"{k}"] = f"{_rounding_fn(v, divisor, 2)} {units}"
row[f"{k.value}" if isinstance(k, _RefType) else f"{k}"] = (
f"{_rounding_fn(v, divisor, 2)} {units}"
)
table_data.append(row)
print(tabulate(table_data, headers="keys", tablefmt="rst"))
@ -411,6 +390,8 @@ class MemTracker(TorchDispatchMode):
# Weak references to the topmost AC module currently active
self._ac_mod: Optional[weakref.ref] = None
self._orig_resize = torch.UntypedStorage.resize_
self._orig_dtensor_dispatch = DTensor._op_dispatcher.dispatch
self._depth = 0
def _update_snap(
self,
@ -462,7 +443,7 @@ class MemTracker(TorchDispatchMode):
reftype: _RefType,
update_existing: bool = False,
) -> set[_WeakRefInfo]:
sts = _WeakRefInfo.get_untyped_storages(t)
sts = get_untyped_storages(t)
winfos = set()
for st in sts:
# Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary.
@ -543,7 +524,7 @@ class MemTracker(TorchDispatchMode):
# Get the storages of the tensor and check if we have already tracked them.
# If yes, then check if the storage size has changed and update the current snapshot.
# Else create a new ``_WeakRefInfo`` instance and add it to the dictionary.
sts = _WeakRefInfo.get_untyped_storages(t)
sts = get_untyped_storages(t)
for st in sts:
winfo, _ = self._WINFO.get(st, (None, None))
if winfo is not None:
@ -640,7 +621,7 @@ class MemTracker(TorchDispatchMode):
def add_inps_or_outs(t: torch.Tensor) -> None:
nonlocal input_or_output_memory
sts = _WeakRefInfo.get_untyped_storages(t)
sts = get_untyped_storages(t)
for st in sts:
winfo, _ = self._WINFO.get(st, (None, None))
if winfo is not None:
@ -694,6 +675,7 @@ class MemTracker(TorchDispatchMode):
mod_stats = self.memory_tracking[module]
state = _ModState.PRE_FW
input_mem = self._track_inputs_or_outputs(inputs)
mod_stats.mod_fqn = mod_name
mod_stats.input_mem = input_mem
mem_snapshot = self.get_tracker_snapshot()
@ -826,6 +808,8 @@ class MemTracker(TorchDispatchMode):
self._track_module_params_and_buffers(obj, install_grad_hooks=False)
elif isinstance(obj, optim.Optimizer):
self._track_optimizer_states(_MemRefType.OPT, obj)
elif obj is None:
continue
else:
raise TypeError(
f"Object of type {type(obj)} is not supported for tracking. "
@ -891,32 +875,65 @@ class MemTracker(TorchDispatchMode):
"""
self.memory_tracking.clear()
def _track_dtensor_dispatch(self) -> None:
def track_dtensor_dispatch(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
with (
self
if op_call in DTensor._op_dispatcher._custom_op_handlers
else nullcontext()
):
return self._orig_dtensor_dispatch(op_call, args, kwargs)
DTensor._op_dispatcher.dispatch = track_dtensor_dispatch # type: ignore[method-assign, assignment]
def _restore_dtensor_dispatch(self) -> None:
DTensor._op_dispatcher.dispatch = self._orig_dtensor_dispatch # type: ignore[method-assign]
def __enter__(self) -> "MemTracker":
self._register_global_optimizer_hook()
self._mod_tracker.register_user_hooks(
self._pre_fw_hook,
self._post_fw_hook,
self._pre_bw_hook,
self._post_bw_hook,
)
self._track_resize()
self._peak_mem_snap = self.get_tracker_snapshot()
self._peak_mem = {
dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items()
}
self._mod_tracker.__enter__()
if self._depth == 0:
self._register_global_optimizer_hook()
self._mod_tracker.register_user_hooks(
self._pre_fw_hook,
self._post_fw_hook,
self._pre_bw_hook,
self._post_bw_hook,
)
self._track_resize()
self._track_dtensor_dispatch()
self._peak_mem_snap = self.get_tracker_snapshot()
self._peak_mem = {
dev: dev_snap[_TOTAL_KEY]
for dev, dev_snap in self._peak_mem_snap.items()
}
self._mod_tracker.__enter__()
super().__enter__()
self._depth += 1
return self
def __exit__(self, *args: Any) -> None:
self._deregister_param_and_optimizer_hooks()
self._mod_tracker.clear_user_hooks()
self._restore_resize()
self._depth -= 1
if self._depth == 0:
self._deregister_param_and_optimizer_hooks()
self._mod_tracker.clear_user_hooks()
self._restore_resize()
self._restore_dtensor_dispatch()
self._mod_tracker.__exit__(*args)
super().__exit__(*args)
self._mod_tracker.__exit__(*args)
def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def]
res = func(*args, **kwargs or {})
if (
func == torch.ops._c10d_functional.wait_tensor.default
and active_fake_mode()
):
# N.B.: This is a hacky way to override the Meta impl of wait_tensor. The original impl returns
# a new tensor, which does not happen in eager mode when wait_tensor is called.
res = args[0]
else:
res = func(*args, **kwargs or {})
# If we are tracking an optimizer state, we use the optimizer reference type.
# If we are in backward region and not in AC region, we use the backward reference type.
# Else we use the forward reference type.

View File

@ -60,6 +60,7 @@ class ModTracker:
self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self._seen_modules: weakref.WeakSet = weakref.WeakSet()
self._has_callback = False
self._post_bw_callbacks_to_enqueue: list[Callable] = []
self._user_pre_fw_hook = None
self._user_post_fw_hook = None
self._user_pre_bw_hook = None
@ -70,6 +71,10 @@ class ModTracker:
if self._has_callback:
return
for post_bw_callback in reversed(self._post_bw_callbacks_to_enqueue):
torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback)
self._post_bw_callbacks_to_enqueue.clear()
def callback():
self.parents = {"Global"}
self._has_callback = False
@ -213,8 +218,13 @@ class ModTracker:
self._user_pre_fw_hook(mod, input)
args, _ = tree_flatten(input)
tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
if not self.is_bw and tensors:
register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))
if not self.is_bw:
if tensors:
register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))
else:
self._post_bw_callbacks_to_enqueue.append(
self._get_pop_fn(w_mod, name, True)
)
def _fw_post_hook(self, mod, input, output):
name = self._get_mod_name(mod)
@ -225,7 +235,9 @@ class ModTracker:
args, _ = tree_flatten(output)
tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
if not self.is_bw and tensors:
register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True))
register_multi_grad_hook(
tensors, self._get_append_fn(w_mod, name, True), mode="any"
)
def __enter__(self):
self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)

View File

@ -1,7 +1,6 @@
import math
import os
import sys
import warnings
from collections import OrderedDict
from dataclasses import astuple, dataclass
from typing import Any, NamedTuple, Optional
@ -11,6 +10,7 @@ import torch
from torch import nan, nn, UntypedStorage
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.common_utils import get_untyped_storages
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.composite_compliance import (
@ -18,10 +18,7 @@ from torch.testing._internal.composite_compliance import (
is_inplace_view_fn,
is_view_fn,
)
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
TorchDispatchMode,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import SAC_IGNORED_OPS
@ -42,38 +39,6 @@ _PYTORCH_MIN_ALLOCATE = (
)
def _get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Retrieves untyped storages from a `torch.Tensor` or one of its traceable wrapper-subclass.
Args:
t (torch.Tensor): Input `torch.Tensor` or traceable wrapper-subclass of `torch.Tensor`.
Returns:
set[torch.UntypedStorage]: Set of untyped storages.
Warns:
UserWarning: If the flattened input is not a tensor or traceable wrapper-subclass.
"""
unflattened_tensors = [t]
flattened_tensor_storages = set()
while len(unflattened_tensors) > 0:
obj = unflattened_tensors.pop()
if is_traceable_wrapper_subclass(obj):
attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined]
unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
else:
if not hasattr(obj, "untyped_storage"):
warnings.warn(
f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
category=UserWarning,
stacklevel=2,
)
else:
flattened_tensor_storages.add(obj.untyped_storage())
return flattened_tensor_storages
def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
try:
from tabulate import tabulate
@ -268,7 +233,7 @@ class SACEstimator(TorchDispatchMode):
# Hook function to track underlying storage IDs of tensors
# Updates the _saved_tensor_ids set with the IDs of the tensor's storages
# Used in conjunction with torch.autograd.graph.saved_tensors_hooks
untyped_storages = _get_untyped_storages(x)
untyped_storages = get_untyped_storages(x)
storage_ids = (hash(st) for st in untyped_storages)
self._saved_tensor_ids.update(storage_ids)
return x
@ -436,10 +401,10 @@ class SACEstimator(TorchDispatchMode):
for o in flat_outs:
if isinstance(o, torch.Tensor):
if o.device.type == "cuda":
out_storages_cuda.update(_get_untyped_storages(o))
out_storages_cuda.update(get_untyped_storages(o))
cuda_devices.add(o.device)
else:
out_storages_cpu.update(_get_untyped_storages(o))
out_storages_cpu.update(get_untyped_storages(o))
# Check if there's more than 1 CUDA device
assert len(cuda_devices) <= 1, (