Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Inductor support for native c10d_functional (#112439)
This PR adds Inductor support for [native c10d_functional ops](https://github.com/pytorch/pytorch/pull/110570). The Inductor IRs introduced in this PR will replace the existing `CollectiveKernel` IR hierarchy. Compared to the existing collective IRs, the new IRs:
- Are target-language agnostic and support AOTInductor.
- Express the constraints solely with read/write deps, which maximizes the potential for buffer reuse.
- Address an issue where an out-of-place collective's input buffers could be mutated while being volatilely read.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112439
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: 297c26bb8e
Commit: 625958d8bc
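For orientation, an editorial sketch (not part of the diff) of the usage pattern these IRs compile: the functional collective returns immediately and wait_tensor completes it. The sketch assumes a distributed build, an initialized NCCL backend, and a process group registered under the name "default", which is what the tests below set up via self._init_process_group().

import torch

def allreduce_avg(x: torch.Tensor) -> torch.Tensor:
    # Functional (out-of-place) collective; returns without waiting.
    out = torch.ops._c10d_functional.all_reduce(x, "avg", "default")
    # wait_tensor blocks until the collective has completed.
    return torch.ops._c10d_functional.wait_tensor(out)

# With this PR, torch.compile lowers both ops through the new collective IRs.
compiled = torch.compile(allreduce_avg)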
@@ -1,14 +1,19 @@
# Owner(s): ["module: c10d"]
import unittest
from typing import List

import torch
import torch.distributed as dist
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests
from torch.utils._triton import has_triton


if not dist.is_available():
@@ -182,6 +187,264 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_all_reduce_single(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = torch.ops._c10d_functional.all_reduce(buf0, "avg", "default")
            ar0 = torch.ops._c10d_functional.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = torch.ops._c10d_functional.all_reduce(arg, "avg", "default")
            ar1 = torch.ops._c10d_functional.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device=self.device)
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty(")
            .check("buf5 = empty(")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf5 is a clone)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            # Expect no extra copy on return
            .check("return (buf0, buf5, )")
            .run(code)
        )
        out = compiled(arg)
        correct = func(arg)
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_all_reduce_coalesced(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(args: List[torch.Tensor]) -> torch.Tensor:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = torch.ops._c10d_functional.all_reduce_coalesced(
                bufs, "avg", "default"
            )
            ar0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = torch.ops._c10d_functional.all_reduce_coalesced(
                args, "avg", "default"
            )
            ar1 = [torch.ops._c10d_functional.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device=self.device) for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty(")
            .check("buf5 = empty(")
            .check("buf1 = empty(")
            .check("buf6 = empty(")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )
        out = compiled(args)
        correct = func(args)
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            # Expect allocation
            buf0 = arg + 42
            ar0 = torch.ops._c10d_functional.all_reduce(buf0, "avg", "default")
            ar0 = torch.ops._c10d_functional.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2

        arg = torch.rand(4, 4, device=self.device)
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty(")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf5 = empty(")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf5")
            # Expect buf0 to be reused
            .check("buf6 = buf0; del buf0 # reuse")
            .check("extern_kernels.mm(arg0_1, buf5, out=buf6")
            # Expect no extra copy on return
            .check("return (buf5, buf6, )")
            .run(code)
        )
        out = compiled(arg)
        correct = func(arg)
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_all_gather_into_tensor_single(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = torch.ops._c10d_functional.all_gather_into_tensor(
                arg, self.world_size, "default"
            )
            ag0 = torch.ops._c10d_functional.wait_tensor(ag0)
            return ag0

        arg = torch.rand(4, 4, device=self.device)
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )
        out = compiled(arg)
        correct = func(arg)
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_all_gather_into_tensor_coalesced(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(args: List[torch.Tensor]) -> torch.Tensor:
            ag0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
                args, self.world_size, "default"
            )
            ag0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        print(code)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )
        out = compiled(args)
        correct = func(args)
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_reduce_scatter_tensor_single(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = torch.ops._c10d_functional.reduce_scatter_tensor(
                arg, "avg", self.world_size, "default"
            )
            rs0 = torch.ops._c10d_functional.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device=self.device)
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )
        out = compiled(arg)
        correct = func(arg)
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        torch._inductor.config.debug = True
        self._init_process_group()

        def func(args: List[torch.Tensor]) -> torch.Tensor:
            rs0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
                args, "avg", self.world_size, "default"
            )
            rs0 = [torch.ops._c10d_functional.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )
        out = compiled(args)
        correct = func(args)
        assert same(out, correct), f"{out} vs {correct}"


if __name__ == "__main__":
    run_tests()
@@ -35,4 +35,4 @@ inputs and outputs have any aliasing, it suffices to check whether the
storages of the input and the storages of the output have any overlap. See
`remove_noop_ops` for an example of how to do this.

Additionally, we do have one pass that *does* introduce mutation - `reinplace_scatters`. This pass must run *just before Inductor lowering*, as otherwise this breaks our invariant.
Additionally, we do have one pass that *does* introduce mutation - `reinplace_inplaceable_ops`. This pass must run *just before Inductor lowering*, as otherwise this breaks our invariant.
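To make the invariant concrete, here is an editorial sketch (mirroring the post_grad.py change later in this diff, not new functionality) of how a functional collective becomes eligible for reinplacing: each out-of-place op is mapped to its in-place variant plus the index of the argument it mutates, and the pass swaps an FX node's target when no later use can observe the mutation.

from collections import namedtuple

import torch

InplaceableOp = namedtuple("InplaceableOp", ["inplace_op", "mutated_arg"])

# Same shape as the table registered below; the _c10d_functional ops only
# exist when torch is built with USE_DISTRIBUTED=1.
inplaceable_collective_ops = {
    torch.ops._c10d_functional.all_reduce.default: InplaceableOp(
        torch.ops._c10d_functional.all_reduce_.default, 0
    ),
}

# Conceptually: for an FX node n with n.target == all_reduce.default, the pass
# sets n.target = inplaceable_collective_ops[n.target].inplace_op when
# can_inplace() decides that mutating n.args[0] in place is unobservable.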
@@ -93,7 +93,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

    # Keep this last, since it introduces mutation. Look at
    # ./fx_passes/README.md for a discussion of mutation invariants.
    reinplace_scatters(gm.graph)
    reinplace_inplaceable_ops(gm.graph)
    gm.recompile()
    gm.graph.lint()

@@ -625,9 +625,9 @@ def remove_noop_ops(graph: torch.fx.Graph):
InplaceableOp = namedtuple("InplaceableOp", ["inplace_op", "mutated_arg"])


def reinplace_scatters(graph):
def reinplace_inplaceable_ops(graph):
    """
    Reinplaces scatter operations.
    Reinplaces in-placeable operations.
    If there are no uses of a view of the mutated arg after the current node,
    it is possible to inplace the op.
    The above algorithm could be justified by observing side effects. While
@@ -676,6 +676,9 @@ def reinplace_scatters(graph):
        return False

    def can_inplace(node, mutated_arg):
        if isinstance(mutated_arg, (list, tuple)):
            return all(can_inplace(node, arg) for arg in mutated_arg)

        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
@@ -690,7 +693,6 @@ def reinplace_scatters(graph):
            ):
                return False

            graph.erase_node(copy_node)
            return True
        elif any(view.op == "placeholder" for view in shared_view_nodes):
            # If mutated arg is view of any of the inputs of the graph,
@@ -708,11 +710,35 @@ def reinplace_scatters(graph):
            inductor_prims._unsafe_index_put_, 0
        ),
    }

    try:
        c10d_functional = torch.ops._c10d_functional
        inplaceable_collective_ops = {
            c10d_functional.all_reduce.default: InplaceableOp(
                c10d_functional.all_reduce_.default, 0
            ),
            c10d_functional.all_reduce_coalesced.default: InplaceableOp(
                c10d_functional.all_reduce_coalesced_.default, 0
            ),
        }
        inplaceable_ops.update(inplaceable_collective_ops)
    except AttributeError:
        # _c10d_functional ops are only available when torch
        # is built with USE_DISTRIBUTED=1.
        pass

    inplaceable_triton_ops = {triton_kernel_wrapper_functional}

    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
            if can_inplace(node, node.args[inplaceable_op.mutated_arg]):
            mutated_arg = node.args[inplaceable_op.mutated_arg]
            if can_inplace(node, mutated_arg):
                # TODO(yifu): this doesn't properly remove copy epilogues for
                # ops that mutate multiple inputs. Need to revise the copy
                # node tracking logic to support the case.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    graph.erase_node(copy_node)
                node.target = inplaceable_op.inplace_op
        elif node.target in inplaceable_triton_ops:
            # inplaceable_triton_ops take an additional argument called
@@ -722,7 +748,12 @@ def reinplace_scatters(graph):
            tensors_to_clone = []
            for arg in node.kwargs["tensors_to_clone"]:
                assert arg in node.kwargs["kwargs"]
                if not can_inplace(node, node.kwargs["kwargs"][arg]):
                mutated_arg = node.kwargs["kwargs"][arg]
                if can_inplace(node, mutated_arg):
                    copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                    if copy_node is not None:
                        graph.erase_node(copy_node)
                else:
                    tensors_to_clone.append(arg)
            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
@@ -2693,7 +2693,7 @@ class Buffer(IRNode):
        # dynamic; it has i0 as one of the arguments. We cannot tell this
        # directly from MultiOutput, we have to look at the input buffer's
        # uses to work this out. No big deal.
        if isinstance(self.layout, MultiOutputLayout):
        if isinstance(self.layout, (NoneLayout, MultiOutputLayout)):
            return set()

        # This kernel defines all unbacked symbols... that it didn't get in as
@@ -4496,6 +4496,15 @@ class FallbackKernel(ExternKernelAlloc):
        else:
            super().codegen(wrapper)

    @staticmethod
    def tensor_to_layout(output: torch.Tensor):
        return FixedLayout(
            output.device,
            output.dtype,
            convert_shape_to_inductor(output.size()),
            convert_shape_to_inductor(output.stride()),
        )

    @classmethod
    def create(cls, kernel, *args, **kwargs):
        fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
@@ -4511,18 +4520,10 @@ class FallbackKernel(ExternKernelAlloc):
            schema,
        ) = cls.process_kernel(kernel, *args, **kwargs)

        device = FallbackKernel.find_device(tensor_args, example_output)
        device = cls.find_device(tensor_args, example_output)
        assert device, "Not sure where to find device info"

        def tensor_to_layout(output: torch.Tensor):
            return FixedLayout(
                output.device,
                output.dtype,
                convert_shape_to_inductor(output.size()),
                convert_shape_to_inductor(output.stride()),
            )

        packed = FallbackKernel(
        packed = cls(
            MultiOutputLayout(device),
            kernel,
            tensor_args,
@@ -4544,7 +4545,7 @@ class FallbackKernel(ExternKernelAlloc):
                }
            elif isinstance(output, torch.Tensor):
                return MultiOutput(
                    tensor_to_layout(output),
                    cls.tensor_to_layout(output),
                    packed,
                    indices,
                )
@@ -6930,6 +6931,176 @@ class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel):
        )


# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel.
class _CollectiveKernel(FallbackKernel):
    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
    # part that checks against input aliasing and mutation.
    def set_cpp_kernel(self, kernel):
        from .codegen.wrapper import get_cpp_op_schema

        self.kernel = kernel._schema.name
        self.cpp_kernel_overlad_name = kernel._schema.overload_name
        self.cpp_kernel_key = (
            f"{self.kernel.replace('::', '_')}_{self.cpp_kernel_overlad_name}"
        )

        self.cpp_op_schema = get_cpp_op_schema(kernel)
        self.ordered_kwargs_for_cpp_kernel = [
            x.name for x in kernel._schema.arguments if x.kwarg_only
        ]

    # NOTE: [In-Place Collective Safety]
    # Between the initiation and completion of an in-place collective, the
    # input buffers are subject to both volatile reads and volatile writes.
    # They must not be read, written to or reused by another kernel. To enforce
    # these constraints, we model collective -> wait_tensor as a two-step
    # mutation of the input buffers.
    @classmethod
    def create_inplace(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ) -> None:
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                schema,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        packed = cls(
            NoneLayout(tensor_args[0].get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
            schema=schema,
        )
        pytree.tree_map(lambda x: MutationOutput(x.layout, x, packed), inputs)

    # NOTE: [Out-of-Place Collective Safety]
    # Between the initiation and completion of an out-of-place collective:
    #
    # Input buffers:
    # - Are subject to volatile reads
    # - Can be read by another kernel
    # - Must not be written to or reused by another kernel
    #
    # Output buffers:
    # - Are subject to volatile writes
    # - Must not be read, written to or reused by another kernel
    #
    # To ensure the safety of input buffers without sacrificing read
    # availability, we add input buffers as read deps of wait_tensor kernels.
    #
    # To ensure the safety of output buffers, we model wait_tensor as a
    # mutation to the output buffer. Note that we also assume the user program
    # is correct and that the output buffer is not consumed by kernels other
    # than wait_tensor.
    #
    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
    # usage in the user program.
    @classmethod
    def create_out_of_place(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ):
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                schema,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        if isinstance(example_output, list):
            device = cls.find_device(tensor_args, example_output)
            packed = cls(
                MultiOutputLayout(device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                schema=schema,
            )
            packed.outputs = [
                MultiOutput(
                    cls.tensor_to_layout(tensor),
                    packed,
                    [(list, i)],
                )
                for i, tensor in enumerate(example_output)
            ]
            return packed.outputs
        else:
            packed = cls(
                cls.tensor_to_layout(example_output),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                schema=schema,
            )
            packed.outputs = [packed]
            return packed


class _WaitKernel(_CollectiveKernel):
    def get_volatile_reads(self):
        inp = self.inputs[0]
        if isinstance(inp, _CollectiveKernel):
            # Out-of-place single-output
            return [inp.inputs[0]]
        elif isinstance(inp, MultiOutput):
            # Out-of-place multi-output
            coll = inp.inputs[0]
            assert isinstance(coll, _CollectiveKernel)
            _, idx = inp.indices[0]
            return [coll.inputs[idx]]
        else:
            # In-place requires no additional deps handling for volatile
            # reads since the inputs are mutated.
            return []

    @classmethod
    def create_wait(cls, kernel, inp: TensorBox) -> None:
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                schema,
            ) = cls.process_kernel(kernel, inp)
        packed = cls(
            NoneLayout(inp.get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
            schema=schema,
        )
        MutationOutput(inp.layout, inp, packed)

    def get_read_writes(self):
        read_writes = super().get_read_writes()
        # See [Out-of-Place Collective Safety].
        volatile_reads = self.get_volatile_reads()
        for vr in volatile_reads:
            read_writes.reads.add(dependencies.StarDep(vr.get_name()))
        return read_writes


# NB: recursive structure here reflects val_to_arg_str, avoid
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
# treatment
@@ -5132,6 +5132,97 @@ try:
            )
        )

    _c10d_functional = torch.ops._c10d_functional

    @register_lowering(_c10d_functional.all_reduce)
    def _all_reduce(inp, reduce_op, group_name):
        inp = clone(inp)
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
        )
        return inp

    @register_lowering(_c10d_functional.all_reduce_)
    def _all_reduce_(inp, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
        )
        return inp

    @register_lowering(_c10d_functional.all_reduce_coalesced)
    def _all_reduce_coalesced(inputs, reduce_op, group_name):
        inputs = [clone(inp) for inp in inputs]
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    @register_lowering(_c10d_functional.all_reduce_coalesced_)
    def _all_reduce_coalesced_(inputs, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    @register_lowering(_c10d_functional.all_gather_into_tensor)
    def _all_gather_into_tensor(inp, group_size, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_gather_into_tensor.default,
                inp,
                group_size,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced)
    def _all_gather_into_tensor_coalesced(inputs, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_gather_into_tensor_coalesced.default,
                inputs,
                group_size,
                group_name,
            ),
        )

    @register_lowering(_c10d_functional.reduce_scatter_tensor)
    def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.reduce_scatter_tensor.default,
                inp,
                reduce_op,
                group_size,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced)
    def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.reduce_scatter_tensor_coalesced.default,
                inputs,
                reduce_op,
                group_size,
                group_name,
            ),
        )

    @register_lowering(_c10d_functional.wait_tensor)
    def _wait_tensor(inp):
        ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp)
        return inp

except ImportError:
    log.info(
        "Inductor support for distributed collectives depends on building torch.distributed"
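A quick way to see the out-of-place lowering path in action is the same recipe the tests earlier in this diff use (an editorial sketch; it assumes a Triton-capable GPU, a distributed build, and a process group registered under the name "default", with world_size matching the actual number of ranks):

import torch
from torch._inductor.utils import run_and_get_triton_code

world_size = 2  # assumed for the sketch; must match the real group size

def fn(x: torch.Tensor) -> torch.Tensor:
    ag = torch.ops._c10d_functional.all_gather_into_tensor(x, world_size, "default")
    return torch.ops._c10d_functional.wait_tensor(ag)

code = run_and_get_triton_code(torch.compile(fn), torch.rand(4, 4, device="cuda"))
# Expect the collective to produce its own output buffer and the wait to target
# that same buffer, e.g. lines containing:
#   buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1
#   torch.ops._c10d_functional.wait_tensor.default(buf0
print(code)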
@@ -542,6 +542,12 @@ def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]

def _all_reduce__meta(inp, *args):
    return inp

def _all_reduce_coalesced__meta(inputs, *args):
    return inputs

def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
    def mk_out_tensor(input):
        out_size = list(input.size())
@@ -577,15 +583,15 @@ def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name
        for input in inputs
    ]

def _reduce_scatter_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return input.new_empty(shape)
    return inp.new_empty(shape)

def _reduce_scatter_tensor_coalesced_native_meta(inputs, group_size, group_name):
def _reduce_scatter_tensor_coalesced_native_meta(inputs, reduce_op, group_size, group_name):
    return [
        _reduce_scatter_tensor_native_meta(input, group_size, group_name)
        for input in inputs
        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
        for inp in inputs
    ]

def _register_ops():
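As a sanity check on the updated signature, the reduce_scatter meta function simply splits the leading dimension by group_size. A self-contained restatement of the function above (the real one lives in this file) shows the shape math:

import torch

def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)

out = _reduce_scatter_tensor_native_meta(
    torch.empty(8, 4, device="meta"), "avg", 4, "default"
)
assert out.shape == (2, 4)  # leading dim 8 split across a group of 4 ranks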
@@ -620,7 +626,9 @@ if not torch._running_with_deploy():

    _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    _c10_lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
    _c10_lib_impl.impl("all_gather_into_tensor_coalesced", _all_gather_into_tensor_coalesced_native_meta, "Meta")