[FSDP] Add training_state parameter to fsdp_pre_all_gather hook (#167886)

Summary:

Add support for passing training state information to FSDP tensor subclass hooks during all-gather operations. This enables tensor subclasses to differentiate between forward and backward passes for state-aware optimizations like conditional quantization.
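
For orientation, a tensor subclass might adopt the new signature roughly as in the sketch below. This is a minimal, hypothetical sketch rather than code from this commit: the class name, the dtype choices, and the private import path for `TrainingState` are assumptions, and the wrapper-subclass construction plus the matching post-all-gather hook are omitted. The diff below only establishes that the hook receives `(mesh, training_state)` and returns an `(all_gather_inputs, metadata)` pair.

```python
# Hypothetical sketch; not part of this commit. Assumes the private import
# path below and that TrainingState has a FORWARD member (only IDLE is
# referenced in the diff).
import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState


class StateAwareQuantTensor(torch.Tensor):
    # Wrapper-subclass construction and the post-all-gather hook are omitted;
    # only the new 2-parameter pre-all-gather hook is shown.
    def fsdp_pre_all_gather(self, mesh, training_state):
        if training_state == TrainingState.FORWARD:
            # Forward pass: all-gather in bf16 (hypothetical policy).
            all_gather_inputs = (self.to(torch.bfloat16),)
        else:
            # Backward and other states: all-gather in fp8 (hypothetical policy).
            all_gather_inputs = (self.to(torch.float8_e4m3fn),)
        metadata = None  # passed through to the post-all-gather hook
        return all_gather_inputs, metadata
```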

**Changes:**

1. **Hook Signature Extension** (`_fsdp_param.py`, `_fsdp_param_group.py`):
   - Added a `param_group` parameter to `FSDPParam.__init__()` to keep a reference to the parent `FSDPParamGroup`
   - Extended hook signature detection to support the new 2-parameter signature `fsdp_pre_all_gather(mesh, training_state)` (the accepted signatures are sketched after this list)
   - Maintained backward compatibility with the existing 1-parameter and 5-parameter signatures
   - Passed the `TrainingState` from the parent group to tensor subclass hooks, defaulting to `TrainingState.IDLE` when `param_group` is None

2. **Implementation Details:**
   - Used a `TYPE_CHECKING` guard to avoid a circular import with `FSDPParamGroup`
   - The reference cycle between `FSDPParam` and `FSDPParamGroup` is safe because parent and child have the same lifetime
   - Added the `TrainingState` import from `_fsdp_common`
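
From the hook author's side, the three accepted arities could look like the stubs below. Class names are placeholders, and the two middle parameters of the 5-parameter form are elided between the diff hunks further down, so their names here are assumptions. Because dispatch is by parameter count via `inspect.signature`, declaring a second parameter opts the subclass into receiving `training_state` even if that parameter has a default value.

```python
# Illustrative stubs only; class names and the middle parameters of the
# 5-parameter form are assumptions.
import torch


class LegacySubclass(torch.Tensor):
    # 1-parameter form: unchanged, still accepted for backward compatibility.
    def fsdp_pre_all_gather(self, mesh):
        return (self,), None


class StateAwareSubclass(torch.Tensor):
    # New 2-parameter form: also receives the owning group's TrainingState
    # (TrainingState.IDLE when the parameter has no group).
    def fsdp_pre_all_gather(self, mesh, training_state):
        return (self,), None


class FullSignatureSubclass(torch.Tensor):
    # 5-parameter form: remains the full/default signature.
    def fsdp_pre_all_gather(self, mesh, outer_size, outer_stride, module, mp_policy):
        return (self,), None
```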

**Benefits:**
- Enables tensor subclasses to implement state-aware behavior (e.g., different quantization strategies for forward vs backward)
- Maintains full backward compatibility with existing hooks
- Clean API extension without breaking changes

Differential Revision: D86938522
Author: Yuhang Yang
Date: 2025-11-14 15:45:50 -08:00
Committed by: Facebook GitHub Bot
Parent: 4ed26f7382
Commit: ef7260317c
2 changed files with 28 additions and 3 deletions

_fsdp_param.py

@@ -4,7 +4,11 @@ import itertools
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass, field
 from enum import auto, Enum
-from typing import Any, cast, Optional
+from typing import Any, cast, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ._fsdp_state import FSDPParamGroup
+
 
 import torch
 import torch.nn as nn
@@ -26,6 +30,7 @@ from ._fsdp_common import (
     compiled_autograd_enabled,
     FSDPMeshInfo,
     HSDPMeshInfo,
+    TrainingState,
 )
@@ -229,6 +234,7 @@ class FSDPParam:
         shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
         mp_policy: MixedPrecisionPolicy,
         offload_policy: OffloadPolicy,
+        param_group: Optional["FSDPParamGroup"] = None,
     ):
         self._module_info: ParamModuleInfo = module_info
         self.mesh_info = mesh_info
@@ -241,6 +247,7 @@
             self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
         )
         self.grad_offload_event: Optional[torch.Event] = None
+        self._param_group = param_group  # Reference to parent param group
         self._init_sharded_param(param, device, shard_placement_fn)
         if self.post_forward_mesh_info:
             self._init_sharded_post_forward_param_metadata(param)
@@ -730,8 +737,9 @@ class FSDPParam:
                     sharded_local_tensor.fsdp_pre_all_gather
                 )
                 num_fn_params = len(pre_all_gather_signature.parameters)
-                # Old signature only passes mesh; keep for BC for now
-                if num_fn_params not in (1, 5):
+                # Support old 1-param signature and new 2-param signature with training_state
+                # Keep 5-param signature as the default signature
+                if num_fn_params not in (1, 2, 5):
                     raise AssertionError(
                         f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n"
                         "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, "
@@ -739,6 +747,7 @@
                         "module: nn.Module, mp_policy: MixedPrecisionPolicy)"
                     )
                 if num_fn_params == 1:
+                    # Old signature - backward compatibility
                     (
                         all_gather_inputs,
                         self._extensions_data.all_gather_metadata,
@@ -746,6 +755,21 @@
                     ) = sharded_local_tensor.fsdp_pre_all_gather(
                         self.shard_mesh_from_root
                     )
+                elif num_fn_params == 2:
+                    # New signature with training_state
+                    training_state = (
+                        self._param_group._training_state
+                        if self._param_group is not None
+                        else TrainingState.IDLE
+                    )
+                    (
+                        all_gather_inputs,
+                        self._extensions_data.all_gather_metadata,
+                        # pyrefly: ignore [missing-attribute]
+                    ) = sharded_local_tensor.fsdp_pre_all_gather(
+                        self.shard_mesh_from_root,
+                        training_state,
+                    )
                 else:
                     (
                         all_gather_inputs,

_fsdp_param_group.py

@@ -147,6 +147,7 @@ class FSDPParamGroup:
                 shard_placement_fn,
                 mp_policy,
                 offload_policy,
+                param_group=self,  # Ref cycle is OK: parent-child have same lifetime
             )
             for param, module_info in zip(params, param_module_infos)
         ]