Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Traceable FSDP2] Use .copy_ instead of .set_ for unsharded_param inplace update; Replace unsharded_param graph input usage with graph intermediate; Support FSDP2+LoRA (#133730)
Using `fsdp.set_` for the unsharded_param in-place update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_`, which fixes the error and also strictly follows eager semantics (i.e. if the user explicitly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copied into via `copy_`; whereas if we just swap out the unsharded_param storage via `set_`, that user-saved alias will not get updated, which is not good). This PR also implements the graph pass to remove the resize_ and copy_ ops when there is a resize_(full) -> copy_ -> resize_(0) pattern.

------

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
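Editor's illustration (not part of the commit): a minimal sketch using only stock tensor ops to show the alias-semantics difference the message describes. `param`, `alias`, and `gathered` are hypothetical names standing in for the unsharded parameter, an alias saved by user module code, and the all-gathered data; the FSDP2 code itself uses the custom `torch.ops.fsdp.copy_`/`set_` ops shown in the diff below.

```python
# Sketch: an alias saved by user code keeps seeing updates made via an in-place
# copy_, but goes stale when the parameter's storage is swapped out via set_.
import torch

# copy_ path: write into the existing storage in place.
param = torch.zeros(4)
alias = param.view(-1)            # stands in for an alias the user's module code saved
gathered = torch.ones(4)          # stands in for the all-gathered unsharded data
param.copy_(gathered)
assert torch.equal(alias, gathered)          # the saved alias observes the update

# set_ path: the storage is swapped out from under the alias.
param2 = torch.zeros(4)
alias2 = param2.view(-1)
param2.set_(torch.ones(4))
assert torch.equal(alias2, torch.zeros(4))   # the saved alias still sees the old data
```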
Committed by: PyTorch MergeBot
Parent: 5ca46be15e
Commit: 94d2471d1f
@@ -4,7 +4,10 @@
import contextlib
import copy
import functools
import itertools
import logging
import unittest
from collections import defaultdict
from unittest import mock

import torch
@@ -30,8 +33,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.utils._triton import has_triton

def _is_op_in_graph(graph, op):
return any(node.target is op for node in graph.nodes)
log = logging.getLogger(__name__)

def _count_op_in_graph(graph, op):
return sum(1 for node in graph.nodes if node.target is op)

def _is_fallback_op_in_snodes(snodes, op):
@@ -130,7 +136,7 @@ class TestFullyShardCompile(FSDPTest):
self.assertEqual(cnt.op_count, 1)
self.assertEqual(len(cnt.graphs), 1)

def test_trace_fsdp_set_(self):
def test_trace_fsdp_copy_(self):
@torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
torch.add(x, 1, out=out)
@@ -140,7 +146,7 @@ class TestFullyShardCompile(FSDPTest):
buf_view = buf.view(-1)
torch.ops.mylib.add_one_out(x, out=buf_view)
buf_view2 = buf.view(-1)
torch.ops.fsdp.set_(x, buf_view2)
torch.ops.fsdp.copy_(x, buf_view2)

ref_x = torch.zeros(2)
x = copy.deepcopy(ref_x)
@@ -148,26 +154,80 @@ class TestFullyShardCompile(FSDPTest):
torch.compile(f, backend="aot_eager")(x)
self.assertEqual(x, ref_x)

def _assert_no_aliased_graph_inputs(self, graph: torch.fx.Graph) -> None:
storage_id_to_graph_inputs = defaultdict(list)
for node in graph.nodes:
if node.op == "placeholder" and isinstance(
node.meta.get("val", None), torch.Tensor
):
storage_id_to_graph_inputs[
id(node.meta["val"].untyped_storage())
].append(node)
no_aliased_graph_inputs = True
err_msg = ""
for aliased_graph_inputs in storage_id_to_graph_inputs.values():
if len(aliased_graph_inputs) > 1:
no_aliased_graph_inputs = False
err_msg += f"""\n
Found aliased graph inputs: {aliased_graph_inputs},
val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
"""
self.assertTrue(no_aliased_graph_inputs, err_msg)

def _check_fsdp_copy_and_resize_ops_count_in_graph(
self,
graph,
*,
fwd_copy_count,
fwd_resize_count,
bwd_copy_count,
bwd_resize_count,
):
def _check_count(copy_count, resize_count):
actual_copy_count = _count_op_in_graph(graph, torch.ops.fsdp.copy_.default)
self.assertEqual(
actual_copy_count,
copy_count,
f"Unexpected number of `fsdp.copy_` ops (expected {copy_count}, got {actual_copy_count}) in graph: {graph}",
)

actual_resize_count = _count_op_in_graph(
graph, torch.ops.inductor.resize_storage_bytes_.default
)
self.assertEqual(
actual_resize_count,
resize_count,
f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950
)

if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
_check_count(fwd_copy_count, fwd_resize_count) # fwd graph
else:
_check_count(bwd_copy_count, bwd_resize_count) # bwd graph

def _reinplace_all_gather_with_optional_checks(self, fullgraph):
def _run_with_checks(graph, orig_fn):
self.assertTrue(
_is_op_in_graph(
graph,
torch.ops._c10d_functional.all_gather_into_tensor.default,
)
self.assertGreater(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
),
0,
)

orig_fn(graph)
self.assertFalse(
_is_op_in_graph(
graph,
torch.ops._c10d_functional.all_gather_into_tensor.default,
)

self.assertEqual(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
),
0,
)
self.assertTrue(
_is_op_in_graph(
graph,
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
)

self.assertGreater(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default
),
0,
)

if fullgraph:
@@ -266,8 +326,6 @@ class TestFullyShardCompile(FSDPTest):
self,
file_check,
overlapped_compute_op_str,
num_resize,
num_set,
last_all_gather=False,
):
file_check = file_check.check("torch.ops.fsdp.all_gather_copy_in.")
@@ -278,16 +336,9 @@ class TestFullyShardCompile(FSDPTest):
# Checks that AGWait is delayed, making the AG overlap with some compute op.
if overlapped_compute_op_str is not None:
file_check = file_check.check(f"{overlapped_compute_op_str}")
file_check = file_check.check_count(
"inductor_ops.resize_storage_bytes_(", num_resize, exactly=True
)
file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
file_check = self.inductor_code_check_no_compute_op(file_check)
file_check = file_check.check("torch.ops.fsdp.split_with_sizes_copy.")
file_check = self.inductor_code_check_no_compute_op(file_check)
file_check = file_check.check_count(
"torch.ops.aten.set_.", num_set, exactly=True
)
if not last_all_gather:
# Checks that there is no compute op between this AGWait and next AG.
file_check = self.inductor_code_check_no_compute_op(file_check)
@@ -307,20 +358,6 @@ class TestFullyShardCompile(FSDPTest):
file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
return file_check

@torch._dynamo.config.patch(
inline_inbuilt_nn_modules=True,
skip_fsdp_hooks=False,
)
@torch._functorch.config.patch(recompute_views=True)
@torch._functorch.config.patch(cse=False)
@torch._inductor.config.patch(
reorder_for_compute_comm_overlap=True,
reorder_for_compute_comm_overlap_passes=[
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
],
)
def _test_traceable_fsdp(
self, model_init_fn, input_creation_fn, backend, fullgraph
):
@@ -334,7 +371,12 @@ class TestFullyShardCompile(FSDPTest):

return _fn

def run_iters(model, optim, n_iter=10, compiled_autograd_backend=None):
def run_iters(
model,
optim,
n_iter=10,
compiled_autograd_backend=None,
):
torch.manual_seed(42)
losses = []
for i in range(n_iter):
@@ -360,7 +402,11 @@ class TestFullyShardCompile(FSDPTest):
run_iters(model, optim, n_iter=1)

model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph)
res = run_iters(model_compiled, optim, compiled_autograd_backend=backend)
res = run_iters(
model_compiled,
optim,
compiled_autograd_backend=backend,
)
return res

def test_eager():
@@ -371,7 +417,23 @@ class TestFullyShardCompile(FSDPTest):
res = run_iters(model, optim)
return res

losses_compiled = test_compiled()
with torch._dynamo.config.patch(
inline_inbuilt_nn_modules=True,
skip_fsdp_hooks=False,
), torch._functorch.config.patch(
recompute_views=True, cse=False
), torch._inductor.config.patch(
reorder_for_compute_comm_overlap=True,
reorder_for_compute_comm_overlap_passes=[
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
],
post_grad_custom_pre_pass=self._assert_no_aliased_graph_inputs
if fullgraph
else None,
):
losses_compiled = test_compiled()
losses_eager = test_eager()
if not self.fake_pg:
for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
@@ -448,9 +510,9 @@ class TestFullyShardCompile(FSDPTest):
)

def forward(self, x):
ret = torch.matmul(x, self.param1)
if not fullgraph:
torch._dynamo.graph_break()
ret = torch.matmul(x, self.param1)
ret = ret * self.param2
ret = torch.relu(ret)
return ret
@@ -519,7 +581,19 @@ class TestFullyShardCompile(FSDPTest):
for fullgraph in [True, False]:
with self._reinplace_all_gather_with_optional_checks(
fullgraph
), self._maybe_run_decide_global_ordering_of_comms_with_checks(fullgraph):
), self._maybe_run_decide_global_ordering_of_comms_with_checks(
fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
)
if fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
*self._create_nested_fully_shard_factory_fns(
@@ -537,46 +611,30 @@ class TestFullyShardCompile(FSDPTest):
fwd_code = triton_codes[0]
file_check = FileCheck().check("def call(args):")
for fwd_ag_block_info in [
dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
dict(overlapped_compute_op_str=None),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=2,
num_set=2,
last_all_gather=True,
),
]:
@@ -588,16 +646,12 @@ class TestFullyShardCompile(FSDPTest):
bwd_code = triton_codes[1]
file_check = FileCheck().check("def call(args):")
for bwd_ag_block_info in [
dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
dict(overlapped_compute_op_str=None),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=0,
num_set=2,
),
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=0,
num_set=2,
last_all_gather=True,
),
]:
@@ -605,7 +659,7 @@ class TestFullyShardCompile(FSDPTest):
file_check, **bwd_ag_block_info
)
for bwd_rs_block_info in [
dict(overlapped_compute_op_str="extern_kernels.mm("),
dict(overlapped_compute_op_str="extern_kernels.addmm("),
dict(
overlapped_compute_op_str=None
), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
@@ -623,9 +677,10 @@ class TestFullyShardCompile(FSDPTest):
"Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
)

def _create_transformer_factory_fns(self):
def _create_transformer_factory_fns(self, all_requires_grad):
seq_len = 16
vocab_size = 8
n_layers = 3

def model_init_fn():
torch.manual_seed(self.rank)
@@ -633,9 +688,20 @@ class TestFullyShardCompile(FSDPTest):
mesh = init_device_mesh("cuda", (self.world_size,))
model_args = ModelArgs(
vocab_size=vocab_size,
n_layers=3,
n_layers=n_layers,
)
model = Transformer(model_args)
if not all_requires_grad:
requires_grad_params = ["attention.wq", "attention.wv"]
requires_grad_param_count = 0
for k, v in model.named_parameters():
for substring in requires_grad_params:
if substring in k:
v.requires_grad_(True)
requires_grad_param_count += 1
else:
v.requires_grad_(False)
assert requires_grad_param_count == n_layers * len(requires_grad_params)
for layer_id, mod in enumerate(model.layers):
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
model = fully_shard(
@@ -672,12 +738,16 @@ class TestFullyShardCompile(FSDPTest):
@skipIfRocm
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_transformer_backend_aot_eager(self):
for fullgraph in [True, False]:
for fullgraph, all_requires_grad in itertools.product(
[True, False], [True, False]
):
with self._maybe_add_graph_break_to_sdpa(
fullgraph
), self._reinplace_all_gather_with_optional_checks(fullgraph):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(),
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
),
"aot_eager",
fullgraph=fullgraph,
)
@@ -687,10 +757,14 @@ class TestFullyShardCompile(FSDPTest):
# TODO: native_dropout has worse accuracy after decomp, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_aot_eager_decomp_partition(self):
for fullgraph in [True, False]:
for fullgraph, all_requires_grad in itertools.product(
[True, False], [True, False]
):
with self._maybe_add_graph_break_to_sdpa(fullgraph):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(),
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
),
"aot_eager_decomp_partition",
fullgraph=fullgraph,
)
@@ -700,17 +774,36 @@ class TestFullyShardCompile(FSDPTest):
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_inductor(self):
for fullgraph in [True, False]:
# TODO: enable fullgraph=False case
for fullgraph, all_requires_grad in itertools.product([True], [True, False]):
log.warning(
f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}" # noqa: G004, G001
)
with self._maybe_add_graph_break_to_sdpa(
fullgraph
), self._reinplace_all_gather_with_optional_checks(
fullgraph
), self._maybe_run_decide_global_ordering_of_comms_with_checks(
fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
)
if fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
*self._create_transformer_factory_fns(),
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
),
"inductor",
fullgraph=fullgraph,
)
@@ -723,21 +816,19 @@ class TestFullyShardCompile(FSDPTest):
fwd_code = triton_codes[0]
file_check = FileCheck().check("def call(args):")
for fwd_ag_block_info in [
dict(overlapped_compute_op_str="triton_", num_resize=0, num_set=4),
dict(
overlapped_compute_op_str="triton_"
if all_requires_grad
else None,
),
dict(
overlapped_compute_op_str="aten.native_dropout.",
num_resize=0,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
num_resize=12,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
num_resize=12,
num_set=12,
last_all_gather=True,
),
]:
@@ -751,35 +842,33 @@ class TestFullyShardCompile(FSDPTest):
for bwd_ag_block_info in [
dict(
overlapped_compute_op_str="extern_kernels.mm(",
num_resize=0,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
num_resize=0,
num_set=12,
),
dict(
overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
num_resize=0,
num_set=12,
last_all_gather=True,
),
]:
file_check = self.inductor_code_check_fsdp_all_gather(
file_check, **bwd_ag_block_info
)
if bwd_ag_block_info is not None:
file_check = self.inductor_code_check_fsdp_all_gather(
file_check, **bwd_ag_block_info
)
for bwd_rs_block_info in [
dict(overlapped_compute_op_str="extern_kernels.mm("),
dict(overlapped_compute_op_str="extern_kernels.mm(")
if all_requires_grad
else None,
dict(
overlapped_compute_op_str=None
), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
dict(overlapped_compute_op_str=None),
dict(overlapped_compute_op_str=None),
dict(overlapped_compute_op_str=None) if all_requires_grad else None,
]:
file_check = self.inductor_code_check_fsdp_reduce_scatter(
file_check, **bwd_rs_block_info
)
if bwd_rs_block_info is not None:
file_check = self.inductor_code_check_fsdp_reduce_scatter(
file_check, **bwd_rs_block_info
)
file_check.run(bwd_code)
else:
# TODO: when fullgraph=False and there is graph break in FWD graph,
@@ -5557,7 +5557,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
z0 = x.sin()
z1 = x.sin()
y = x + 1
torch.ops.fsdp.set_.default(x, y)
torch.ops.fsdp.copy_.default(x, y)
# z3 and z3 can be CSEd with each other,
# but *not* with z0/z1 (they cross a mutation boundary)
z2 = x.sin()
@@ -5589,7 +5589,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
z = x.sin()
y = x + 1
# graph input has its storage mutated
torch.ops.fsdp.set_.default(x, y)
torch.ops.fsdp.copy_.default(x, y)
z2 = x.sin()
return z2, l**2
@@ -725,59 +725,6 @@ def forward(self, primals_1):
return (add,)""",
)

@unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case")
@unittest.skipIf(IS_MACOS, "TODO: need to fix the test case")
def test_input_mutation_fsdp_set__into_same_input(self):
import torch.distributed._composable.fsdp._fsdp_param

def f(a):
b = torch.arange(9, dtype=a.dtype).view(3, 3)
c = torch.arange(9, dtype=a.dtype).view(3, 3)
d = torch.arange(9, dtype=a.dtype).view(3, 3)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
torch.ops.fsdp.set_.default(a, b)
x = a * a
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
torch.ops.fsdp.set_.default(a, c)
y = a * a
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
torch.ops.fsdp.set_.default(a, c)
z = a * a
return x + y + z

inp = [torch.ones(3, 3, requires_grad=True)]
fw_graph = self.verify_aot_autograd(
f, inp, test_mutation=True, keep_inp_mutations=True
)
inp = [torch.ones(3, 3, requires_grad=False)]
self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
"""
Expected behavior:
(1) When there are multiple set_() calls on the same graph input primal_X,
we want those set_() calls to all show up with primal_X as the first arg in the graph.
(2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892),
but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior.
"""
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1):
arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(arange, [3, 3]); arange = None
arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None
set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None
mul = torch.ops.aten.mul.Tensor(primals_1, primals_1)
set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None
mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None
mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None
add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None
return (add_1, primals_1)""",
)
self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp))

def test_input_mutation_simple_with_none_and_nontensor(self):
# Tensor, None, int
def f(a, b, c):
@@ -26,22 +26,13 @@ def init_fake_distributed(device="cpu"):
return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()

def fw_pre_hook(mod, inp):
if not compiled_autograd.compiled_autograd_enabled:
# torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
else:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.set_(
mod.unsharded_weight, all_gather(mod.sharded_weight)
)
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
mod._parameters["weight"] = mod.unsharded_weight

# Forward:
@@ -58,22 +49,13 @@ def init_fake_distributed(device="cpu"):
mod.unsharded_weight.untyped_storage().resize_(0)

def bw_pre_hook(mod, gO):
if not compiled_autograd.compiled_autograd_enabled:
# torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
else:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.set_(
mod.unsharded_weight, all_gather(mod.sharded_weight)
)
mod.unsharded_weight.untyped_storage().resize_(
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
)
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
mod.unsharded_weight
):
torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
mod._parameters["weight"] = mod.unsharded_weight

# Backward:
@@ -465,7 +447,6 @@ class DistributedPatternTests(TestCase):
@requires_gpu()
@torch._functorch.config.patch(recompute_views=True)
def test_fake_distributed_inductor(self):
# TODO: fix .set_ lowering in CPU inductor, and enable the CPU test.
m1, inp1 = init_fake_distributed(GPU_TYPE)
out1 = steps(m1, inp1)

@@ -1447,7 +1447,7 @@ class TestPatternMatcher(TestCase):
(t, [64, 128, 8, 8]),
{"dim": 1, "out": [t, t, t, t]},
)
check("call_function", torch.ops.fsdp.set_, (t, t), {})
check("call_function", torch.ops.fsdp.copy_, (t, t), {})
check(
"call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False
)

@@ -405,8 +405,8 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
torch.ops.aten.copy_.default,
torch.ops.aten.set_.source_Tensor,
]
if hasattr(torch.ops.fsdp, "set_"):
allowed_mutation_ops.append(torch.ops.fsdp.set_.default)
if hasattr(torch.ops.fsdp, "copy_"):
allowed_mutation_ops.append(torch.ops.fsdp.copy_.default)

placeholders = set()
mutation_count = 0

@@ -1272,6 +1272,7 @@ def get_default_op_list() -> OpTypes:
aten.expand,
aten.as_strided,
aten.permute,
aten.select,
]
view_ops = recomputable_view_ops
default_recomputable_ops += [
@@ -3,12 +3,14 @@
from __future__ import annotations

import heapq
import logging
import operator
import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING

import torch
from torch.multiprocessing.reductions import StorageWeakRef

from . import config, ir
from .dependencies import WeakDep
@@ -23,6 +25,7 @@ from .utils import (
)

log = logging.getLogger(__name__)
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

if TYPE_CHECKING:
@@ -342,6 +345,202 @@ def reorder_compute_and_comm_for_overlap(
return order

def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
"""
This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding
graph intermediates that were fsdp.copy_ into the unsharded params in the original graph.

NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern
(or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case
where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't
remove these resize and copy ops and thus we will have worse performance there.

In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param"
is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern
(in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed.
"""
node_list = list(graph.nodes)

# Find all graph inputs and their resize counts
graph_input_to_resized_to_full_node_idxes = defaultdict(list)
graph_input_to_resized_to_0_node_idxes = defaultdict(list)
for idx, node in enumerate(node_list):
if (
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
):
assert (
node.args[0].op == "placeholder"
), f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
"""
graph_input = node.args[0]
new_size = node.args[1]
if new_size > 0:
graph_input_to_resized_to_full_node_idxes[graph_input].append(idx)
else:
graph_input_to_resized_to_0_node_idxes[graph_input].append(idx)

def check_resize_pattern(graph_input):
# Check the number of resize-to-full and resize-to-0 nodes are equal,
# and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node
# always happens before the resize-to-0 node.
# This is the precondition for being able to remove all the resize and copy nodes
# for this specific unsharded param.
resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get(
graph_input, []
)
resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])

if not len(resized_to_full_idxes) == len(resized_to_0_idxes):
log.warning(
f"""
Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:
{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}.
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass.
""" # noqa: G004
)
return False

# Check the sequence: (resize_to_full -> resize_to_0)+
for resize_to_full_idx, resize_to_0_idx in zip(
resized_to_full_idxes, resized_to_0_idxes
):
if resize_to_full_idx >= resize_to_0_idx:
log.warning(
f"""
For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx}
happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}.
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param.
""" # noqa: G004
)
return False
return True

# Find all eligible unsharded params and their corresponding graph intermediates.
unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
for idx, node in enumerate(node_list):
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
fsdp_copy_node = node
unsharded_param = node.args[0]
assert (
unsharded_param.op == "placeholder"
), f"""
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
Offending node: {unsharded_param}. Graph: {graph}
"""
if check_resize_pattern(unsharded_param):
unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx)

def is_allowed_mutation(node):
return (
node.target == torch.ops.fsdp.copy_.default
or node.target == torch.ops.inductor.resize_storage_bytes_.default
)

def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
# Check whether the node is mutating any of the unsharded params or their aliases.
mutated_arg_idxes = (
[
i
for i, x in enumerate(node.target._schema.arguments)
if x.alias_info is not None and x.alias_info.is_write
]
if isinstance(node.target, torch._ops.OpOverload)
else []
)
mutated_node_arg_storages = {
StorageWeakRef(node.args[i].meta["val"].untyped_storage())
for i in mutated_arg_idxes
}
storages_of_unsharded_params = {
StorageWeakRef(unsharded_param.meta["val"].untyped_storage())
for unsharded_param in unsharded_params
}
return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0

# Check no user mutation on any unsharded_param
for node in node_list:
if (
node.op == "call_function"
and isinstance(node.target, torch._ops.OpOverload)
and node.target._schema.is_mutable
and not is_allowed_mutation(node)
):
assert not is_node_mutating_unsharded_param_or_its_alias(
node, unsharded_param_to_fsdp_copy_node_idxes.keys()
), f"""\
User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node}
"""

# For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`.
#
# NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input.
# e.g.
# ```
# fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1)
# ... (use of unsharded_param_1) -> Subgraph 1
# fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2)
# ... (use of unsharded_param_1) -> Subgraph 2
# fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3)
# ... (use of unsharded_param_1) -> Subgraph 3
# ```
# We must do the replacement only within each subgraph.
for (
unsharded_param,
fsdp_copy_node_idxes,
) in unsharded_param_to_fsdp_copy_node_idxes.items():
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
fsdp_copy_node = node_list[fsdp_copy_node_idx]
assert fsdp_copy_node.args[0] is unsharded_param
_, replacement = fsdp_copy_node.args
# subgraph_start_idx is exclusive
subgraph_start_idx = fsdp_copy_node_idx + 1
# subgraph_end_idx is exclusive (also intentionally don't replace args in return op)
subgraph_end_idx = (
fsdp_copy_node_idxes[i + 1]
if i < len(fsdp_copy_node_idxes) - 1
else len(node_list) - 1
)
subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx]
assert not any(
is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param])
for node in subgraph_nodes
), f"""\
Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true!
Graph: {graph}
"""
for node in subgraph_nodes:
if (
node.op == "call_function"
and unsharded_param in node.args
and node.target != torch.ops.inductor.resize_storage_bytes_.default
): # TODO(yf225): implement replacement in kwargs
new_args = tuple(
replacement if arg is unsharded_param else arg
for arg in node.args
)
node.args = new_args

# Delete `fsdp.copy_(unsharded_param, Y)` nodes
for (
unsharded_param,
fsdp_copy_node_idxes,
) in unsharded_param_to_fsdp_copy_node_idxes.items():
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
fsdp_copy_node = node_list[fsdp_copy_node_idx]
graph.erase_node(fsdp_copy_node)

# Delete `resize_(unsharded_param, ...)` nodes
for node in node_list:
if (
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
):
graph.erase_node(node)

def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
try:
import torch.distributed._composable.fsdp._fsdp_collectives
@@ -509,12 +708,11 @@ def enforce_comm_ordering_for_fsdp(
name_to_fused_node,
)

# Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
# Find the "all_gather + all_gather_wait_tensor + copy_out" code block
allowed_ops = {
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
torch.ops._c10d_functional.wait_tensor.default,
torch.ops.fsdp.split_with_sizes_copy.default,
torch.ops.aten.set_.source_Tensor,
}
find_recursive_users_of_node(
ag_snode,
@@ -560,7 +758,7 @@ def enforce_comm_ordering_for_fsdp(
assert wait_node_idx is not None
ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

# Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
# Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode
ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

@@ -22,6 +22,7 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver

from .. import config, ir, pattern_matcher
from ..codegen.common import BackendFeature, has_backend_feature
from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
from ..lowering import lowerings as L
from ..pattern_matcher import (
@@ -76,6 +77,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

The IR here has been normalized and functionalized.
"""
if not torch._dynamo.config.skip_fsdp_hooks:
remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)

if config.dce:
# has some issues with mutation in inference mode
gm.graph.eliminate_dead_code()
@@ -6102,13 +6102,17 @@ def set__source_tensor(self, source_tensor):
return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor))

if hasattr(torch.ops.fsdp, "set_"):
if hasattr(torch.ops.fsdp, "copy_"):

@register_lowering(torch.ops.fsdp.set_.default)
def fsdp_set_(self, source_tensor):
self.realize()
source_tensor.realize()
ir.SetSourceTensorKernel(self, source_tensor)
@register_lowering(torch.ops.fsdp.copy_.default)
def fsdp_copy_(dst, src):
if dst is src:
# dst.copy_(dst) can happen from the reinplacing pass
return dst
src = to_device(src, dst.get_device())
src = to_dtype(src, dst.get_dtype())
src = expand(src, dst.get_size())
return mutate_to(dst, src)

@register_lowering(torch.ops.aten.resize)
@@ -65,73 +65,77 @@ data, so we use storage resizing on the all-gather output.

lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901

lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")
lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()")

@torch.library.impl(lib, "set_", "Meta")
@torch.library.impl(lib, "set_", "CUDA")
@torch.library.impl(lib, "set_", "CPU")
def set_(tensor, data):
tensor.set_(data)
@torch.library.impl(lib, "copy_", "Meta")
@torch.library.impl(lib, "copy_", "CUDA")
@torch.library.impl(lib, "copy_", "CPU")
def copy_(tensor, data):
tensor.copy_(data)

"""
[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_]

Currently we don't functionalize `fsdp.set_` op or `inductor.resize_storage_bytes_(0)` op
Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op
(i.e. they show up as a mutation op in the middle of the AOT joint graph).

Reason:
Traceable FSDP2 compiled autograd BWD graph have the following traits:
(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param).
(3) They are both subclasses.
The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
So this doesn't work at all for Traceable FSDP2.

The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops.
This avoids the problem above, because from AOTAutograd point-of-view there are no mutations
that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)

We can avoid this functionalization because:
(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream.
(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).

Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
A: Yes - this is WIP but we will try to get back to functional graph as early as possible in the lowering process.
Specifically, we believe we can move both .set_ and .resize_(0) ops to end of graph in AOT joint graph before partitioner
(i.e. effectively "re-functionalizing" those ops). Put it in another way, we avoid functionalization for those two ops just to
make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
This requires a custom FX pass but we believe it's not hard to write and maintain.
Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance?
A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops
for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0).

Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors?
A: This is critical: we do want to save FWD nn.Parameter graph input (instead of its view) for BWD use,
so that downstream ops in BWD graph uses the post-`.set_` nn.Parameter instead of any of its saved views as input.
This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
nn.Parameter in order to see the result of .set_.
TODO:
Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward,
it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are:

Cons (of functionalizing those ops):
(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time
in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them
in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse
peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_:
https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089)

Pros (of functionalizing):
(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations
mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more).
(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it.
But maybe there are few enough mutations induced by FSDP for this to matter.
"""

@torch.library.impl(lib, "set_", "Functionalize")
def set__functionalize(tensor, data):
@torch.library.impl(lib, "copy_", "Functionalize")
def copy__functionalize(tensor, data):
torch._sync(tensor)
torch._sync(data)
# AOTDispatcher needs to know if any inputs had their storages mutated.
# (Why? It sometimes detaches inputs before sending them into the graph,
# when it sees that they do not need to have any gradients computed)
torch._functionalize_set_storage_changed(tensor)
tensor_inner = torch._from_functional_tensor(tensor)
data_inner = torch._from_functional_tensor(data)
with torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
):
torch.ops.fsdp.set_.default(tensor_inner, data_inner)
torch.ops.fsdp.copy_.default(tensor_inner, data_inner)

torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default)
torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)

class ShardedState(Enum):
@@ -475,7 +479,12 @@ class FSDPParam:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
self._unsharded_param
):
torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param)
# NOTE: Under compile, if an unsharded param goes through
# resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
# resize_ and copy_ ops in a compiler graph pass
# `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance.
alloc_storage(self._unsharded_param)
torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -605,13 +614,26 @@ class FSDPParam:
alloc_storage(tensor)

def free_unsharded_param(self) -> None:
for tensor in itertools.chain(
self.all_gather_outputs, self._unsharded_inner_tensors
):
free_storage(tensor)
if ca.compiled_autograd_enabled:
"""
Assumptions under compile:
- `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`.
Instead, we resize `self._unsharded_param` storage size to full and then
explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param`
in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove
the resize_ and copy_ ops in a compiler graph pass to recover performance.)
- `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT
graph inputs. They are created within the graph and is guaranteed to be freed
by the end of the graph. They don't leak outside of the graph.
"""
self._unsharded_param.untyped_storage().resize_(0)
self.all_gather_outputs = []
self._unsharded_inner_tensors = []
else:
for tensor in itertools.chain(
self.all_gather_outputs, self._unsharded_inner_tensors
):
free_storage(tensor)

@property
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D