[FSDP][Replicate] tests replicate type casting behavior and edge cases in mixed precision (#162861)

**Summary:** Ensures that `replicate` handles the same type-casting behavior and edge cases that `fully_shard` does when mixed precision is used.
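
For context, the pattern these tests exercise looks roughly like the sketch below: wrap a submodule with `replicate` and a `MixedPrecisionPolicy` so that its parameters (and, by default, its forward inputs) are cast to a lower-precision dtype while the rest of the model stays in fp32. The model, dtypes, and the `replicate` import are illustrative assumptions here; see the test file diff below for the actual imports and subtests.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy

# `replicate` is assumed to be imported from its composable API module, as in
# the test file below, and torch.distributed is assumed to be initialized
# (the test harness below handles this).
model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).to("cuda")

# Run the second layer in fp16; `cast_forward_inputs` defaults to True, so its
# forward inputs are cast to fp16 as well.
replicate(model[1], mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
# The root keeps the default policy, so its own compute stays in fp32.
replicate(model)

model(torch.randn(4, 32, device="cuda")).sum().backward()
```

The new `TestReplicateMixedPrecisionCasts` tests below vary `param_dtype`, `output_dtype`, and `cast_forward_inputs` around this pattern to cover the same edge cases as the `fully_shard` tests.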

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_float16_on_one_submodule
2. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_submodules_with_external_inputs
3. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_bf16
4. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_fp16
5. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_clamp_reduce_dtype
6. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_dataclass_input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162861
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853, #162855
Author: Anshul Sinha
Date: 2025-09-25 12:23:46 -07:00
Committed by: PyTorch MergeBot
Parent: ae4fd4ea75
Commit: 2810977d3a


@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import copy
import dataclasses
import functools
from typing import Optional, Union
@@ -16,17 +17,23 @@ from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
    requires_nccl_version,
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    get_devtype,
    MLP,
    patch_reduce_scatter,
    reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocmVersionLessThan
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_HPU,
)


device_type = torch.device(get_devtype())
@@ -383,5 +390,237 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
                ref_param_compute.detach().copy_(ref_param)


class TestReplicateMixedPrecisionCasts(FSDPTestMultiThread):
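    # Mirrors the fully_shard mixed-precision cast tests, checking that
    # replicate applies the same forward input/output dtype casting in the
    # same edge cases.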
    @property
    def world_size(self) -> int:
        return 2
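
    # FSDPTestMultiThread runs ranks as threads in one process, so a single
    # accelerator is enough for these tests (hence skip_if_lt_x_gpu(1)).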
    @skip_if_lt_x_gpu(1)
    def test_float16_on_one_submodule(self):
        x = torch.zeros(2, 100, device=device_type)

        # Subtest 1: use fp16 on the second child submodule -- does not require
        # any additional casting logic
        forward_inputs: dict[str, nn.Module] = {}
        model = SaveForwardInputsModel(
            forward_inputs,
            cast_forward_inputs=False,
        ).to(device_type)
        replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        replicate(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

        # Subtest 2: use fp16 on the second child module, where the user module
        # owns the cast
        forward_inputs: dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=True
        ).to(device_type)
        replicate(
            model.c2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=False
            ),
        )
        replicate(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

        # Subtest 3: use fp16 on the first child module and specify its output
        # dtype so that the second child module does not need to cast
        forward_inputs: dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).to(device_type)
        replicate(
            model.c1,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, output_dtype=torch.float32
            ),
        )
        replicate(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

    @skip_if_lt_x_gpu(1)
    def test_submodules_with_external_inputs(self):
        self.run_subtests(
            {"enable_submodule_cast": [False, True]},
            self._test_submodules_with_external_inputs,
        )

    def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
        class ToyModule(nn.Module):
            def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l = nn.Linear(100, 100)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["l2_input_x"] = x
                self.forward_inputs["l2_input_y"] = y
                return self.l(x)

        class ToyModel(nn.Module):
            def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l1 = nn.Linear(100, 100)
                self.l2 = ToyModule(forward_inputs)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["model_input_x"] = x
                y = torch.ones(
                    2, 100, device=device_type.type, dtype=torch.float32
                )  # external input
                return self.l2(self.l1(x), y)

        forward_inputs: dict[str, torch.Tensor] = {}
        model = ToyModel(forward_inputs).to(device_type)
        x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
        replicate(
            model.l2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
            ),
        )
        replicate(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        model(x).sum().backward()
        # If we enable `model.l2` to cast (the default), then `l2_input_y` gets
        # cast to fp16, and if we disable it, then it stays as fp32.
        self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
        self.assertEqual(
            forward_inputs["l2_input_y"].dtype,
            torch.float16 if enable_submodule_cast else torch.float32,
        )

    @skip_if_lt_x_gpu(1)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_norm_modules_bf16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
        self._test_norm_modules(mp_policy)

    @skip_if_lt_x_gpu(1)
    def test_norm_modules_fp16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
        self._test_norm_modules(mp_policy)

    def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
        def inner(model: nn.Module, x: torch.Tensor):
            # Run forward and backward to check for no type mismatch errors
            z = model(x)
            self.assertEqual(z.dtype, mp_policy.param_dtype)
            z.sum().backward()

        # Layer norm
        model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 1D
        model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 2D: error in backward from buffer dtype mismatch
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        if TEST_HPU:
            inner(model, torch.randn((3, 1, 9, 9)))
        else:
            with self.assertRaisesRegex(
                RuntimeError,
                "Expected running_mean to have type",  # Error not seen on HPUs and hence it can be skipped
            ):
                # Errors in batch norm 2D backward
                inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: cast buffers down to lower precision
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        # Casting batch norm buffers to the lower precision allows backward
        model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
        model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
        inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: use special mixed precision policy
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
        replicate(model[1], mp_policy=bn_mp_policy)
        for module in (model[0], model[2], model):
            replicate(module, mp_policy=mp_policy)
        inner(model, torch.randn((3, 1, 9, 9)))

    @skip_if_lt_x_gpu(1)
    def test_clamp_reduce_dtype(self):
        # Initialize the model directly in bf16
        init_dtype = torch.bfloat16
        model = nn.Sequential(
            nn.Linear(32, 32, dtype=init_dtype),
            nn.Linear(32, 32, dtype=init_dtype),
        ).to(device_type.type)
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
        )
        # Check that we did not clamp the reduce dtype
        self.assertEqual(mp_policy.reduce_dtype, torch.bfloat16)
        for module in model:
            replicate((module), mp_policy=mp_policy)
        replicate(model, mp_policy=mp_policy)

        # Check that the reduce-scatter runs in bf16 even after we change the
        # model from bf16 to fp32
        model.to(torch.float32)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, torch.bfloat16)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        with patch_reduce_scatter(reduce_scatter):
            inp = torch.randn((4, 32), device=device_type.type)
            loss = model(inp).sum()
            loss.backward()

    @skip_if_lt_x_gpu(1)
    def test_dataclass_input(self):
        @dataclasses.dataclass
        class Input:
            x: torch.Tensor

        class Model(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self._layer = nn.Linear(10, 10)

            def forward(self, input: Input):
                return self._layer(input.x)
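
        # Positional args: param_dtype, reduce_dtype, output_dtype, cast_forward_inputs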
        mp_policy = MixedPrecisionPolicy(
            torch.bfloat16, torch.bfloat16, torch.bfloat16, True
        )
        model = Model()
        inp = Input(torch.randn(2, 10).cuda())
        replicate(model, mp_policy=mp_policy)
        loss = model(inp).sum()
        loss.backward()


if __name__ == "__main__":
    run_tests()