[ROCm] Enable and fix several FSDP + Inductor distributed unit tests (#165011)

This PR enables a number of FSDP + Inductor distributed unit tests and applies the fixes needed for them to pass on ROCm. The changes have been verified on both MI200 and MI300 hardware.

This work addresses the following issues:
- https://github.com/ROCm/frameworks-internal/issues/13586
- https://github.com/ROCm/frameworks-internal/issues/13578

**Enabled Tests**

The following tests have been enabled and are now passing:
1. test_compiled_autograd_ctx
2. test_simple_mlp_fullgraph_backend_aot_eager
3. test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition
4. test_simple_mlp_fullgraph_backend_inductor
5. test_nested_fully_shard_backend_aot_eager
6. test_nested_fully_shard_backend_aot_eager_decomp_partition
7. test_nested_fully_shard_backend_inductor_fullgraph_True
8. test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition
9. test_transformer_backend_aot_eager
10. test_transformer_backend_aot_eager_decomp_partition
11. test_storage_resize_zero_gpu
12. test_storage_resize_nonzero_gpu
13. test_fake_distributed_inductor

**Tests skipped due to upstream issues:**
1. test_nested_fully_shard_backend_inductor_fullgraph_False
2. test_transformer_backend_inductor_fullgraph_True
3. test_transformer_backend_inductor_fullgraph_True_graph_partition
4. test_transformer_backend_inductor_fullgraph_False

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165011
Approved by: https://github.com/jeffdaily
Author: Chinmay Kuchinad
Date: 2025-10-10 14:10:54 +00:00
Committed by: PyTorch MergeBot
Commit: 55f01a48af (parent: 68913d8f2a)
3 changed files with 8 additions and 22 deletions

Changed file 1 of 3 (FSDP fully_shard compile tests, `TestFullyShardCompile`):

@@ -32,7 +32,7 @@ from torch.testing._internal.common_distributed import (
     sm_is_or_higher_than,
 )
 from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
-from torch.testing._internal.common_utils import run_tests, skipIfRocm
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
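
For context, `skipIfRocm` (from `torch.testing._internal.common_utils`) unconditionally skips a test on ROCm builds, so dropping it from this import and from the decorators in the hunks below is what actually enables these tests on ROCm. A minimal sketch of what such a decorator does; the name and implementation here are illustrative, not PyTorch's actual one:

```python
import functools
import unittest

import torch


def skip_if_rocm(fn):
    """Illustrative stand-in for skipIfRocm: skip the wrapped test on ROCm."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # torch.version.hip is a version string on ROCm builds and None on
        # CUDA/CPU builds, so it is the standard way to detect ROCm.
        if torch.version.hip is not None:
            raise unittest.SkipTest("test disabled on ROCm")
        return fn(*args, **kwargs)

    return wrapper
```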
@@ -133,7 +133,11 @@ class TestFullyShardCompile(FSDPTest):
             device_type.type,
             self.rank % torch.get_device_module(device_type).device_count(),
         )
-        if device_type.type == "cuda" and not sm_is_or_higher_than(device, 8, 0):
+        if (
+            device_type.type == "cuda"
+            and not torch.version.hip
+            and not sm_is_or_higher_than(device, 8, 0)
+        ):
             self.skipTest("bf16 requires sm >= 8.0")

     def test_dynamo_trace_use_training_state(self):
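
The extra `not torch.version.hip` condition matters because ROCm builds also report a `"cuda"` device type, while CUDA `sm_*` compute-capability levels do not apply to HIP devices (and MI200/MI300 do support bf16, so the skip must not fire there). A standalone sketch of the intended gating, assuming `sm_is_or_higher_than` boils down to a `torch.cuda.get_device_capability` comparison:

```python
import torch


def should_skip_bf16(device: torch.device) -> bool:
    """Sketch of the guard above: skip bf16 only on pre-Ampere NVIDIA GPUs."""
    if device.type != "cuda":
        return False  # other device types are not gated by this check
    if torch.version.hip is not None:
        return False  # ROCm reports "cuda" too, but sm_* levels do not apply
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) < (8, 0)  # bf16 needs sm_80 (Ampere) or newer
```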
@@ -478,7 +482,6 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
         file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
         return file_check

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_compiled_autograd_ctx(self):
         self.skipTestForOldSm()
@@ -643,14 +646,12 @@ Unsupported Tensor.backward() call
         return model_init_fn, input_creation_fn

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_simple_mlp_fullgraph_backend_aot_eager(self):
         self._test_traceable_fsdp(
             *self._create_simple_mlp_factory_fns(), "aot_eager", fwd_fullgraph=True
         )

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
         self._test_traceable_fsdp(
@@ -659,7 +660,6 @@ Unsupported Tensor.backward() call
             fwd_fullgraph=True,
         )

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_simple_mlp_fullgraph_backend_inductor(self):
         self.skipTestForOldSm()
@@ -731,7 +731,6 @@ Unsupported Tensor.backward() call
         return model_init_fn, input_creation_fn

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_aot_eager(self):
         # TODO: fix fwd_fullgraph=False case
@@ -744,7 +743,6 @@ Unsupported Tensor.backward() call
                 fwd_fullgraph=fwd_fullgraph,
             )

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
         # TODO: fix fwd_fullgraph=False case
@@ -866,19 +864,16 @@ Unsupported Tensor.backward() call
             pass
         file_check.run(bwd_code)

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
         self._test_nested_fully_shard_backend_inductor_fullgraph_True()

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch("graph_partition", True)
     def test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition(self):
         self._test_nested_fully_shard_backend_inductor_fullgraph_True()

     @unittest.skip("TODO: fix fwd_fullgraph=False case")
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
         self.skipTestForOldSm()
@@ -956,7 +951,6 @@ Unsupported Tensor.backward() call
         else:
             return contextlib.nullcontext()

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_transformer_backend_aot_eager(self):
         # TODO: fix fwd_fullgraph=False case
@@ -975,7 +969,6 @@ Unsupported Tensor.backward() call
                 fwd_fullgraph=fwd_fullgraph,
             )

-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout has worse accuracy after decomp, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
@@ -1111,7 +1104,6 @@ Unsupported Tensor.backward() call
         file_check.run(bwd_code)

     @unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
@@ -1119,7 +1111,6 @@ Unsupported Tensor.backward() call
         self._test_transformer_backend_inductor_fullgraph_True()

     @unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
@@ -1128,7 +1119,6 @@ Unsupported Tensor.backward() call
         self._test_transformer_backend_inductor_fullgraph_True()

     @unittest.skip("TODO: fix fwd_fullgraph=False case")
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)

Changed file 2 of 3 (distributed-pattern tests, `DistributedPatternTests`):

@@ -7,7 +7,7 @@ from torch import nn
 from torch._dynamo import compiled_autograd
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._dynamo.testing import CompileCounter
-from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
+from torch.testing._internal.common_utils import IS_MACOS, skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu
@@ -205,7 +205,6 @@ class DistributedPatternTests(TestCase):
     def test_storage_resize_zero_cpu(self):
         self._test_storage_resize_zero("cpu")

-    @skipIfRocm
     @requires_gpu()
     def test_storage_resize_zero_gpu(self):
         self._test_storage_resize_zero(GPU_TYPE)
@@ -230,7 +229,6 @@ class DistributedPatternTests(TestCase):
     def test_storage_resize_nonzero_cpu(self):
         self._test_storage_resize_nonzero("cpu")

-    @skipIfRocm
     @requires_gpu()
     def test_storage_resize_nonzero_gpu(self):
         self._test_storage_resize_nonzero(GPU_TYPE)
@@ -485,7 +483,6 @@ class DistributedPatternTests(TestCase):
         # Recompile on grad==None/grad!=None
         self.assertEqual(bw_cnt.frame_count, 2)

-    @skipIfRocm
     @skipIfXpu
     @requires_gpu()
     @torch._functorch.config.patch(recompute_views=True)

Changed file 3 of 3 (`resize_storage_bytes_` C++ source):

@@ -14,8 +14,7 @@ using namespace at;
 static void resize_storage_bytes_(const Tensor& variable, SymInt new_size) {
   // similar to THPStorage_resize_ in StorageMethods.cpp, but is traceable
   if (variable.storage().device_type() == at::kCUDA) {
-    // rocm build has undefined reference to resize_bytes_cuda
-#if defined(USE_CUDA) && !defined(USE_ROCM)
+#if defined(USE_CUDA)
     at::native::resize_bytes_cuda(
         variable.storage().unsafeGetStorageImpl(), new_size.expect_int());
 #else
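
The deleted comment documents why ROCm was excluded before: older ROCm builds failed to link `at::native::resize_bytes_cuda`, so they fell through to the `#else` fallback. With that reference now resolving, ROCm builds take the same storage-resize path as CUDA. Assuming the op is exposed as `torch.ops.inductor.resize_storage_bytes_` (the `inductor` namespace this file registers into), a hedged usage sketch:

```python
import torch

# Resize a tensor's backing storage in a traceable way; FSDP2 uses this to
# free and re-materialize sharded parameters. The same call works on ROCm,
# where the device type is also reported as "cuda".
t = torch.empty(1024, device="cuda")
torch.ops.inductor.resize_storage_bytes_(t, 0)  # drop the allocation
assert t.untyped_storage().nbytes() == 0
torch.ops.inductor.resize_storage_bytes_(t, 4096)  # re-allocate 4096 bytes
```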