Simplifying computation of the final result for equals op on DTensor (#164999)

Instead of collecting local results with all_gather_object and then reducing them locally, this change computes the final equals result with a single all_reduce using the MIN reduction op. This change is needed to enable the LocalTensor work (all_gather_object introduces challenges for DTensor and LocalTensor integration).

topic: not user facing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164999
Approved by: https://github.com/ezyang
commit ae25dd51fc
parent a61d0de9f9
committed by PyTorch MergeBot
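The change in a nutshell: one boolean per rank can be combined with a single collective, because MIN over {0, 1} is a logical AND, and 1 is the identity element for MIN, so ranks that did not participate can safely contribute it. Below is a minimal standalone sketch of that pattern, not code from this commit; it assumes a gloo process group launched via torchrun, and spmd_equal plus the rank-based inputs are illustrative names.

from typing import Optional

import torch
import torch.distributed as dist


def spmd_equal(local_equal: Optional[bool], device: str = "cpu") -> bool:
    # Each participating rank contributes 0 (unequal) or 1 (equal); ranks
    # without a local result contribute 1, the identity for MIN, so one
    # all_reduce(MIN) acts as a global logical AND across all ranks.
    r = torch.tensor(int(local_equal) if local_equal is not None else 1, device=device)
    dist.all_reduce(r, op=dist.ReduceOp.MIN)
    return bool(r.item())


if __name__ == "__main__":
    # Run with: torchrun --nproc-per-node 2 sketch.py
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    # Pretend rank 1 found its local shards unequal: every rank prints False.
    print(rank, spmd_equal(local_equal=(rank != 1)))
    dist.destroy_process_group()

The old path instead moved pickled Python objects between ranks (all_gather_object) and reduced them locally with functools.reduce; a single fixed-shape all_reduce is both cheaper and much easier to model for LocalTensor. Initializing local_results to None up front (third hunk below) is what lets non-participating ranks fall through to the identity value.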
diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py
@@ -300,6 +300,7 @@ dtensor_fails = {
     xfail("nn.functional.multi_margin_loss"),
     xfail("nn.functional.multilabel_margin_loss"),
     xfail("nn.functional.multilabel_soft_margin_loss"),
+    xfail("nn.functional.multi_head_attention_forward"),
     xfail("nn.functional.pad", "reflect"),
     xfail("nn.functional.pad", "replicate"),
     xfail("nn.functional.pad", "replicate_negative"),
diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py
@@ -1,8 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import contextlib
-import functools
 import logging
-import operator
 import warnings
 from collections.abc import Sequence
 from typing import cast, Optional
@@ -175,6 +173,7 @@ class OpDispatcher:
 
         mesh = op_info.compute_mesh
         participating = mesh.get_coordinate() is not None
+        local_results = None
         if participating:
             # computation that happens in the current rank of the mesh, normal case
             if output_sharding.needs_redistribute:
@@ -278,14 +277,16 @@ class OpDispatcher:
 
         if output_sharding.output_spec is None:
             if op_call == aten.equal.default:
-                # For the equal operator, the local results from all devices are all-gathered
-                # and a reduce op (AND) is performed on the list of results to ensure SPMD
-                # execution. We can extend this for more ops if necessary.
-                obj_list = [None for _ in range(dist.get_world_size())]
-                dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
-                obj_list = list(filter(lambda x: x is not None, obj_list))
-                # perform reduce on the collection with AND op
-                local_results = functools.reduce(operator.and_, obj_list, True)
+                # The output of the equal op is a bool. By converting it into a
+                # single-value tensor, we can use all-reduce with the MIN reduce
+                # op to simulate a logical AND.
+                assert local_results is None or isinstance(local_results, bool)
+                r = torch.tensor(
+                    int(local_results) if local_results is not None else 1,
+                    device=mesh.device_type,
+                )
+                dist.all_reduce(r, op=dist.ReduceOp.MIN)
+                local_results = bool(r.item())
 
         if op_info.schema.is_inplace_op():
             # inplace op should return self instead of re-wrapping
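For reference, a hedged end-to-end example of the path the hunk above changes. It assumes PyTorch 2.4+, where distribute_tensor and Shard live in the public torch.distributed.tensor namespace (older builds expose them under torch.distributed._tensor): torch.equal on two DTensors dispatches aten.equal.default through OpDispatcher, and every rank must come away with the same bool even though each one only compares its local shard.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Run with: torchrun --nproc-per-node 2 demo.py
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Shard two identical 8-element tensors across the mesh.
a = distribute_tensor(torch.arange(8.0), mesh, [Shard(0)])
b = distribute_tensor(torch.arange(8.0), mesh, [Shard(0)])

# Dispatches aten.equal.default; with this commit the final bool is agreed
# on via a single all_reduce(MIN) instead of all_gather_object + reduce.
print(torch.equal(a, b))  # True on every rank

dist.destroy_process_group()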