pytorch/torch/distributed/_tools/fake_collectives.py
PaliC 6162e650b0 [BE] remove torch deploy - conditionals (#158288)
This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started.
1. Remove test_deploy_interaction as we no longer need to worry about this
2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1)
3. Remove `USE_DEPLOY` and switch to the default path always

Note: MyPy does fail on a number of things here, since several older files are touched. It may be better to fix these in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
2025-07-29 17:40:49 +00:00

import random
from typing import Any
import torch
from torch._C._distributed_c10d import (
    _resolve_process_group,
    FakeWork,
    ProcessGroup,
    Work,
)
from torch.utils._pytree import tree_map_only
torch.distributed.batch_isend_irecv
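
# Op namespaces used below: legacy (non-functional) c10d collectives, functional
# collectives (with and without autograd variants), and DTensor-specific ops.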
c10d = torch.ops.c10d
_c10d_functional = torch.ops._c10d_functional
_c10d_functional_autograd = torch.ops._c10d_functional_autograd
_dtensor = torch.ops._dtensor
used_ids: set[int] = set()
def generate_unique_id() -> int:
    while True:
        new_id = random.randint(1, 10**9)
        if new_id not in used_ids:
            used_ids.add(new_id)
            return new_id


# Function to create and return FakeWork object
def create_fakework(args, return_first_arg=True):  # type: ignore[no-untyped-def]
    work = FakeWork()
    work.seq_id = generate_unique_id()
    fakework_script_obj = work.boxed()
    return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj
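
# Illustration (comment only): the meta kernels registered below mimic the C++
# signatures, returning either `(args[0], boxed_fake_work)` for ops that return a
# (tensor list, Work) pair, or just `boxed_fake_work` for ops returning only a Work.
# A rough sketch of what a caller sees for the first case:
#
#     out_tensors, work_obj = c10d.broadcast_(...)   # under FakeTensorMode
#     work = FakeWork.unbox(work_obj)                # recover the FakeWork handle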
# Dictionary mapping collective operations to their meta functions
# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included
# _DEPRECATED_META_FUNCTIONS = {
# "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# }
_META_FUNCTIONS = {
    "broadcast_": lambda *args: create_fakework(args),
    "allreduce_": lambda *args: create_fakework(args),
    "allgather_": lambda *args: create_fakework(args),
    "_allgather_base_": lambda *args: create_fakework(args),
    "reduce_scatter_": lambda *args: create_fakework(args),
    "_reduce_scatter_base_": lambda *args: create_fakework(args),
    "reduce_": lambda *args: create_fakework(args, return_first_arg=False),
    "gather_": lambda *args: create_fakework(args, return_first_arg=False),
    "scatter_": lambda *args: create_fakework(args),
    "alltoall_": lambda *args: create_fakework(args),
    "alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False),
    "barrier": lambda *args: create_fakework(args, return_first_arg=False),
    "monitored_barrier_": lambda *args: None,
    "send": lambda *args: create_fakework(args, return_first_arg=False),
    "recv_": lambda *args: create_fakework(args, return_first_arg=False),
    "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
}
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
for op, meta_func in _META_FUNCTIONS.items():
    lib_impl.impl(op, meta_func, "Meta")
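
# The registrations above make the legacy c10d collectives traceable under
# FakeTensorMode: each call returns a FakeWork instead of doing real communication.
# A minimal usage sketch (commented out; assumes the "fake" process-group backend
# and FakeStore from torch.testing._internal.distributed.fake_pg are available):
#
#     import torch.distributed as dist
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     from torch.testing._internal.distributed.fake_pg import FakeStore
#
#     dist.init_process_group("fake", rank=0, world_size=4, store=FakeStore())
#     with FakeTensorMode():
#         t = torch.randn(8, 8)
#         dist.all_reduce(t)  # dispatches c10d.allreduce_ to the Meta impl above
#     dist.destroy_process_group()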
# List of collective operation functions including functional collectives
# Note: The following collectives might be deprecated soon, hence they are not added
# deprecated_non_functional_collectives = [
# c10d.allreduce_coalesced_.default,
# c10d.reduce_scatter_tensor_coalesced_.default,
# c10d.allgather_into_tensor_coalesced_.default,
# c10d.allgather_coalesced_.default,
# ]
non_functional_collectives: set[torch._ops.OpOverload] = {
    c10d.broadcast_.default,
    c10d.allreduce_.default,
    c10d.reduce_.default,
    c10d.send.default,
    c10d.recv_.default,
    c10d.recv_any_source_.default,
    c10d.allgather_.default,
    c10d.reduce_scatter_.default,
    c10d._reduce_scatter_base_.default,
    c10d._allgather_base_.default,
    c10d.gather_.default,
    c10d.scatter_.default,
    c10d.alltoall_.default,
    c10d.alltoall_base_.default,
    c10d.barrier.default,
    c10d.monitored_barrier_.default,
}
functional_collectives: set[torch._ops.OpOverload] = {
    _c10d_functional.broadcast.default,
    _c10d_functional.all_reduce.default,
    _c10d_functional.all_gather_into_tensor.default,
    _c10d_functional.reduce_scatter_tensor.default,
    _c10d_functional.all_to_all_single.default,
    _c10d_functional_autograd.all_to_all_single.default,
    _c10d_functional.wait_tensor.default,
    _c10d_functional.all_reduce_.default,
    _c10d_functional.all_reduce_coalesced.default,
    _c10d_functional.all_reduce_coalesced_.default,
    _c10d_functional.all_gather_into_tensor_out.default,
    _c10d_functional.all_gather_into_tensor_coalesced.default,
    _c10d_functional_autograd.all_gather_into_tensor.default,
    _c10d_functional.reduce_scatter_tensor_coalesced.default,
    _c10d_functional_autograd.reduce_scatter_tensor.default,
    _c10d_functional.broadcast_.default,
    _dtensor.shard_dim_alltoall.default,
}
sync_ops: set[torch._ops.OpOverload] = {
    c10d.barrier.default,
    c10d.monitored_barrier_.default,
    _c10d_functional.wait_tensor.default,
}
collective_ops = set.union(functional_collectives, non_functional_collectives)
class CollectiveOp:
    # Static sets for performance optimization
    PG_ARG_1 = {
        c10d.broadcast_.default,
        c10d.allreduce_.default,
        c10d.reduce_.default,
        c10d.send.default,
        c10d.recv_.default,
        c10d.recv_any_source_.default,
        c10d.barrier.default,
        # c10d.allreduce_coalesced_.default
    }
    PG_ARG_2 = {
        c10d.allgather_.default,
        c10d._allgather_base_.default,
        c10d.reduce_scatter_.default,
        c10d._reduce_scatter_base_.default,
        c10d.gather_.default,
        c10d.scatter_.default,
        c10d.alltoall_.default,
        c10d.alltoall_base_.default,
        # c10d.allgather_coalesced_.default,
        # c10d.allgather_into_tensor_coalesced_.default
        # c10d.reduce_scatter_tensor_coalesced_.default
    }
    PG_ARG_3 = {
        _c10d_functional.broadcast.default,
        _c10d_functional.broadcast_.default,
        _c10d_functional.all_reduce.default,
        _c10d_functional.all_reduce_.default,
        _c10d_functional.all_reduce_coalesced.default,
        _c10d_functional.all_reduce_coalesced_.default,
        _c10d_functional.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor_out.default,
        _c10d_functional_autograd.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor_coalesced.default,
    }
    PG_ARG_4 = {
        _c10d_functional.reduce_scatter_tensor.default,
        _c10d_functional.reduce_scatter_tensor_coalesced.default,
        _c10d_functional_autograd.reduce_scatter_tensor.default,
        _c10d_functional.all_to_all_single.default,
        _c10d_functional_autograd.all_to_all_single.default,
        _dtensor.shard_dim_alltoall.default,
    }
    WK_ARG_1 = {
        c10d.broadcast_.default,
        c10d.allreduce_.default,
        c10d.allgather_.default,
        c10d.reduce_scatter_.default,
        c10d._reduce_scatter_base_.default,
        c10d._allgather_base_.default,
        c10d.scatter_.default,
        c10d.alltoall_.default,
    }
    WK = {
        c10d.send.default,
        c10d.recv_.default,
        c10d.recv_any_source_.default,
        c10d.reduce_.default,
        c10d.gather_.default,
        c10d.alltoall_base_.default,
        c10d.barrier.default,
    }
    COMM_TENSOR_ARG_0 = {
        c10d.allreduce_.default,
        c10d.send.default,
        c10d.recv_.default,
        c10d.recv_any_source_.default,
        c10d.allgather_.default,
        c10d.gather_.default,
        c10d.reduce_.default,
        c10d.broadcast_.default,
        _c10d_functional.all_reduce_coalesced.default,
        _c10d_functional.all_reduce_coalesced_.default,
        # c10d.allreduce_coalesced_.default
        # c10d.allgather_coalesced_.default
        # c10d.allgather_into_tensor_coalesced_.default,
    }
    COMM_TENSOR_ARG_1 = {
        c10d.reduce_scatter_.default,
        c10d.scatter_.default,
        # c10d.reduce_scatter_tensor_coalesced_.default,
    }
    COMM_TENSOR_ARG_RES = {
        _c10d_functional.all_gather_into_tensor.default,
        _c10d_functional_autograd.all_gather_into_tensor.default,
    }
    COMM_TENSOR_SINGLE_UNTYPED_STORAGE = {
        c10d._allgather_base_.default,
        _c10d_functional.broadcast.default,
        _c10d_functional.broadcast_.default,
        _c10d_functional.all_reduce.default,
        _c10d_functional.all_reduce_.default,
        _c10d_functional.reduce_scatter_tensor.default,
        _c10d_functional_autograd.reduce_scatter_tensor.default,
    }
    COMM_TENSOR_ARG_0_AND_RES = {
        _c10d_functional.all_to_all_single.default,
        _c10d_functional_autograd.all_to_all_single.default,
        _dtensor.shard_dim_alltoall.default,
    }
    COMM_TENSOR_RES_SUM = {
        _c10d_functional.all_gather_into_tensor_coalesced.default,
        _c10d_functional.reduce_scatter_tensor_coalesced.default,
    }
    @staticmethod
    def sum_tensors(arg: Any) -> int:
        """Calculate total memory consumed by the tensors in the argument."""
        total_memory = 0

        def sum_bytes(t: torch.Tensor) -> None:
            nonlocal total_memory
            total_memory += t.untyped_storage().nbytes()

        tree_map_only(torch.Tensor, sum_bytes, arg)
        return total_memory
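
    # Illustrative example (assumption: contiguous tensors that own their storage):
    # sum_tensors([torch.randn(1024), torch.randn(1024)]) -> 2 * 1024 * 4 = 8192 bytes,
    # since each float32 tensor of 1024 elements has a 4096-byte untyped storage.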
    @staticmethod
    def get_process_group(func, args) -> ProcessGroup:  # type: ignore[no-untyped-def]
        """Retrieve the process group for collective operations, except `wait_tensor`."""
        if func in CollectiveOp.PG_ARG_1:
            return ProcessGroup.unbox(args[1])
        if func in CollectiveOp.PG_ARG_2:
            return ProcessGroup.unbox(args[2])
        if func in CollectiveOp.PG_ARG_3:
            return _resolve_process_group(args[2])
        if func in CollectiveOp.PG_ARG_4:
            return _resolve_process_group(args[3])
        raise TypeError(f"Func {func} not found in {collective_ops}")
    @staticmethod
    def get_comm_tensor_size(func, res, args, kwargs) -> int:  # type: ignore[no-untyped-def]
        """Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`."""
        if func in CollectiveOp.COMM_TENSOR_ARG_0:
            return CollectiveOp.sum_tensors(args[0])
        if func in CollectiveOp.COMM_TENSOR_ARG_1:
            return CollectiveOp.sum_tensors(args[1])
        if func in CollectiveOp.COMM_TENSOR_ARG_RES:
            return res.untyped_storage().nbytes()
        if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE:
            return args[0].untyped_storage().nbytes()
        if func == c10d._reduce_scatter_base_.default:
            return args[1].untyped_storage().nbytes()
        if func == c10d.alltoall_.default:
            # TODO(@sanketpurandare) - Confirm size computation
            return max(
                CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1])
            )
        if func == c10d.alltoall_base_.default:
            # TODO(@sanketpurandare) - Confirm size computation
            return max(
                args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes()
            )
        if func == _c10d_functional.all_gather_into_tensor_out.default:
            return args[-1].untyped_storage().nbytes()
        if func in CollectiveOp.COMM_TENSOR_RES_SUM:
            return CollectiveOp.sum_tensors(res)
        if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES:
            # TODO(@sanketpurandare) - Confirm size computation
            return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes()
        raise TypeError(f"Unknown function: {func} in {collective_ops}")
    @staticmethod
    def get_work(func, res) -> Work:  # type: ignore[no-untyped-def]
        """Retrieve the FakeWork object from the result of a non-functional collective."""
        if func in CollectiveOp.WK:
            return FakeWork.unbox(res)
        elif func in CollectiveOp.WK_ARG_1:
            return FakeWork.unbox(res[1])
        raise TypeError(f"Func {func} not found in {collective_ops}")
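
# How these helpers are typically consumed (illustrative sketch only; the
# `CommTracker` mode below is hypothetical and not part of this module): a
# TorchDispatchMode-based tracker can intercept collective calls and use
# CollectiveOp to pull out the process group, the communicated bytes, and the
# fake Work handle.
#
#     from torch.utils._python_dispatch import TorchDispatchMode
#
#     class CommTracker(TorchDispatchMode):
#         def __torch_dispatch__(self, func, types, args=(), kwargs=None):
#             kwargs = kwargs or {}
#             res = func(*args, **kwargs)
#             if func in collective_ops and func not in sync_ops:
#                 pg = CollectiveOp.get_process_group(func, args)
#                 nbytes = CollectiveOp.get_comm_tensor_size(func, res, args, kwargs)
#                 print(f"{func}: {nbytes} bytes on a group of size {pg.size()}")
#             return res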