Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "add torchrec collectives to enforce global ordering (#141970)"
This reverts commit ceb94d6a7d38930d662e7eb71b9c7620de8c2997. Reverted https://github.com/pytorch/pytorch/pull/141970 on behalf of https://github.com/malfet due to: apologies for reverting this change, but it broke MacOS testing; CI was broken at the time, so the breakage was not caught before merge ([comment](https://github.com/pytorch/pytorch/pull/141970#issuecomment-2529367680))
@@ -681,10 +681,6 @@ class BaseSchedulerNode:
                 # falling back to 0
                 log.info(e)
                 return 0
-            except TypeError as e:
-                # this happens when the collective is not of type ir._CollectiveKernel
-                log.info(e)
-                return 0
 
         elif is_wait(self.node):
             # ir.Wait is only used for collective ops.
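For context, the lines removed in the hunk above are the TypeError fallback that the reverted PR had added: when runtime estimation encounters a node that is not an ir._CollectiveKernel (such as a torchrec fallback kernel), it logs the error and reports a runtime of 0. A minimal standalone sketch of that fallback pattern, assuming a caller-supplied estimator that raises ValueError or TypeError; the function and parameter names here are hypothetical, not Inductor API:

import logging

log = logging.getLogger(__name__)

def estimated_runtime_or_zero(node, estimator):
    # Sketch of the removed fallback: if the estimator cannot handle the node
    # (e.g. it is not the collective type it expects), log and report 0.
    try:
        return estimator(node)
    except ValueError as e:
        # falling back to 0
        log.info(e)
        return 0
    except TypeError as e:
        # raised when the node is not the collective type the estimator expects
        log.info(e)
        return 0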
@@ -1853,11 +1849,12 @@ class Scheduler:
         self.dead_node_elimination()
         self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
         self.compute_ancestors()
-        self.nodes = comms.decide_global_ordering_of_comms(
-            self.nodes,
-            self.name_to_buf,
-            self.name_to_fused_node,
-        )
+        if config.reorder_for_compute_comm_overlap:
+            self.nodes = comms.decide_global_ordering_of_comms(
+                self.nodes,
+                self.name_to_buf,
+                self.name_to_fused_node,
+            )
 
         metrics.ir_nodes_pre_fusion += len(self.nodes)
         V.debug.ir_pre_fusion(self.nodes)
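The practical effect of the hunk above is that, after the revert, comms.decide_global_ordering_of_comms runs only when the compute/communication overlap pass is enabled, rather than unconditionally for every graph. A minimal sketch of opting into that pass from user code, assuming only the torch._inductor.config flag named in the diff; the toy compiled function is illustrative and not taken from the PR:

import torch
import torch._inductor.config as inductor_config

# Post-revert, the global ordering of communication ops is decided only when
# this flag is set; with it off, the pass shown in the hunk above is skipped.
inductor_config.reorder_for_compute_comm_overlap = True

@torch.compile
def scale(x):
    return x * 2

print(scale(torch.ones(4)))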
@@ -1754,31 +1754,7 @@ def pass_execution_and_save(func, gm, inp, msg):
 def is_collective(node, op=None):
     from . import ir
 
-    return (
-        type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
-    ) or (
-        # TODO: this is a temporary solution to ensure that we can identify torchrec's
-        # communication ops. But in order to allow better communication and computation
-        # overlap, torchrec's communication ops should be not used.
-        type(node) == ir.FallbackKernel
-        and (
-            # NOTE: the `hasattr()` check is to bypass errors such as the following:
-            # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single'
-            (
-                hasattr(torch.ops.torchrec, "all_to_all_single")
-                and node.op_overload == torch.ops.torchrec.all_to_all_single.default
-            )
-            or (
-                hasattr(torch.ops.torchrec, "all_gather_into_tensor")
-                and node.op_overload
-                == torch.ops.torchrec.all_gather_into_tensor.default
-            )
-            or (
-                hasattr(torch.ops.torchrec, "reduce_scatter_tensor")
-                and node.op_overload == torch.ops.torchrec.reduce_scatter_tensor.default
-            )
-        )
-    )
+    return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
 
 
 def is_wait(node):
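The branch removed above relies on hasattr() guards because accessing torch.ops.torchrec.<op> raises AttributeError when torchrec has not registered that op, for example when torchrec is not installed. A small sketch of that guarded-lookup pattern in isolation; the helper name is hypothetical, and the namespace and op name are passed in rather than hard-coded as in the Inductor code:

import torch

def op_overload_matches(op_overload, namespace, op_name):
    # Look up torch.ops.<namespace>.<op_name>.default without tripping the
    # AttributeError an unregistered op would raise, then compare overloads.
    ns = getattr(torch.ops, namespace)
    if not hasattr(ns, op_name):
        return False
    return op_overload == getattr(ns, op_name).default

# Example: matches a registered aten op against its default overload.
print(op_overload_matches(torch.ops.aten.relu.default, "aten", "relu"))  # True
# An op that was never registered is simply reported as non-matching.
print(op_overload_matches(torch.ops.aten.relu.default, "torchrec", "all_to_all_single"))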