**Motivation:** Enable FSDP tests on XPU device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147518
Approved by: https://github.com/weifengpy
# Owner(s): ["oncall: distributed"]

import copy
import sys
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import _get_device_module
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_CUDA,
    TEST_WITH_DEV_DBG_ASAN,
)

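# Note: `get_devtype()` (from `common_fsdp`) reports which accelerator backend
# is available in the current environment, so the same tests can target CUDA,
# HPU, or XPU devices without per-backend branches.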
device_type = torch.device(get_devtype())

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

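# The two helpers below model a frozen submodule whose extra input is unused:
# `LinearUnusedInput` accepts a second argument that it ignores, so a frozen
# `layer1_frozen` can be called without its output depending on anything that
# requires grad. This exercises the FSDP path where a reshard must happen even
# though no gradient flows through the module (see
# `test_parity_with_non_frozen_fsdp` below).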
class LinearUnusedInput(nn.Linear):
    def forward(self, frozen_input, learnable_input):
        return super().forward(frozen_input)


class ModelUnusedInput(nn.Module):
    def __init__(self, freeze: bool):
        super().__init__()
        self.layer0 = LinearUnusedInput(4, 4)
        self.layer1_frozen = LinearUnusedInput(4, 4)
        if freeze:
            for param in self.layer1_frozen.parameters():
                param.requires_grad = False
        self.layer2 = LinearUnusedInput(4, 4)

    def forward(self, frozen_input, learnable_input):
        x = self.layer0(frozen_input, learnable_input)
        y = self.layer1_frozen(frozen_input, learnable_input)
        z = self.layer2(frozen_input, learnable_input)
        return torch.concat([x, y, z, learnable_input])


class TestFSDPFineTune(FSDPTest):
    """Tests fine-tuning cases where some parameters are frozen."""

    NUM_LINEARS = 6

    @property
    def world_size(self) -> int:
        return min(_get_device_module(self.device_type).device_count(), 2)

    def _init_seq_module(self, device) -> nn.Module:
        torch.manual_seed(42)
        modules = []
        for _ in range(self.NUM_LINEARS):
            modules += [nn.Linear(5, 5, device=device), nn.ReLU()]
        seq = nn.Sequential(*modules)
        self._set_seq_module_requires_grad(seq, False)
        return seq

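    # With `NUM_LINEARS = 6`, the loop below toggles `requires_grad` on the
    # linears at indices 0, 2, and 4 (i.e. `seq[0]`, `seq[4]`, `seq[8]`, since
    # each linear is followed by an `nn.ReLU`), leaving the odd-indexed
    # linears untouched. `_init_seq_module` calls this with
    # `requires_grad=False`, so the model interleaves frozen and trainable
    # flat parameters.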
    def _set_seq_module_requires_grad(self, seq: nn.Module, requires_grad: bool):
        # Assume that the linears are leaf modules, meaning that we can pass
        # `recurse=True` to have this work for both pre/post FSDP wrapping
        for i in range(self.NUM_LINEARS):
            # Only set for every other linear to test mixing frozen/non-frozen
            if i % 2 == 0:
                for param in seq[i * 2].parameters(recurse=True):
                    param.requires_grad = requires_grad

    @skip_if_lt_x_gpu(2)
    def test_backward_reshard_hooks(self, device):
        """
        Tests that the post-backward reshard happens even for flat parameters
        that do not require gradients.
        """
        self.run_subtests(
            {
                "device_id": [device],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [False, True],
                "inp_requires_grad": [False, True],
                "unfreeze_params": [False, True],
            },
            self._test_backward_reshard_hooks,
        )

    def _test_backward_reshard_hooks(
        self,
        device_id,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        inp_requires_grad: bool,
        unfreeze_params: bool,
    ):
        seq = self._init_seq_module(device_type)
        policy = ModuleWrapPolicy({nn.Linear})
        fsdp_kwargs = {"device_id": device_type}
        seq = FSDP(
            seq,
            auto_wrap_policy=policy,
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            **fsdp_kwargs,
        )
        orig_post_backward_reshard = (
            torch.distributed.fsdp._runtime_utils._post_backward_reshard
        )
        post_backward_reshard_count = 0

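        # Wrap the internal `_post_backward_reshard` so every invocation is
        # counted while still delegating to the original implementation; the
        # count is checked after each backward pass below.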
        def _post_backward_reshard_with_count(*args, **kwargs):
            nonlocal post_backward_reshard_count
            post_backward_reshard_count += 1
            return orig_post_backward_reshard(*args, **kwargs)

        def _assert_post_backward_requires_grad(seq):
            if step_idx == num_steps - 1 and unfreeze_params:
                self.assertTrue(
                    all(p.requires_grad for p in seq.parameters()),
                    msg="Expected all parameters to require grad but some did not!",
                )

        def _assert_post_backward_reshard_count(step_idx, num_steps):
            if step_idx < num_steps - 1 or not unfreeze_params:
                # If the input does not require gradient, then the 0th
                # frozen linear gets resharded in the catch-all reshard
                # since we cannot register an autograd hook on it
                expected_post_backward_reshard_count = (
                    self.NUM_LINEARS if inp_requires_grad else self.NUM_LINEARS - 1
                )
            else:
                # This follows the normal post-backward hook path
                expected_post_backward_reshard_count = self.NUM_LINEARS
            self.assertEqual(
                post_backward_reshard_count, expected_post_backward_reshard_count
            )

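        # Worked example with `NUM_LINEARS = 6`: if the input requires grad,
        # all six wrapped linears reshard through the counted hooks
        # (count == 6); if it does not, the 0th frozen linear has no autograd
        # hook and is resharded by the uncounted catch-all instead
        # (count == 5).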
        with mock.patch(
            "torch.distributed.fsdp._runtime_utils._post_backward_reshard",
            _post_backward_reshard_with_count,
        ):
            num_steps = 3
            # Interleave a `no_grad` step to validate that post-backward hooks
            # are not registered in that context and that `requires_grad` is
            # reset appropriately when unfreezing
            nograd_step_idx = 1
            for step_idx in range(num_steps):
                if unfreeze_params and step_idx == num_steps - 1:
                    # Unfreeze the parameters on the last step to emulate some
                    # kinds of fine-tuning
                    self._set_seq_module_requires_grad(seq, True)

                inp = torch.randn(
                    (8, 5), device=device_type, requires_grad=inp_requires_grad
                )
                if step_idx == nograd_step_idx:
                    with torch.no_grad():
                        output = seq(inp)
                else:
                    output = seq(inp)
                if step_idx != nograd_step_idx:
                    output.sum().backward()
                    _assert_post_backward_requires_grad(seq)
                    _assert_post_backward_reshard_count(step_idx, num_steps)
                    post_backward_reshard_count = 0

    def _init_multi_traversal_module(self, device) -> nn.Module:
        torch.manual_seed(42)

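        # `layer_with_grad` and `layer_no_grad` are applied ten times per
        # forward, so the same FSDP-wrapped parameters are traversed
        # repeatedly within a single pass; the hooks must unshard/reshard
        # them correctly each time (see `test_hooks_multi_traversal`).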
        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer_0 = nn.Linear(5, 5, device=device)
                self.layer_no_grad = nn.Linear(5, 5, device=device)
                self.layer_with_grad = nn.Linear(5, 5, device=device)
                self.layer_no_grad.requires_grad_(False)

            def forward(self, x):
                # `layer_no_grad` and `layer_with_grad` are called multiple
                # times, i.e., their parameters are used multiple times
                # during the forward pass.
                x = self.layer_0(x)
                for _ in range(10):
                    x = self.layer_no_grad(self.layer_with_grad(x))
                # Make sure calling the same layer multiple times works
                # regardless of whether gradient is enabled.
                with torch.no_grad():
                    x += self.layer_with_grad(x)
                return x

        return TestModule()

    @skip_if_lt_x_gpu(2)
    def test_hooks_multi_traversal(self):
        """
        Tests that the hooks reshard/unshard correctly when the same
        parameters are used multiple times during the forward pass.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [False, True],
                "inp_requires_grad": [False, True],
                "forward_prefetch": [False, True],
            },
            self._test_hooks_multi_traversal,
        )

    def _test_hooks_multi_traversal(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        inp_requires_grad: bool,
        forward_prefetch: bool,
    ):
        seq = self._init_multi_traversal_module(device_type.type)
        policy = ModuleWrapPolicy({nn.Linear})
        fsdp_kwargs = {"device_id": device_type}
        fsdp_seq = FSDP(
            copy.deepcopy(seq),
            auto_wrap_policy=policy,
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            forward_prefetch=forward_prefetch,
            **fsdp_kwargs,
        )
        ddp_seq = DDP(copy.deepcopy(seq), device_ids=[device_type])
        fsdp_optim = torch.optim.Adam(fsdp_seq.parameters(), lr=1e-2)
        ddp_optim = torch.optim.Adam(ddp_seq.parameters(), lr=1e-2)
        torch.manual_seed(self.rank + 1)
        losses = []
        for _ in range(6):
            inp = torch.randn(
                (8, 5), device=device_type, requires_grad=inp_requires_grad
            )
            for seq, optim in ((fsdp_seq, fsdp_optim), (ddp_seq, ddp_optim)):
                loss = seq(inp).sum()
                losses.append(loss)
                loss.backward()
                optim.step()
                optim.zero_grad()
            torch.testing.assert_close(losses[0], losses[1])
            losses.clear()

    @skip_if_lt_x_gpu(2)
    def test_parity_with_ddp(self):
        """
        Tests parity with DDP when mixing flat parameters that require and do
        not require gradients.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [False, True],
            },
            self._test_parity_with_ddp,
        )

    def _test_parity_with_ddp(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        seq = self._init_seq_module(device_type)
        policy = ModuleWrapPolicy({nn.Linear})
        fsdp_kwargs = {"device_id": device_type}
        fsdp_seq = FSDP(
            copy.deepcopy(seq),
            auto_wrap_policy=policy,
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            **fsdp_kwargs,
        )
        ddp_seq = DDP(copy.deepcopy(seq), device_ids=[device_type])
        fsdp_optim = torch.optim.Adam(fsdp_seq.parameters(), lr=1e-2)
        ddp_optim = torch.optim.Adam(ddp_seq.parameters(), lr=1e-2)
        torch.manual_seed(self.rank + 1)
        losses = []
        for _ in range(6):
            inp = torch.randn((8, 5), device=device_type.type)
            for seq, optim in ((fsdp_seq, fsdp_optim), (ddp_seq, ddp_optim)):
                loss = seq(inp).sum()
                losses.append(loss)
                loss.backward()
                optim.step()
                optim.zero_grad()
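            # Default tolerances are used on CUDA; other accelerator backends
            # (e.g. XPU, HPU) may produce slightly different numerics due to
            # accumulation order, hence the relaxed tolerances below.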
            if TEST_CUDA:
                torch.testing.assert_close(losses[0], losses[1])
            else:
                torch.testing.assert_close(losses[0], losses[1], atol=1e-03, rtol=1e-03)
            losses.clear()

    @skip_if_lt_x_gpu(2)
    def test_parity_with_non_frozen_fsdp(self, device):
        """
        For frozen modules with unused input, a reshard can happen without an
        unshard. Verifies numerical parity between the
        `_post_backward_reshard_only_hook` and `_post_backward_hook` paths.
        """
        self.run_subtests(
            {
                "device_id": [device],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ],
                "use_orig_params": [True, False],
                "offload_params": [True, False],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        buffer_dtype=torch.float16,
                        reduce_dtype=torch.float16,
                    ),
                ],
                "backward_prefetch": [
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_parity_with_non_frozen_fsdp,
        )

    def _test_parity_with_non_frozen_fsdp(
        self,
        device_id,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        offload_params: bool,
        mixed_precision: MixedPrecision,
        backward_prefetch: BackwardPrefetch,
    ):
        torch.manual_seed(42)
        model = ModelUnusedInput(freeze=True).to(device_type)
        torch.manual_seed(42)
        ref_model = ModelUnusedInput(freeze=False).to(device_type)
        fsdp_kwargs = {
            "device_id": device_type,
            "auto_wrap_policy": ModuleWrapPolicy({LinearUnusedInput}),
            "sharding_strategy": sharding_strategy,
            "use_orig_params": use_orig_params,
            "cpu_offload": CPUOffload(offload_params=offload_params),
            "mixed_precision": mixed_precision,
            "backward_prefetch": backward_prefetch,
        }
        model = FSDP(model, **fsdp_kwargs)
        ref_model = FSDP(ref_model, **fsdp_kwargs)
        model_optim = torch.optim.Adam(model.parameters(), lr=1e-2)
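        # The reference model leaves `layer1_frozen` trainable but excludes
        # its parameters from the optimizer, so those weights still receive
        # gradients through the regular `_post_backward_hook` path yet are
        # never updated, keeping them numerically aligned with the truly
        # frozen model.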
        ref_model_optim = torch.optim.Adam(
            [
                param
                for name, param in ref_model.named_parameters()
                if not name.startswith("_fsdp_wrapped_module.layer1_frozen")
            ],
            lr=1e-2,
        )
        torch.manual_seed(self.rank + 1)
        losses = []
        for _ in range(6):
            frozen_input = torch.randn((4, 4), device=device_type, requires_grad=False)
            for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)):
                loss = _model(frozen_input, frozen_input).sum()
                losses.append(loss)
                loss.backward()
                _optim.step()
                _optim.zero_grad()
            self.assertEqual(losses[0], losses[1])
            losses.clear()
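        # `summon_full_params` gathers the full, unsharded parameters on each
        # rank so the element-wise comparison below covers the complete
        # weights rather than only the local shards.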
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(ref_model):
                for param, ref_param in zip(model.parameters(), ref_model.parameters()):
                    self.assertEqual(param, ref_param)


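# `instantiate_device_type_tests` generates a device-specific copy of the test
# class for each backend listed in `devices` (e.g. `TestFSDPFineTuneCUDA`);
# `allow_xpu=True` opts the XPU backend in.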
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
    TestFSDPFineTune, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
    run_tests()