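"""Benchmark a torchbench (or toy) model under DDP or FSDP, optionally compiled
with a torch._dynamo backend, and report the mean per-iteration latency.

The script reads RANK and WORLD_SIZE from the environment (e.g. as set by a
distributed launcher) and can additionally dump a torchviz autograd graph or a
profiler chrome trace.
"""
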
import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function

try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for i in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)


def run_model(args, model, inputs, key):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # result_q = []

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total


if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--device", default="cuda")
|
|
parser.add_argument(
|
|
"--dynamo",
|
|
default=None,
|
|
help="if set to a str, uses dynamo[str] backend. else, eager",
|
|
)
|
|
parser.add_argument("--verbose", action="store_true")
|
|
parser.add_argument("--batch-size", "--batch_size", default=None)
|
|
parser.add_argument(
|
|
"--torchviz", action="store_true", help="Dump autograd graph with torchviz"
|
|
)
|
|
parser.add_argument("--profile", action="store_true", help="Run the profiler")
|
|
parser.add_argument(
|
|
"--trace-file", "--trace_file", default="profile.json", help="Run the profiler"
|
|
)
|
|
parser.add_argument("--repeat", default=10, help="Repeats for timing run")
|
|
parser.add_argument(
|
|
"--dynamo-no-optimize-ddp",
|
|
"--dynamo_no_optimize_ddp",
|
|
action="store_true",
|
|
help="Disable dynamo's ddp optimizer (enabled by default)",
|
|
)
|
|
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")