Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add support for non-functional collectives under FakeTensorMode and fake_pg for memory tracking (#147566)

This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`, which eliminates the need to patch collectives for memory and runtime estimation. It also modifies `ModTracker` so that the post-backward hook fires for modules whose inputs don't require gradients but whose parameters do. For memory tracking, the `DTensor` dispatcher is now traced for custom dispatch functions such as `entropy_loss`; the dispatcher is hooked only for the memory-tracking portion and restored as soon as it is done.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
Committed by PyTorch MergeBot
parent 439782960c
commit 9841f0ddcf
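In practice, the change means a collective issued against a fake process group now runs under `FakeTensorMode` without any monkey-patching. A minimal sketch condensed from the new test file below (single process, no real communication; the CUDA device used in the test is incidental):

```python
# Hedged sketch of the behavior this PR enables; mirrors the new test below.
import torch
import torch.distributed as dist
import torch.distributed._tools.fake_collectives  # registers Meta impls for c10d ops
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group("fake", rank=0, world_size=4, store=FakeStore())
with FakeTensorMode():
    t = torch.randn(100)
    # Previously this required patching dist.all_reduce; now the registered
    # Meta implementation returns a FakeWork handle that completes immediately.
    work = dist.all_reduce(t, async_op=True)
    work.wait()
dist.destroy_process_group()
```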
216  test/distributed/_tools/test_fake_collectives.py  Normal file

@@ -0,0 +1,216 @@
# Owner(s): ["module: unknown"]
import unittest

import torch
import torch.distributed as dist
from torch._C._distributed_c10d import FakeWork, ProcessGroup
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._functional_collectives import (
    all_gather_into_tensor_coalesced,
    all_gather_tensor,
    all_gather_tensor_autograd,
    all_reduce,
    all_reduce_coalesced,
    all_to_all_single,
    all_to_all_single_autograd,
    broadcast,
    reduce_scatter_tensor,
    reduce_scatter_tensor_autograd,
    reduce_scatter_tensor_coalesced,
    wait_tensor,
)
from torch.distributed._tools.fake_collectives import (
    collective_ops,
    CollectiveOp,
    non_functional_collectives,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._python_dispatch import TorchDispatchMode


aten = torch.ops.aten
c10d = torch.ops.c10d
_c10d_functional = torch.ops._c10d_functional
_c10d_functional_autograd = torch.ops._c10d_functional_autograd


class TestFakeCollectives(TestCase):
    def _setup_distributed(self):
        world_size = 4
        store = FakeStore()
        dist.init_process_group("fake", rank=0, world_size=world_size, store=store)
        torch.cuda.set_device(torch.cuda.current_device())

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_collectives(self):
        try:
            self._setup_distributed()
            with FakeTensorMode(), CollectiveTest(test=self):
                test_tensor_list = [torch.randn(100, device="cuda") for _ in range(4)]
                test_tensor_list_2 = [torch.randn(400, device="cuda") for _ in range(4)]
                test_tensor = torch.randn(100, device="cuda")
                # Used as gather output or scatter input
                test_tensor2 = torch.randn(400, device="cuda")

                # Testing non-functional collective operations
                dist.broadcast(test_tensor, src=0)
                dist.all_reduce(test_tensor)
                dist.reduce(test_tensor, dst=0)
                dist.send(test_tensor, dst=1)
                dist.recv(test_tensor, src=1)
                dist.all_gather(test_tensor_list, test_tensor)
                dist.reduce_scatter(test_tensor, test_tensor_list)
                dist.reduce_scatter_tensor(test_tensor, test_tensor2)
                dist.scatter(test_tensor, scatter_list=test_tensor_list, src=0)
                dist.gather(test_tensor, gather_list=test_tensor_list, dst=0)
                dist.all_gather_into_tensor(test_tensor2, test_tensor)
                dist.all_to_all(test_tensor_list, test_tensor_list)
                dist.all_to_all_single(test_tensor2, test_tensor2)
                dist.barrier()

                # Testing functional collectives
                wait_tensor(test_tensor)
                broadcast(test_tensor, src=0, group=dist.group.WORLD)
                all_reduce(test_tensor, reduceOp="avg", group=dist.group.WORLD)
                all_gather_tensor(test_tensor, gather_dim=0, group=dist.group.WORLD)
                all_gather_tensor_autograd(
                    test_tensor, gather_dim=0, group=dist.group.WORLD
                )
                reduce_scatter_tensor(
                    test_tensor2, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD
                )
                reduce_scatter_tensor_autograd(
                    test_tensor2, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD
                )
                all_to_all_single(
                    test_tensor,
                    output_split_sizes=[0],
                    input_split_sizes=[1],
                    group=dist.group.WORLD,
                )
                all_reduce_coalesced(
                    test_tensor_list, reduceOp="avg", group=dist.group.WORLD
                )
                all_gather_into_tensor_coalesced(
                    test_tensor_list, group=dist.group.WORLD
                )
                reduce_scatter_tensor_coalesced(
                    test_tensor_list_2,
                    scatter_dim=[0] * 4,
                    reduceOp="sum",
                    group=dist.group.WORLD,
                )
                all_to_all_single_autograd(
                    test_tensor,
                    output_split_sizes=[0],
                    input_split_sizes=[1],
                    group=dist.group.WORLD,
                )
        finally:
            if dist.group.WORLD is not None:
                dist.destroy_process_group()


class CollectiveTest(TorchDispatchMode):
    collective_size_exclude = {
        c10d.barrier.default,
        c10d.monitored_barrier_.default,
        _c10d_functional.wait_tensor.default,
    }

    def __init__(self, test: TestFakeCollectives, _dispatch_key=None):
        super().__init__(_dispatch_key)
        self.test = test

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        res = func(*args, **(kwargs or {}))

        if func in collective_ops:
            if func != _c10d_functional.wait_tensor.default:
                pg = CollectiveOp.get_process_group(func, args)
                self.test.assertIsInstance(
                    pg, ProcessGroup, "Error: pg is not an instance of ProcessGroup"
                )
                self.test.assertEqual(
                    pg, dist.group.WORLD, "Error: pg is not equal to dist.group.WORLD"
                )
                self.test.assertEqual(
                    pg.size(),
                    4,
                    f"Error: Expected pg.size() to be 4, but got {pg.size()}",
                )
                self.test.assertNotEqual(
                    pg.name(), "", "Error: pg.name() should not be an empty string"
                )

            if func not in CollectiveTest.collective_size_exclude:
                # Compute expected communication tensor size
                computed_size = CollectiveOp.get_comm_tensor_size(
                    func, res, args, kwargs
                )
                expected_size = self.get_expected_size(func, res, args, kwargs)

                self.test.assertEqual(
                    computed_size,
                    expected_size,
                    msg=f"Size mismatch for {func.__name__}: expected {expected_size}, got {computed_size}",
                )

            if (
                func in non_functional_collectives
                and func != c10d.monitored_barrier_.default
            ):
                work = res[-1] if isinstance(res, (tuple, list)) else res
                self.test.assertIsInstance(FakeWork.unbox(work), FakeWork)

        return res

    @staticmethod
    def get_expected_size(func, res, args, kwargs):
        """Return expected tensor size for collectives explicitly used in test_collectives()."""
        WORLD_SIZE, TENSOR_100, TENSOR_400 = 4, 100 * 4, 400 * 4
        TENSOR_LIST_100, TENSOR_LIST_400 = (
            WORLD_SIZE * TENSOR_100,
            WORLD_SIZE * TENSOR_400,
        )

        size_map = {
            # Non-functional collectives
            c10d.broadcast_.default: TENSOR_100,
            c10d.allreduce_.default: TENSOR_100,
            c10d.reduce_.default: TENSOR_100,
            c10d.send.default: TENSOR_100,
            c10d.recv_.default: TENSOR_100,
            c10d.allgather_.default: TENSOR_LIST_100,
            c10d.reduce_scatter_.default: TENSOR_LIST_100,
            c10d._reduce_scatter_base_.default: TENSOR_400,
            c10d.scatter_.default: TENSOR_LIST_100,
            c10d.gather_.default: TENSOR_LIST_100,
            c10d._allgather_base_.default: TENSOR_400,
            c10d.alltoall_.default: TENSOR_LIST_100,
            c10d.alltoall_base_.default: TENSOR_400,
            # Functional collectives
            _c10d_functional.broadcast.default: TENSOR_100,
            _c10d_functional.all_reduce.default: TENSOR_100,
            _c10d_functional.all_gather_into_tensor.default: TENSOR_LIST_100,
            _c10d_functional_autograd.all_gather_into_tensor.default: TENSOR_LIST_100,
            _c10d_functional.reduce_scatter_tensor.default: TENSOR_400,
            _c10d_functional_autograd.reduce_scatter_tensor.default: TENSOR_400,
            _c10d_functional.all_to_all_single.default: TENSOR_100,
            _c10d_functional_autograd.all_to_all_single.default: TENSOR_100,
            _c10d_functional.all_reduce_coalesced.default: TENSOR_LIST_100,
            _c10d_functional.all_gather_into_tensor_coalesced.default: TENSOR_LIST_400,
            _c10d_functional.reduce_scatter_tensor_coalesced.default: TENSOR_LIST_100,
        }

        if func in size_map:
            return size_map[func]

        raise ValueError(f"Unhandled function: {func}")


if __name__ == "__main__":
    run_tests()
@@ -192,6 +192,8 @@ class TestModTracker(TestCase):
                ("post_fw", "Foo.linears.1", True, True),
                ("post_bw", "Foo.linears.1", True, True),
                ("pre_bw", "Foo.linears.0", True, True),
                ("post_bw", "Foo.linears.0", True, True),
                ("post_bw", "Foo", True, True),
            ]
        self.assertEqual(test_op, expected_op)
@@ -535,6 +535,12 @@ class ProcessGroup:
 class FakeProcessGroup(Backend):
     def __init__(self, rank: int, world_size: int) -> None: ...
 
+class FakeWork(Work):
+    seq_id: int
+    def __init__(self) -> None: ...
+    def wait(self, timeout: timedelta = ...) -> bool: ...
+    def getFuture(self) -> Future: ...
+
 class ProcessGroupGloo(Backend):
     class Device: ...
@@ -6,7 +6,8 @@ namespace c10d {
 
 class FakeWork : public Work {
  public:
-  bool wait(std::chrono::milliseconds timeout) override {
+  int seq_id = -1;
+  bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
     return true;
   }
@@ -177,6 +178,18 @@ class FakeProcessGroup : public Backend {
     return c10::make_intrusive<FakeWork>();
   }
 
+  void startCoalescing() override {
+    // No-op
+  }
+
+  c10::intrusive_ptr<Work> endCoalescing(OpType /* optype */) {
+    return c10::make_intrusive<FakeWork>();
+  }
+
+  c10::intrusive_ptr<Work> endCoalescing() override {
+    return c10::make_intrusive<FakeWork>();
+  }
+
   c10::intrusive_ptr<Work> barrier(
       const BarrierOptions& /* opts */ = BarrierOptions()) override {
     return c10::make_intrusive<FakeWork>();
@@ -3221,67 +3221,68 @@ Example::
       .def_readonly("time_started", &::c10d::WorkInfo::timeStarted)
       .def_readonly("time_finished", &::c10d::WorkInfo::timeFinished)
       .def_readonly("active_duration", &::c10d::WorkInfo::activeDuration);
 
-  py::class_<
-      ::c10d::Work,
-      c10::intrusive_ptr<::c10d::Work>,
-      ::c10d::PyProcessGroup::PyWork>(module, "Work", R"(
+  auto work =
+      py::class_<
+          ::c10d::Work,
+          c10::intrusive_ptr<::c10d::Work>,
+          ::c10d::PyProcessGroup::PyWork>(module, "Work", R"(
 A `Work` object represents the handle to a pending asynchronous operation in
 PyTorch's distributed package. It is returned by non-blocking collective operations,
 such as `dist.all_reduce(tensor, async_op=True)`.
 )")
-      .def(py::init<>())
-      .def("is_completed", &::c10d::Work::isCompleted)
-      .def(
-          "is_success",
-          [](::c10d::Work& work) -> bool {
-            TORCH_WARN_ONCE(
-                fmt::format(kDeprecationWarning, "Work::is_success"));
-            return work.isSuccess();
-          })
-      .def(
-          "exception",
-          [](::c10d::Work& work) -> std::exception_ptr {
-            TORCH_WARN_ONCE(
-                fmt::format(kDeprecationWarning, "Work::exception"));
-            return work.exception();
-          })
-      .def(
-          "source_rank",
-          [](::c10d::Work& work) -> int {
-            TORCH_WARN_ONCE(
-                fmt::format(kDeprecationWarning, "Work::source_rank"));
-            return work.sourceRank();
-          })
-      .def("_source_rank", &::c10d::Work::sourceRank)
-      .def(
-          "result",
-          [](::c10d::Work& work) -> std::vector<at::Tensor> {
-            // Deprecation reason:
-            // Work.result() returns a vector of tensors. This signature is
-            // problematic as some collectives may just return one tensor
-            // (e.g all-reduce), while some others may return multiple
-            // tensors (e.g. all-gather).
-            // Deprecating work.result() would
-            // also allow us to remove the `outputs_` field in the Work
-            // class, avoiding an "artificial" reference to the tensors,
-            // which could potentially hold up the tensors' memory.
-            TORCH_WARN_ONCE(fmt::format(kDeprecationWarning, "Work::result"));
-            return work.result();
-          })
-      .def(
-          "synchronize",
-          [](::c10d::Work& work) -> void {
-            TORCH_WARN_ONCE(
-                fmt::format(kDeprecationWarning, "Work::synchronize"));
-            work.synchronize();
-          })
-      .def(
-          "wait",
-          &::c10d::Work::wait,
-          py::arg("timeout") = kNoTimeout,
-          py::call_guard<py::gil_scoped_release>(),
-          R"(
+          .def(py::init<>())
+          .def("is_completed", &::c10d::Work::isCompleted)
+          .def(
+              "is_success",
+              [](::c10d::Work& work) -> bool {
+                TORCH_WARN_ONCE(
+                    fmt::format(kDeprecationWarning, "Work::is_success"));
+                return work.isSuccess();
+              })
+          .def(
+              "exception",
+              [](::c10d::Work& work) -> std::exception_ptr {
+                TORCH_WARN_ONCE(
+                    fmt::format(kDeprecationWarning, "Work::exception"));
+                return work.exception();
+              })
+          .def(
+              "source_rank",
+              [](::c10d::Work& work) -> int {
+                TORCH_WARN_ONCE(
+                    fmt::format(kDeprecationWarning, "Work::source_rank"));
+                return work.sourceRank();
+              })
+          .def("_source_rank", &::c10d::Work::sourceRank)
+          .def(
+              "result",
+              [](::c10d::Work& work) -> std::vector<at::Tensor> {
+                // Deprecation reason:
+                // Work.result() returns a vector of tensors. This signature is
+                // problematic as some collectives may just return one tensor
+                // (e.g all-reduce), while some others may return multiple
+                // tensors (e.g. all-gather).
+                // Deprecating work.result() would
+                // also allow us to remove the `outputs_` field in the Work
+                // class, avoiding an "artificial" reference to the tensors,
+                // which could potentially hold up the tensors' memory.
+                TORCH_WARN_ONCE(
+                    fmt::format(kDeprecationWarning, "Work::result"));
+                return work.result();
+              })
+          .def(
+              "synchronize",
+              [](::c10d::Work& work) -> void {
+                TORCH_WARN_ONCE(
+                    fmt::format(kDeprecationWarning, "Work::synchronize"));
+                work.synchronize();
+              })
+          .def(
+              "wait",
+              &::c10d::Work::wait,
+              py::arg("timeout") = kNoTimeout,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
 Returns:
     true/false.
 
@@ -3298,13 +3299,14 @@ such as `dist.all_reduce(tensor, async_op=True)`.
 However, if timeout is set, it will block the CPU thread until the NCCL work is completed
 or timed out. If timeout, exception will be thrown.
 )")
-      .def(
-          "get_future_result",
-          [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
-            return std::make_shared<jit::PythonFutureWrapper>(
-                work.getFutureResult());
-          },
-          R"(
+          .def(
+              "get_future_result",
+              [](::c10d::Work& work)
+                  -> std::shared_ptr<jit::PythonFutureWrapper> {
+                return std::make_shared<jit::PythonFutureWrapper>(
+                    work.getFutureResult());
+              },
+              R"(
 Returns:
     A ``torch.futures.Future`` object of int type which maps to the enum type of WorkResult
 As an example, a future object can be retrieved
@@ -3319,12 +3321,14 @@ such as `dist.all_reduce(tensor, async_op=True)`.
 .. warning ::
     ``get_future_result`` API supports NCCL
 )")
-      .def(
-          "get_future",
-          [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
-            return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
-          },
-          R"(
+          .def(
+              "get_future",
+              [](::c10d::Work& work)
+                  -> std::shared_ptr<jit::PythonFutureWrapper> {
+                return std::make_shared<jit::PythonFutureWrapper>(
+                    work.getFuture());
+              },
+              R"(
 Returns:
     A ``torch.futures.Future`` object which is associated with the completion of
     the ``Work``. As an example, a future object can be retrieved
@@ -3363,16 +3367,16 @@ such as `dist.all_reduce(tensor, async_op=True)`.
    true when tensors have arrived on respective nodes, but not yet necessarily synched on
    respective GPUs (similarly to GPU work).
 )")
-      .def(
-          "_get_op_type",
-          [](::c10d::Work& work) -> int {
-            return static_cast<int>(work.retrieveOpType());
-          })
-      .def(
-          "_get_duration",
-          &::c10d::Work::getDuration,
-          py::call_guard<py::gil_scoped_release>(),
-          R"(
+          .def(
+              "_get_op_type",
+              [](::c10d::Work& work) -> int {
+                return static_cast<int>(work.retrieveOpType());
+              })
+          .def(
+              "_get_duration",
+              &::c10d::Work::getDuration,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
 Returns:
     Duration of the corresponding collective communication.
 
@@ -3380,17 +3384,17 @@ such as `dist.all_reduce(tensor, async_op=True)`.
 This API only works for NCCL backend for now and must set
 TORCH_NCCL_ENABLE_TIMING environment variable.
 )")
-      .def(
-          "boxed",
-          [](c10::intrusive_ptr<::c10d::Work> self) {
-            return torch::jit::toPyObject(c10::IValue(std::move(self)));
-          })
-      .def_static("unbox", [](py::object obj) {
-        auto typePtr =
-            torch::getCustomClass("__torch__.torch.classes.c10d.Work");
-        auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
-        return ivalue.toCustomClass<::c10d::Work>();
-      });
+          .def(
+              "boxed",
+              [](c10::intrusive_ptr<::c10d::Work> self) {
+                return torch::jit::toPyObject(c10::IValue(std::move(self)));
+              })
+          .def_static("unbox", [](py::object obj) {
+            auto typePtr =
+                torch::getCustomClass("__torch__.torch.classes.c10d.Work");
+            auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
+            return ivalue.toCustomClass<::c10d::Work>();
+          });
 
   auto fakeProcessGroup =
       intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>(
@@ -3402,6 +3406,13 @@ such as `dist.all_reduce(tensor, async_op=True)`.
           }),
           py::arg("rank"),
           py::arg("world_size"));
+  auto fakeWork =
+      intrusive_ptr_no_gil_destructor_class_<::c10d::FakeWork>(
+          module, "FakeWork", work)
+          .def(py::init<>())
+          .def_readwrite("seq_id", &::c10d::FakeWork::seq_id) // Expose seq_id
+          .def("wait", &::c10d::FakeWork::wait, py::arg("timeout") = kNoTimeout)
+          .def("getFuture", &::c10d::FakeWork::getFuture);
 
   py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
       .def(py::init<>())
33  torch/distributed/_tools/common_utils.py  Normal file

@@ -0,0 +1,33 @@
import warnings

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
    """
    Recursively extracts untyped storages from a tensor or its subclasses.

    Args:
        t (torch.Tensor): The tensor to extract storages from.

    Returns:
        Set[torch.UntypedStorage]: A set of untyped storages.
    """
    unflattened_tensors = [t]
    flattened_tensor_storages = set()
    while len(unflattened_tensors) > 0:
        obj = unflattened_tensors.pop()
        if is_traceable_wrapper_subclass(obj):
            attrs, _ = obj.__tensor_flatten__()  # type: ignore[attr-defined]
            unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
        else:
            if not hasattr(obj, "untyped_storage"):
                warnings.warn(
                    f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
                    category=UserWarning,
                    stacklevel=2,
                )
            else:
                flattened_tensor_storages.add(obj.untyped_storage())
    return flattened_tensor_storages
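For orientation, a small usage sketch of the helper above (illustrative, not part of the commit):

```python
# get_untyped_storages returns the set of untyped storages backing a tensor;
# for a traceable wrapper subclass (e.g. DTensor) it walks __tensor_flatten__
# and returns the inner tensors' storages instead.
import torch
from torch.distributed._tools.common_utils import get_untyped_storages

t = torch.randn(4, 4)
storages = get_untyped_storages(t)
print(len(storages))  # 1: a plain tensor is backed by a single storage
print(next(iter(storages)).nbytes())  # 64: 16 float32 elements * 4 bytes
```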
307  torch/distributed/_tools/fake_collectives.py  Normal file

@@ -0,0 +1,307 @@
import random
from typing import Any

import torch
from torch._C._distributed_c10d import (
    _resolve_process_group,
    FakeWork,
    ProcessGroup,
    Work,
)
from torch.utils._pytree import tree_map_only


torch.distributed.batch_isend_irecv

c10d = torch.ops.c10d
_c10d_functional = torch.ops._c10d_functional
_c10d_functional_autograd = torch.ops._c10d_functional_autograd
_dtensor = torch.ops._dtensor
used_ids: set[int] = set()


def generate_unique_id() -> int:
    while True:
        new_id = random.randint(1, 10**9)
        if new_id not in used_ids:
            used_ids.add(new_id)
            return new_id


# Function to create and return a FakeWork object
def create_fakework(args, return_first_arg=True):  # type: ignore[no-untyped-def]
    work = FakeWork()
    work.seq_id = generate_unique_id()
    fakework_script_obj = work.boxed()
    return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj


# Dictionary mapping collective operations to their meta functions
# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included
# _DEPRECATED_META_FUNCTIONS = {
#     "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
#     "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
#     "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
#     "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# }
_META_FUNCTIONS = {
    "broadcast_": lambda *args: create_fakework(args),
    "allreduce_": lambda *args: create_fakework(args),
    "allgather_": lambda *args: create_fakework(args),
    "_allgather_base_": lambda *args: create_fakework(args),
    "reduce_scatter_": lambda *args: create_fakework(args),
    "_reduce_scatter_base_": lambda *args: create_fakework(args),
    "reduce_": lambda *args: create_fakework(args, return_first_arg=False),
    "gather_": lambda *args: create_fakework(args, return_first_arg=False),
    "scatter_": lambda *args: create_fakework(args),
    "alltoall_": lambda *args: create_fakework(args),
    "alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False),
    "barrier": lambda *args: create_fakework(args, return_first_arg=False),
    "monitored_barrier_": lambda *args: None,
    "send": lambda *args: create_fakework(args, return_first_arg=False),
    "recv_": lambda *args: create_fakework(args, return_first_arg=False),
    "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
}

if not torch._running_with_deploy():
    lib_impl = torch.library.Library("c10d", "IMPL")  # noqa: TOR901
    for op, meta_func in _META_FUNCTIONS.items():
        lib_impl.impl(op, meta_func, "Meta")

# List of collective operation functions including functional collectives
# Note: The following collectives might be deprecated soon, hence they are not added
# deprecated_non_functional_collectives = [
#     c10d.allreduce_coalesced_.default,
#     c10d.reduce_scatter_tensor_coalesced_.default,
#     c10d.allgather_into_tensor_coalesced_.default,
#     c10d.allgather_coalesced_.default,
# ]
non_functional_collectives: set[torch._ops.OpOverload] = {
    c10d.broadcast_.default,
    c10d.allreduce_.default,
    c10d.reduce_.default,
    c10d.send.default,
    c10d.recv_.default,
    c10d.recv_any_source_.default,
    c10d.allgather_.default,
    c10d.reduce_scatter_.default,
    c10d._reduce_scatter_base_.default,
    c10d._allgather_base_.default,
    c10d.gather_.default,
    c10d.scatter_.default,
    c10d.alltoall_.default,
    c10d.alltoall_base_.default,
    c10d.barrier.default,
    c10d.monitored_barrier_.default,
}
functional_collectives: set[torch._ops.OpOverload] = {
    _c10d_functional.broadcast.default,
    _c10d_functional.all_reduce.default,
    _c10d_functional.all_gather_into_tensor.default,
    _c10d_functional.reduce_scatter_tensor.default,
    _c10d_functional.all_to_all_single.default,
    _c10d_functional_autograd.all_to_all_single.default,
    _c10d_functional.wait_tensor.default,
    _c10d_functional.all_reduce_.default,
    _c10d_functional.all_reduce_coalesced.default,
    _c10d_functional.all_reduce_coalesced_.default,
    _c10d_functional.all_gather_into_tensor_out.default,
    _c10d_functional.all_gather_into_tensor_coalesced.default,
    _c10d_functional_autograd.all_gather_into_tensor.default,
    _c10d_functional.reduce_scatter_tensor_coalesced.default,
    _c10d_functional_autograd.reduce_scatter_tensor.default,
    _c10d_functional.broadcast_.default,
    _dtensor.shard_dim_alltoall.default,
}

sync_ops: set[torch._ops.OpOverload] = {
    c10d.barrier.default,
    c10d.monitored_barrier_.default,
    _c10d_functional.wait_tensor.default,
}

collective_ops = set.union(functional_collectives, non_functional_collectives)


class CollectiveOp:
    # Static sets for performance optimization
    PG_ARG_1 = {
        c10d.broadcast_.default,
        c10d.allreduce_.default,
        c10d.reduce_.default,
        c10d.send.default,
        c10d.recv_.default,
        c10d.recv_any_source_.default,
        c10d.barrier.default,
        # c10d.allreduce_coalesced_.default
    }

    PG_ARG_2 = {
        c10d.allgather_.default,
        c10d._allgather_base_.default,
        c10d.reduce_scatter_.default,
        c10d._reduce_scatter_base_.default,
        c10d.gather_.default,
        c10d.scatter_.default,
        c10d.alltoall_.default,
        c10d.alltoall_base_.default,
        # c10d.allgather_coalesced_.default,
        # c10d.allgather_into_tensor_coalesced_.default
        # c10d.reduce_scatter_tensor_coalesced_.default
    }

    PG_ARG_3 = {
        _c10d_functional.broadcast.default,
        _c10d_functional.broadcast_.default,
        _c10d_functional.all_reduce.default,
        _c10d_functional.all_reduce_.default,
        _c10d_functional.all_reduce_coalesced.default,
        _c10d_functional.all_reduce_coalesced_.default,
        _c10d_functional.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor_out.default,
        _c10d_functional_autograd.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor_coalesced.default,
    }

    PG_ARG_4 = {
        _c10d_functional.reduce_scatter_tensor.default,
        _c10d_functional.reduce_scatter_tensor_coalesced.default,
        _c10d_functional_autograd.reduce_scatter_tensor.default,
        _c10d_functional.all_to_all_single.default,
        _c10d_functional_autograd.all_to_all_single.default,
        _dtensor.shard_dim_alltoall.default,
    }

    WK_ARG_1 = {
        c10d.broadcast_.default,
        c10d.allreduce_.default,
        c10d.allgather_.default,
        c10d.reduce_scatter_.default,
        c10d._reduce_scatter_base_.default,
        c10d._allgather_base_.default,
        c10d.scatter_.default,
        c10d.alltoall_.default,
    }

    WK = {
        c10d.send.default,
        c10d.recv_.default,
        c10d.recv_any_source_.default,
        c10d.reduce_.default,
        c10d.gather_.default,
        c10d.alltoall_base_.default,
        c10d.barrier.default,
    }

    COMM_TENSOR_ARG_0 = {
        c10d.allreduce_.default,
        c10d.send.default,
        c10d.recv_.default,
        c10d.recv_any_source_.default,
        c10d.allgather_.default,
        c10d.gather_.default,
        c10d.reduce_.default,
        c10d.broadcast_.default,
        _c10d_functional.all_reduce_coalesced.default,
        _c10d_functional.all_reduce_coalesced_.default,
        # c10d.allreduce_coalesced_.default
        # c10d.allgather_coalesced_.default
        # c10d.allgather_into_tensor_coalesced_.default,
    }

    COMM_TENSOR_ARG_1 = {
        c10d.reduce_scatter_.default,
        c10d.scatter_.default,
        # c10d.reduce_scatter_tensor_coalesced_.default,
    }

    COMM_TENSOR_ARG_RES = {
        _c10d_functional.all_gather_into_tensor.default,
        _c10d_functional_autograd.all_gather_into_tensor.default,
    }

    COMM_TENSOR_SINGLE_UNTYPED_STORAGE = {
        c10d._allgather_base_.default,
        _c10d_functional.broadcast.default,
        _c10d_functional.broadcast_.default,
        _c10d_functional.all_reduce.default,
        _c10d_functional.all_reduce_.default,
        _c10d_functional.reduce_scatter_tensor.default,
        _c10d_functional_autograd.reduce_scatter_tensor.default,
    }

    COMM_TENSOR_ARG_0_AND_RES = {
        _c10d_functional.all_to_all_single.default,
        _c10d_functional_autograd.all_to_all_single.default,
        _dtensor.shard_dim_alltoall.default,
    }

    COMM_TENSOR_RES_SUM = {
        _c10d_functional.all_gather_into_tensor_coalesced.default,
        _c10d_functional.reduce_scatter_tensor_coalesced.default,
    }

    @staticmethod
    def sum_tensors(arg: Any) -> int:
        """Calculate total memory consumed by the tensors in the argument."""
        total_memory = 0

        def sum_bytes(t: torch.Tensor) -> None:
            nonlocal total_memory
            total_memory += t.untyped_storage().nbytes()

        tree_map_only(torch.Tensor, sum_bytes, arg)
        return total_memory

    @staticmethod
    def get_process_group(func, args) -> ProcessGroup:  # type: ignore[no-untyped-def]
        """Retrieve the process group for collective operations, except `wait_tensor`."""
        if func in CollectiveOp.PG_ARG_1:
            return ProcessGroup.unbox(args[1])
        if func in CollectiveOp.PG_ARG_2:
            return ProcessGroup.unbox(args[2])
        if func in CollectiveOp.PG_ARG_3:
            return _resolve_process_group(args[2])
        if func in CollectiveOp.PG_ARG_4:
            return _resolve_process_group(args[3])
        raise TypeError(f"Func {func} not found in {collective_ops}")

    @staticmethod
    def get_comm_tensor_size(func, res, args, kwargs) -> int:  # type: ignore[no-untyped-def]
        """Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`."""
        if func in CollectiveOp.COMM_TENSOR_ARG_0:
            return CollectiveOp.sum_tensors(args[0])
        if func in CollectiveOp.COMM_TENSOR_ARG_1:
            return CollectiveOp.sum_tensors(args[1])
        if func in CollectiveOp.COMM_TENSOR_ARG_RES:
            return res.untyped_storage().nbytes()
        if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE:
            return args[0].untyped_storage().nbytes()
        if func == c10d._reduce_scatter_base_.default:
            return args[1].untyped_storage().nbytes()
        if func == c10d.alltoall_.default:
            # TODO(@sanketpurandare) - Confirm size computation
            return max(
                CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1])
            )
        if func == c10d.alltoall_base_.default:
            # TODO(@sanketpurandare) - Confirm size computation
            return max(
                args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes()
            )
        if func == _c10d_functional.all_gather_into_tensor_out.default:
            return args[-1].untyped_storage().nbytes()
        if func in CollectiveOp.COMM_TENSOR_RES_SUM:
            return CollectiveOp.sum_tensors(res)
        if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES:
            # TODO(@sanketpurandare) - Confirm size computation
            return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes()
        raise TypeError(f"Unknown function: {func} in {collective_ops}")

    @staticmethod
    def get_work(func, res) -> Work:  # type: ignore[no-untyped-def]
        if func in CollectiveOp.WK:
            return FakeWork.unbox(res)
        elif func in CollectiveOp.WK_ARG_1:
            return FakeWork.unbox(res[1])
        raise TypeError(f"Func {func} not found in {collective_ops}")
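The `CollectiveOp` helpers above are what the memory and runtime estimators consume. A hedged sketch of that pattern (assuming a fake process group is already initialized and collectives run under `FakeTensorMode`, as in the test file earlier in this commit; this mode is illustrative, not part of the diff):

```python
# Minimal dispatch mode that totals communication bytes using CollectiveOp.
import torch
from torch.distributed._tools.fake_collectives import (
    collective_ops,
    CollectiveOp,
    sync_ops,
)
from torch.utils._python_dispatch import TorchDispatchMode


class CommBytesTracker(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.comm_bytes = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        res = func(*args, **(kwargs or {}))
        if func in collective_ops and func not in sync_ops:
            # barrier / monitored_barrier / wait_tensor carry no payload
            self.comm_bytes += CollectiveOp.get_comm_tensor_size(
                func, res, args, kwargs
            )
        return res
```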
@@ -1,23 +1,16 @@
 from copy import deepcopy
-from datetime import timedelta
 from enum import auto, Enum
 from functools import partial, wraps
 from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
 from typing_extensions import ParamSpec, TypeVarTuple, Unpack
 
 import torch
 import torch.distributed as dist
+import torch.distributed._tools.fake_collectives
 from torch import nn, optim
 from torch._guards import active_fake_mode
 from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker
-from torch.distributed.distributed_c10d import (
-    _IllegalWork,
-    ProcessGroup,
-    ReduceOp,
-    Work,
-)
 from torch.distributed.fsdp import FSDPModule
 from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
-from torch.futures import Future
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map_only
 from torch.utils.weak import WeakIdKeyDictionary, weakref
@@ -31,6 +24,8 @@ _P = ParamSpec("_P")
 _R = TypeVar("_R")
 _Ts = TypeVarTuple("_Ts")
 
+c10d = torch.ops.c10d
+
 
 class _FSDPRefType(_RefType):
     """
@@ -68,13 +63,6 @@ class _SavedFSDPMethods(NamedTuple):
     post_backward: Callable
 
 
-class _SavedCollectives(NamedTuple):
-    all_gather_into_tensor: Callable
-    reduce_scatter_tensor: Callable
-    all_reduce: Callable
-    barrier: Callable
-
-
 class _FSDPModState(_State):
     """
     Enumerates the states of FSDP modules during the forward and backward passes.
@@ -117,6 +105,15 @@ class _FSDPModMemStats:
         ] = {}
 
 
+class _FSDPState(Enum):
+    PRE_FW = auto()
+    FW = auto()
+    POST_FW = auto()
+    PRE_BW = auto()
+    BW = auto()
+    POST_BW = auto()
+
+
 class FSDPMemTracker(MemTracker):
     """
     A ``TorchDispatchMode`` based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track
@@ -166,9 +163,8 @@ class FSDPMemTracker(MemTracker):
         assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules"
         self._root_mod = mod
         self._optm = optm
-        self._in_fake_mode: bool = False
         self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary()
-        self._saved_collectives: _SavedCollectives
+        self._fsdp_state: _FSDPState = _FSDPState.PRE_FW
         self._ref_class: type[_RefType] = _FSDPRefType
 
     def _instrument_fsdp_sharded_params_grads(
@@ -209,6 +205,7 @@ class FSDPMemTracker(MemTracker):
         def inner(
             *args: _P.args, **kwargs: _P.kwargs
         ) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]:
+            self._fsdp_state = _FSDPState.PRE_FW
             mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
             assert mod_fqn is not None
             if fsdp_mod not in self.memory_tracking:
@@ -251,6 +248,7 @@ class FSDPMemTracker(MemTracker):
             else:
                 state = _FSDPModState.AFT_PRE_FW
             mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
+            self._fsdp_state = _FSDPState.FW
             return args, kwargs
 
         return inner
@@ -276,6 +274,7 @@ class FSDPMemTracker(MemTracker):
             else:
                 state = _FSDPModState.BEF_POST_FW
             mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
+            self._fsdp_state = _FSDPState.POST_FW
 
             output = orig_fsdp_state_post_fw(*args, **kwargs)
 
@@ -296,6 +295,7 @@ class FSDPMemTracker(MemTracker):
         # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module.
         @wraps(orig_fsdp_param_group_pre_backward)
         def inner(*args: _P.args, **kwargs: _P.kwargs) -> None:
+            self._fsdp_state = _FSDPState.PRE_BW
             mod_stat = self.memory_tracking[fsdp_mod]
             snapshot = self.get_tracker_snapshot()
             mod_stat.local_peak = {
@@ -310,6 +310,7 @@ class FSDPMemTracker(MemTracker):
             mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append(
                 self.get_tracker_snapshot()
             )
+            self._fsdp_state = _FSDPState.BW
 
         return inner
 
@@ -338,7 +339,7 @@ class FSDPMemTracker(MemTracker):
             mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append(
                 self.get_tracker_snapshot()
             )
-
+            self._fsdp_state = _FSDPState.POST_BW
             orig_fsdp_param_group_post_backward(*args, **kwargs)
 
             if fsdp_param_group := fsdp_state._fsdp_param_group:
@@ -453,110 +454,6 @@ class FSDPMemTracker(MemTracker):
                 handle.remove()
             self._optimizer_hook_handles = None
 
-    def _instrument_and_maybe_bypass_collectives(self) -> None:
-        # Monkey-patching collectives is required because they do not work with `FakeTensorMode`
-        # It's also easier to track `all_gather` and `reduce_scatter` buffers faithfully.
-        self._saved_collectives = _SavedCollectives(
-            dist.all_gather_into_tensor,
-            dist.reduce_scatter_tensor,
-            dist.all_reduce,
-            dist.barrier,
-        )
-
-        class FakeWork(Work):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def get_future(self) -> Future:
-                future: Future = Future()
-                future.set_result(None)
-                return future
-
-            def wait(self, timeout: Optional[timedelta] = None) -> bool:
-                return True
-
-        @wraps(dist.all_gather_into_tensor)
-        def all_gather_into_tensor(
-            output_tensor: torch.Tensor,
-            input_tensor: torch.Tensor,
-            group: Union[ProcessGroup, None] = None,
-            async_op: bool = False,
-        ) -> Union[Work, _IllegalWork, None]:
-            self._update_and_maybe_create_winfos(
-                output_tensor,
-                _FSDPRefType.ALL_GATHER,
-                update_existing=True,
-            )
-
-            if self._in_fake_mode:
-                if async_op:
-                    return FakeWork()
-                return None
-            else:
-                return self._saved_collectives.all_gather_into_tensor(
-                    output_tensor, input_tensor, group, async_op
-                )
-
-        @wraps(dist.reduce_scatter_tensor)
-        def reduce_scatter_tensor(
-            output: torch.Tensor,
-            input: torch.Tensor,
-            op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
-            group: Union[ProcessGroup, None] = None,
-            async_op: bool = False,
-        ) -> Union[Work, _IllegalWork, None]:
-            self._update_and_maybe_create_winfos(
-                input,
-                _FSDPRefType.REDUCE_SCATTER,
-                update_existing=True,
-            )
-
-            if self._in_fake_mode:
-                if async_op:
-                    return FakeWork()
-                return None
-            else:
-                return self._saved_collectives.reduce_scatter_tensor(
-                    output, input, op, group, async_op
-                )
-
-        @wraps(dist.all_reduce)
-        def all_reduce(
-            tensor: torch.Tensor,
-            op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
-            group: Union[ProcessGroup, None] = None,
-            async_op: bool = False,
-        ) -> Union[Work, _IllegalWork, None]:
-            if self._in_fake_mode:
-                if async_op:
-                    return FakeWork()
-                return None
-            else:
-                return self._saved_collectives.all_reduce(tensor, op, group, async_op)
-
-        @wraps(dist.barrier)
-        def barrier(
-            group: Union[ProcessGroup, None] = dist.GroupMember.WORLD,
-            async_op: bool = False,
-            device_ids: Union[list[int], None] = None,
-        ) -> Union[Work, None]:
-            if self._in_fake_mode:
-                return None
-            else:
-                return self._saved_collectives.barrier(group, async_op, device_ids)
-
-        dist.all_gather_into_tensor = all_gather_into_tensor
-        dist.reduce_scatter_tensor = reduce_scatter_tensor
-        dist.all_reduce = all_reduce
-        dist.barrier = barrier
-
-    def _restore_collectives(self) -> None:
-        dist.all_gather_into_tensor = self._saved_collectives.all_gather_into_tensor
-        dist.reduce_scatter_tensor = self._saved_collectives.reduce_scatter_tensor
-        dist.all_reduce = self._saved_collectives.all_reduce
-        dist.barrier = self._saved_collectives.barrier
-        del self._saved_collectives
-
     def track_inputs(self, inputs: tuple[Any, ...]) -> None:
         """
         This is used to track the input tensors to the model and annotate them as ``Inputs``.
@@ -579,27 +476,39 @@ class FSDPMemTracker(MemTracker):
         """This is no-op for ``FSDPMemTracker``"""
 
     def __enter__(self) -> "FSDPMemTracker":
-        self._in_fake_mode = True if active_fake_mode() else False
-        self._register_module_and_optimizer_hooks()
-        self._instrument_and_maybe_bypass_collectives()
-        self._track_resize()
-        self._peak_mem_snap = self.get_tracker_snapshot()
-        self._peak_mem = {
-            dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items()
-        }
-        self._mod_tracker.__enter__()
+        if self._depth == 0:
+            self._register_module_and_optimizer_hooks()
+            self._track_resize()
+            self._track_dtensor_dispatch()
+            self._peak_mem_snap = self.get_tracker_snapshot()
+            self._peak_mem = {
+                dev: dev_snap[_TOTAL_KEY]
+                for dev, dev_snap in self._peak_mem_snap.items()
+            }
+            self._mod_tracker.__enter__()
        TorchDispatchMode.__enter__(self)
+        self._depth += 1
        return self
 
     def __exit__(self, *args: Any) -> None:
-        self._deregister_module_and_optimizer_hooks()
-        self._restore_collectives()
-        self._restore_resize()
+        self._depth -= 1
+        if self._depth == 0:
+            self._deregister_module_and_optimizer_hooks()
+            self._restore_resize()
+            self._restore_dtensor_dispatch()
+            self._mod_tracker.__exit__(*args)
         TorchDispatchMode.__exit__(self, *args)
-        self._mod_tracker.__exit__(*args)
 
     def __torch_dispatch__(self, func, types, args=..., kwargs=None):  # type: ignore[no-untyped-def]
-        res = func(*args, **kwargs or {})
+        if (
+            func == torch.ops._c10d_functional.wait_tensor.default
+            and active_fake_mode()
+        ):
+            # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
+            # a new tensor which does not happen in eager mode, when a wait_tensor is called.
+            res = args[0]
+        else:
+            res = func(*args, **kwargs or {})
         # If we are tracking an optimizer state, we use the optimizer reference type.
         # If we are in backward region and not in AC region, we use the backward reference type.
         # Else we use the forward reference type.
@@ -609,6 +518,27 @@ class FSDPMemTracker(MemTracker):
             reftype = _FSDPRefType.TEMP
         else:
             reftype = _FSDPRefType.ACT
+        if func == c10d._allgather_base_.default and self._fsdp_state in [
+            _FSDPState.PRE_FW,
+            _FSDPState.PRE_BW,
+        ]:
+            output_tensor = args[0]
+            self._update_and_maybe_create_winfos(
+                output_tensor,
+                _FSDPRefType.ALL_GATHER,
+                update_existing=True,
+            )
+        if (
+            func == c10d._reduce_scatter_base_.default
+            and self._fsdp_state == _FSDPState.POST_BW
+        ):
+            input_tensor = args[1]
+            self._update_and_maybe_create_winfos(
+                input_tensor,
+                _FSDPRefType.REDUCE_SCATTER,
+                update_existing=True,
+            )
+
         tree_map_only(torch.Tensor, partial(self._track, reftype), res)
         peak_state = (
             _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW
@@ -2,6 +2,7 @@ import math
 import os
 import re
 import warnings
+from contextlib import nullcontext
 from copy import deepcopy
 from enum import auto, Enum
 from functools import partial, wraps
@@ -9,16 +10,17 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from typing_extensions import Self
 
 import torch
+import torch.distributed._tools.fake_collectives
 from torch import nn, optim
 from torch._guards import active_fake_mode
+from torch.distributed._tools.common_utils import get_untyped_storages
 from torch.distributed._tools.mod_tracker import ModTracker
+from torch.distributed.tensor import DTensor
 from torch.optim.optimizer import (
     register_optimizer_step_post_hook,
     register_optimizer_step_pre_hook,
 )
-from torch.utils._python_dispatch import (
-    is_traceable_wrapper_subclass,
-    TorchDispatchMode,
-)
+from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.weak import WeakIdKeyDictionary, weakref
@@ -170,35 +172,6 @@ class _WeakRefInfo:
         self.mem_consumed = self._calculate_mem_consumed()
         return self.mem_consumed
 
-    @staticmethod
-    def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
-        """
-        Recursively extracts untyped storages from a tensor or its subclasses.
-
-        Args:
-            t (torch.Tensor): The tensor to extract storages from.
-
-        Returns:
-            set[torch.UntypedStorage]: A set of untyped storages.
-        """
-        unflattened_tensors = [t]
-        flattened_tensor_storages = set()
-        while len(unflattened_tensors) > 0:
-            obj = unflattened_tensors.pop()
-            if is_traceable_wrapper_subclass(obj):
-                attrs, _ = obj.__tensor_flatten__()  # type: ignore[attr-defined]
-                unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
-            else:
-                if not hasattr(obj, "untyped_storage"):
-                    warnings.warn(
-                        f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
-                        category=UserWarning,
-                        stacklevel=2,
-                    )
-                else:
-                    flattened_tensor_storages.add(obj.untyped_storage())
-        return flattened_tensor_storages
-
     @classmethod
     def create_winfo(
         cls,
@@ -253,7 +226,9 @@ def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) ->
         print(
             f"Device: {dev}",
             *(
-                f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}"
+                f"\t{k.value}: {_rounding_fn(v, divisor, 2)} {units}"
+                if isinstance(k, _RefType)
+                else f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}"
                 for k, v in dev_snap.items()
             ),
             sep="\n",
@@ -275,7 +250,9 @@ def _print_snapshot_tabular(
     divisor = _get_mem_divisor(units)
     table_data = []
     key_list = list(next(iter(snapshot.values())).keys())
-    headers = ["Device"] + [f"{key}" for key in key_list]
+    headers = ["Device"] + [
+        f"{key.value}" if isinstance(key, _RefType) else f"{key}" for key in key_list
+    ]
 
     for dev, dev_snap in snapshot.items():
         if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
@@ -290,7 +267,7 @@ def _print_state_snapshots(
     snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str
 ) -> None:
     for state, snapshot_list in snapshots.items():
-        print(f"{state}")
+        print(f"{state.value}")
         for i, snapshot in enumerate(snapshot_list):
             print(f"# {i + 1}:")
             _print_snapshot(snapshot, units)
@@ -312,7 +289,7 @@ def _print_state_snapshots_tabular(
     divisor = _get_mem_divisor(units)
     for state, snapshot_list in snapshots.items():
         for i, snapshot in enumerate(snapshot_list):
-            state_call = f"{state} # {i + 1}"
+            state_call = f"{state.value} # {i + 1}"
             for dev, dev_snap in snapshot.items():
                 if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
                     continue
@@ -324,7 +301,9 @@ def _print_state_snapshots_tabular(
                 }
                 last_state_call = state_call
            for k, v in dev_snap.items():
-                row[f"{k}"] = f"{_rounding_fn(v, divisor, 2)} {units}"
+                row[f"{k.value}" if isinstance(k, _RefType) else f"{k}"] = (
+                    f"{_rounding_fn(v, divisor, 2)} {units}"
+                )
            table_data.append(row)
    print(tabulate(table_data, headers="keys", tablefmt="rst"))
@@ -411,6 +390,8 @@ class MemTracker(TorchDispatchMode):
         # Weak references to the topmost AC module currently active
         self._ac_mod: Optional[weakref.ref] = None
         self._orig_resize = torch.UntypedStorage.resize_
+        self._orig_dtensor_dispatch = DTensor._op_dispatcher.dispatch
+        self._depth = 0
 
     def _update_snap(
         self,
@@ -462,7 +443,7 @@ class MemTracker(TorchDispatchMode):
         reftype: _RefType,
         update_existing: bool = False,
     ) -> set[_WeakRefInfo]:
-        sts = _WeakRefInfo.get_untyped_storages(t)
+        sts = get_untyped_storages(t)
         winfos = set()
         for st in sts:
             # Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary.
@@ -543,7 +524,7 @@ class MemTracker(TorchDispatchMode):
         # Get the storages of the tensor and check if we have already tracked them.
         # If yes, then check if the storage size has changed and update the current snapshot.
         # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary.
-        sts = _WeakRefInfo.get_untyped_storages(t)
+        sts = get_untyped_storages(t)
         for st in sts:
             winfo, _ = self._WINFO.get(st, (None, None))
             if winfo is not None:
@@ -640,7 +621,7 @@ class MemTracker(TorchDispatchMode):
 
         def add_inps_or_outs(t: torch.Tensor) -> None:
             nonlocal input_or_output_memory
-            sts = _WeakRefInfo.get_untyped_storages(t)
+            sts = get_untyped_storages(t)
             for st in sts:
                 winfo, _ = self._WINFO.get(st, (None, None))
                 if winfo is not None:
@@ -694,6 +675,7 @@ class MemTracker(TorchDispatchMode):
         mod_stats = self.memory_tracking[module]
         state = _ModState.PRE_FW
         input_mem = self._track_inputs_or_outputs(inputs)
+        mod_stats.mod_fqn = mod_name
         mod_stats.input_mem = input_mem
 
         mem_snapshot = self.get_tracker_snapshot()
@@ -826,6 +808,8 @@ class MemTracker(TorchDispatchMode):
                 self._track_module_params_and_buffers(obj, install_grad_hooks=False)
             elif isinstance(obj, optim.Optimizer):
                 self._track_optimizer_states(_MemRefType.OPT, obj)
+            elif obj is None:
+                continue
             else:
                 raise TypeError(
                     f"Object of type {type(obj)} is not supported for tracking. "
@@ -891,32 +875,65 @@ class MemTracker(TorchDispatchMode):
         """
         self.memory_tracking.clear()
 
+    def _track_dtensor_dispatch(self) -> None:
+        def track_dtensor_dispatch(
+            op_call: torch._ops.OpOverload,
+            args: tuple[object, ...],
+            kwargs: dict[str, object],
+        ) -> object:
+            with (
+                self
+                if op_call in DTensor._op_dispatcher._custom_op_handlers
+                else nullcontext()
+            ):
+                return self._orig_dtensor_dispatch(op_call, args, kwargs)
+
+        DTensor._op_dispatcher.dispatch = track_dtensor_dispatch  # type: ignore[method-assign, assignment]
+
+    def _restore_dtensor_dispatch(self) -> None:
+        DTensor._op_dispatcher.dispatch = self._orig_dtensor_dispatch  # type: ignore[method-assign]
+
     def __enter__(self) -> "MemTracker":
-        self._register_global_optimizer_hook()
-        self._mod_tracker.register_user_hooks(
-            self._pre_fw_hook,
-            self._post_fw_hook,
-            self._pre_bw_hook,
-            self._post_bw_hook,
-        )
-        self._track_resize()
-        self._peak_mem_snap = self.get_tracker_snapshot()
-        self._peak_mem = {
-            dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items()
-        }
-        self._mod_tracker.__enter__()
+        if self._depth == 0:
+            self._register_global_optimizer_hook()
+            self._mod_tracker.register_user_hooks(
+                self._pre_fw_hook,
+                self._post_fw_hook,
+                self._pre_bw_hook,
+                self._post_bw_hook,
+            )
+            self._track_resize()
+            self._track_dtensor_dispatch()
+            self._peak_mem_snap = self.get_tracker_snapshot()
+            self._peak_mem = {
+                dev: dev_snap[_TOTAL_KEY]
+                for dev, dev_snap in self._peak_mem_snap.items()
+            }
+            self._mod_tracker.__enter__()
         super().__enter__()
+        self._depth += 1
         return self
 
     def __exit__(self, *args: Any) -> None:
-        self._deregister_param_and_optimizer_hooks()
-        self._mod_tracker.clear_user_hooks()
-        self._restore_resize()
+        self._depth -= 1
+        if self._depth == 0:
+            self._deregister_param_and_optimizer_hooks()
+            self._mod_tracker.clear_user_hooks()
+            self._restore_resize()
+            self._restore_dtensor_dispatch()
+            self._mod_tracker.__exit__(*args)
         super().__exit__(*args)
-        self._mod_tracker.__exit__(*args)
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):  # type: ignore[no-untyped-def]
-        res = func(*args, **kwargs or {})
+        if (
+            func == torch.ops._c10d_functional.wait_tensor.default
+            and active_fake_mode()
+        ):
+            # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
+            # a new tensor which does not happen in eager mode, when a wait_tensor is called.
+            res = args[0]
+        else:
+            res = func(*args, **kwargs or {})
         # If we are tracking an optimizer state, we use the optimizer reference type.
         # If we are in backward region and not in AC region, we use the backward reference type.
         # Else we use the forward reference type.
@@ -60,6 +60,7 @@ class ModTracker:
         self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
         self._seen_modules: weakref.WeakSet = weakref.WeakSet()
         self._has_callback = False
+        self._post_bw_callbacks_to_enqueue: list[Callable] = []
         self._user_pre_fw_hook = None
         self._user_post_fw_hook = None
         self._user_pre_bw_hook = None
@@ -70,6 +71,10 @@ class ModTracker:
         if self._has_callback:
             return
 
+        for post_bw_callback in reversed(self._post_bw_callbacks_to_enqueue):
+            torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback)
+        self._post_bw_callbacks_to_enqueue.clear()
+
         def callback():
             self.parents = {"Global"}
             self._has_callback = False
@@ -213,8 +218,13 @@ class ModTracker:
             self._user_pre_fw_hook(mod, input)
         args, _ = tree_flatten(input)
         tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
-        if not self.is_bw and tensors:
-            register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))
+        if not self.is_bw:
+            if tensors:
+                register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))
+            else:
+                self._post_bw_callbacks_to_enqueue.append(
+                    self._get_pop_fn(w_mod, name, True)
+                )
 
     def _fw_post_hook(self, mod, input, output):
         name = self._get_mod_name(mod)
@@ -225,7 +235,9 @@ class ModTracker:
         args, _ = tree_flatten(output)
         tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
         if not self.is_bw and tensors:
-            register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True))
+            register_multi_grad_hook(
+                tensors, self._get_append_fn(w_mod, name, True), mode="any"
+            )
 
     def __enter__(self):
         self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
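The `_fw_pre_hook` change above handles the case called out in the commit message: when a module's inputs don't require gradients, no multi-grad hook can fire, so the pop function is queued on the autograd engine instead (via the `queue_callback` loop added earlier in this file). A hedged repro sketch; the `register_user_hooks` keyword and the hook signature are assumptions based on how this file uses them:

```python
# The input does not require grad, only the parameters do; the post-backward
# hook for `lin` now fires via the queued autograd callback.
import torch
from torch import nn
from torch.distributed._tools.mod_tracker import ModTracker

lin = nn.Linear(4, 4)
x = torch.randn(2, 4)  # requires_grad=False

mt = ModTracker()
# Assumed signature: post_bw_hook receives (module, grad_output).
mt.register_user_hooks(
    post_bw_hook=lambda mod, gout: print("post_bw:", mt.get_known_fqn(mod))
)
with mt:
    lin(x).sum().backward()
```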
@@ -1,7 +1,6 @@
 import math
 import os
 import sys
-import warnings
 from collections import OrderedDict
 from dataclasses import astuple, dataclass
 from typing import Any, NamedTuple, Optional
@@ -11,6 +10,7 @@ import torch
 from torch import nan, nn, UntypedStorage
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.distributed._tools.common_utils import get_untyped_storages
 from torch.distributed._tools.mod_tracker import ModTracker
 from torch.distributed._tools.runtime_estimator import RuntimeEstimator
 from torch.testing._internal.composite_compliance import (
@@ -18,10 +18,7 @@ from torch.testing._internal.composite_compliance import (
     is_inplace_view_fn,
     is_view_fn,
 )
-from torch.utils._python_dispatch import (
-    is_traceable_wrapper_subclass,
-    TorchDispatchMode,
-)
+from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_flatten
 from torch.utils.checkpoint import SAC_IGNORED_OPS
 
@@ -42,38 +39,6 @@ _PYTORCH_MIN_ALLOCATE = (
 )
 
 
-def _get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
-    """
-    Retrieves untyped storages from a `torch.Tensor` or one of its traceable wrapper-subclass.
-
-    Args:
-        t (torch.Tensor): Input `torch.Tensor` or traceable wrapper-subclass of `torch.Tensor`.
-
-    Returns:
-        set[torch.UntypedStorage]: Set of untyped storages.
-
-    Warns:
-        UserWarning: If the flattened input is not a tensor or traceable wrapper-subclass.
-    """
-    unflattened_tensors = [t]
-    flattened_tensor_storages = set()
-    while len(unflattened_tensors) > 0:
-        obj = unflattened_tensors.pop()
-        if is_traceable_wrapper_subclass(obj):
-            attrs, _ = obj.__tensor_flatten__()  # type: ignore[attr-defined]
-            unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
-        else:
-            if not hasattr(obj, "untyped_storage"):
-                warnings.warn(
-                    f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
-                    category=UserWarning,
-                    stacklevel=2,
-                )
-            else:
-                flattened_tensor_storages.add(obj.untyped_storage())
-    return flattened_tensor_storages
-
-
 def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
     try:
         from tabulate import tabulate
@@ -268,7 +233,7 @@ class SACEstimator(TorchDispatchMode):
             # Hook function to track underlying storage IDs of tensors
            # Updates the _saved_tensor_ids set with the IDs of the tensor's storages
            # Used in conjunction with torch.autograd.graph.saved_tensors_hooks
-            untyped_storages = _get_untyped_storages(x)
+            untyped_storages = get_untyped_storages(x)
             storage_ids = (hash(st) for st in untyped_storages)
             self._saved_tensor_ids.update(storage_ids)
             return x
@@ -436,10 +401,10 @@ class SACEstimator(TorchDispatchMode):
         for o in flat_outs:
             if isinstance(o, torch.Tensor):
                 if o.device.type == "cuda":
-                    out_storages_cuda.update(_get_untyped_storages(o))
+                    out_storages_cuda.update(get_untyped_storages(o))
                     cuda_devices.add(o.device)
                 else:
-                    out_storages_cpu.update(_get_untyped_storages(o))
+                    out_storages_cpu.update(get_untyped_storages(o))
 
         # Check if there's more than 1 CUDA device
         assert len(cuda_devices) <= 1, (