[FSDP] Do not check fwd order in eval mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77195

Approved by: https://github.com/zhaojuanmao
Author: Andrew Gu
Date: 2022-05-11 15:55:28 +00:00
Committed by: PyTorch MergeBot
Parent: e8f53ad1f1
Commit: e912d24303
2 changed files with 97 additions and 52 deletions
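
Summary of the change: the execution-order validation messages now say
"forward order" instead of "all-gather order", the check is implemented in
`_check_rebuild_full_params()` (replacing `_check_all_gather()`) and is
called before the all-gather, regardless of whether an all-gather is
actually needed, and the check is skipped entirely in eval mode or when the
world size is 1. The sketch below shows the kind of train/eval loop that the
new `test_train_eval` test exercises; after this change, the eval-mode
forward passes no longer participate in the forward-order validation and so
cannot trigger the "Forward order differs" warning. The model, optimizer,
and batch names here are illustrative only and not part of the commit.

import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def train_then_eval(model, train_batches, eval_batches, device):
    # Assumes `torch.distributed` has already been initialized (e.g. via
    # `init_process_group`) and that `model` is an ordinary `nn.Module`.
    fsdp_model = FSDP(model.to(device))
    optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)
    for _ in range(2):  # epochs
        fsdp_model.train()
        for inp, target in train_batches:
            optim.zero_grad()
            loss = F.mse_loss(fsdp_model(inp), target)
            loss.backward()
            optim.step()
        fsdp_model.eval()
        with torch.no_grad():
            # With this commit, these forward passes are not recorded or
            # validated against the first iteration's forward order.
            for inp, _ in eval_batches:
                fsdp_model(inp)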

test/distributed/fsdp/test_fsdp_exec_order.py

@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import sys
import warnings
from contextlib import suppress
import torch
@@ -109,7 +110,7 @@ class TestFSDPExecOrder(FSDPTest):
fsdp_model.flip_path()
inp = fsdp_model.module.get_input(self.device)
# Match the error message with the following prefix
error_regex = "^(All-gather order differs across ranks)"
error_regex = "^(Forward order differs across ranks)"
with self.assertRaisesRegex(RuntimeError, error_regex):
fsdp_model(*inp)
@@ -135,7 +136,7 @@ class TestFSDPExecOrder(FSDPTest):
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
# Match the warning message with the following prefix
regex = "^(All-gather order differs from that of the first iteration " \
regex = "^(Forward order differs from that of the first iteration " \
f"on rank {self.rank} -- collectives are unchecked and may give " \
"incorrect results or hang)"
context = self.assertWarnsRegex(
@@ -155,6 +156,37 @@ class TestFSDPExecOrder(FSDPTest):
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
@skip_if_lt_x_gpu(2)
@parametrize(
"sharding_strategy",
[ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
)
def test_train_eval(self, sharding_strategy: ShardingStrategy):
fsdp_model = Model.wrap(sharding_strategy, self.device)
NUM_ITERS = 3
NUM_EPOCHS = 2
with warnings.catch_warnings(record=True) as w: # records warnings to `w`
for _ in range(NUM_EPOCHS):
fsdp_model.train()
for _ in range(NUM_ITERS):
inp = fsdp_model.module.get_input(self.device)
output = fsdp_model(*inp)
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
fsdp_model.eval()
for _ in range(NUM_ITERS):
inp = fsdp_model.module.get_input(self.device)
output = fsdp_model(*inp)
fsdp_model.module.get_loss(inp, output).to(self.device)
# Check that the order validation warning was not issued (errors do not
# need to be checked since they will be directly reported)
warning_prefix = "Forward order differs"
for warning in w:
if str(warning.message).startswith(warning_prefix):
raise AssertionError(f"Warning was incorrectly issued: {warning.message}")
# If we still validate the forward execution order in eval mode, then
# an `AssertionError` will be raised above for both sharding strategies
instantiate_parametrized_tests(TestFSDPExecOrder)

torch/distributed/fsdp/fully_sharded_data_parallel.py

@@ -324,9 +324,9 @@ class _ExecOrderData():
_all_flat_params (List[FlatParameter]): A :class:`list` of all
flattened parameters contained in the FSDP module hierarchy with
the list index implicitly giving a unique parameter index.
param_to_unflat_param_names (Dict[FlatParameter, List[str]]): A mapping
from flattened parameter to the comprising unflattened parameters'
names.
_param_to_unflat_param_names (Dict[FlatParameter, List[str]]): A
mapping from flattened parameter to the comprising unflattened
parameters' names.
is_first_iter (bool): Whether executing in the first iteration or not.
param_order (List[int]): Order that parameters participate in the
forward pass; constructed on the first iteration and validated
@@ -391,9 +391,8 @@ class _ExecOrderData():
return self._param_to_unflat_param_names[param]
def reset(self):
"""Called in :meth:`_wait_for_post_backward` or in
:meth:`_post_backward_hook` when inside ``no_sync()`` to reset data for
the next iteration."""
"""Called in :meth:`_wait_for_post_backward` to reset data for the next
iteration."""
self.is_first_iter = False
self.index = 0
# `reset()` marks the end of an iteration, so transition if needed
@@ -2429,10 +2428,6 @@ class FullyShardedDataParallel(nn.Module):
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
if not self._require_backward_grad_sync:
# Reset the execution order data structure here since the
# `_wait_for_post_backward()` callback is skipped
if self._is_root:
self._exec_order_data.reset()
return
# Wait for all work in the current stream to finish, then start the
@@ -2724,6 +2719,10 @@ class FullyShardedDataParallel(nn.Module):
# p._is_sharded = False. However when it is set, the
# device is always self.compute_device.
p.data = p.data.to(self.compute_device, non_blocking=True)
# Check the validity of this `_rebuild_full_params()` call in
# terms of execution order (regardless of if FSDP actually
# needs to all-gather or not)
self._check_rebuild_full_params(p)
# e.g., when world_size == 1
if not p._is_sharded: # type: ignore[attr-defined]
if mixed_precision_cast_ran:
@@ -2784,7 +2783,6 @@ class FullyShardedDataParallel(nn.Module):
# Allocate based on full size from all shards.
_alloc_storage(p._full_param_padded, size=p_full_size) # type: ignore[attr-defined]
output_tensor = p._full_param_padded # type: ignore[attr-defined]
self._check_all_gather(p)
# Fill output_tensor with (p.data for each shard in self.world_size)
dist._all_gather_base(
output_tensor, p_data, group=self.process_group
@@ -2805,34 +2803,37 @@ class FullyShardedDataParallel(nn.Module):
self._free_mp_shard(cast(List[FlatParameter], [p]))
return output_tensors
def _check_all_gather(self, param: FlatParameter):
def _check_rebuild_full_params(self, param: FlatParameter):
"""
Checks the validity of an all-gather to rebuild the full parameter
``param``. If on the first iteration, this uses an all-gather to check
that all ranks plan to all-gather the same parameter, erroring if not,
and on subsequent iterations, if the all-gather order differs from that
of the first iteration (meaning that we can no longer guarantee correct
execution), then we issue a warning to the user. This only issues
Checks the validity of a call to :meth:`_rebuild_full_params` in terms
of the execution order. If on the first iteration, this uses an
all-gather to check that all ranks are running ``forward()`` with the
same parameter, erroring if not, and on subsequent iterations, if the
forward order differs from that of the first iteration (meaning that we
can no longer guarantee correct execution since all-gathers may be
mismatched), then we issue a warning to the user. This only issues
warnings on the first deviating iteration and stops checking
thereafter.
For now, only the all-gathers to rebuild full parameters in the forward
pass are checked since (1) a correct forward order should imply a
correct pre-backward order for typical cases and (2) there may be some
issues with pre-fetching that need to be looked into:
https://github.com/pytorch/pytorch/issues/76553
Only the :meth:`_rebuild_full_params` calls in the forward pass are
checked since a correct forward order should imply a correct
pre-backward order for typical cases.
Executing in ``no_sync()`` does not affect this check for
``FULL_SHARD`` and ``SHARD_GRAD_OP``: (1) Being in ``no_sync()`` in the
first iteration does not yield a different all-gather sequence, and (2)
being in ``no_sync()`` in a later iteration does not give false
positive warnings since the all-gather sequence still matches the first
first iteration does not yield a different forward
:meth:`_rebuild_full_params()` sequence, and (2) being in ``no_sync()``
in a later iteration does not give false positive warnings since the
forward :meth:`_rebuild_full_params()` sequence still matches the first
iteration sequence (for ``FULL_SHARD``) or the first iteration
sequence's prefix (for ``SHARD_GRAD_OP``).
"""
# Only check all-gathers when rebuilding the full parameters in the
# forward pass
if self.training_state != TrainingState_.FORWARD:
# Only check when rebuilding the full parameters in the forward pass,
# and skip the check (1) when in eval mode since then there is not a
# safe point at which to reset the execution order data and (2) if
# world size is 1 since then there is no chance of desynchronization
if self.training_state != TrainingState_.FORWARD or \
not self.training or self.world_size == 1:
return
eod = self._exec_order_data
param_index = eod.get_param_index(param)
@@ -2843,40 +2844,52 @@ class FullyShardedDataParallel(nn.Module):
return
# However, we may issue multiple warnings on the first deviating
# iteration to help debugging, where either:
# 1. This iteration sees more all-gathers than the first iteration
msg_prefix = all_gather_seq = None # non-`None` means we warn
# 1. This iteration sees an extra `_rebuild_full_params()` in
# `forward()` compared to the first iteration
msg_prefix = curr_param_order = None # non-`None` means we warn
if eod.index >= len(eod.param_order):
msg_prefix = "Expected no more all-gathers but got an all-gather for "
all_gather_seq = eod.param_order + [param_index]
msg_prefix = "Expected to not rebuild any more parameters " \
"in `forward()` for this module but trying to rebuild " \
"parameters for "
curr_param_order = eod.param_order + [param_index]
else:
expected_param_index = eod.param_order[eod.index]
# 2. This iteration sees the same number of all-gathers (so
# far) but the current parameter to all-gather differs
# 2. This iteration sees the same number of
# `_rebuild_full_params()` (so far) but the current parameter
# differs
if param_index != expected_param_index:
expected_param_names = eod.get_unflat_param_names(expected_param_index)
assert len(expected_param_names) > 0, \
"The expected parameter to all-gather should always be valid"
msg_prefix = "Expected an all-gather for the FSDP module " \
f"wrapping {expected_param_names} but got an all-gather for "
all_gather_seq = eod.param_order[:eod.index - 1] + [param_index]
"Expected parameter should always be valid"
msg_prefix = "Expected to rebuild parameters in " \
f"`forward()` for {expected_param_names} but " \
"instead trying to rebuild parameters for "
curr_param_order = eod.param_order[:eod.index - 1] + [param_index]
to_issue_warning = msg_prefix is not None
if to_issue_warning:
assert all_gather_seq is not None
assert curr_param_order is not None
param_names = eod.get_unflat_param_names(param_index)
is_added_param = len(param_names) == 0
if is_added_param:
msg_suffix = "a newly-added parameter since construction time"
else:
msg_suffix = f"the FSDP module wrapping {param_names}"
msg_suffix = f"{param_names}"
sub_msg = msg_prefix + msg_suffix
first_iter_param_names = [
eod.get_unflat_param_names(index) for index in eod.param_order
]
curr_iter_param_names = [
eod.get_unflat_param_names(index) for index in curr_param_order
]
print(first_iter_param_names, type(first_iter_param_names))
print(curr_iter_param_names, type(curr_iter_param_names))
warnings.warn(
"All-gather order differs from that of the first iteration "
"Forward order differs from that of the first iteration "
f"on rank {self.rank} -- collectives are unchecked and may "
"give incorrect results or hang\n" + sub_msg + "\n" +
f"First iteration's all-gather sequence: {eod.param_order}"
"\nThis iteration's all-gather sequence (so far): "
f"{all_gather_seq}\nwhere indices follow the root FSDP "
"module's `.parameters()` order"
f"First iteration's forward order: {first_iter_param_names}"
"\nThis iteration's forward order (so far): "
f"{curr_iter_param_names}"
)
eod.warn_status = _ExecOrderWarnStatus.WARNING
eod.index += 1
@@ -2896,10 +2909,10 @@ class FullyShardedDataParallel(nn.Module):
r1_param_names = eod.get_unflat_param_names(i1)
r2_param_names = eod.get_unflat_param_names(i2)
raise RuntimeError(
f"All-gather order differs across ranks: rank {r1} is "
f"all-gathering the flattened parameter wrapping "
f"{r1_param_names} while rank {r2} is all-gathering "
f"the flattened parameter wrapping {r2_param_names}"
f"Forward order differs across ranks: rank {r1} is "
"rebuilding full parameters in `forward()` for "
f"{r1_param_names} while rank {r2} is rebuilding full "
f"parameters in `forward()` for {r2_param_names}"
)
eod.param_order.append(param_index)
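
For reference, a simplified, self-contained sketch of the execution-order
tracking that `_ExecOrderData` and `_check_rebuild_full_params()` implement
after this change (this is not the actual FSDP code): the forward rebuild
order is recorded on the first iteration, later iterations are validated
against that recorded order, and the check is skipped in eval mode or when
the world size is 1.

import warnings
from typing import List

class ExecOrderSketch:
    """Toy stand-in for `_ExecOrderData`; parameter indices follow the root
    FSDP module's `.parameters()` order."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self.is_first_iter = True
        self.param_order: List[int] = []  # recorded on the first iteration
        self.index = 0
        self.warned = False

    def check_forward(self, param_index: int, is_training: bool) -> None:
        # Mirrors the new guard: no checking in eval mode or with world size 1.
        if not is_training or self.world_size == 1:
            return
        if self.is_first_iter:
            # The real implementation also all-gathers `param_index` here to
            # verify that all ranks are rebuilding the same parameter,
            # raising a RuntimeError otherwise.
            self.param_order.append(param_index)
        elif not self.warned and (
            self.index >= len(self.param_order)
            or self.param_order[self.index] != param_index
        ):
            # Simplified: warn once and stop checking thereafter.
            warnings.warn(
                "Forward order differs from that of the first iteration"
            )
            self.warned = True
        self.index += 1

    def reset(self) -> None:
        # Called after the backward pass to mark the end of an iteration.
        self.is_first_iter = False
        self.index = 0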