[DTensor] Add aten.amin/amax to linear_reduction_strategy (#143747)

In the same vein as https://github.com/pytorch/pytorch/pull/134206, these two ops still seemed missing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143747
Approved by: https://github.com/kwen2501
This commit is contained in:
Luca Wehrstedt
2024-12-24 09:03:36 +00:00
committed by PyTorch MergeBot
parent b77406a9ec
commit aec3b46274
4 changed files with 6 additions and 9 deletions

View File

@ -98,8 +98,6 @@ dtensor_fails = {
xfail("all"),
xfail("allclose"),
xfail("alias_copy"),
xfail("amax"),
xfail("amin"),
xfail("aminmax"),
xfail("any"),
xfail("arange"),
@ -207,13 +205,10 @@ dtensor_fails = {
xfail("linalg.lu_factor"),
xfail("linalg.lu_factor_ex"),
xfail("linalg.lu_solve"),
xfail("linalg.matrix_norm"),
xfail("linalg.matrix_power"),
xfail("linalg.matrix_rank"),
xfail("linalg.matrix_rank", "hermitian"),
xfail("linalg.multi_dot"),
xfail("linalg.norm"),
xfail("linalg.norm", "subgradients_at_zero"),
xfail("linalg.pinv"),
xfail("linalg.pinv", "hermitian"),
xfail("linalg.slogdet"),
@ -238,8 +233,6 @@ dtensor_fails = {
xfail("masked_fill"),
xfail("masked_scatter"),
xfail("masked_select"),
xfail("masked.amax"),
xfail("masked.amin"),
xfail("masked.argmax"),
xfail("masked.argmin"),
xfail("masked.cumprod"),

View File

@ -83,7 +83,7 @@ class DistMathOpsTest(DTensorTestBase):
@with_comms
def test_linear_op_reductions(self):
for op_str in ("all", "sum", "prod", "max", "min", "any"):
for op_str in ("all", "sum", "prod", "max", "min", "any", "amax", "amin"):
self.linear_op_reductions(op_str)
@with_comms

View File

@ -112,7 +112,7 @@ class FlightRecorderEventTest(TestCase):
)
def test_all_events(self):
for collective in COLLECTIVES:
for collective in sorted(COLLECTIVES):
event = create_one_event(
collective, ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)

View File

@ -322,6 +322,10 @@ LINEAR_REDUCTION_OP_MAP = {
aten.any.default: "sum",
aten.any.dim: "sum",
aten.any.out: "sum",
aten.amax.default: "max",
aten.amax.out: "max",
aten.amin.default: "min",
aten.amin.out: "min",
}