add private config to temporarily preserve old FSDP guard behavior (#142871)

Summary: https://github.com/pytorch/pytorch/pull/138819 changed how dynamo guards are generated in a way that caused a performance regression in some workloads, so this PR temporarily adds a private config to restore the old behavior while we investigate.
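The flag is read from the environment when torch._dynamo.config is imported, so the old behavior is opted into by setting the variable before torch is first imported. A minimal sketch (illustrative only; the attribute name comes from the diff below):

    import os

    # Must be set before the dynamo config is imported: the default is seeded
    # from the environment once, at import time.
    os.environ["UNSAFE_SKIP_FSDP_MODULE_GUARDS"] = "1"

    import torch._dynamo.config as dynamo_config

    # The resulting default is exposed on the (private) dynamo config module.
    assert dynamo_config._unsafe_skip_fsdp_module_guards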

Test Plan: CI

Differential Revision: D67096751

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142871
Approved by: https://github.com/yf225
Author: Brian Hirsh
Date: 2024-12-13 22:06:48 +00:00
Committed by: PyTorch MergeBot
Parent: 8fae4397b4
Commit: e19f493f02
3 changed files with 18 additions and 0 deletions


@@ -546,6 +546,11 @@ automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env(
    "TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"
)

# temporary config to kill later
_unsafe_skip_fsdp_module_guards = (
    os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1"
)

# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

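The guard-time checks in the hunks below consult the config attribute rather than re-reading the environment variable, so the flag can also be toggled programmatically. A minimal sketch, assuming the usual torch._dynamo.config.patch mechanism applies to this private field (the toy nn.Linear merely stands in for an FSDP-wrapped model):

    import torch
    import torch._dynamo.config as dynamo_config

    model = torch.nn.Linear(8, 8)
    compiled = torch.compile(model)

    # Scoped override: guards built while the patch is active see the flag as
    # True; the previous value is restored when the block exits.
    with dynamo_config.patch(_unsafe_skip_fsdp_module_guards=True):
        compiled(torch.randn(2, 8))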

@@ -1891,6 +1891,8 @@ class GuardBuilder(GuardBuilderBase):
        )

    def TENSOR_MATCH(self, guard: Guard, value=None):
        if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module():
            return
        # For tensors that are part of the Dynamo extracted Fx graph module, an
        # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these
        # will be lifted as inputs and have a TENSOR_MATCH guard.


@@ -114,6 +114,17 @@ class GuardSource(enum.Enum):
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        import torch._dynamo.config as config

        if config._unsafe_skip_fsdp_module_guards:
            return (
                self
                in (
                    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                    GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
                )
                or self.is_fsdp_module()
            )
        return self in (
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,