mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replace assert statements with explicit if/raise patterns across 20 files: - _optim_utils.py (38 asserts) - _flat_param.py (25 asserts) - _fully_shard/_fsdp_param.py (23 asserts) - sharded_grad_scaler.py (12 asserts) - fully_sharded_data_parallel.py (11 asserts) - wrap.py (10 asserts) - _state_dict_utils.py (9 asserts) - _fully_shard/_fsdp_param_group.py (8 asserts) - _runtime_utils.py (6 asserts) - _init_utils.py (6 asserts) - 10 additional files (16 asserts) This prevents assertions from being disabled with Python -O flag. Fixes partially #164878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165235 Approved by: https://github.com/albanD
160 lines
5.7 KiB
Python
160 lines
5.7 KiB
Python
# mypy: allow-untyped-defs
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from collections.abc import Iterator
|
|
from contextlib import contextmanager
|
|
from enum import Enum
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.distributed.fsdp._flat_param as flat_param_file
|
|
from torch.distributed.fsdp._common_utils import (
|
|
_apply_to_modules,
|
|
_get_module_fsdp_state,
|
|
clean_tensor_name,
|
|
)
|
|
|
|
|
|
# Module-level logger; SimpleProfiler.dump_and_reset writes its timing summary here.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SimpleProfiler:
    """Accumulate wall-clock durations for named code regions.

    All counters live on the class itself, so they are shared by every user
    of ``SimpleProfiler``. Wrap a region in ``profile(<type>)`` to time it and
    call ``dump_and_reset`` to log the totals and start over.
    """

    class Type(str, Enum):
        # Labels for the regions this profiler is used with; usable anywhere a
        # ``str`` profile type is expected.
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    # Total seconds spent per profile type, summed across ``profile`` calls.
    results: dict[str, float] = defaultdict(float)
    # Profile types whose ``profile`` context is currently open.
    profiling: set[str] = set()

    @classmethod
    def reset(cls) -> None:
        """Forget every open region and every accumulated timing."""
        cls.profiling.clear()
        cls.results.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        """Time the enclosed block and add the elapsed seconds to ``results``.

        Args:
            profile_type: Key under which the duration is accumulated.

        Raises:
            AssertionError: If a ``profile`` context for ``profile_type`` is
                already open; overlapping regions of one type are unsupported.
        """
        already_active = profile_type in cls.profiling
        if already_active:
            raise AssertionError(
                f"{profile_type} is already being profiled. "
                "SimpleProfiler does not support profiling multiple instances at "
                "the same time. "
            )

        cls.profiling.add(profile_type)
        start = time.monotonic()
        try:
            yield
        finally:
            # Runs even if the body raises, so the region is always closed and
            # its partial duration still counted.
            elapsed = time.monotonic() - start
            cls.results[profile_type] += elapsed
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        """Log ``msg`` plus the accumulated results on rank 0, then reset.

        Logging happens only when the distributed debug level is INFO. Under
        DETAIL the extra distributed logging would make the timings very
        inaccurate, so nothing is printed. The counters are cleared on every
        rank either way.
        """
        if dist.get_rank() == 0:
            if dist.get_debug_level() == dist.DebugLevel.INFO:
                logger.info("%s %s", msg, cls.results)
        cls.reset()
|
|
|
|
|
|
def _get_sharded_module_tree_with_module_name_to_fqns(
|
|
model: torch.nn.Module,
|
|
) -> tuple[str, dict[str, list[str]]]:
|
|
"""
|
|
It is used for composable fully_shard() code path, it returns
|
|
1. sharded module tree info: each line represents a submodule name that contains the
|
|
submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`,
|
|
the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree
|
|
level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model
|
|
is like this:
|
|
[CompositeModel] FULLY SHARDED
|
|
l1[Linear]
|
|
u1[UnitModule] FULLY SHARDED
|
|
u1.l1[Linear]
|
|
u1.seq[Sequential]
|
|
u1.seq.0[ReLU]
|
|
u1.seq.1[Linear]
|
|
u1.seq.2[ReLU]
|
|
u1.l2[Linear]
|
|
u2[UnitModule] FULLY SHARDED
|
|
u2.l1[Linear]
|
|
u2.seq[Sequential]
|
|
u2.seq.0[ReLU]
|
|
u2.seq.1[Linear]
|
|
u2.seq.2[ReLU]
|
|
u2.l2[Linear]
|
|
l2[Linear]
|
|
2. a dict mapping from the concated module FQN and class name to a list of its managed
|
|
original parameters' FQNs. An example of the dict for the above toy sharded model is like this:
|
|
{'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
|
|
'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
|
|
'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
|
|
}
|
|
All FQNs are prefixed starting from ``model``.
|
|
|
|
Args:
|
|
model (torch.nn.Module): Root module (which may or may not be passed to
|
|
composable `fully_shard()`).
|
|
"""
|
|
|
|
def module_fn(
|
|
module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
|
|
):
|
|
num_spaces = tree_level * 4
|
|
trimed_prefix = (
|
|
prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
|
|
)
|
|
prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
|
|
printed_prefixed_module_name = " " * num_spaces + prefixed_module_name
|
|
|
|
state = _get_module_fsdp_state(module)
|
|
if state is None:
|
|
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
|
|
return
|
|
|
|
handle = state._fully_sharded_module_to_handle.get(module, None)
|
|
|
|
if handle:
|
|
sharded_tree_info[0] += (
|
|
printed_prefixed_module_name + " FULLY SHARDED" + "\n"
|
|
)
|
|
else:
|
|
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
|
|
|
|
if handle:
|
|
param = handle.flat_param
|
|
if not isinstance(param, flat_param_file.FlatParameter):
|
|
raise AssertionError(f"Expected FlatParameter, got {type(param)}")
|
|
global_fqns = [
|
|
clean_tensor_name(prefix + name) for name in param._fqns
|
|
] # prefixed from the top level `model` (i.e. including `prefix`)
|
|
|
|
if prefixed_module_name in sharded_module_name_to_fqns:
|
|
sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
|
|
else:
|
|
sharded_module_name_to_fqns[prefixed_module_name] = global_fqns
|
|
|
|
def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
|
|
return sharded_tree_info[0], sharded_module_name_to_fqns
|
|
|
|
# Use List to mutate its value in place while running the recursive functions
|
|
sharded_tree_info: list[str] = [
|
|
"",
|
|
]
|
|
sharded_module_name_to_fqns: dict[str, list[str]] = {}
|
|
return _apply_to_modules(
|
|
model,
|
|
module_fn,
|
|
return_fn,
|
|
[key for key, _ in model.named_parameters()],
|
|
sharded_tree_info,
|
|
sharded_module_name_to_fqns,
|
|
)
|