[FSDP][Replicate] tests replicate type casting behavior and edge cases in mixed precision (#162861)
**Summary:** Ensures that `replicate` can handle the same type-casting behavior and edge cases that `fully_shard` can when mixed precision is used.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_float16_on_one_submodule
2. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_submodules_with_external_inputs
3. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_bf16
4. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_fp16
5. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_clamp_reduce_dtype
6. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_dataclass_input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162861
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853, #162855
Committed by: PyTorch MergeBot
Parent: ae4fd4ea75
Commit: 2810977d3a
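All of the new tests in the diff below follow the same pattern: apply `replicate` with a `MixedPrecisionPolicy` to one or more submodules (and usually the root), run a forward/backward pass, and assert the dtypes observed at module boundaries. The following is a condensed sketch of that pattern, adapted from the first subtest in the diff; it assumes a CUDA-capable rank with `torch.distributed` already initialized and `replicate`, `MixedPrecisionPolicy`, and the `SaveForwardInputsModel` test helper imported as in the test module.

```python
import torch


def check_fp16_on_one_submodule(device: torch.device) -> None:
    # Record the forward input seen by each submodule so its dtype can be checked.
    forward_inputs: dict[torch.nn.Module, torch.Tensor] = {}
    model = SaveForwardInputsModel(forward_inputs, cast_forward_inputs=False).to(device)

    # Only the second child runs in fp16; replicate casts its forward inputs for it.
    replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
    replicate(model)

    model(torch.zeros(2, 100, device=device)).sum().backward()

    # The root and first child stay in fp32; only c2 sees fp16 inputs.
    assert forward_inputs[model].dtype == torch.float32
    assert forward_inputs[model.c1].dtype == torch.float32
    assert forward_inputs[model.c2].dtype == torch.float16
```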
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import copy
import dataclasses
import functools
from typing import Optional, Union

@@ -16,17 +17,23 @@ from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
    requires_nccl_version,
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    get_devtype,
    MLP,
    patch_reduce_scatter,
    reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocmVersionLessThan
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_HPU,
)


device_type = torch.device(get_devtype())
@@ -383,5 +390,237 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
                ref_param_compute.detach().copy_(ref_param)


class TestReplicateMixedPrecisionCasts(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(1)
    def test_float16_on_one_submodule(self):
        x = torch.zeros(2, 100, device=device_type)

        # Subtest 1: use fp16 on the second child submodule -- does not require
        # any additional casting logic
        forward_inputs: dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs,
            cast_forward_inputs=False,
        ).to(device_type)
        replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        replicate(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

        # Subtest 2: use fp16 on the second child module, where the user module
        # owns the cast
        forward_inputs: dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=True
        ).to(device_type)
        replicate(
            model.c2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=False
            ),
        )
        replicate(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

        # Subtest 3: use fp16 on the first child module and specify its output
        # dtype so that the second child module does not need to cast
        forward_inputs: dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).to(device_type)
        replicate(
            model.c1,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, output_dtype=torch.float32
            ),
        )
        replicate(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

    @skip_if_lt_x_gpu(1)
    def test_submodules_with_external_inputs(self):
        self.run_subtests(
            {"enable_submodule_cast": [False, True]},
            self._test_submodules_with_external_inputs,
        )

    def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
        class ToyModule(nn.Module):
            def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l = nn.Linear(100, 100)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["l2_input_x"] = x
                self.forward_inputs["l2_input_y"] = y
                return self.l(x)

        class ToyModel(nn.Module):
            def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l1 = nn.Linear(100, 100)
                self.l2 = ToyModule(forward_inputs)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["model_input_x"] = x
                y = torch.ones(
                    2, 100, device=device_type.type, dtype=torch.float32
                )  # external input
                return self.l2(self.l1(x), y)

        forward_inputs: dict[str, torch.Tensor] = {}
        model = ToyModel(forward_inputs).to(device_type)
        x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
        replicate(
            model.l2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
            ),
        )
        replicate(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        model(x).sum().backward()

        # If we enable `model.l2` to cast (as default), then `l2_input_y` gets
        # cast to fp16, and if we disable, then it stays as fp32.
        self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
        self.assertEqual(
            forward_inputs["l2_input_y"].dtype,
            torch.float16 if enable_submodule_cast else torch.float32,
        )

    @skip_if_lt_x_gpu(1)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_norm_modules_bf16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
        self._test_norm_modules(mp_policy)

    @skip_if_lt_x_gpu(1)
    def test_norm_modules_fp16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
        self._test_norm_modules(mp_policy)

    def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
        def inner(model: nn.Module, x: torch.Tensor):
            # Run forward and backward to check for no type mismatch errors
            z = model(x)
            self.assertEqual(z.dtype, mp_policy.param_dtype)
            z.sum().backward()

        # Layer norm
        model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 1D
        model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 2D: error in backward from buffer dtype mismatch
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        if TEST_HPU:
            inner(model, torch.randn((3, 1, 9, 9)))
        else:
            with self.assertRaisesRegex(
                RuntimeError,
                "Expected running_mean to have type",  # Error not seen on HPUs and hence it can be skipped
            ):
                # Errors in batch norm 2D backward
                inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: cast buffers down to lower precision
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            replicate(module, mp_policy=mp_policy)
        # Casting batch norm buffers to the lower precision allows backward
        model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
        model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
        inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: use special mixed precision policy
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
        replicate(model[1], mp_policy=bn_mp_policy)
        for module in (model[0], model[2], model):
            replicate(module, mp_policy=mp_policy)
        inner(model, torch.randn((3, 1, 9, 9)))

    @skip_if_lt_x_gpu(1)
    def test_clamp_reduce_dtype(self):
        # Initialize the model directly in bf16
        init_dtype = torch.bfloat16
        model = nn.Sequential(
            nn.Linear(32, 32, dtype=init_dtype),
            nn.Linear(32, 32, dtype=init_dtype),
        ).to(device_type.type)
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
        )
        # Check that we did not clamp the reduce dtype
        self.assertEqual(mp_policy.reduce_dtype, torch.bfloat16)
        for module in model:
            replicate(module, mp_policy=mp_policy)
        replicate(model, mp_policy=mp_policy)

        # Check that the reduce-scatter runs in bf16 even after we change the
        # model from bf16 to fp32
        model.to(torch.float32)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, torch.bfloat16)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        with patch_reduce_scatter(reduce_scatter):
            inp = torch.randn((4, 32), device=device_type.type)
            loss = model(inp).sum()
            loss.backward()

    @skip_if_lt_x_gpu(1)
    def test_dataclass_input(self):
        @dataclasses.dataclass
        class Input:
            x: torch.Tensor

        class Model(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self._layer = nn.Linear(10, 10)

            def forward(self, input: Input):
                return self._layer(input.x)

        mp_policy = MixedPrecisionPolicy(
            torch.bfloat16, torch.bfloat16, torch.bfloat16, True
        )
        model = Model()
        inp = Input(torch.randn(2, 10).cuda())

        replicate(model, mp_policy=mp_policy)
        loss = model(inp).sum()
        loss.backward()


if __name__ == "__main__":
    run_tests()