# Owner(s): ["oncall: distributed"]
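# This test tracks per-iteration CUDA memory usage of a small convolutional
# model trained under FullyShardedDataParallel (FSDP), with and without
# activation checkpointing, and compares the readings against hand-derived
# expected values.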

import sys
import unittest

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_HPU,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def get_cur_mem(rank, result, prefix):
    """Collect memory allocated values in a result dict in MB"""
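    # Clearing the cuBLAS workspaces first should keep their cached buffers
    # from inflating the reading, so that memory_allocated() mostly reflects
    # tensor allocations.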
    torch._C._cuda_clearCublasWorkspaces()
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)
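

# Note: when with_fsdp is True, each BatchNorm is additionally wrapped in its
# own nested FSDP unit, so the model exercises a nested FSDP hierarchy rather
# than a single flat wrap.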
class Model(nn.Module):
    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()
        if with_fsdp:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                FSDP(nn.BatchNorm2d(64)),
                nn.ReLU(inplace=True),
            )
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        if with_fsdp:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
        else:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        if self.with_checkpoint:
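            # The reentrant checkpoint frees the activations of `blocks` during
            # forward and recomputes them in backward; this is why the test's
            # "after fwd" expectations are far smaller with checkpointing on.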
            return self.head(checkpoint(self.blocks, self.stem(x), use_reentrant=True))
        else:
            return self.head(self.blocks(self.stem(x)))


def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
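        # Wrap each top-level submodule in its own FSDP unit; the root module
        # itself is wrapped once more in `_dist_train` below.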
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model


class TestFSDPMemory(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
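        # The outer wrap adds the root FSDP instance on top of the units
        # already created in create_model, completing the nested hierarchy.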
        model = FSDP(model)

        # We enable momentum so that after the first iteration, the optimizer
        # state is added to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # results of memory stats
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

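            # Reduce the first sample's logits to a scalar so a fake MSE loss
            # can be computed against a fixed scalar target.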
            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # It is important to use `set_to_none=True` below, rather than
            # `optimizer.zero_grad()`, so that the gradient memory is reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")

    @unittest.skipIf(TEST_HPU, "Memory will be different for CUDA and HPU, skipping")
    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model

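        # Under FSDP, each of the world_size=2 ranks holds 1/world_size of the
        # flattened parameters, so the per-rank shard is half the model size.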
        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # We have observed that the 4th iteration can sometimes fail after the
        # 3rd succeeds (not in this test, but in much larger-scale tests), so
        # we run 4 iterations here just in case it happens.
        iterations = 4

        expected = {}

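        # Build the expected allocated-memory table for every measurement
        # point. The activation footprints (51MB with checkpointing, 340MB
        # without) were measured empirically; the remaining terms follow from
        # the sharded model, gradient, and optimizer-state sizes.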
        for iteration in range(iterations):
            if iteration == 0:
                # sharded model size + 1MB temp memory
                expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                # This activation memory is hard to compute analytically; the
                # values were taken from observed memory usage.
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = 51
                    expected[f"iter {iteration}: after loss"] = 51
                else:
                    expected[f"iter {iteration}: after fwd"] = 340
                    expected[f"iter {iteration}: after loss"] = 340
                # sharded model size + sharded grad size + 1MB temp memory
                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
            else:
                # After the optimizer step in the first iteration, memory usage
                # increases by sharded_model_size_mb because the optimizer
                # (momentum) state is now allocated.
                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = (
                        51 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        51 + sharded_model_size_mb
                    )
                else:
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # Gradient memory is reclaimed after setting grads to None:
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Get the checkpoint flag.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)

if __name__ == "__main__":
    run_tests()