Revert "[dynamo][fsdp] Consistent behavior of int attributes (#157262)"

This reverts commit 178fe7aa98987111a73534375099f4ad255e8b59.

Reverted https://github.com/pytorch/pytorch/pull/157262 on behalf of https://github.com/huydhn due to This fails some internal tests and needs to be relanded ([comment](https://github.com/pytorch/pytorch/pull/157262#issuecomment-3059463896))
PyTorch MergeBot
2025-07-10 23:11:18 +00:00
parent 1a195bf7d6
commit e517066f41
4 changed files with 11 additions and 98 deletions

@@ -678,88 +678,6 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_dynamism_on_int_attr(self):
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class ToyModelWithIntAttr(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.attr = 2

                def forward(self, x):
                    out = x + self.attr

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            def get_model_with_int_attr(device):
                m = ToyModelWithIntAttr().to(device)
                inputs = torch.rand(10).to(device)
                outputs = m(inputs)
                return m, inputs, outputs

            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            compiled_fsdp_m = torch.compile(
                fsdp_m, backend="eager", dynamic=True, fullgraph=True
            )
            outputs = compiled_fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            FileCheck().check(
                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH"""
            ).run(GUARDS_FILE.getvalue())

    @config.patch(enable_compiler_collectives=True)
    @config.patch(allow_unspec_int_on_fsdp_module=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_dynamism_on_int_attr_unspec(self):
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class ToyModelWithIntAttr(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.attr = 2

                def forward(self, x):
                    out = x + self.attr

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            def get_model_with_int_attr(device):
                m = ToyModelWithIntAttr().to(device)
                inputs = torch.rand(10).to(device)
                outputs = m(inputs)
                return m, inputs, outputs

            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            compiled_fsdp_m = torch.compile(
                fsdp_m, backend="eager", dynamic=True, fullgraph=True
            )
            outputs = compiled_fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # No presence of EQUALS_MATCH because the guard will be dynamic
            FileCheck().check(
                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH"""
            ).run(GUARDS_FILE.getvalue())

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_ddp_optimizer_cudagraph(self):
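
Both removed tests rely on Dynamo's comptime hook to dump the guards installed while tracing. A minimal standalone sketch of that pattern, with no FSDP or GPU involved (the free function fn and the GUARDS buffer are illustrative, not part of this change):

from io import StringIO

import torch
from torch._dynamo.comptime import comptime

GUARDS = StringIO()


def fn(x, n):
    out = x + n

    @comptime
    def _(ctx):
        # Runs at trace time and prints the guards Dynamo has accumulated so far.
        ctx.print_guards(file=GUARDS)

    return out


torch.compile(fn, backend="eager", fullgraph=True)(torch.rand(10), 2)
print(GUARDS.getvalue())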

@@ -284,13 +284,6 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# Defaults to False for BC.
allow_unspec_int_on_nn_module = False
# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions,
# integer attributes on FSDP modules were treated as dynamic, while the same
# attributes on plain nn.Modules were static. We unified the behaviour by making
# FSDP ints static too. Set this flag to True to restore the legacy dynamic
# handling if needed.
allow_unspec_int_on_fsdp_module = False
# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 3 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically

@@ -2400,15 +2400,6 @@ def is_int_specialization_case(value, source):
            source.guard_source().is_specialized_nn_module()
            and not config.allow_unspec_int_on_nn_module
        )
        # integers coming from FSDP modules are considered static. This is
        # purely empirical and perhaps we should have a better heuristic.
        or (
            source.guard_source().is_fsdp_module()
            and not (
                config.allow_unspec_int_on_nn_module
                or config.allow_unspec_int_on_fsdp_module
            )
        )
        or (
            source.guard_source().is_unspecialized_builtin_nn_module()
            and not config.allow_unspec_int_on_nn_module
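
The branch removed above only changes which modules fall under the specialize-integers rule; the rule itself is observable from user code. A minimal sketch of the default nn.Module behaviour (allow_unspec_int_on_nn_module=False), reusing the shape of the removed test but without FSDP:

import torch
import torch.nn as nn


class ToyModelWithIntAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = 2  # plain Python int attribute on the module

    def forward(self, x):
        return x + self.attr


m = ToyModelWithIntAttr()
compiled = torch.compile(m, backend="eager", dynamic=True)
out = compiled(torch.rand(10))

# With the default allow_unspec_int_on_nn_module=False, self.attr is specialized:
# the traced graph hard-codes 2 and an EQUALS_MATCH guard is installed, so
# mutating the attribute should trigger a recompile on the next call.
m.attr = 3
out = compiled(torch.rand(10))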

@@ -155,6 +155,17 @@ class GuardSource(enum.Enum):
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        import torch._dynamo.config as config

        if config._unsafe_skip_fsdp_module_guards:
            return (
                self
                in (
                    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                    GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
                )
                or self.is_fsdp_module()
            )
        return self in (
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,