Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-16 07:24:54 +08:00)
[FSDP] Add training_state parameter to fsdp_pre_all_gather hook (#167886)
Summary: Add support for passing training-state information to FSDP tensor subclass hooks during all-gather operations. This enables tensor subclasses to differentiate between the forward and backward passes for state-aware optimizations such as conditional quantization.

**Changes:**

1. **Hook signature extension** (`_fsdp_param.py`, `_fsdp_param_group.py`):
   - Added a `param_group` parameter to `FSDPParam.__init__()` to maintain a reference to the parent `FSDPParamGroup`.
   - Extended hook signature detection to support the new 2-parameter signature `fsdp_pre_all_gather(mesh, training_state)`.
   - Maintained backward compatibility with the existing 1-parameter and 5-parameter signatures.
   - Pass the `TrainingState` from the parent group to tensor subclass hooks, defaulting to `TrainingState.IDLE` when `param_group` is None.

2. **Implementation details:**
   - Used a `TYPE_CHECKING` guard to avoid a circular import with `FSDPParamGroup`.
   - The reference cycle between `FSDPParam` and `FSDPParamGroup` is safe (parent and child have the same lifetime).
   - Added the `TrainingState` import from `_fsdp_common`.

**Benefits:**

- Enables tensor subclasses to implement state-aware behavior (e.g., different quantization strategies for forward vs. backward).
- Maintains full backward compatibility with existing hooks.
- Clean API extension without breaking changes.

Differential Revision: D86938522
Committed by: Facebook GitHub Bot
Parent: 4ed26f7382
Commit: ef7260317c
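To illustrate the new 2-parameter hook, here is a minimal sketch of a tensor subclass that branches on the training state. Only the `fsdp_pre_all_gather(mesh, training_state)` signature, the `(all_gather_inputs, metadata)` return pair, and `TrainingState` itself come from this change; the subclass name, the dtype choices, the specific enum members checked, and the import path are illustrative assumptions.

```python
# Hedged sketch, not part of this commit: a hypothetical tensor subclass
# opting into the new 2-parameter fsdp_pre_all_gather hook.
import torch

# Assumed import path for the per-parameter-sharding FSDP internals; adjust to
# the layout of your PyTorch build.
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState


class StateAwareWeight(torch.Tensor):
    """Toy subclass; a real extension also needs the usual tensor-subclass
    machinery (__new__, __torch_dispatch__, etc.), omitted here."""

    def fsdp_pre_all_gather(self, mesh, training_state):
        # FSDP unpacks exactly (all_gather_inputs, metadata), as in the
        # _fsdp_param.py hunk in the diff below.
        if training_state == TrainingState.FORWARD:
            # e.g. aggressive low-precision all-gather for the forward pass
            dtype = torch.float8_e4m3fn
        else:
            # PRE_BACKWARD (or the IDLE fallback): keep higher precision
            dtype = torch.bfloat16
        return (self.to(dtype),), (self.dtype,)
```

Because dispatch is by parameter count, existing 1-parameter and 5-parameter hooks continue to work unchanged.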
_fsdp_param.py

@@ -4,7 +4,11 @@ import itertools
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass, field
 from enum import auto, Enum
-from typing import Any, cast, Optional
+from typing import Any, cast, Optional, TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from ._fsdp_state import FSDPParamGroup
 
 import torch
 import torch.nn as nn
@@ -26,6 +30,7 @@ from ._fsdp_common import (
     compiled_autograd_enabled,
     FSDPMeshInfo,
     HSDPMeshInfo,
+    TrainingState,
 )
 
 
@@ -229,6 +234,7 @@ class FSDPParam:
         shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
         mp_policy: MixedPrecisionPolicy,
         offload_policy: OffloadPolicy,
+        param_group: Optional["FSDPParamGroup"] = None,
     ):
         self._module_info: ParamModuleInfo = module_info
         self.mesh_info = mesh_info
@@ -241,6 +247,7 @@ class FSDPParam:
             self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
         )
         self.grad_offload_event: Optional[torch.Event] = None
+        self._param_group = param_group  # Reference to parent param group
         self._init_sharded_param(param, device, shard_placement_fn)
         if self.post_forward_mesh_info:
             self._init_sharded_post_forward_param_metadata(param)
@@ -730,8 +737,9 @@ class FSDPParam:
                     sharded_local_tensor.fsdp_pre_all_gather
                 )
                 num_fn_params = len(pre_all_gather_signature.parameters)
-                # Old signature only passes mesh; keep for BC for now
-                if num_fn_params not in (1, 5):
+                # Support old 1-param signature and new 2-param signature with training_state
+                # Keep 5-param signature as the default signature
+                if num_fn_params not in (1, 2, 5):
                     raise AssertionError(
                         f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n"
                         "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, "
@@ -739,6 +747,7 @@ class FSDPParam:
                         "module: nn.Module, mp_policy: MixedPrecisionPolicy)"
                     )
                 if num_fn_params == 1:
+                    # Old signature - backward compatibility
                     (
                         all_gather_inputs,
                         self._extensions_data.all_gather_metadata,
@@ -746,6 +755,21 @@ class FSDPParam:
                     ) = sharded_local_tensor.fsdp_pre_all_gather(
                         self.shard_mesh_from_root
                     )
+                elif num_fn_params == 2:
+                    # New signature with training_state
+                    training_state = (
+                        self._param_group._training_state
+                        if self._param_group is not None
+                        else TrainingState.IDLE
+                    )
+                    (
+                        all_gather_inputs,
+                        self._extensions_data.all_gather_metadata,
+                        # pyrefly: ignore [missing-attribute]
+                    ) = sharded_local_tensor.fsdp_pre_all_gather(
+                        self.shard_mesh_from_root,
+                        training_state,
+                    )
                 else:
                     (
                         all_gather_inputs,
_fsdp_param_group.py

@@ -147,6 +147,7 @@ class FSDPParamGroup:
                 shard_placement_fn,
                 mp_policy,
                 offload_policy,
+                param_group=self,  # Ref cycle is OK: parent-child have same lifetime
             )
             for param, module_info in zip(params, param_module_infos)
         ]
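As a standalone illustration of the arity-based dispatch in the `_fsdp_param.py` hunks above: the `_Weight` class below is a hypothetical stand-in, while the 1/2/5 rule and the use of `inspect.signature` on the bound hook come from the diff.

```python
import inspect


class _Weight:
    # Hypothetical stand-in for a sharded local tensor subclass.
    def fsdp_pre_all_gather(self, mesh, training_state):
        return (mesh,), None


weight = _Weight()
# `self` is not counted on a bound method, mirroring the check in _fsdp_param.py.
num_fn_params = len(inspect.signature(weight.fsdp_pre_all_gather).parameters)
assert num_fn_params in (1, 2, 5)
print(num_fn_params)  # 2 -> FSDP calls fsdp_pre_all_gather(mesh, training_state)
```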