Revert "add torchrec collectives to enforce global ordering (#141970)"

This reverts commit ceb94d6a7d38930d662e7eb71b9c7620de8c2997.

Reverted https://github.com/pytorch/pytorch/pull/141970 on behalf of https://github.com/malfet due to Apologies for reverting this change, but it broke MacOS testing, which went unnoticed because CI was broken at the time ([comment](https://github.com/pytorch/pytorch/pull/141970#issuecomment-2529367680))
PyTorch MergeBot
2024-12-09 20:25:04 +00:00
parent 960a81fdcd
commit 5c76a2834d
2 changed files with 7 additions and 34 deletions

@@ -681,10 +681,6 @@ class BaseSchedulerNode:
                 # falling back to 0
                 log.info(e)
                 return 0
-            except TypeError as e:
-                # this happens when the collective is not of type ir._CollectiveKernel
-                log.info(e)
-                return 0
 
         elif is_wait(self.node):
             # ir.Wait is only used for collective ops.
@@ -1853,11 +1849,12 @@ class Scheduler:
         self.dead_node_elimination()
         self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
         self.compute_ancestors()
-        self.nodes = comms.decide_global_ordering_of_comms(
-            self.nodes,
-            self.name_to_buf,
-            self.name_to_fused_node,
-        )
+        if config.reorder_for_compute_comm_overlap:
+            self.nodes = comms.decide_global_ordering_of_comms(
+                self.nodes,
+                self.name_to_buf,
+                self.name_to_fused_node,
+            )
 
         metrics.ir_nodes_pre_fusion += len(self.nodes)
         V.debug.ir_pre_fusion(self.nodes)
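
Note (not part of the commit): with the guard restored above, comms.decide_global_ordering_of_comms runs only when the inductor config flag reorder_for_compute_comm_overlap is enabled. A minimal opt-in sketch, assuming the standard torch._inductor.config knob; the toy function below contains no collectives, so the flag is a no-op there and only shows where it is set:

import torch
import torch._inductor.config as inductor_config

# Opt back into the comm-ordering / overlap passes that this revert re-gates.
inductor_config.reorder_for_compute_comm_overlap = True

@torch.compile
def fn(x):
    # collective-free toy body; the flag only affects graphs that contain collectives
    return x * 2

out = fn(torch.randn(8))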

@@ -1754,31 +1754,7 @@ def pass_execution_and_save(func, gm, inp, msg):
 def is_collective(node, op=None):
     from . import ir
 
-    return (
-        type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
-    ) or (
-        # TODO: this is a temporary solution to ensure that we can identify torchrec's
-        # communication ops. But in order to allow better communication and computation
-        # overlap, torchrec's communication ops should be not used.
-        type(node) == ir.FallbackKernel
-        and (
-            # NOTE: the `hasattr()` check is to bypass errors such as the following:
-            # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single'
-            (
-                hasattr(torch.ops.torchrec, "all_to_all_single")
-                and node.op_overload == torch.ops.torchrec.all_to_all_single.default
-            )
-            or (
-                hasattr(torch.ops.torchrec, "all_gather_into_tensor")
-                and node.op_overload
-                == torch.ops.torchrec.all_gather_into_tensor.default
-            )
-            or (
-                hasattr(torch.ops.torchrec, "reduce_scatter_tensor")
-                and node.op_overload == torch.ops.torchrec.reduce_scatter_tensor.default
-            )
-        )
-    )
+    return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
 
 
 def is_wait(node):
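
Note (not part of the commit): after the revert, is_collective matches only ir._CollectiveKernel nodes, so torchrec ops lowered as ir.FallbackKernel are no longer treated as collectives by the scheduler. An illustrative sketch of the restored check, assuming the helper lives in torch._inductor.utils (per the hunk context above) and using torch.ops._c10d_functional.all_reduce.default purely as an example op filter:

import torch
from torch._inductor.utils import is_collective  # assumed import path

def is_functional_all_reduce(node) -> bool:
    # True only for ir._CollectiveKernel nodes with the matching op overload;
    # torchrec FallbackKernel nodes now fall through to False.
    return is_collective(node, op=torch.ops._c10d_functional.all_reduce.default)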