Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)
[FSDP] Do not check fwd order in eval mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77195
Approved by: https://github.com/zhaojuanmao
Committed by: PyTorch MergeBot
Parent: e8f53ad1f1
Commit: e912d24303
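In effect, this change lets an FSDP-wrapped model alternate between train() and eval() without the forward-order validation warning firing during eval-mode forwards, which is what the new test_train_eval below exercises. A minimal sketch of that usage, assuming an already-initialized process group (e.g., via torchrun) and a CUDA device per rank; the toy module, shapes, and helper name here are illustrative and not taken from the PR:

# Hedged sketch: the toy module, shapes, and helper name are illustrative,
# not taken from the PR. Assumes torch.distributed is already initialized
# and that each rank has a CUDA device.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def run_train_then_eval() -> None:
    device = torch.device("cuda", torch.cuda.current_device())
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).to(device)
    fsdp_model = FSDP(model)
    optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)
    inp = torch.randn(8, 16, device=device)
    fsdp_model.train()
    for _ in range(3):
        loss = fsdp_model(inp).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()
    # After this commit, eval-mode forwards skip the forward execution
    # order check, so no "Forward order differs ..." warning is expected.
    fsdp_model.eval()
    with torch.no_grad():
        for _ in range(3):
            fsdp_model(inp)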
				
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import sys
+import warnings
 from contextlib import suppress
 
 import torch
@@ -109,7 +110,7 @@ class TestFSDPExecOrder(FSDPTest):
             fsdp_model.flip_path()
         inp = fsdp_model.module.get_input(self.device)
         # Match the error message with the following prefix
-        error_regex = "^(All-gather order differs across ranks)"
+        error_regex = "^(Forward order differs across ranks)"
         with self.assertRaisesRegex(RuntimeError, error_regex):
             fsdp_model(*inp)
 
@@ -135,7 +136,7 @@ class TestFSDPExecOrder(FSDPTest):
             loss = fsdp_model.module.get_loss(inp, output).to(self.device)
             fsdp_model.module.run_backward(loss)
         # Match the warning message with the following prefix
-        regex = "^(All-gather order differs from that of the first iteration " \
+        regex = "^(Forward order differs from that of the first iteration " \
             f"on rank {self.rank} -- collectives are unchecked and may give " \
             "incorrect results or hang)"
         context = self.assertWarnsRegex(
@@ -155,6 +156,37 @@ class TestFSDPExecOrder(FSDPTest):
         loss = fsdp_model.module.get_loss(inp, output).to(self.device)
         fsdp_model.module.run_backward(loss)
 
+    @skip_if_lt_x_gpu(2)
+    @parametrize(
+        "sharding_strategy",
+        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
+    )
+    def test_train_eval(self, sharding_strategy: ShardingStrategy):
+        fsdp_model = Model.wrap(sharding_strategy, self.device)
+        NUM_ITERS = 3
+        NUM_EPOCHS = 2
+        with warnings.catch_warnings(record=True) as w:  # records warnings to `w`
+            for _ in range(NUM_EPOCHS):
+                fsdp_model.train()
+                for _ in range(NUM_ITERS):
+                    inp = fsdp_model.module.get_input(self.device)
+                    output = fsdp_model(*inp)
+                    loss = fsdp_model.module.get_loss(inp, output).to(self.device)
+                    fsdp_model.module.run_backward(loss)
+                fsdp_model.eval()
+                for _ in range(NUM_ITERS):
+                    inp = fsdp_model.module.get_input(self.device)
+                    output = fsdp_model(*inp)
+                    fsdp_model.module.get_loss(inp, output).to(self.device)
+        # Check that the order validation warning was not issued (errors do not
+        # need to be checked since they will be directly reported)
+        warning_prefix = "Forward order differs"
+        for warning in w:
+            if str(warning.message).startswith(warning_prefix):
+                raise AssertionError(f"Warning was incorrectly issued: {warning.message}")
+        # If we still validate the forward execution order in eval mode, then
+        # an `AssertionError` will be raised above for both sharding strategies
+
 
 instantiate_parametrized_tests(TestFSDPExecOrder)
 
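The `_ExecOrderData` attributes touched in the next hunks (`param_order`, `index`, `is_first_iter`, `warn_status`) implement a record-then-compare scheme: the first iteration records the order in which flattened parameters are rebuilt in `forward()`, and later iterations are compared against that record, with `reset()` marking the end of each iteration. A simplified, hypothetical stand-in for that scheme (not the actual FSDP class) might look like the following sketch:

import warnings
from typing import List

class ExecOrderTracker:
    """Hypothetical, simplified stand-in for FSDP's `_ExecOrderData`."""
    def __init__(self) -> None:
        self.is_first_iter = True
        self.param_order: List[int] = []  # order recorded on the first iteration
        self.index = 0                    # position within the current iteration
        self.warned = False

    def record_or_check(self, param_index: int) -> None:
        if self.is_first_iter:
            # First iteration: only record the order
            self.param_order.append(param_index)
            return
        if self.warned:
            # Warn once on the first deviating iteration, then stop checking
            return
        if self.index >= len(self.param_order) or self.param_order[self.index] != param_index:
            warnings.warn("Forward order differs from that of the first iteration")
            self.warned = True
        self.index += 1

    def reset(self) -> None:
        # Called at the end of an iteration (in FSDP, from
        # `_wait_for_post_backward()`), marking later iterations as non-first
        self.is_first_iter = False
        self.index = 0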
@@ -324,9 +324,9 @@ class _ExecOrderData():
         _all_flat_params (List[FlatParameter]): A :class:`list` of all
             flattened parameters contained in the FSDP module hierarchy with
             the list index implicitly giving a unique parameter index.
-        param_to_unflat_param_names (Dict[FlatParameter, List[str]]): A mapping
-            from flattened parameter to the comprising unflattened parameters'
-            names.
+        _param_to_unflat_param_names (Dict[FlatParameter, List[str]]): A
+            mapping from flattened parameter to the comprising unflattened
+            parameters' names.
         is_first_iter (bool): Whether executing in the first iteration or not.
         param_order (List[int]): Order that parameters participate in the
             forward pass; constructed on the first iteration and validated
@@ -391,9 +391,8 @@ class _ExecOrderData():
         return self._param_to_unflat_param_names[param]
 
     def reset(self):
-        """Called in :meth:`_wait_for_post_backward` or in
-        :meth:`_post_backward_hook` when inside ``no_sync()`` to reset data for
-        the next iteration."""
+        """Called in :meth:`_wait_for_post_backward` to reset data for the next
+        iteration."""
         self.is_first_iter = False
         self.index = 0
         # `reset()` marks the end of an iteration, so transition if needed
@@ -2429,10 +2428,6 @@ class FullyShardedDataParallel(nn.Module):
             torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
 
         if not self._require_backward_grad_sync:
-            # Reset the execution order data structure here since the
-            # `_wait_for_post_backward()` callback is skipped
-            if self._is_root:
-                self._exec_order_data.reset()
             return
 
         # Wait for all work in the current stream to finish, then start the
@@ -2724,6 +2719,10 @@ class FullyShardedDataParallel(nn.Module):
                     # p._is_sharded = False. However when it is set, the
                     # device is always self.compute_device.
                     p.data = p.data.to(self.compute_device, non_blocking=True)
+                # Check the validity of this `_rebuild_full_params()` call in
+                # terms of execution order (regardless of if FSDP actually
+                # needs to all-gather or not)
+                self._check_rebuild_full_params(p)
                 # e.g., when world_size == 1
                 if not p._is_sharded:  # type: ignore[attr-defined]
                     if mixed_precision_cast_ran:
@@ -2784,7 +2783,6 @@ class FullyShardedDataParallel(nn.Module):
                         # Allocate based on full size from all shards.
                         _alloc_storage(p._full_param_padded, size=p_full_size)  # type: ignore[attr-defined]
                         output_tensor = p._full_param_padded  # type: ignore[attr-defined]
-                    self._check_all_gather(p)
                     # Fill output_tensor with (p.data for each shard in self.world_size)
                     dist._all_gather_base(
                         output_tensor, p_data, group=self.process_group
@@ -2805,34 +2803,37 @@ class FullyShardedDataParallel(nn.Module):
                         self._free_mp_shard(cast(List[FlatParameter], [p]))
         return output_tensors
 
-    def _check_all_gather(self, param: FlatParameter):
+    def _check_rebuild_full_params(self, param: FlatParameter):
         """
-        Checks the validity of an all-gather to rebuild the full parameter
-        ``param``. If on the first iteration, this uses an all-gather to check
-        that all ranks plan to all-gather the same parameter, erroring if not,
-        and on subsequent iterations, if the all-gather order differs from that
-        of the first iteration (meaning that we can no longer guarantee correct
-        execution), then we issue a warning to the user. This only issues
+        Checks the validity of a call to :meth:`_rebuild_full_params` in terms
+        of the execution order. If on the first iteration, this uses an
+        all-gather to check that all ranks are running ``forward()`` with the
+        same parameter, erroring if not, and on subsequent iterations, if the
+        forward order differs from that of the first iteration (meaning that we
+        can no longer guarantee correct execution since all-gathers may be
+        mismatched), then we issue a warning to the user. This only issues
         warnings on the first deviating iteration and stops checking
         thereafter.
 
-        For now, only the all-gathers to rebuild full parameters in the forward
-        pass are checked since (1) a correct forward order should imply a
-        correct pre-backward order for typical cases and (2) there may be some
-        issues with pre-fetching that need to be looked into:
-        https://github.com/pytorch/pytorch/issues/76553
+        Only the :meth:`_rebuild_full_params` calls in the forward pass are
+        checked since a correct forward order should imply a correct
+        pre-backward order for typical cases.
 
         Executing in ``no_sync()`` does not affect this check for
         ``FULL_SHARD`` and ``SHARD_GRAD_OP``: (1) Being in ``no_sync()`` in the
-        first iteration does not yield a different all-gather sequence, and (2)
-        being in ``no_sync()`` in a later iteration does not give false
-        positive warnings since the all-gather sequence still matches the first
+        first iteration does not yield a different forward
+        :meth:`_rebuild_full_params()` sequence, and (2) being in ``no_sync()``
+        in a later iteration does not give false positive warnings since the
+        forward :meth:`_rebuild_full_params()` sequence still matches the first
         iteration sequence (for ``FULL_SHARD``) or the first iteration
         sequence's prefix (for ``SHARD_GRAD_OP``).
         """
-        # Only check all-gathers when rebuilding the full parameters in the
-        # forward pass
-        if self.training_state != TrainingState_.FORWARD:
+        # Only check when rebuilding the full parameters in the forward pass,
+        # and skip the check (1) when in eval mode since then there is not a
+        # safe point at which to reset the execution order data and (2) if
+        # world size is 1 since then there is no chance of desynchronization
+        if self.training_state != TrainingState_.FORWARD or \
+                not self.training or self.world_size == 1:
             return
         eod = self._exec_order_data
         param_index = eod.get_param_index(param)
@@ -2843,40 +2844,52 @@ class FullyShardedDataParallel(nn.Module):
                 return
             # However, we may issue multiple warnings on the first deviating
             # iteration to help debugging, where either:
-            # 1. This iteration sees more all-gathers than the first iteration
-            msg_prefix = all_gather_seq = None  # non-`None` means we warn
+            # 1. This iteration sees an extra `_rebuild_full_params()` in
+            # `forward()` compared to the first iteration
+            msg_prefix = curr_param_order = None  # non-`None` means we warn
             if eod.index >= len(eod.param_order):
-                msg_prefix = "Expected no more all-gathers but got an all-gather for "
-                all_gather_seq = eod.param_order + [param_index]
+                msg_prefix = "Expected to not rebuild any more parameters " \
+                    "in `forward()` for this module but trying to rebuild " \
+                    "parameters for "
+                curr_param_order = eod.param_order + [param_index]
             else:
                 expected_param_index = eod.param_order[eod.index]
-                # 2. This iteration sees the same number of all-gathers (so
-                # far) but the current parameter to all-gather differs
+                # 2. This iteration sees the same number of
+                # `_rebuild_full_params()` (so far) but the current parameter
+                # differs
                 if param_index != expected_param_index:
                     expected_param_names = eod.get_unflat_param_names(expected_param_index)
                     assert len(expected_param_names) > 0, \
-                        "The expected parameter to all-gather should always be valid"
-                    msg_prefix = "Expected an all-gather for the FSDP module " \
-                        f"wrapping {expected_param_names} but got an all-gather for "
-                    all_gather_seq = eod.param_order[:eod.index - 1] + [param_index]
+                        "Expected parameter should always be valid"
+                    msg_prefix = "Expected to rebuild parameters in " \
+                        f"`forward()` for {expected_param_names} but " \
+                        "instead trying to rebuild parameters for "
+                    curr_param_order = eod.param_order[:eod.index - 1] + [param_index]
             to_issue_warning = msg_prefix is not None
             if to_issue_warning:
-                assert all_gather_seq is not None
+                assert curr_param_order is not None
                 param_names = eod.get_unflat_param_names(param_index)
                 is_added_param = len(param_names) == 0
                 if is_added_param:
                     msg_suffix = "a newly-added parameter since construction time"
                 else:
-                    msg_suffix = f"the FSDP module wrapping {param_names}"
+                    msg_suffix = f"{param_names}"
                 sub_msg = msg_prefix + msg_suffix
+                first_iter_param_names = [
+                    eod.get_unflat_param_names(index) for index in eod.param_order
+                ]
+                curr_iter_param_names = [
+                    eod.get_unflat_param_names(index) for index in curr_param_order
+                ]
+                print(first_iter_param_names, type(first_iter_param_names))
+                print(curr_iter_param_names, type(curr_iter_param_names))
                 warnings.warn(
-                    "All-gather order differs from that of the first iteration "
+                    "Forward order differs from that of the first iteration "
                     f"on rank {self.rank} -- collectives are unchecked and may "
                     "give incorrect results or hang\n" + sub_msg + "\n" +
-                    f"First iteration's all-gather sequence: {eod.param_order}"
-                    "\nThis iteration's all-gather sequence (so far): "
-                    f"{all_gather_seq}\nwhere indices follow the root FSDP "
-                    "module's `.parameters()` order"
+                    f"First iteration's forward order: {first_iter_param_names}"
+                    "\nThis iteration's forward order (so far): "
+                    f"{curr_iter_param_names}"
                 )
                 eod.warn_status = _ExecOrderWarnStatus.WARNING
             eod.index += 1
@@ -2896,10 +2909,10 @@ class FullyShardedDataParallel(nn.Module):
                     r1_param_names = eod.get_unflat_param_names(i1)
                     r2_param_names = eod.get_unflat_param_names(i2)
                     raise RuntimeError(
-                        f"All-gather order differs across ranks: rank {r1} is "
-                        f"all-gathering the flattened parameter wrapping "
-                        f"{r1_param_names} while rank {r2} is all-gathering "
-                        f"the flattened parameter wrapping {r2_param_names}"
+                        f"Forward order differs across ranks: rank {r1} is "
+                        "rebuilding full parameters in `forward()` for "
+                        f"{r1_param_names} while rank {r2} is rebuilding full "
+                        f"parameters in `forward()` for {r2_param_names}"
                     )
             eod.param_order.append(param_index)
 
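The first-iteration check that raises the "Forward order differs across ranks" error relies on an all-gather of each rank's current parameter index, per the docstring above. A hedged, simplified sketch of that style of cross-rank check follows; the helper name is hypothetical, it assumes an initialized NCCL process group with one CUDA device per rank, and it omits FSDP's mapping from indices back to unflattened parameter names:

import torch
import torch.distributed as dist

def check_same_param_index_across_ranks(param_index: int) -> None:
    # Gather every rank's parameter index and error if any pair disagrees.
    world_size = dist.get_world_size()
    device = torch.device("cuda", torch.cuda.current_device())
    local = torch.tensor([param_index], device=device)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    indices = [int(t.item()) for t in gathered]
    for rank, idx in enumerate(indices):
        if idx != indices[0]:
            raise RuntimeError(
                f"Forward order differs across ranks: rank 0 is rebuilding "
                f"parameter index {indices[0]} while rank {rank} is "
                f"rebuilding parameter index {idx}"
            )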