diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
index 237f59673828..e8ba77e8fa0e 100644
--- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
+++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
@@ -230,7 +230,7 @@ class FSDPState(_State):
         self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> tuple[tuple[Any, ...], dict[str, Any]]:
         # When composing with module-hook-based activation checkpointing, the
-        # the pre-backward hook is responsible for the unshard
+        # pre-backward hook is responsible for the unshard
         if self._training_state == TrainingState.PRE_BACKWARD:
             return args, kwargs
         self._training_state = TrainingState.FORWARD