[Traceable FSDP2] Use .copy_ instead of .set_ for unsharded_param inplace update; Replace unsharded_param graph input usage with graph intermediate; Support FSDP2+LoRA (#133730)

Using `fsdp.set_` for unsharded_param inplace update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_` which fixes the error and also strictly follows eager semantics (i.e. if user explictly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copy_ into; whereas if we just swap out unsharded_param storage via set_, that user-saved alias will not get updated, which is not good).

This PR also implements the graph pass to remove the resizes and copy if there is a resize_(full) -> copy_ -> resize_(0) pattern.

------

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
This commit is contained in:
Will Feng
2024-09-11 12:39:07 -07:00
committed by PyTorch MergeBot
parent 5ca46be15e
commit 94d2471d1f
11 changed files with 481 additions and 235 deletions

View File

@ -4,7 +4,10 @@
import contextlib
import copy
import functools
import itertools
import logging
import unittest
from collections import defaultdict
from unittest import mock
import torch
@ -30,8 +33,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.utils._triton import has_triton
def _is_op_in_graph(graph, op):
return any(node.target is op for node in graph.nodes)
log = logging.getLogger(__name__)
def _count_op_in_graph(graph, op):
return sum(1 for node in graph.nodes if node.target is op)
def _is_fallback_op_in_snodes(snodes, op):
@ -130,7 +136,7 @@ class TestFullyShardCompile(FSDPTest):
self.assertEqual(cnt.op_count, 1)
self.assertEqual(len(cnt.graphs), 1)
def test_trace_fsdp_set_(self):
def test_trace_fsdp_copy_(self):
@torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
torch.add(x, 1, out=out)
@ -140,7 +146,7 @@ class TestFullyShardCompile(FSDPTest):
buf_view = buf.view(-1)
torch.ops.mylib.add_one_out(x, out=buf_view)
buf_view2 = buf.view(-1)
torch.ops.fsdp.set_(x, buf_view2)
torch.ops.fsdp.copy_(x, buf_view2)
ref_x = torch.zeros(2)
x = copy.deepcopy(ref_x)
@ -148,26 +154,80 @@ class TestFullyShardCompile(FSDPTest):
torch.compile(f, backend="aot_eager")(x)
self.assertEqual(x, ref_x)
def _assert_no_aliased_graph_inputs(self, graph: torch.fx.Graph) -> None:
storage_id_to_graph_inputs = defaultdict(list)
for node in graph.nodes:
if node.op == "placeholder" and isinstance(
node.meta.get("val", None), torch.Tensor
):
storage_id_to_graph_inputs[
id(node.meta["val"].untyped_storage())
].append(node)
no_aliased_graph_inputs = True
err_msg = ""
for aliased_graph_inputs in storage_id_to_graph_inputs.values():
if len(aliased_graph_inputs) > 1:
no_aliased_graph_inputs = False
err_msg += f"""\n
Found aliased graph inputs: {aliased_graph_inputs},
val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
"""
self.assertTrue(no_aliased_graph_inputs, err_msg)
def _check_fsdp_copy_and_resize_ops_count_in_graph(
self,
graph,
*,
fwd_copy_count,
fwd_resize_count,
bwd_copy_count,
bwd_resize_count,
):
def _check_count(copy_count, resize_count):
actual_copy_count = _count_op_in_graph(graph, torch.ops.fsdp.copy_.default)
self.assertEqual(
actual_copy_count,
copy_count,
f"Unexpected number of `fsdp.copy_` ops (expected {copy_count}, got {actual_copy_count}) in graph: {graph}",
)
actual_resize_count = _count_op_in_graph(
graph, torch.ops.inductor.resize_storage_bytes_.default
)
self.assertEqual(
actual_resize_count,
resize_count,
f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950
)
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
_check_count(fwd_copy_count, fwd_resize_count) # fwd graph
else:
_check_count(bwd_copy_count, bwd_resize_count) # bwd graph
def _reinplace_all_gather_with_optional_checks(self, fullgraph):
def _run_with_checks(graph, orig_fn):
self.assertTrue(
_is_op_in_graph(
graph,
torch.ops._c10d_functional.all_gather_into_tensor.default,
)
self.assertGreater(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
),
0,
)
orig_fn(graph)
self.assertFalse(
_is_op_in_graph(
graph,
torch.ops._c10d_functional.all_gather_into_tensor.default,
)
self.assertEqual(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
),
0,
)
self.assertTrue(
_is_op_in_graph(
graph,
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
)
self.assertGreater(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default
),
0,
)
if fullgraph:
@ -266,8 +326,6 @@ class TestFullyShardCompile(FSDPTest):
self,
file_check,
overlapped_compute_op_str,
num_resize,
num_set,
last_all_gather=False,
):
file_check = file_check.check("torch.ops.fsdp.all_gather_copy_in.")
@ -278,16 +336,9 @@ class TestFullyShardCompile(FSDPTest):
# Checks that AGWait is delayed, making the AG overlap with some compute op.
if overlapped_compute_op_str is not None:
file_check = file_check.check(f"{overlapped_compute_op_str}")
file_check = file_check.check_count(
"inductor_ops.resize_storage_bytes_(", num_resize, exactly=True
)
file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
file_check = self.inductor_code_check_no_compute_op(file_check)
file_check = file_check.check("torch.ops.fsdp.split_with_sizes_copy.")
file_check = self.inductor_code_check_no_compute_op(file_check)
file_check = file_check.check_count(
"torch.ops.aten.set_.", num_set, exactly=True
)
if not last_all_gather:
# Checks that there is no compute op between this AGWait and next AG.
file_check = self.inductor_code_check_no_compute_op(file_check)
@ -307,20 +358,6 @@ class TestFullyShardCompile(FSDPTest):
file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
return file_check
@torch._dynamo.config.patch(
inline_inbuilt_nn_modules=True,
skip_fsdp_hooks=False,
)
@torch._functorch.config.patch(recompute_views=True)
@torch._functorch.config.patch(cse=False)
@torch._inductor.config.patch(
reorder_for_compute_comm_overlap=True,
reorder_for_compute_comm_overlap_passes=[
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
],
)
def _test_traceable_fsdp(
self, model_init_fn, input_creation_fn, backend, fullgraph
):
@ -334,7 +371,12 @@ class TestFullyShardCompile(FSDPTest):
return _fn
def run_iters(model, optim, n_iter=10, compiled_autograd_backend=None):
def run_iters(
model,
optim,
n_iter=10,
compiled_autograd_backend=None,
):
torch.manual_seed(42)
losses = []
for i in range(n_iter):
@ -360,7 +402,11 @@ class TestFullyShardCompile(FSDPTest):
run_iters(model, optim, n_iter=1)
model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph)
res = run_iters(model_compiled, optim, compiled_autograd_backend=backend)
res = run_iters(
model_compiled,
optim,
compiled_autograd_backend=backend,
)
return res
def test_eager():
@ -371,7 +417,23 @@ class TestFullyShardCompile(FSDPTest):
res = run_iters(model, optim)
return res
losses_compiled = test_compiled()
with torch._dynamo.config.patch(
inline_inbuilt_nn_modules=True,
skip_fsdp_hooks=False,
), torch._functorch.config.patch(
recompute_views=True, cse=False
), torch._inductor.config.patch(
reorder_for_compute_comm_overlap=True,
reorder_for_compute_comm_overlap_passes=[
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
],
post_grad_custom_pre_pass=self._assert_no_aliased_graph_inputs
if fullgraph
else None,
):
losses_compiled = test_compiled()
losses_eager = test_eager()
if not self.fake_pg:
for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
@ -448,9 +510,9 @@ class TestFullyShardCompile(FSDPTest):
)
def forward(self, x):
ret = torch.matmul(x, self.param1)
if not fullgraph:
torch._dynamo.graph_break()
ret = torch.matmul(x, self.param1)
ret = ret * self.param2
ret = torch.relu(ret)
return ret
@ -519,7 +581,19 @@ class TestFullyShardCompile(FSDPTest):
for fullgraph in [True, False]:
with self._reinplace_all_gather_with_optional_checks(
fullgraph
), self._maybe_run_decide_global_ordering_of_comms_with_checks(fullgraph):
), self._maybe_run_decide_global_ordering_of_comms_with_checks(
fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
)
if fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
*self._create_nested_fully_shard_factory_fns(
@ -537,46 +611,30 @@ class TestFullyShardCompile(FSDPTest):
fwd_code = triton_codes[0]
file_check = FileCheck().check("def call(args):")
for fwd_ag_block_info in [
dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
dict(overlapped_compute_op_str=None),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
last_all_gather=True,
),
]:
@ -588,16 +646,12 @@ class TestFullyShardCompile(FSDPTest):
bwd_code = triton_codes[1]
file_check = FileCheck().check("def call(args):")
for bwd_ag_block_info in [
dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
dict(overlapped_compute_op_str=None),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=0,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=0,
num_set=2,
last_all_gather=True,
),
]:
@ -605,7 +659,7 @@ class TestFullyShardCompile(FSDPTest):
file_check, **bwd_ag_block_info
)
for bwd_rs_block_info in [
dict(overlapped_compute_op_str="extern_kernels.mm("),
dict(overlapped_compute_op_str="extern_kernels.addmm("),
dict(
overlapped_compute_op_str=None
), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
@ -623,9 +677,10 @@ class TestFullyShardCompile(FSDPTest):
"Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
)
def _create_transformer_factory_fns(self):
def _create_transformer_factory_fns(self, all_requires_grad):
seq_len = 16
vocab_size = 8
n_layers = 3
def model_init_fn():
torch.manual_seed(self.rank)
@ -633,9 +688,20 @@ class TestFullyShardCompile(FSDPTest):
mesh = init_device_mesh("cuda", (self.world_size,))
model_args = ModelArgs(
vocab_size=vocab_size,
n_layers=3,
n_layers=n_layers,
)
model = Transformer(model_args)
if not all_requires_grad:
requires_grad_params = ["attention.wq", "attention.wv"]
requires_grad_param_count = 0
for k, v in model.named_parameters():
for substring in requires_grad_params:
if substring in k:
v.requires_grad_(True)
requires_grad_param_count += 1
else:
v.requires_grad_(False)
assert requires_grad_param_count == n_layers * len(requires_grad_params)
for layer_id, mod in enumerate(model.layers):
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
model = fully_shard(
@ -672,12 +738,16 @@ class TestFullyShardCompile(FSDPTest):
@skipIfRocm
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_transformer_backend_aot_eager(self):
for fullgraph in [True, False]:
for fullgraph, all_requires_grad in itertools.product(
[True, False], [True, False]
):
with self._maybe_add_graph_break_to_sdpa(
fullgraph
), self._reinplace_all_gather_with_optional_checks(fullgraph):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(),
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
),
"aot_eager",
fullgraph=fullgraph,
)
@ -687,10 +757,14 @@ class TestFullyShardCompile(FSDPTest):
# TODO: native_dropout has worse accuracy after decomp, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_aot_eager_decomp_partition(self):
for fullgraph in [True, False]:
for fullgraph, all_requires_grad in itertools.product(
[True, False], [True, False]
):
with self._maybe_add_graph_break_to_sdpa(fullgraph):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(),
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
),
"aot_eager_decomp_partition",
fullgraph=fullgraph,
)
@ -700,17 +774,36 @@ class TestFullyShardCompile(FSDPTest):
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_inductor(self):
for fullgraph in [True, False]:
# TODO: enable fullgraph=False case
for fullgraph, all_requires_grad in itertools.product([True], [True, False]):
log.warning(
f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}" # noqa: G004, G001
)
with self._maybe_add_graph_break_to_sdpa(
fullgraph
), self._reinplace_all_gather_with_optional_checks(
fullgraph
), self._maybe_run_decide_global_ordering_of_comms_with_checks(
fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
)
if fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
*self._create_transformer_factory_fns(),
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
),
"inductor",
fullgraph=fullgraph,
)
@ -723,21 +816,19 @@ class TestFullyShardCompile(FSDPTest):
fwd_code = triton_codes[0]
file_check = FileCheck().check("def call(args):")
for fwd_ag_block_info in [
dict(overlapped_compute_op_str="triton_", num_resize=0, num_set=4),
dict(
overlapped_compute_op_str="triton_"
if all_requires_grad
else None,
),
dict(
overlapped_compute_op_str="aten.native_dropout.",
num_resize=0,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
num_resize=12,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
num_resize=12,
num_set=12,
last_all_gather=True,
),
]:
@ -751,35 +842,33 @@ class TestFullyShardCompile(FSDPTest):
for bwd_ag_block_info in [
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=0,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
num_resize=0,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
num_resize=0,
num_set=12,
last_all_gather=True,
),
]:
file_check = self.inductor_code_check_fsdp_all_gather(
file_check, **bwd_ag_block_info
)
if bwd_ag_block_info is not None:
file_check = self.inductor_code_check_fsdp_all_gather(
file_check, **bwd_ag_block_info
)
for bwd_rs_block_info in [
dict(overlapped_compute_op_str="extern_kernels.mm("),
dict(overlapped_compute_op_str="extern_kernels.mm(")
if all_requires_grad
else None,
dict(
overlapped_compute_op_str=None
), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
dict(overlapped_compute_op_str=None),
dict(overlapped_compute_op_str=None),
dict(overlapped_compute_op_str=None) if all_requires_grad else None,
]:
file_check = self.inductor_code_check_fsdp_reduce_scatter(
file_check, **bwd_rs_block_info
)
if bwd_rs_block_info is not None:
file_check = self.inductor_code_check_fsdp_reduce_scatter(
file_check, **bwd_rs_block_info
)
file_check.run(bwd_code)
else:
# TODO: when fullgraph=False and there is graph break in FWD graph,

View File

@ -5557,7 +5557,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
z0 = x.sin()
z1 = x.sin()
y = x + 1
torch.ops.fsdp.set_.default(x, y)
torch.ops.fsdp.copy_.default(x, y)
# z3 and z3 can be CSEd with each other,
# but *not* with z0/z1 (they cross a mutation boundary)
z2 = x.sin()
@ -5589,7 +5589,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
z = x.sin()
y = x + 1
# graph input has its storage mutated
torch.ops.fsdp.set_.default(x, y)
torch.ops.fsdp.copy_.default(x, y)
z2 = x.sin()
return z2, l**2

View File

@ -725,59 +725,6 @@ def forward(self, primals_1):
return (add,)""",
)
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case")
@unittest.skipIf(IS_MACOS, "TODO: need to fix the test case")
def test_input_mutation_fsdp_set__into_same_input(self):
import torch.distributed._composable.fsdp._fsdp_param
def f(a):
b = torch.arange(9, dtype=a.dtype).view(3, 3)
c = torch.arange(9, dtype=a.dtype).view(3, 3)
d = torch.arange(9, dtype=a.dtype).view(3, 3)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
torch.ops.fsdp.set_.default(a, b)
x = a * a
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
torch.ops.fsdp.set_.default(a, c)
y = a * a
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
torch.ops.fsdp.set_.default(a, c)
z = a * a
return x + y + z
inp = [torch.ones(3, 3, requires_grad=True)]
fw_graph = self.verify_aot_autograd(
f, inp, test_mutation=True, keep_inp_mutations=True
)
inp = [torch.ones(3, 3, requires_grad=False)]
self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
"""
Expected behavior:
(1) When there are multiple set_() calls on the same graph input primal_X,
we want those set_() calls to all show up with primal_X as the first arg in the graph.
(2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892),
but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior.
"""
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1):
arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(arange, [3, 3]); arange = None
arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None
set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None
mul = torch.ops.aten.mul.Tensor(primals_1, primals_1)
set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None
mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None
mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None
add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None
return (add_1, primals_1)""",
)
self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp))
def test_input_mutation_simple_with_none_and_nontensor(self):
# Tensor, None, int
def f(a, b, c):

View File

@ -26,22 +26,13 @@ def init_fake_distributed(device="cpu"):
return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()
def fw_pre_hook(mod, inp):
if not compiled_autograd.compiled_autograd_enabled:
# torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
else:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.set_(
mod.unsharded_weight, all_gather(mod.sharded_weight)
)
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
mod._parameters["weight"] = mod.unsharded_weight
# Forward:
@ -58,22 +49,13 @@ def init_fake_distributed(device="cpu"):
mod.unsharded_weight.untyped_storage().resize_(0)
def bw_pre_hook(mod, gO):
if not compiled_autograd.compiled_autograd_enabled:
# torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
else:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.set_(
mod.unsharded_weight, all_gather(mod.sharded_weight)
)
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
mod._parameters["weight"] = mod.unsharded_weight
# Backward:
@ -465,7 +447,6 @@ class DistributedPatternTests(TestCase):
@requires_gpu()
@torch._functorch.config.patch(recompute_views=True)
def test_fake_distributed_inductor(self):
# TODO: fix .set_ lowering in CPU inductor, and enable the CPU test.
m1, inp1 = init_fake_distributed(GPU_TYPE)
out1 = steps(m1, inp1)

View File

@ -1447,7 +1447,7 @@ class TestPatternMatcher(TestCase):
(t, [64, 128, 8, 8]),
{"dim": 1, "out": [t, t, t, t]},
)
check("call_function", torch.ops.fsdp.set_, (t, t), {})
check("call_function", torch.ops.fsdp.copy_, (t, t), {})
check(
"call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False
)

View File

@ -405,8 +405,8 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
torch.ops.aten.copy_.default,
torch.ops.aten.set_.source_Tensor,
]
if hasattr(torch.ops.fsdp, "set_"):
allowed_mutation_ops.append(torch.ops.fsdp.set_.default)
if hasattr(torch.ops.fsdp, "copy_"):
allowed_mutation_ops.append(torch.ops.fsdp.copy_.default)
placeholders = set()
mutation_count = 0

View File

@ -1272,6 +1272,7 @@ def get_default_op_list() -> OpTypes:
aten.expand,
aten.as_strided,
aten.permute,
aten.select,
]
view_ops = recomputable_view_ops
default_recomputable_ops += [

View File

@ -3,12 +3,14 @@
from __future__ import annotations
import heapq
import logging
import operator
import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING
import torch
from torch.multiprocessing.reductions import StorageWeakRef
from . import config, ir
from .dependencies import WeakDep
@ -23,6 +25,7 @@ from .utils import (
)
log = logging.getLogger(__name__)
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
if TYPE_CHECKING:
@ -342,6 +345,202 @@ def reorder_compute_and_comm_for_overlap(
return order
def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
"""
This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding
graph intermediates that were fsdp.copy_ into the unsharded params in the original graph.
NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern
(or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case
where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't
remove these resize and copy ops and thus we will have worse performance there.
In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param"
is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern
(in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed.
"""
node_list = list(graph.nodes)
# Find all graph inputs and their resize counts
graph_input_to_resized_to_full_node_idxes = defaultdict(list)
graph_input_to_resized_to_0_node_idxes = defaultdict(list)
for idx, node in enumerate(node_list):
if (
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
):
assert (
node.args[0].op == "placeholder"
), f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
"""
graph_input = node.args[0]
new_size = node.args[1]
if new_size > 0:
graph_input_to_resized_to_full_node_idxes[graph_input].append(idx)
else:
graph_input_to_resized_to_0_node_idxes[graph_input].append(idx)
def check_resize_pattern(graph_input):
# Check the number of resize-to-full and resize-to-0 nodes are equal,
# and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node
# always happens before the resize-to-0 node.
# This is the precondition for being able to remove all the resize and copy nodes
# for this specific unsharded param.
resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get(
graph_input, []
)
resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])
if not len(resized_to_full_idxes) == len(resized_to_0_idxes):
log.warning(
f"""
Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:
{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}.
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass.
""" # noqa: G004
)
return False
# Check the sequence: (resize_to_full -> resize_to_0)+
for resize_to_full_idx, resize_to_0_idx in zip(
resized_to_full_idxes, resized_to_0_idxes
):
if resize_to_full_idx >= resize_to_0_idx:
log.warning(
f"""
For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx}
happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}.
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param.
""" # noqa: G004
)
return False
return True
# Find all eligible unsharded params and their corresponding graph intermediates.
unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
for idx, node in enumerate(node_list):
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
fsdp_copy_node = node
unsharded_param = node.args[0]
assert (
unsharded_param.op == "placeholder"
), f"""
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
Offending node: {unsharded_param}. Graph: {graph}
"""
if check_resize_pattern(unsharded_param):
unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx)
def is_allowed_mutation(node):
return (
node.target == torch.ops.fsdp.copy_.default
or node.target == torch.ops.inductor.resize_storage_bytes_.default
)
def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
# Check whether the node is mutating any of the unsharded params or their aliases.
mutated_arg_idxes = (
[
i
for i, x in enumerate(node.target._schema.arguments)
if x.alias_info is not None and x.alias_info.is_write
]
if isinstance(node.target, torch._ops.OpOverload)
else []
)
mutated_node_arg_storages = {
StorageWeakRef(node.args[i].meta["val"].untyped_storage())
for i in mutated_arg_idxes
}
storages_of_unsharded_params = {
StorageWeakRef(unsharded_param.meta["val"].untyped_storage())
for unsharded_param in unsharded_params
}
return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0
# Check no user mutation on any unsharded_param
for node in node_list:
if (
node.op == "call_function"
and isinstance(node.target, torch._ops.OpOverload)
and node.target._schema.is_mutable
and not is_allowed_mutation(node)
):
assert not is_node_mutating_unsharded_param_or_its_alias(
node, unsharded_param_to_fsdp_copy_node_idxes.keys()
), f"""\
User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node}
"""
# For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`.
#
# NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input.
# e.g.
# ```
# fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1)
# ... (use of unsharded_param_1) -> Subgraph 1
# fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2)
# ... (use of unsharded_param_1) -> Subgraph 2
# fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3)
# ... (use of unsharded_param_1) -> Subgraph 3
# ```
# We must do the replacement only within each subgraph.
for (
unsharded_param,
fsdp_copy_node_idxes,
) in unsharded_param_to_fsdp_copy_node_idxes.items():
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
fsdp_copy_node = node_list[fsdp_copy_node_idx]
assert fsdp_copy_node.args[0] is unsharded_param
_, replacement = fsdp_copy_node.args
# subgraph_start_idx is exclusive
subgraph_start_idx = fsdp_copy_node_idx + 1
# subgraph_end_idx is exclusive (also intentionally don't replace args in return op)
subgraph_end_idx = (
fsdp_copy_node_idxes[i + 1]
if i < len(fsdp_copy_node_idxes) - 1
else len(node_list) - 1
)
subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx]
assert not any(
is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param])
for node in subgraph_nodes
), f"""\
Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true!
Graph: {graph}
"""
for node in subgraph_nodes:
if (
node.op == "call_function"
and unsharded_param in node.args
and node.target != torch.ops.inductor.resize_storage_bytes_.default
): # TODO(yf225): implement replacement in kwargs
new_args = tuple(
replacement if arg is unsharded_param else arg
for arg in node.args
)
node.args = new_args
# Delete `fsdp.copy_(unsharded_param, Y)` nodes
for (
unsharded_param,
fsdp_copy_node_idxes,
) in unsharded_param_to_fsdp_copy_node_idxes.items():
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
fsdp_copy_node = node_list[fsdp_copy_node_idx]
graph.erase_node(fsdp_copy_node)
# Delete `resize_(unsharded_param, ...)` nodes
for node in node_list:
if (
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
):
graph.erase_node(node)
def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
try:
import torch.distributed._composable.fsdp._fsdp_collectives
@ -509,12 +708,11 @@ def enforce_comm_ordering_for_fsdp(
name_to_fused_node,
)
# Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
# Find the "all_gather + all_gather_wait_tensor + copy_out" code block
allowed_ops = {
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
torch.ops._c10d_functional.wait_tensor.default,
torch.ops.fsdp.split_with_sizes_copy.default,
torch.ops.aten.set_.source_Tensor,
}
find_recursive_users_of_node(
ag_snode,
@ -560,7 +758,7 @@ def enforce_comm_ordering_for_fsdp(
assert wait_node_idx is not None
ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])
# Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
# Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode
ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])
ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

View File

@ -22,6 +22,7 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from .. import config, ir, pattern_matcher
from ..codegen.common import BackendFeature, has_backend_feature
from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
from ..lowering import lowerings as L
from ..pattern_matcher import (
@ -76,6 +77,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
The IR here has been normalized and functionalized.
"""
if not torch._dynamo.config.skip_fsdp_hooks:
remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)
if config.dce:
# has some issues with mutation in inference mode
gm.graph.eliminate_dead_code()

View File

@ -6102,13 +6102,17 @@ def set__source_tensor(self, source_tensor):
return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor))
if hasattr(torch.ops.fsdp, "set_"):
if hasattr(torch.ops.fsdp, "copy_"):
@register_lowering(torch.ops.fsdp.set_.default)
def fsdp_set_(self, source_tensor):
self.realize()
source_tensor.realize()
ir.SetSourceTensorKernel(self, source_tensor)
@register_lowering(torch.ops.fsdp.copy_.default)
def fsdp_copy_(dst, src):
if dst is src:
# dst.copy_(dst) can happen from the reinplacing pass
return dst
src = to_device(src, dst.get_device())
src = to_dtype(src, dst.get_dtype())
src = expand(src, dst.get_size())
return mutate_to(dst, src)
@register_lowering(torch.ops.aten.resize)

View File

@ -65,73 +65,77 @@ data, so we use storage resizing on the all-gather output.
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")
lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()")
@torch.library.impl(lib, "set_", "Meta")
@torch.library.impl(lib, "set_", "CUDA")
@torch.library.impl(lib, "set_", "CPU")
def set_(tensor, data):
tensor.set_(data)
@torch.library.impl(lib, "copy_", "Meta")
@torch.library.impl(lib, "copy_", "CUDA")
@torch.library.impl(lib, "copy_", "CPU")
def copy_(tensor, data):
tensor.copy_(data)
"""
[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_]
Currently we don't functionalize `fsdp.set_` op or `inductor.resize_storage_bytes_(0)` op
Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op
(i.e. they show up as a mutation op in the middle of the AOT joint graph).
Reason:
Traceable FSDP2 compiled autograd BWD graph have the following traits:
(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param).
(3) They are both subclasses.
The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
So this doesn't work at all for Traceable FSDP2.
The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops.
This avoids the problem above, because from AOTAutograd point-of-view there are no mutations
that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
We can avoid this functionalization because:
(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream.
(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
A: Yes - this is WIP but we will try to get back to functional graph as early as possible in the lowering process.
Specifically, we believe we can move both .set_ and .resize_(0) ops to end of graph in AOT joint graph before partitioner
(i.e. effectively "re-functionalizing" those ops). Put it in another way, we avoid functionalization for those two ops just to
make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
This requires a custom FX pass but we believe it's not hard to write and maintain.
Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance?
A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops
for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0).
Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors?
A: This is critical: we do want to save FWD nn.Parameter graph input (instead of its view) for BWD use,
so that downstream ops in BWD graph uses the post-`.set_` nn.Parameter instead of any of its saved views as input.
This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
nn.Parameter in order to see the result of .set_.
TODO:
Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward,
it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are:
Cons (of functionalizing those ops):
(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time
in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them
in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse
peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_:
https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089)
Pros (of functionalizing):
(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations
mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more).
(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it.
But maybe there are few enough mutations induced by FSDP for this to matter.
"""
@torch.library.impl(lib, "set_", "Functionalize")
def set__functionalize(tensor, data):
@torch.library.impl(lib, "copy_", "Functionalize")
def copy__functionalize(tensor, data):
torch._sync(tensor)
torch._sync(data)
# AOTDispatcher needs to know if any inputs had their storages mutated.
# (Why? It sometimes detaches inputs before sending them into the graph,
# when it sees that they do not need to have any gradients computed)
torch._functionalize_set_storage_changed(tensor)
tensor_inner = torch._from_functional_tensor(tensor)
data_inner = torch._from_functional_tensor(data)
with torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
):
torch.ops.fsdp.set_.default(tensor_inner, data_inner)
torch.ops.fsdp.copy_.default(tensor_inner, data_inner)
torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default)
torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)
class ShardedState(Enum):
@ -475,7 +479,12 @@ class FSDPParam:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
self._unsharded_param
):
torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param)
# NOTE: Under compile, if an unsharded param goes through
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
# resize_ and copy_ ops in a compiler graph pass
# `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance.
alloc_storage(self._unsharded_param)
torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@ -605,13 +614,26 @@ class FSDPParam:
alloc_storage(tensor)
def free_unsharded_param(self) -> None:
for tensor in itertools.chain(
self.all_gather_outputs, self._unsharded_inner_tensors
):
free_storage(tensor)
if ca.compiled_autograd_enabled:
"""
Assumptions under compile:
- `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`.
Instead, we resize `self._unsharded_param` storage size to full and then
explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param`
in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove
the resize_ and copy_ ops in a compiler graph pass to recover performance.)
- `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT
graph inputs. They are created within the graph and is guaranteed to be freed
by the end of the graph. They don't leak outside of the graph.
"""
self._unsharded_param.untyped_storage().resize_(0)
self.all_gather_outputs = []
self._unsharded_inner_tensors = []
else:
for tensor in itertools.chain(
self.all_gather_outputs, self._unsharded_inner_tensors
):
free_storage(tensor)
@property
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D