Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[dynamo][fsdp] Consistent behavior of int attributes (#157262)"
This reverts commit 178fe7aa98987111a73534375099f4ad255e8b59. Reverted https://github.com/pytorch/pytorch/pull/157262 on behalf of https://github.com/huydhn due to This fails some internal tests and needs to be relanded ([comment](https://github.com/pytorch/pytorch/pull/157262#issuecomment-3059463896))
@@ -678,88 +678,6 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
             outputs = fsdp_m(inputs)
             self.assertTrue(same(correct_outputs, outputs))
 
-    @config.patch(enable_compiler_collectives=True)
-    @skip_if_lt_x_gpu(1)
-    def test_fsdp_dynamism_on_int_attr(self):
-        global GUARDS_FILE
-        GUARDS_FILE = StringIO()
-
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
-
-            class ToyModelWithIntAttr(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.attr = 2
-
-                def forward(self, x):
-                    out = x + self.attr
-
-                    @comptime
-                    def _(ctx):
-                        ctx.print_guards(file=GUARDS_FILE)
-
-                    return out
-
-            def get_model_with_int_attr(device):
-                m = ToyModelWithIntAttr().to(device)
-                inputs = torch.rand(10).to(device)
-                outputs = m(inputs)
-                return m, inputs, outputs
-
-            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
-            fsdp_m = FSDP(m, use_orig_params=True)
-            compiled_fsdp_m = torch.compile(
-                fsdp_m, backend="eager", dynamic=True, fullgraph=True
-            )
-            outputs = compiled_fsdp_m(inputs)
-            self.assertTrue(same(correct_outputs, outputs))
-
-            FileCheck().check(
-                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH"""
-            ).run(GUARDS_FILE.getvalue())
-
-    @config.patch(enable_compiler_collectives=True)
-    @config.patch(allow_unspec_int_on_fsdp_module=True)
-    @skip_if_lt_x_gpu(1)
-    def test_fsdp_dynamism_on_int_attr_unspec(self):
-        global GUARDS_FILE
-        GUARDS_FILE = StringIO()
-
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
-
-            class ToyModelWithIntAttr(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.attr = 2
-
-                def forward(self, x):
-                    out = x + self.attr
-
-                    @comptime
-                    def _(ctx):
-                        ctx.print_guards(file=GUARDS_FILE)
-
-                    return out
-
-            def get_model_with_int_attr(device):
-                m = ToyModelWithIntAttr().to(device)
-                inputs = torch.rand(10).to(device)
-                outputs = m(inputs)
-                return m, inputs, outputs
-
-            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
-            fsdp_m = FSDP(m, use_orig_params=True)
-            compiled_fsdp_m = torch.compile(
-                fsdp_m, backend="eager", dynamic=True, fullgraph=True
-            )
-            outputs = compiled_fsdp_m(inputs)
-            self.assertTrue(same(correct_outputs, outputs))
-
-            # No presence of EQUALS_MATCH because the guard will be dynamic
-            FileCheck().check(
-                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH"""
-            ).run(GUARDS_FILE.getvalue())
-
     @skip_if_lt_x_gpu(2)
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_ddp_optimizer_cudagraph(self):
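The deleted tests exercise Dynamo's compile-time `comptime` hook to dump the guards installed for an integer module attribute. For reference, a minimal non-distributed sketch of the same inspection pattern (plain nn.Module, no FSDP wrapper or process group; the model name and output buffer are illustrative):

# Sketch only: mirrors the guard-inspection pattern of the deleted tests on a
# plain nn.Module, so it runs without FSDP or torch.distributed.
from io import StringIO

import torch
import torch.nn as nn
from torch._dynamo.comptime import comptime

GUARDS_FILE = StringIO()

class ToyModelWithIntAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = 2  # plain Python int attribute

    def forward(self, x):
        out = x + self.attr

        @comptime
        def _(ctx):
            # Dump the guards Dynamo has installed at this point of the trace.
            ctx.print_guards(file=GUARDS_FILE)

        return out

compiled = torch.compile(ToyModelWithIntAttr(), backend="eager", dynamic=True, fullgraph=True)
compiled(torch.rand(10))
print(GUARDS_FILE.getvalue())  # inspect how `.attr` ends up guarded

On a plain nn.Module the int attribute should be specialized by default (allow_unspec_int_on_nn_module defaults to False); the reverted PR extended that same static treatment to FSDP-wrapped modules, which is what the EQUALS_MATCH vs. TYPE_MATCH checks above were asserting.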
@@ -284,13 +284,6 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
 # Defaults to False for BC.
 allow_unspec_int_on_nn_module = False
 
-# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions,
-# integer attributes on FSDP modules were treated as dynamic, while the same
-# attributes on plain nn.Modules were static. We unified the behaviour by making
-# FSDP ints static too. Set this flag to True to restore the legacy dynamic
-# handling if needed.
-allow_unspec_int_on_fsdp_module = False
-
 # Specify how to optimize a compiled DDP module. The flag accepts a boolean
 # value or a string. There are 3 modes.
 # 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
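The comment block removed above documented an escape hatch: before this revert, the flag could be flipped to restore the pre-unification dynamic handling of FSDP int attributes. A sketch of how it was toggled; note that after this revert the attribute no longer exists in torch._dynamo.config, so this applies only to builds that still carry the reverted change:

# Sketch: using the now-removed flag on a build that still has it.
import torch._dynamo.config as dynamo_config

# Process-wide: treat int attributes on FSDP-wrapped modules as dynamic again.
dynamo_config.allow_unspec_int_on_fsdp_module = True

# Or scoped to a single test, as the deleted test_fsdp_dynamism_on_int_attr_unspec did:
# @dynamo_config.patch(allow_unspec_int_on_fsdp_module=True)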
@@ -2400,15 +2400,6 @@ def is_int_specialization_case(value, source):
             source.guard_source().is_specialized_nn_module()
             and not config.allow_unspec_int_on_nn_module
         )
-        # integers coming from FSDP modules are considered static. This is
-        # purely empirical and perhaps we should have a better heuristic.
-        or (
-            source.guard_source().is_fsdp_module()
-            and not (
-                config.allow_unspec_int_on_nn_module
-                or config.allow_unspec_int_on_fsdp_module
-            )
-        )
         or (
             source.guard_source().is_unspecialized_builtin_nn_module()
             and not config.allow_unspec_int_on_nn_module
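The clause removed above is one branch of the heuristic in is_int_specialization_case: an int read off an FSDP module stayed static unless one of the opt-out flags was set. A standalone restatement of just that branch, using a hypothetical helper name and plain boolean arguments in place of the real Source and config objects:

# Illustrative restatement of the deleted branch; not Dynamo code.
def int_on_fsdp_is_static(
    is_fsdp_module: bool,
    allow_unspec_int_on_nn_module: bool,
    allow_unspec_int_on_fsdp_module: bool,
) -> bool:
    """True when an int attribute on an FSDP module should be specialized
    (guarded by value) instead of being traced as a dynamic SymInt."""
    return is_fsdp_module and not (
        allow_unspec_int_on_nn_module or allow_unspec_int_on_fsdp_module
    )

assert int_on_fsdp_is_static(True, False, False)       # default: static
assert not int_on_fsdp_is_static(True, False, True)    # flag restores dynamic handling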
@@ -155,6 +155,17 @@ class GuardSource(enum.Enum):
         return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
 
     def is_specialized_nn_module(self) -> bool:
+        import torch._dynamo.config as config
+
+        if config._unsafe_skip_fsdp_module_guards:
+            return (
+                self
+                in (
+                    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
+                    GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
+                )
+                or self.is_fsdp_module()
+            )
         return self in (
             GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
             GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
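The branch restored above means the internal `_unsafe_skip_fsdp_module_guards` escape hatch once again makes FSDP guard sources report themselves as specialized nn-module sources. A small sketch of the observable effect, assuming a build with this revert applied:

# Sketch of the behaviour restored by the hunk above.
import torch._dynamo.config as dynamo_config
from torch._guards import GuardSource

dynamo_config._unsafe_skip_fsdp_module_guards = True
# FSDP sources now also count as specialized nn-module sources.
assert GuardSource.LOCAL_FSDP_MODULE.is_specialized_nn_module()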