[FSDP][Replicate] tests replicate with custom forward method (#162851)
**Summary:** Tests that `replicate` works when users use custom forward methods.

**Test Cases**
1. `pytest test/distributed/_composable/test_replicate_training.py -k test_register_fsdp_forward_method`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162851
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839
Committed by: PyTorch MergeBot
Parent: 1ce9563ff6
Commit: d3bdf8c32e
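For context on what the new test exercises: `replicate` (like FSDP) installs its pre- and post-forward hooks only around `nn.Module.forward`, so a custom entry point such as `forward_features` has to be registered explicitly with `register_fsdp_forward_method`, otherwise calls to it bypass those hooks. Below is a minimal sketch of that pattern; it is not part of the PR, it assumes a process group is already initialized and an FSDP-backed `replicate` as introduced in this stack, and the `replicate` import path is an assumption since the test's own import lies outside the hunks shown in the diff.

```python
# Minimal sketch (not from this PR): register a non-`forward` method so that
# replicate's pre-/post-forward hooks run around it.
# Assumptions: torch.distributed is already initialized; `replicate` is the
# FSDP-backed composable API from this PR stack; the import path below is a
# guess, since the test's import is not shown in the hunks.
import torch
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed.fsdp import register_fsdp_forward_method


class Encoder(nn.Module):
    """Hypothetical stand-in for the test's VisionTransformer."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # Custom entry point that callers invoke instead of `forward`.
        return self.proj(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_features(x).sum(dim=1)


encoder = Encoder()
replicate(encoder)
# Without this registration, calling `encoder.forward_features(x)` directly
# would skip the communication hooks that replicate installs around `forward`.
register_fsdp_forward_method(encoder, "forward_features")
```

The test below follows this pattern and then checks parity against a plain reference model whose gradients are manually all-reduced with `ReduceOp.AVG`.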
@@ -19,7 +19,12 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    FSDPModule,
+    OffloadPolicy,
+    register_fsdp_forward_method,
+)
 from torch.distributed.tensor import DTensor, init_device_mesh
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -1078,5 +1083,53 @@ class TestReplicateGradientAccumulation(FSDPTest):
         check_sharded_parity(self, ref_model, model)
 
 
+class TestReplicateCustomForwardMethod(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(torch.get_device_module(device_type).device_count(), 2)
+
+    @skip_if_lt_x_gpu(2)
+    def test_register_fsdp_forward_method(self):
+        class VisionTransformer(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)
+
+            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
+                return self.patch_proj(imgs).flatten(2).transpose(1, 2)
+
+            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
+                return self.forward_features(imgs).sum(dim=1)
+
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)
+
+            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
+                # Run `vit.forward_features`, which is not `forward`!
+                patch_embeddings = self.vit.forward_features(imgs)
+                return self.projector(patch_embeddings)
+
+        torch.manual_seed(42)
+        model = Model()
+        ref_model = copy.deepcopy(model).to(device_type)
+        replicate(model.vit)
+        replicate(model.projector)
+        replicate(model)
+        register_fsdp_forward_method(model.vit, "forward_features")
+
+        torch.manual_seed(42 + self.rank + 1)
+        inp = torch.randn(4, 3, 224, 224, device=device_type.type)
+        ref_loss = ref_model(inp).sum()
+        loss = model(inp).sum()
+        self.assertEqual(ref_loss, loss)
+        ref_loss.backward()
+        loss.backward()
+        for param in ref_model.parameters():
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
+        check_sharded_parity(self, ref_model, model)
+
+
 if __name__ == "__main__":
     run_tests()