[FSDP] Do not check fwd order in eval mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77195
Approved by: https://github.com/zhaojuanmao
Commit e912d24303 (parent e8f53ad1f1), committed by PyTorch MergeBot
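The change makes FSDP skip its forward execution-order validation while the module is in eval mode, where no backward pass runs to reset the order-tracking state. As a rough usage sketch of the loop shape this targets (hypothetical `MyModel`, `num_epochs`, `train_batches`, and `eval_batches`; assumes an initialized process group and one CUDA device per rank):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = FSDP(MyModel().cuda())  # MyModel is a placeholder nn.Module
    optim = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(num_epochs):
        model.train()
        for batch in train_batches:
            loss = model(batch).sum()
            loss.backward()       # the post-backward callback resets the exec-order data
            optim.step()
            optim.zero_grad()
        model.eval()              # eval forwards run with no backward, hence no reset;
        with torch.no_grad():     # with this change the order check is skipped here
            for batch in eval_batches:
                model(batch)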
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import sys
+import warnings
 from contextlib import suppress
 
 import torch
@@ -109,7 +110,7 @@ class TestFSDPExecOrder(FSDPTest):
         fsdp_model.flip_path()
         inp = fsdp_model.module.get_input(self.device)
         # Match the error message with the following prefix
-        error_regex = "^(All-gather order differs across ranks)"
+        error_regex = "^(Forward order differs across ranks)"
         with self.assertRaisesRegex(RuntimeError, error_regex):
             fsdp_model(*inp)
 
@@ -135,7 +136,7 @@ class TestFSDPExecOrder(FSDPTest):
             loss = fsdp_model.module.get_loss(inp, output).to(self.device)
             fsdp_model.module.run_backward(loss)
         # Match the warning message with the following prefix
-        regex = "^(All-gather order differs from that of the first iteration " \
+        regex = "^(Forward order differs from that of the first iteration " \
             f"on rank {self.rank} -- collectives are unchecked and may give " \
             "incorrect results or hang)"
         context = self.assertWarnsRegex(
@@ -155,6 +156,37 @@ class TestFSDPExecOrder(FSDPTest):
         loss = fsdp_model.module.get_loss(inp, output).to(self.device)
         fsdp_model.module.run_backward(loss)
 
+    @skip_if_lt_x_gpu(2)
+    @parametrize(
+        "sharding_strategy",
+        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
+    )
+    def test_train_eval(self, sharding_strategy: ShardingStrategy):
+        fsdp_model = Model.wrap(sharding_strategy, self.device)
+        NUM_ITERS = 3
+        NUM_EPOCHS = 2
+        with warnings.catch_warnings(record=True) as w:  # records warnings to `w`
+            for _ in range(NUM_EPOCHS):
+                fsdp_model.train()
+                for _ in range(NUM_ITERS):
+                    inp = fsdp_model.module.get_input(self.device)
+                    output = fsdp_model(*inp)
+                    loss = fsdp_model.module.get_loss(inp, output).to(self.device)
+                    fsdp_model.module.run_backward(loss)
+                fsdp_model.eval()
+                for _ in range(NUM_ITERS):
+                    inp = fsdp_model.module.get_input(self.device)
+                    output = fsdp_model(*inp)
+                    fsdp_model.module.get_loss(inp, output).to(self.device)
+        # Check that the order validation warning was not issued (errors do not
+        # need to be checked since they will be directly reported)
+        warning_prefix = "Forward order differs"
+        for warning in w:
+            if str(warning.message).startswith(warning_prefix):
+                raise AssertionError(f"Warning was incorrectly issued: {warning.message}")
+        # If we still validate the forward execution order in eval mode, then
+        # an `AssertionError` will be raised above for both sharding strategies
+
 
 instantiate_parametrized_tests(TestFSDPExecOrder)
 
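The new `test_train_eval` asserts that no order-mismatch warning fires by recording warnings; a minimal standalone sketch of that recording pattern (the helper `assert_no_warning_with_prefix` is made up for illustration):

    import warnings

    def assert_no_warning_with_prefix(fn, prefix: str) -> None:
        # Record every warning raised while running `fn`, then fail if any
        # recorded message starts with `prefix`.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            fn()
        for warning in w:
            if str(warning.message).startswith(prefix):
                raise AssertionError(f"Warning was incorrectly issued: {warning.message}")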
@@ -324,9 +324,9 @@ class _ExecOrderData():
         _all_flat_params (List[FlatParameter]): A :class:`list` of all
             flattened parameters contained in the FSDP module hierarchy with
             the list index implicitly giving a unique parameter index.
-        param_to_unflat_param_names (Dict[FlatParameter, List[str]]): A mapping
-            from flattened parameter to the comprising unflattened parameters'
-            names.
+        _param_to_unflat_param_names (Dict[FlatParameter, List[str]]): A
+            mapping from flattened parameter to the comprising unflattened
+            parameters' names.
         is_first_iter (bool): Whether executing in the first iteration or not.
         param_order (List[int]): Order that parameters participate in the
             forward pass; constructed on the first iteration and validated
@@ -391,9 +391,8 @@ class _ExecOrderData():
         return self._param_to_unflat_param_names[param]
 
     def reset(self):
-        """Called in :meth:`_wait_for_post_backward` or in
-        :meth:`_post_backward_hook` when inside ``no_sync()`` to reset data for
-        the next iteration."""
+        """Called in :meth:`_wait_for_post_backward` to reset data for the next
+        iteration."""
         self.is_first_iter = False
         self.index = 0
         # `reset()` marks the end of an iteration, so transition if needed
@@ -2429,10 +2428,6 @@ class FullyShardedDataParallel(nn.Module):
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
 
         if not self._require_backward_grad_sync:
-            # Reset the execution order data structure here since the
-            # `_wait_for_post_backward()` callback is skipped
-            if self._is_root:
-                self._exec_order_data.reset()
             return
 
         # Wait for all work in the current stream to finish, then start the
@@ -2724,6 +2719,10 @@ class FullyShardedDataParallel(nn.Module):
                 # p._is_sharded = False. However when it is set, the
                 # device is always self.compute_device.
                 p.data = p.data.to(self.compute_device, non_blocking=True)
+                # Check the validity of this `_rebuild_full_params()` call in
+                # terms of execution order (regardless of if FSDP actually
+                # needs to all-gather or not)
+                self._check_rebuild_full_params(p)
                 # e.g., when world_size == 1
                 if not p._is_sharded:  # type: ignore[attr-defined]
                     if mixed_precision_cast_ran:
@@ -2784,7 +2783,6 @@ class FullyShardedDataParallel(nn.Module):
                     # Allocate based on full size from all shards.
                     _alloc_storage(p._full_param_padded, size=p_full_size)  # type: ignore[attr-defined]
                     output_tensor = p._full_param_padded  # type: ignore[attr-defined]
-                    self._check_all_gather(p)
                     # Fill output_tensor with (p.data for each shard in self.world_size)
                     dist._all_gather_base(
                         output_tensor, p_data, group=self.process_group
@@ -2805,34 +2803,37 @@ class FullyShardedDataParallel(nn.Module):
                     self._free_mp_shard(cast(List[FlatParameter], [p]))
         return output_tensors
 
-    def _check_all_gather(self, param: FlatParameter):
+    def _check_rebuild_full_params(self, param: FlatParameter):
         """
-        Checks the validity of an all-gather to rebuild the full parameter
-        ``param``. If on the first iteration, this uses an all-gather to check
-        that all ranks plan to all-gather the same parameter, erroring if not,
-        and on subsequent iterations, if the all-gather order differs from that
-        of the first iteration (meaning that we can no longer guarantee correct
-        execution), then we issue a warning to the user. This only issues
+        Checks the validity of a call to :meth:`_rebuild_full_params` in terms
+        of the execution order. If on the first iteration, this uses an
+        all-gather to check that all ranks are running ``forward()`` with the
+        same parameter, erroring if not, and on subsequent iterations, if the
+        forward order differs from that of the first iteration (meaning that we
+        can no longer guarantee correct execution since all-gathers may be
+        mismatched), then we issue a warning to the user. This only issues
         warnings on the first deviating iteration and stops checking
         thereafter.
 
-        For now, only the all-gathers to rebuild full parameters in the forward
-        pass are checked since (1) a correct forward order should imply a
-        correct pre-backward order for typical cases and (2) there may be some
-        issues with pre-fetching that need to be looked into:
-        https://github.com/pytorch/pytorch/issues/76553
+        Only the :meth:`_rebuild_full_params` calls in the forward pass are
+        checked since a correct forward order should imply a correct
+        pre-backward order for typical cases.
 
         Executing in ``no_sync()`` does not affect this check for
         ``FULL_SHARD`` and ``SHARD_GRAD_OP``: (1) Being in ``no_sync()`` in the
-        first iteration does not yield a different all-gather sequence, and (2)
-        being in ``no_sync()`` in a later iteration does not give false
-        positive warnings since the all-gather sequence still matches the first
+        first iteration does not yield a different forward
+        :meth:`_rebuild_full_params()` sequence, and (2) being in ``no_sync()``
+        in a later iteration does not give false positive warnings since the
+        forward :meth:`_rebuild_full_params()` sequence still matches the first
         iteration sequence (for ``FULL_SHARD``) or the first iteration
         sequence's prefix (for ``SHARD_GRAD_OP``).
         """
-        # Only check all-gathers when rebuilding the full parameters in the
-        # forward pass
-        if self.training_state != TrainingState_.FORWARD:
+        # Only check when rebuilding the full parameters in the forward pass,
+        # and skip the check (1) when in eval mode since then there is not a
+        # safe point at which to reset the execution order data and (2) if
+        # world size is 1 since then there is no chance of desynchronization
+        if self.training_state != TrainingState_.FORWARD or \
+                not self.training or self.world_size == 1:
             return
         eod = self._exec_order_data
         param_index = eod.get_param_index(param)
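As a rough standalone illustration of the record-then-validate scheme described in the docstring above, including the eval-mode skip (the `ExecOrderTracker` class is hypothetical, not FSDP's `_ExecOrderData`):

    import warnings
    from typing import List

    class ExecOrderTracker:
        def __init__(self) -> None:
            self.is_first_iter = True
            self.index = 0
            self.param_order: List[int] = []
            self.warned = False

        def record_or_check(self, param_index: int, training: bool) -> None:
            if not training:
                return  # eval mode: skip, since no backward pass will reset this state
            if self.is_first_iter:
                self.param_order.append(param_index)  # record the first iteration's order
            elif not self.warned and (
                self.index >= len(self.param_order)
                or self.param_order[self.index] != param_index
            ):
                warnings.warn("Forward order differs from that of the first iteration")
                self.warned = True
            self.index += 1

        def reset(self) -> None:
            # Called at the end of the backward pass
            self.is_first_iter = False
            self.index = 0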
@@ -2843,40 +2844,52 @@ class FullyShardedDataParallel(nn.Module):
             return
         # However, we may issue multiple warnings on the first deviating
         # iteration to help debugging, where either:
-        # 1. This iteration sees more all-gathers than the first iteration
-        msg_prefix = all_gather_seq = None  # non-`None` means we warn
+        # 1. This iteration sees an extra `_rebuild_full_params()` in
+        # `forward()` compared to the first iteration
+        msg_prefix = curr_param_order = None  # non-`None` means we warn
         if eod.index >= len(eod.param_order):
-            msg_prefix = "Expected no more all-gathers but got an all-gather for "
-            all_gather_seq = eod.param_order + [param_index]
+            msg_prefix = "Expected to not rebuild any more parameters " \
+                "in `forward()` for this module but trying to rebuild " \
+                "parameters for "
+            curr_param_order = eod.param_order + [param_index]
         else:
             expected_param_index = eod.param_order[eod.index]
-            # 2. This iteration sees the same number of all-gathers (so
-            # far) but the current parameter to all-gather differs
+            # 2. This iteration sees the same number of
+            # `_rebuild_full_params()` (so far) but the current parameter
+            # differs
             if param_index != expected_param_index:
                 expected_param_names = eod.get_unflat_param_names(expected_param_index)
                 assert len(expected_param_names) > 0, \
-                    "The expected parameter to all-gather should always be valid"
-                msg_prefix = "Expected an all-gather for the FSDP module " \
-                    f"wrapping {expected_param_names} but got an all-gather for "
-                all_gather_seq = eod.param_order[:eod.index - 1] + [param_index]
+                    "Expected parameter should always be valid"
+                msg_prefix = "Expected to rebuild parameters in " \
+                    f"`forward()` for {expected_param_names} but " \
+                    "instead trying to rebuild parameters for "
+                curr_param_order = eod.param_order[:eod.index - 1] + [param_index]
         to_issue_warning = msg_prefix is not None
         if to_issue_warning:
-            assert all_gather_seq is not None
+            assert curr_param_order is not None
             param_names = eod.get_unflat_param_names(param_index)
             is_added_param = len(param_names) == 0
             if is_added_param:
                 msg_suffix = "a newly-added parameter since construction time"
             else:
-                msg_suffix = f"the FSDP module wrapping {param_names}"
+                msg_suffix = f"{param_names}"
             sub_msg = msg_prefix + msg_suffix
+            first_iter_param_names = [
+                eod.get_unflat_param_names(index) for index in eod.param_order
+            ]
+            curr_iter_param_names = [
+                eod.get_unflat_param_names(index) for index in curr_param_order
+            ]
+            print(first_iter_param_names, type(first_iter_param_names))
+            print(curr_iter_param_names, type(curr_iter_param_names))
             warnings.warn(
-                "All-gather order differs from that of the first iteration "
+                "Forward order differs from that of the first iteration "
                 f"on rank {self.rank} -- collectives are unchecked and may "
                 "give incorrect results or hang\n" + sub_msg + "\n" +
-                f"First iteration's all-gather sequence: {eod.param_order}"
-                "\nThis iteration's all-gather sequence (so far): "
-                f"{all_gather_seq}\nwhere indices follow the root FSDP "
-                "module's `.parameters()` order"
+                f"First iteration's forward order: {first_iter_param_names}"
+                "\nThis iteration's forward order (so far): "
+                f"{curr_iter_param_names}"
             )
             eod.warn_status = _ExecOrderWarnStatus.WARNING
         eod.index += 1
@@ -2896,10 +2909,10 @@ class FullyShardedDataParallel(nn.Module):
                     r1_param_names = eod.get_unflat_param_names(i1)
                     r2_param_names = eod.get_unflat_param_names(i2)
                     raise RuntimeError(
-                        f"All-gather order differs across ranks: rank {r1} is "
-                        f"all-gathering the flattened parameter wrapping "
-                        f"{r1_param_names} while rank {r2} is all-gathering "
-                        f"the flattened parameter wrapping {r2_param_names}"
+                        f"Forward order differs across ranks: rank {r1} is "
+                        "rebuilding full parameters in `forward()` for "
+                        f"{r1_param_names} while rank {r2} is rebuilding full "
+                        f"parameters in `forward()` for {r2_param_names}"
                     )
             eod.param_order.append(param_index)
 