[FSDP][Replicate] tests replicate with custom forward method (#162851)

**Summary:** Tests that `replicate` works when users invoke custom forward methods registered via `register_fsdp_forward_method`.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_register_fsdp_forward_method
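For context, the mechanism under test is sketched below using the documented FSDP2 entry point `fully_shard` (the new test applies the same registration to `replicate`). By default, the pre-/post-forward hooks that all-gather and reshard parameters only wrap `forward`, so calling a custom method such as a ViT's `forward_features` directly would bypass them; `register_fsdp_forward_method` installs the hooks on the named method. The `Encoder` module, `encode` method name, and launch command are illustrative assumptions, not part of this PR.

```python
# Minimal sketch (assumes launch via `torchrun --nproc-per-node=2` on CUDA GPUs).
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method


class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Custom entry point that is not `forward`.
        return self.proj(x)


dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = Encoder().cuda()
fully_shard(model)
# Without this call, invoking `model.encode(...)` directly would skip the
# pre-/post-forward hooks that manage the sharded parameters.
register_fsdp_forward_method(model, "encode")

out = model.encode(torch.randn(4, 16, device="cuda"))
out.sum().backward()
dist.destroy_process_group()
```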

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162851
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839
Author: Anshul Sinha
Date: 2025-09-25 12:23:45 -07:00
Committed by: PyTorch MergeBot
Parent: 1ce9563ff6
Commit: d3bdf8c32e


```diff
@@ -19,7 +19,12 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    FSDPModule,
+    OffloadPolicy,
+    register_fsdp_forward_method,
+)
 from torch.distributed.tensor import DTensor, init_device_mesh
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
```
```diff
@@ -1078,5 +1083,53 @@ class TestReplicateGradientAccumulation(FSDPTest):
         check_sharded_parity(self, ref_model, model)
 
 
+class TestReplicateCustomForwardMethod(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(torch.get_device_module(device_type).device_count(), 2)
+
+    @skip_if_lt_x_gpu(2)
+    def test_register_fsdp_forward_method(self):
+        class VisionTransformer(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)
+
+            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
+                return self.patch_proj(imgs).flatten(2).transpose(1, 2)
+
+            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
+                return self.forward_features(imgs).sum(dim=1)
+
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)
+
+            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
+                # Run `vit.forward_features`, which is not `forward`!
+                patch_embeddings = self.vit.forward_features(imgs)
+                return self.projector(patch_embeddings)
+
+        torch.manual_seed(42)
+        model = Model()
+        ref_model = copy.deepcopy(model).to(device_type)
+        replicate(model.vit)
+        replicate(model.projector)
+        replicate(model)
+        register_fsdp_forward_method(model.vit, "forward_features")
+
+        torch.manual_seed(42 + self.rank + 1)
+        inp = torch.randn(4, 3, 224, 224, device=device_type.type)
+        ref_loss = ref_model(inp).sum()
+        loss = model(inp).sum()
+        self.assertEqual(ref_loss, loss)
+        ref_loss.backward()
+        loss.backward()
+        for param in ref_model.parameters():
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
+        check_sharded_parity(self, ref_model, model)
+
+
 if __name__ == "__main__":
     run_tests()
```