Compare commits

...

1 Commits

2 changed files with 6 additions and 0 deletions

View File

@ -3213,6 +3213,8 @@ if torch.distributed.is_available():
# the forward_hook won't be ignored.
"torch.distributed._composable.replicate",
}
if not torch._dynamo.config.skip_fsdp_hooks:
LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp")
# Force inline functions under these modules, even they are in *_SKIPLIST.
@ -3263,6 +3265,8 @@ if torch.distributed.is_available():
MOD_INLINELIST.add("torch.distributed")
MOD_INLINELIST.add("torch.distributed._functional_collectives")
MOD_INLINELIST.add("torch.distributed._composable.replicate")
if not torch._dynamo.config.skip_fsdp_hooks:
MOD_INLINELIST.add("torch.distributed._composable.fsdp")
@functools.lru_cache(None)

View File

@ -351,12 +351,14 @@ def _register_group_forward_hooks(
"""
modules_set = set(modules)
@disable_if_config_true
@functools.wraps(pre_hook)
def wrapped_pre_hook(*args: Any, **kwargs: Any):
if len(modules_to_run) == 0: # first to run
modules_to_run.update(modules_set)
return pre_hook(*args, **kwargs)
@disable_if_config_true
def get_wrapped_post_hook(module: nn.Module):
@functools.wraps(post_hook)
def wrapped_post_hook(*args: Any, **kwargs: Any):