mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DTensor] Add aten.amin/amax to linear_reduction_strategy (#143747)
In the same vein as https://github.com/pytorch/pytorch/pull/134206, these two ops still seemed missing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143747 Approved by: https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
b77406a9ec
commit
aec3b46274
@ -98,8 +98,6 @@ dtensor_fails = {
|
||||
xfail("all"),
|
||||
xfail("allclose"),
|
||||
xfail("alias_copy"),
|
||||
xfail("amax"),
|
||||
xfail("amin"),
|
||||
xfail("aminmax"),
|
||||
xfail("any"),
|
||||
xfail("arange"),
|
||||
@ -207,13 +205,10 @@ dtensor_fails = {
|
||||
xfail("linalg.lu_factor"),
|
||||
xfail("linalg.lu_factor_ex"),
|
||||
xfail("linalg.lu_solve"),
|
||||
xfail("linalg.matrix_norm"),
|
||||
xfail("linalg.matrix_power"),
|
||||
xfail("linalg.matrix_rank"),
|
||||
xfail("linalg.matrix_rank", "hermitian"),
|
||||
xfail("linalg.multi_dot"),
|
||||
xfail("linalg.norm"),
|
||||
xfail("linalg.norm", "subgradients_at_zero"),
|
||||
xfail("linalg.pinv"),
|
||||
xfail("linalg.pinv", "hermitian"),
|
||||
xfail("linalg.slogdet"),
|
||||
@ -238,8 +233,6 @@ dtensor_fails = {
|
||||
xfail("masked_fill"),
|
||||
xfail("masked_scatter"),
|
||||
xfail("masked_select"),
|
||||
xfail("masked.amax"),
|
||||
xfail("masked.amin"),
|
||||
xfail("masked.argmax"),
|
||||
xfail("masked.argmin"),
|
||||
xfail("masked.cumprod"),
|
||||
|
@ -83,7 +83,7 @@ class DistMathOpsTest(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
def test_linear_op_reductions(self):
|
||||
for op_str in ("all", "sum", "prod", "max", "min", "any"):
|
||||
for op_str in ("all", "sum", "prod", "max", "min", "any", "amax", "amin"):
|
||||
self.linear_op_reductions(op_str)
|
||||
|
||||
@with_comms
|
||||
|
@ -112,7 +112,7 @@ class FlightRecorderEventTest(TestCase):
|
||||
)
|
||||
|
||||
def test_all_events(self):
|
||||
for collective in COLLECTIVES:
|
||||
for collective in sorted(COLLECTIVES):
|
||||
event = create_one_event(
|
||||
collective, ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
)
|
||||
|
@ -322,6 +322,10 @@ LINEAR_REDUCTION_OP_MAP = {
|
||||
aten.any.default: "sum",
|
||||
aten.any.dim: "sum",
|
||||
aten.any.out: "sum",
|
||||
aten.amax.default: "max",
|
||||
aten.amax.out: "max",
|
||||
aten.amin.default: "min",
|
||||
aten.amin.out: "min",
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user