diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 471ba4f901a7..1be373ee0e25 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -98,8 +98,6 @@ dtensor_fails = { xfail("all"), xfail("allclose"), xfail("alias_copy"), - xfail("amax"), - xfail("amin"), xfail("aminmax"), xfail("any"), xfail("arange"), @@ -207,13 +205,10 @@ dtensor_fails = { xfail("linalg.lu_factor"), xfail("linalg.lu_factor_ex"), xfail("linalg.lu_solve"), - xfail("linalg.matrix_norm"), xfail("linalg.matrix_power"), xfail("linalg.matrix_rank"), xfail("linalg.matrix_rank", "hermitian"), xfail("linalg.multi_dot"), - xfail("linalg.norm"), - xfail("linalg.norm", "subgradients_at_zero"), xfail("linalg.pinv"), xfail("linalg.pinv", "hermitian"), xfail("linalg.slogdet"), @@ -238,8 +233,6 @@ dtensor_fails = { xfail("masked_fill"), xfail("masked_scatter"), xfail("masked_select"), - xfail("masked.amax"), - xfail("masked.amin"), xfail("masked.argmax"), xfail("masked.argmin"), xfail("masked.cumprod"), diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index 1a8ee437342e..67fbeaa93066 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -83,7 +83,7 @@ class DistMathOpsTest(DTensorTestBase): @with_comms def test_linear_op_reductions(self): - for op_str in ("all", "sum", "prod", "max", "min", "any"): + for op_str in ("all", "sum", "prod", "max", "min", "any", "amax", "amin"): self.linear_op_reductions(op_str) @with_comms diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index f47b8b6c0e51..92b981928776 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -112,7 +112,7 @@ class FlightRecorderEventTest(TestCase): ) def test_all_events(self): - for collective in COLLECTIVES: + for collective in sorted(COLLECTIVES): event = create_one_event( collective, ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 7233afe87a69..b3af6252663e 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -322,6 +322,10 @@ LINEAR_REDUCTION_OP_MAP = { aten.any.default: "sum", aten.any.dim: "sum", aten.any.out: "sum", + aten.amax.default: "max", + aten.amax.out: "max", + aten.amin.default: "min", + aten.amin.out: "min", }