import argparse
import functools
import importlib
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

try:
    from .torchbench import setup_torchbench_cwd
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", "0")
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()

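# Illustrative usage (a sketch, not part of the benchmark driver): setup() falls
# back to single-process defaults when the usual launcher environment variables
# are absent, so a local smoke test can bracket its work like this:
#
#     setup(rank=0, world_size=1)
#     ...  # build the model and run iterations
#     cleanup()
#
# Note that init_process_group("nccl") still requires at least one CUDA device.
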
class CustomLinear(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(a, b))

    def forward(self, x):
        return torch.mm(x, self.weight)

class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            *[nn.Linear(10, 10000), nn.ReLU()]
            + [nn.Linear(10000, 10000), nn.ReLU()]
            + [MyModule(10000, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [nn.Linear(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)

def model_iter_fn(model, example_inputs, collect_outputs=False):
    outputs = model(*example_inputs)
    loss = reduce_to_scalar_loss(outputs)
    loss.backward()
    if collect_outputs:
        return outputs

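# Illustrative single step (a sketch): model_iter_fn() runs one forward pass,
# reduces the output to a scalar loss, and backpropagates; it does not step an
# optimizer, so gradients simply accumulate on the parameters.
#
#     model = ToyModel()
#     outputs = model_iter_fn(model, (torch.randn(20, 10),), collect_outputs=True)
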
def get_model(args):
    if args.torchbench_model:
        setup_torchbench_cwd()
        module = importlib.import_module(
            f"torchbenchmark.models.{args.torchbench_model}"
        )
        benchmark_cls = getattr(module, "Model", None)
        bm = benchmark_cls(test="train", device=args.device, batch_size=args.batch_size)
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        raise argparse.ArgumentError(
            args.torchbench_model, message="Must specify a model"
        )

    return model, inputs

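# Hedged note: get_model() only reads a handful of attributes off the parsed
# argparse namespace (torchbench_model, toy_model, device, batch_size). A minimal
# stand-in for experimentation could look like the sketch below; the field values
# shown are assumptions, not the benchmark's actual CLI defaults.
#
#     import types
#
#     args = types.SimpleNamespace(
#         torchbench_model=None,  # or e.g. "hf_Bert" with toy_model=False
#         toy_model=True,
#         device="cuda",
#         batch_size=None,
#     )
#     model, inputs = get_model(args)
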
def fsdp_checkpointing_base(model, blocks):
    """Apply activation checkpointing to the model.

    Returns None, as the model is updated in place.
    """
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )

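# Illustrative call (a sketch, not exercised by this file itself): wrap every
# MyModule block of a ToyModel with non-reentrant activation checkpointing,
# mirroring how apply_fsdp() uses this helper below.
#
#     toy = ToyModel()
#     fsdp_checkpointing_base(toy, blocks=(MyModule,))
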
MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}

def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    wrap_policy = None
    blocks = MODEL_FSDP_WRAP[
        "toy_model" if model.__class__ is ToyModel else args.torchbench_model
    ]
    if use_wrap_policy:
        wrap_policy = ModuleWrapPolicy(blocks)

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model
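
# Illustrative end-to-end flow through these helpers (a sketch; the argparse
# namespace fields are assumptions drawn from get_model()/apply_fsdp(), and a
# CUDA device plus NCCL backend are required):
#
#     import types
#
#     args = types.SimpleNamespace(
#         torchbench_model=None, toy_model=True, device="cuda", batch_size=None
#     )
#     setup(rank=0, world_size=1)
#     model, inputs = get_model(args)
#     model = model.to(args.device)
#     inputs = tuple(t.to(args.device) for t in inputs)
#     model = apply_fsdp(args, model, use_checkpointing=True, use_wrap_policy=True)
#     model_iter_fn(model, inputs)
#     cleanup()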