Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[dynamo][fsdp] Consistent behavior of int attributes (#157262)

Reimpl of https://github.com/pytorch/pytorch/pull/150954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157262
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: a9352bd25e
Commit: 42b48ee672
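For context, a minimal sketch (not part of this commit) of the user-visible behavior being unified: with the new default, a plain Python int attribute read in an FSDP-wrapped module's forward is specialized and guarded just like on an ordinary nn.Module, so changing it triggers a recompile rather than being treated as a dynamic int. The single-process process-group setup and the SimpleModule toy class below are assumptions for illustration only.

# Illustration only (assumed setup): single-process group, one CUDA device.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(
    backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)


class SimpleModule(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.attr = 2  # plain Python int attribute

    def forward(self, x):
        return x + self.attr


fsdp_m = FSDP(SimpleModule().cuda(), use_orig_params=True)
compiled = torch.compile(fsdp_m, backend="eager", dynamic=True)
compiled(torch.rand(10, device="cuda"))

# With this commit's default (allow_unspec_int_on_fsdp_module=False), the int is
# guarded with EQUALS_MATCH, so bumping it recompiles, matching plain nn.Module.
fsdp_m.module.attr = 3
compiled(torch.rand(10, device="cuda"))

dist.destroy_process_group()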
@@ -678,6 +678,88 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_dynamism_on_int_attr(self):
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class ToyModelWithIntAttr(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.attr = 2

                def forward(self, x):
                    out = x + self.attr

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            def get_model_with_int_attr(device):
                m = ToyModelWithIntAttr().to(device)
                inputs = torch.rand(10).to(device)
                outputs = m(inputs)
                return m, inputs, outputs

            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            compiled_fsdp_m = torch.compile(
                fsdp_m, backend="eager", dynamic=True, fullgraph=True
            )
            outputs = compiled_fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            FileCheck().check(
                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH"""
            ).run(GUARDS_FILE.getvalue())

    @config.patch(enable_compiler_collectives=True)
    @config.patch(allow_unspec_int_on_fsdp_module=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_dynamism_on_int_attr_unspec(self):
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class ToyModelWithIntAttr(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.attr = 2

                def forward(self, x):
                    out = x + self.attr

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            def get_model_with_int_attr(device):
                m = ToyModelWithIntAttr().to(device)
                inputs = torch.rand(10).to(device)
                outputs = m(inputs)
                return m, inputs, outputs

            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            compiled_fsdp_m = torch.compile(
                fsdp_m, backend="eager", dynamic=True, fullgraph=True
            )
            outputs = compiled_fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # No presence of EQUALS_MATCH because the guard will be dynamic
            FileCheck().check(
                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH"""
            ).run(GUARDS_FILE.getvalue())

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_ddp_optimizer_cudagraph(self):
@@ -284,6 +284,13 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# Defaults to False for BC.
allow_unspec_int_on_nn_module = False

# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions,
# integer attributes on FSDP modules were treated as dynamic, while the same
# attributes on plain nn.Modules were static. We unified the behaviour by making
# FSDP ints static too. Set this flag to True to restore the legacy dynamic
# handling if needed.
allow_unspec_int_on_fsdp_module = False

# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 3 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
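A short sketch, assuming typical usage of torch._dynamo config flags, of how the legacy dynamic handling can be restored once this flag lands; the scoped form mirrors the config.patch decorator used by the new test above.

# Sketch: opting back into dynamic ints on FSDP modules (assumed usage patterns).
import torch._dynamo.config as dynamo_config

# Process-wide: flip the module-level flag before compiling.
dynamo_config.allow_unspec_int_on_fsdp_module = True

# Scoped: patch the flag around a single compile/run, as the new
# test_fsdp_dynamism_on_int_attr_unspec test does with @config.patch(...).
with dynamo_config.patch(allow_unspec_int_on_fsdp_module=True):
    pass  # compile and run the FSDP-wrapped module here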
@@ -2398,6 +2398,15 @@ def is_int_specialization_case(value, source):
            source.guard_source().is_specialized_nn_module()
            and not config.allow_unspec_int_on_nn_module
        )
        # integers coming from FSDP modules are considered static. This is
        # purely empirical and perhaps we should have a better heuristic.
        or (
            source.guard_source().is_fsdp_module()
            and not (
                config.allow_unspec_int_on_nn_module
                or config.allow_unspec_int_on_fsdp_module
            )
        )
        or (
            source.guard_source().is_unspecialized_builtin_nn_module()
            and not config.allow_unspec_int_on_nn_module
@@ -155,17 +155,6 @@ class GuardSource(enum.Enum):
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        import torch._dynamo.config as config

        if config._unsafe_skip_fsdp_module_guards:
            return (
                self
                in (
                    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
                    GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
                )
                or self.is_fsdp_module()
            )
        return self in (
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,