**Motivation:** Enable FSDP tests on XPU device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147518
Approved by: https://github.com/weifengpy
# Owner(s): ["oncall: distributed"]

import sys
import warnings
from contextlib import nullcontext

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

device_type = torch.device(get_devtype())

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Model(torch.nn.Module):
    """
    Model that supports two computation paths: `layer0` -> `layer1` and
    `layer0` -> `layer2`. Notably, both `layer1` and `layer2` have 36 elements
    when flattened, which means that their corresponding all-gathers and
    reduce-scatters may be silently matched if we do not perform any checks.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(5, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        )
        self.relu = torch.nn.ReLU()
        self.use_alt_path = False
        for param in self.layer2.parameters():
            param.requires_grad = False

    def forward(self, x):
        # `layer0` -> `layer1` (normal)
        # `layer0` -> `layer2` (alternate)
        z = self.relu(self.layer0(x))
        z = (
            self.relu(self.layer2(z))
            if self.use_alt_path
            else self.relu(self.layer1(z))
        )
        return z

    def get_input(self, device):
        return (torch.randn((8, 5)).to(device),)

    def get_loss(self, input, output):
        return output.sum()

    def run_backward(self, loss):
        loss.backward()

    def flip_path(self):
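        # Swap which branch `forward` uses and which parameters require grad,
        # so the set of FSDP all-gathers issued in forward changes; calling
        # this on only some ranks makes the collective order diverge.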
        params_to_freeze = (
            self.layer2.parameters() if self.use_alt_path else self.layer1.parameters()
        )
        params_to_unfreeze = (
            self.layer1.parameters() if self.use_alt_path else self.layer2.parameters()
        )
        for param in params_to_freeze:
            param.requires_grad = False
        for param in params_to_unfreeze:
            param.requires_grad = True
        self.use_alt_path = not self.use_alt_path

    @staticmethod
    def wrap(sharding_strategy: ShardingStrategy, device):
        model = Model()
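        # Wrap `layer1` and `layer2` as separate FSDP instances so that each
        # issues its own all-gather; these are the per-module collectives that
        # the execution order validation compares across ranks.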
        model.layer1 = FSDP(
            model.layer1, sharding_strategy=sharding_strategy, device_id=device
        )
        model.layer2 = FSDP(
            model.layer2, sharding_strategy=sharding_strategy, device_id=device
        )
        fsdp_model = FSDP(model, sharding_strategy=sharding_strategy, device_id=device)
        return fsdp_model.to(device)


class TestFSDPExecOrder(FSDPTest):
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    def test_invalid_first_iter_order(
        self,
        device,
        sharding_strategy: ShardingStrategy,
    ):
        """Tests that FSDP errors if the all-gather order differs across ranks
        in the first iteration."""
        # Rank 0 runs the forward pass in one order, and all other ranks run it
        # in a different order
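        # FSDP's execution order validation only runs at the DETAIL
        # distributed debug level, so enable it before the forward pass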
        dist.set_debug_level(dist.DebugLevel.DETAIL)
        fsdp_model = Model.wrap(sharding_strategy, device_type)
        if self.rank != 0:
            fsdp_model.flip_path()
        inp = fsdp_model.module.get_input(device_type)
        # Match the error message with the following prefix
        error_regex = "^(Forward order differs across ranks)"
        with self.assertRaisesRegex(RuntimeError, error_regex):
            fsdp_model(*inp)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    @parametrize("iters_before_path_change", [1, 3])
    def test_invalid_later_iter_order(
        self,
        device,
        sharding_strategy: ShardingStrategy,
        iters_before_path_change: int,
    ):
        """Tests that FSDP warns the user if the all-gather order changes after
        the first iteration."""
        dist.set_debug_level(dist.DebugLevel.DETAIL)
        # On the first iteration, all ranks run the same order, and on the next
        # iteration, all but rank 0 run in a different order
        fsdp_model = Model.wrap(sharding_strategy, device_type)
        for _ in range(iters_before_path_change):
            inp = fsdp_model.module.get_input(device_type)
            output = fsdp_model(*inp)
            loss = fsdp_model.module.get_loss(inp, output).to(device_type)
            fsdp_model.module.run_backward(loss)
        # Match the warning message with the following prefix
        regex = (
            "^(Forward order differs from that of the first iteration "
            f"on rank {self.rank}. Collectives are unchecked and may give "
            "incorrect results or hang)"
        )
        context = (
            self.assertWarnsRegex(
                expected_warning=UserWarning,
                expected_regex=regex,
            )
            if self.rank != 0
            else nullcontext()
        )
        if self.rank != 0:
            fsdp_model.flip_path()
        inp = fsdp_model.module.get_input(device_type)
        # Expect a warning only for the forward pass all-gather
        with context:
            output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(device_type)
        fsdp_model.module.run_backward(loss)
        # Run an additional iteration to check that there are no more warnings
        inp = fsdp_model.module.get_input(device_type)
        output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(device_type)
        fsdp_model.module.run_backward(loss)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    def test_train_eval(self, device, sharding_strategy: ShardingStrategy):
        dist.set_debug_level(dist.DebugLevel.DETAIL)
        fsdp_model = Model.wrap(sharding_strategy, device_type)
        NUM_ITERS = 3
        NUM_EPOCHS = 2
        with warnings.catch_warnings(record=True) as w:  # records warnings to `w`
            for _ in range(NUM_EPOCHS):
                fsdp_model.train()
                for _ in range(NUM_ITERS):
                    inp = fsdp_model.module.get_input(device_type)
                    output = fsdp_model(*inp)
                    loss = fsdp_model.module.get_loss(inp, output).to(device_type)
                    fsdp_model.module.run_backward(loss)
                fsdp_model.eval()
                for _ in range(NUM_ITERS):
                    inp = fsdp_model.module.get_input(device_type)
                    output = fsdp_model(*inp)
                    fsdp_model.module.get_loss(inp, output).to(device_type)
        # Check that the order validation warning was not issued (errors do not
        # need to be checked since they will be directly reported)
        warning_prefix = "Forward order differs"
        for warning in w:
            if str(warning.message).startswith(warning_prefix):
                raise AssertionError(
                    f"Warning was incorrectly issued: {warning.message}"
                )
        # If we still validated the forward execution order in eval mode, then
        # an `AssertionError` would be raised above for both sharding strategies


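# Including "xpu" here, together with `allow_xpu=True` below, generates XPU
# variants of these tests in addition to the CUDA and HPU ones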
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
    TestFSDPExecOrder, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
    run_tests()