[dynamo][fsdp] Consistent behavior of int attributes (#157262)

Reimpl of https://github.com/pytorch/pytorch/pull/150954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157262
Approved by: https://github.com/bdhirsh
This commit is contained in:
Animesh Jain
2025-06-30 12:04:14 -07:00
committed by PyTorch MergeBot
parent a9352bd25e
commit 42b48ee672
4 changed files with 98 additions and 11 deletions


@@ -678,6 +678,88 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_dynamism_on_int_attr(self):
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class ToyModelWithIntAttr(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.attr = 2

                def forward(self, x):
                    out = x + self.attr

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            def get_model_with_int_attr(device):
                m = ToyModelWithIntAttr().to(device)
                inputs = torch.rand(10).to(device)
                outputs = m(inputs)
                return m, inputs, outputs

            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            compiled_fsdp_m = torch.compile(
                fsdp_m, backend="eager", dynamic=True, fullgraph=True
            )
            outputs = compiled_fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            FileCheck().check(
                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH"""
            ).run(GUARDS_FILE.getvalue())

    @config.patch(enable_compiler_collectives=True)
    @config.patch(allow_unspec_int_on_fsdp_module=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_dynamism_on_int_attr_unspec(self):
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class ToyModelWithIntAttr(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.attr = 2

                def forward(self, x):
                    out = x + self.attr

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            def get_model_with_int_attr(device):
                m = ToyModelWithIntAttr().to(device)
                inputs = torch.rand(10).to(device)
                outputs = m(inputs)
                return m, inputs, outputs

            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            compiled_fsdp_m = torch.compile(
                fsdp_m, backend="eager", dynamic=True, fullgraph=True
            )
            outputs = compiled_fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # No presence of EQUALS_MATCH because the guard will be dynamic
            FileCheck().check(
                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH"""
            ).run(GUARDS_FILE.getvalue())

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_ddp_optimizer_cudagraph(self):


@@ -284,6 +284,13 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# Defaults to False for BC.
allow_unspec_int_on_nn_module = False

# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions,
# integer attributes on FSDP modules were treated as dynamic, while the same
# attributes on plain nn.Modules were static. We unified the behaviour by making
# FSDP ints static too. Set this flag to True to restore the legacy dynamic
# handling if needed.
allow_unspec_int_on_fsdp_module = False

# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 3 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically


@@ -2398,6 +2398,15 @@ def is_int_specialization_case(value, source):
            source.guard_source().is_specialized_nn_module()
            and not config.allow_unspec_int_on_nn_module
        )
        # integers coming from FSDP modules are considered static. This is
        # purely empirical and perhaps we should have a better heuristic.
        or (
            source.guard_source().is_fsdp_module()
            and not (
                config.allow_unspec_int_on_nn_module
                or config.allow_unspec_int_on_fsdp_module
            )
        )
        or (
            source.guard_source().is_unspecialized_builtin_nn_module()
            and not config.allow_unspec_int_on_nn_module
View File

@@ -155,17 +155,6 @@ class GuardSource(enum.Enum):
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        import torch._dynamo.config as config

        if config._unsafe_skip_fsdp_module_guards:
            return (
                self
                in (
                    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                    GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
                )
                or self.is_fsdp_module()
            )
        return self in (
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,