Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[ROCm] Enable and fix several FSDP + Inductor distributed unit tests (#165011)
This PR enables a number of distributed unit tests and applies the fixes needed for them to pass on ROCm platforms. The changes have been tested successfully on both MI200 and MI300 hardware.

This work addresses the following issues:
- https://github.com/ROCm/frameworks-internal/issues/13586
- https://github.com/ROCm/frameworks-internal/issues/13578

**Enabled tests**

The following tests have been enabled and are now passing:
1. test_compiled_autograd_ctx
2. test_simple_mlp_fullgraph_backend_aot_eager
3. test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition
4. test_simple_mlp_fullgraph_backend_inductor
5. test_nested_fully_shard_backend_aot_eager
6. test_nested_fully_shard_backend_aot_eager_decomp_partition
7. test_nested_fully_shard_backend_inductor_fullgraph_True
8. test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition
9. test_transformer_backend_aot_eager
10. test_transformer_backend_aot_eager_decomp_partition
11. test_storage_resize_zero_gpu
12. test_storage_resize_nonzero_gpu
13. test_fake_distributed_inductor

**Tests skipped due to upstream issues**

1. test_nested_fully_shard_backend_inductor_fullgraph_False
2. test_transformer_backend_inductor_fullgraph_True
3. test_transformer_backend_inductor_fullgraph_True_graph_partition
4. test_transformer_backend_inductor_fullgraph_False

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165011
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent: 68913d8f2a
Commit: 55f01a48af
```diff
@@ -32,7 +32,7 @@ from torch.testing._internal.common_distributed import (
     sm_is_or_higher_than,
 )
 from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
-from torch.testing._internal.common_utils import run_tests, skipIfRocm
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
```
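For context, `skipIfRocm` (imported here from `torch.testing._internal.common_utils`) unconditionally skips the decorated test on ROCm builds, so dropping the import together with the decorators below is what re-enables these tests on MI200/MI300. A minimal sketch of how such a guard can be written against the public `torch.version.hip` attribute; the real helper is internal and may differ in detail:

```python
import unittest

import torch


def skip_if_rocm(fn):
    # torch.version.hip is None on CUDA/CPU builds and a version string on ROCm,
    # so this skips the wrapped test only when running a ROCm build.
    return unittest.skipIf(torch.version.hip is not None, "skipped on ROCm")(fn)
```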
```diff
@@ -133,7 +133,11 @@ class TestFullyShardCompile(FSDPTest):
             device_type.type,
             self.rank % torch.get_device_module(device_type).device_count(),
         )
-        if device_type.type == "cuda" and not sm_is_or_higher_than(device, 8, 0):
+        if (
+            device_type.type == "cuda"
+            and not torch.version.hip
+            and not sm_is_or_higher_than(device, 8, 0)
+        ):
             self.skipTest("bf16 requires sm >= 8.0")
 
     def test_dynamo_trace_use_training_state(self):
```
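The rewritten condition keeps the bf16 capability check NVIDIA-only: on ROCm builds `torch.version.hip` is set, so MI200/MI300 no longer skip on the sm_80 requirement. A standalone sketch of the same guard; `sm_is_or_higher_than` is an internal helper, so `torch.cuda.get_device_capability` stands in for it here:

```python
import torch


def needs_bf16_skip(device: torch.device) -> bool:
    # Only NVIDIA CUDA devices below sm_80 should skip bf16 tests;
    # ROCm builds (torch.version.hip is set) are exempt from the check.
    if device.type != "cuda" or torch.version.hip:
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) < (8, 0)
```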
```diff
@@ -478,7 +482,6 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
         file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
         return file_check
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_compiled_autograd_ctx(self):
         self.skipTestForOldSm()
```
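The `file_check` chain in this hunk uses `torch.testing.FileCheck` to assert that the functional collective `wait_tensor` op shows up in the generated backward code. A small self-contained sketch of that pattern, with a made-up code string standing in for the captured Inductor output:

```python
from torch.testing import FileCheck

# Stand-in for code captured from the compiled backward pass.
bwd_code = """
buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0)
buf1 = torch.ops._c10d_functional.wait_tensor.default(buf0)
"""

# Passes silently if the pattern occurs; raises a RuntimeError otherwise.
FileCheck().check("torch.ops._c10d_functional.wait_tensor.").run(bwd_code)
```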
```diff
@@ -643,14 +646,12 @@ Unsupported Tensor.backward() call
 
         return model_init_fn, input_creation_fn
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_simple_mlp_fullgraph_backend_aot_eager(self):
         self._test_traceable_fsdp(
             *self._create_simple_mlp_factory_fns(), "aot_eager", fwd_fullgraph=True
         )
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
         self._test_traceable_fsdp(
@@ -659,7 +660,6 @@ Unsupported Tensor.backward() call
             fwd_fullgraph=True,
         )
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_simple_mlp_fullgraph_backend_inductor(self):
         self.skipTestForOldSm()
```
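These `_test_traceable_fsdp` tests compile an FSDP2-wrapped MLP with the given backend. A minimal single-rank sketch of the underlying pattern, assuming one GPU and an NCCL-capable build (older releases expose `fully_shard` under `torch.distributed._composable.fsdp` rather than `torch.distributed.fsdp`):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()
fully_shard(model)  # shard parameters over the (single-rank) process group

# The tests additionally require the traced forward to stay in one graph
# (fwd_fullgraph=True); a plain torch.compile call is enough for a smoke test.
compiled = torch.compile(model, backend="inductor")
compiled(torch.randn(2, 8, device="cuda")).sum().backward()

dist.destroy_process_group()
```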
```diff
@@ -731,7 +731,6 @@ Unsupported Tensor.backward() call
 
         return model_init_fn, input_creation_fn
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_aot_eager(self):
         # TODO: fix fwd_fullgraph=False case
@@ -744,7 +743,6 @@ Unsupported Tensor.backward() call
                 fwd_fullgraph=fwd_fullgraph,
             )
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
         # TODO: fix fwd_fullgraph=False case
@@ -866,19 +864,16 @@ Unsupported Tensor.backward() call
             pass
         file_check.run(bwd_code)
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
         self._test_nested_fully_shard_backend_inductor_fullgraph_True()
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch("graph_partition", True)
     def test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition(self):
         self._test_nested_fully_shard_backend_inductor_fullgraph_True()
 
     @unittest.skip("TODO: fix fwd_fullgraph=False case")
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
         self.skipTestForOldSm()
```
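The `_graph_partition` variant re-runs the same fullgraph Inductor test with graph partitioning enabled. `torch._inductor.config.patch` works both as a decorator (as in the diff above) and as a context manager; a small usage sketch, assuming the `graph_partition` flag exists in the installed build:

```python
import torch
import torch._inductor.config as inductor_config


def f(x):
    return torch.relu(x) + 1


# Temporarily flip the Inductor flag; the previous value is restored on exit.
with inductor_config.patch("graph_partition", True):
    compiled = torch.compile(f, backend="inductor")
    print(compiled(torch.randn(4)))
```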
```diff
@@ -956,7 +951,6 @@ Unsupported Tensor.backward() call
         else:
             return contextlib.nullcontext()
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_transformer_backend_aot_eager(self):
         # TODO: fix fwd_fullgraph=False case
@@ -975,7 +969,6 @@ Unsupported Tensor.backward() call
                 fwd_fullgraph=fwd_fullgraph,
             )
 
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout has worse accuracy after decomp, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
@@ -1111,7 +1104,6 @@ Unsupported Tensor.backward() call
         file_check.run(bwd_code)
 
     @unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
@@ -1119,7 +1111,6 @@ Unsupported Tensor.backward() call
         self._test_transformer_backend_inductor_fullgraph_True()
 
     @unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
@@ -1128,7 +1119,6 @@ Unsupported Tensor.backward() call
         self._test_transformer_backend_inductor_fullgraph_True()
 
     @unittest.skip("TODO: fix fwd_fullgraph=False case")
-    @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     # TODO: native_dropout causes CUDA IMA error, need to figure out why
     @torch._inductor.config.patch(fallback_random=True)
```
```diff
@@ -7,7 +7,7 @@ from torch import nn
 from torch._dynamo import compiled_autograd
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._dynamo.testing import CompileCounter
-from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
+from torch.testing._internal.common_utils import IS_MACOS, skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu
 
 
```
```diff
@@ -205,7 +205,6 @@ class DistributedPatternTests(TestCase):
     def test_storage_resize_zero_cpu(self):
         self._test_storage_resize_zero("cpu")
 
-    @skipIfRocm
     @requires_gpu()
     def test_storage_resize_zero_gpu(self):
         self._test_storage_resize_zero(GPU_TYPE)
@@ -230,7 +229,6 @@ class DistributedPatternTests(TestCase):
     def test_storage_resize_nonzero_cpu(self):
         self._test_storage_resize_nonzero("cpu")
 
-    @skipIfRocm
     @requires_gpu()
     def test_storage_resize_nonzero_gpu(self):
         self._test_storage_resize_nonzero(GPU_TYPE)
```
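`test_storage_resize_zero_gpu` and `test_storage_resize_nonzero_gpu` cover the storage-resize pattern FSDP uses to free and re-materialize parameter memory. A minimal eager sketch of that pattern using the public `untyped_storage` API (the device is chosen defensively so the snippet also runs without a GPU):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(16, device=device)
nbytes = x.untyped_storage().nbytes()

x.untyped_storage().resize_(0)       # free the backing allocation, keep the tensor object
assert x.untyped_storage().nbytes() == 0

x.untyped_storage().resize_(nbytes)  # re-allocate; contents are undefined until rewritten
x.zero_()
```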
```diff
@@ -485,7 +483,6 @@ class DistributedPatternTests(TestCase):
         # Recompile on grad==None/grad!=None
         self.assertEqual(bw_cnt.frame_count, 2)
 
-    @skipIfRocm
     @skipIfXpu
     @requires_gpu()
     @torch._functorch.config.patch(recompute_views=True)
```
```diff
@@ -14,8 +14,7 @@ using namespace at;
 static void resize_storage_bytes_(const Tensor& variable, SymInt new_size) {
   // similar to THPStorage_resize_ in StorageMethods.cpp, but is traceable
   if (variable.storage().device_type() == at::kCUDA) {
-    // rocm build has undefined reference to resize_bytes_cuda
-#if defined(USE_CUDA) && !defined(USE_ROCM)
+#if defined(USE_CUDA)
     at::native::resize_bytes_cuda(
         variable.storage().unsafeGetStorageImpl(), new_size.expect_int());
 #else
```