# mypy: allow-untyped-defs
import logging

import torch
import torch.utils._pytree as pytree
from torch._inductor.utils import is_symbolic
from torch.utils._ordered_set import OrderedSet

from . import config, ir
from .virtualized import V

log = logging.getLogger(__name__)


# NOTE [lowering-time collective optimization]
#
# In collective communication libraries such as NCCL, every rank maintains
# communication buffers that are remotely accessible by some peers. Depending
# on the underlying transport, remote accessibility may be established via
# mechanisms such as ib_reg_mr, CUDA P2P, or CUDA multicast. Typically, these
# buffers are private to the communication library by default, and
# communication ops copy user data in and out of these buffers.
#
# To prevent these copies, an optimization commonly known as "user buffer
# registration" can be employed. This allows direct establishment of remote
# accessibility on user buffers, eliminating the need for copying. However,
# this optimization introduces stringent usage requirements, which are
# typically hard to satisfy without being intrusive to the user code:
#
# - Establishing remote accessibility is expensive and often done ahead of
# time. In such implementations, all ranks must agree on the set of allocations
# used for every collective op. Failing to meet this requirement can
# lead to runtime errors or even silent correctness issues.
# - Even if the collective communication library supports gracefully falling
# back to "unregistered" implementations, the fallback mechanism would nullify
# the optimization.
# - Some communication mechanisms impose stricter requirements than others. For
# example, CUDA's multicast + multi-mem instructions require all ranks to agree
# not only on the allocations used for every collective but also on the offsets
# within these allocations.
#
# To support all different mechanisms with optimal results, we aim to satisfy
# the strictest requirement for this family of optimizations - we ensure that
# every collective op invocation is guaranteed to operate on the same
# allocation, at the same offset, in every iteration.
#
# For eligible collective ops, we identify communication buffers at lowering
# time and optionally choose to lower the op to a different kernel
# (communication libraries like NCCL handle both registered and non-registered
# buffers transparently within the same op, though some may require different
# ops for different cases). Later, the codegen will perform "persistent
# allocation" to satisfy the aforementioned constraints, and optionally,
# perform buffer planning to optimize overall memory usage.
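#
# A minimal illustrative sketch (not an actual call site; the real uses are
# the registered lowerings later in this file): a lowering that wants a
# symmetric-memory buffer would first gate on eligibility, then realize:
#
#     if can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM):
#         realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM, group_name)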
def can_realize_as_comm_buffer(
    x: ir.TensorBox, comm_buffer_type: ir.CommBufferType
) -> bool:
    """
    Check if an input can be realized as a comm buffer of the specified
    `comm_buffer_type`.
    """
    data = _get_data(x)

    if isinstance(data, ir.Loops):
        return True

    layout = data.get_output_spec()
    if isinstance(layout, ir.CommBufferLayout):
        return True

    if isinstance(layout, ir.FlexibleLayout) and not is_symbolic(data.get_numel()):
        return True

    return False


def realize_as_comm_buffer(
    x: ir.TensorBox, comm_buffer_type: ir.CommBufferType, group_name: str
) -> None:
    """
    Realize an input as a comm buffer of the specified `comm_buffer_type`.

    Specifically, this realizes the underlying buffer if it's still unrealized
    and changes the layout of the buffer to `ir.CommBufferLayout`.
    """
    x.realize()
    buffer = _get_data(x)
    assert isinstance(buffer, ir.Buffer)

    layout = buffer.get_output_spec()
    if isinstance(layout, ir.CommBufferLayout):
        return

    if not isinstance(layout, ir.FlexibleLayout):
        raise AssertionError(
            "A buffer can only be realized as a comm buffer if it "
            f"has `FlexibleLayout` (got {layout})."
        )

    if is_symbolic(buffer.get_numel()):
        raise AssertionError(
            "A buffer with symbolic shape cannot be converted to "
            f"a comm buffer (got {layout})."
        )

    buffer.layout = ir.CommBufferLayout(
        layout=layout,
        comm_buffer_type=comm_buffer_type,
        group_name=group_name,
    )
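    # Once the layout is swapped to `ir.CommBufferLayout`, codegen performs the
    # "persistent allocation" described in the NOTE above, so every iteration
    # reuses the same allocation at the same offset.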


def _get_data(x: ir.TensorBox) -> ir.IRNode:
    if isinstance(x.data, ir.BaseView):
        # TensorBox -> *View -> StorageBox -> IRNode
        node = x.data.unwrap_view()
        assert isinstance(node, (ir.BaseView, ir.MutableBox))
        return node.data
    elif isinstance(x.data, ir.StorageBox):
        # TensorBox -> StorageBox -> IRNode
        return x.data.data
    else:
        raise AssertionError(
            "Expect the data attr of a `TensorBox` to be either "
            f"an `ir.BaseView` or `ir.StorageBox` (got {x.data})."
        )


_bufs_to_skip_wait = OrderedSet[tuple[int, str]]()


def mark_as_skip_wait(x: ir.IRNode) -> None:
    """
    If a non-blocking collective is lowered as a blocking collective, the wait
    node in the original graph becomes useless and we can skip lowering it.
    """
    _bufs_to_skip_wait.add((id(V.graph), x.get_name()))


def should_skip_wait(x: ir.IRNode) -> bool:
    return (id(V.graph), x.get_name()) in _bufs_to_skip_wait
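

# Illustrative flow (see `_all_reduce_` and `_wait_tensor` below): when a
# non-blocking collective is lowered as a blocking one-shot all-reduce, the
# lowering calls `mark_as_skip_wait(ret)`; the later `_wait_tensor` lowering
# then observes `should_skip_wait(inp)` and returns the input unchanged.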


def _should_lower_as_one_shot_all_reduce(
    inp: ir.TensorBox, reduce_op: str, group_name: str
):
    from torch.distributed._symmetric_memory import is_symm_mem_enabled_for_group

    inp_size = inp.get_numel() * inp.get_dtype().itemsize
    return (
        config._collective.auto_select
        and is_symm_mem_enabled_for_group(group_name)
        and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM)
        and reduce_op == "sum"
        and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes
    )
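

# Size-gate example (illustrative): a float32 tensor with 262,144 elements has
# inp_size = 262,144 * 4 = 1,048,576 bytes (1 MiB), so it is one-shot eligible
# only if `one_shot_all_reduce_threshold_bytes` is at least that large (the
# default threshold is configuration-dependent).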


def _one_shot_all_reduce(inp: ir.TensorBox, reduce_op, group_name):
    realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM, group_name)
    return pytree.tree_map(
        ir.TensorBox.create,
        ir.FallbackKernel.create(
            torch.ops.symm_mem.one_shot_all_reduce.default,
            inp,
            reduce_op,
            group_name,
        ),
    )


def register_comm_lowerings():
    try:
        torch.ops._c10d_functional.all_reduce
    except AttributeError:
        log.info(
            "Inductor support for distributed collectives depends on building "
            "torch.distributed"
        )
        return

    from .lowering import (
        add_layout_constraint,
        clone,
        constrain_to_fx_strides,
        copy_,
        register_lowering,
    )

    def register_comm_lowering(fn):
        add_layout_constraint(fn, constrain_to_fx_strides)
        return register_lowering(fn)
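
    # Used below as a decorator: each collective op is first constrained to
    # the strides produced by the FX graph (`constrain_to_fx_strides`) and
    # then registered as an inductor lowering.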

    c10d = torch.ops._c10d_functional

    @register_comm_lowering(c10d.all_reduce)  # type: ignore[misc]
    def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.TensorBox:
        if _should_lower_as_one_shot_all_reduce(inp, reduce_op, group_name):
            return _one_shot_all_reduce(inp, reduce_op, group_name)

        # Lower as c10d.all_reduce_
        inp = clone(inp)
        if config.reorder_for_compute_comm_overlap:
            # The horizontal fusion of this clone often severely delays the
            # scheduling of the all_reduce_ node. Horizontally fusing this
            # clone can almost never out-perform scheduling the all_reduce_
            # earlier. Also in most cases, this clone is eliminated via
            # in-place reuse. Therefore, we tell the scheduler to not fuse it.
            inp.realize()
            V.graph.no_fuse_buffer_names.add(inp.get_name())
        # pyrefly: ignore # bad-assignment
        inp = ir.ExternKernel.require_contiguous(inp)
        # Because we are lowering as inplace c10d.all_reduce_, we should generate
        # _AllReduce_Kernel instead of _AllReduceKernel.
        ir._AllReduce_Kernel.create_inplace(
            c10d.all_reduce_.default,
            inp,  # type: ignore[arg-type]
            reduce_op,
            group_name,  # type: ignore[arg-type]
        )
        return inp  # type: ignore[return-value]

    @register_comm_lowering(c10d.all_reduce_)  # type: ignore[misc]
    def _all_reduce_(
        inp: ir.TensorBox, reduce_op: str, group_name: str
    ) -> ir.TensorBox:
        if _should_lower_as_one_shot_all_reduce(inp, reduce_op, group_name):
            ret = copy_(
                inp,
                _one_shot_all_reduce(inp, reduce_op, group_name),
            )
            mark_as_skip_wait(ret)
            return inp

        # Lower as c10d.all_reduce_
        # pyrefly: ignore # bad-assignment
        inp = ir.ExternKernel.require_contiguous(inp)
        ir._AllReduce_Kernel.create_inplace(
            c10d.all_reduce_.default,
            inp,  # type: ignore[arg-type]
            reduce_op,
            group_name,  # type: ignore[arg-type]
        )
        return inp  # type: ignore[return-value]

    @register_comm_lowering(c10d.all_reduce_coalesced)
    def _all_reduce_coalesced(inputs, reduce_op, group_name):
        inputs = [clone(inp) for inp in inputs]
        ir._CollectiveKernel.create_inplace(
            c10d.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    @register_comm_lowering(c10d.all_reduce_coalesced_)
    def _all_reduce_coalesced_(inputs, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            c10d.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    def _create_out_of_place(kernel, inputs, *args) -> ir.IRNode:
        node = ir._CollectiveKernel.create_out_of_place(kernel, inputs, *args)
        assert isinstance(node, ir.IRNode)
        return ir.TensorBox.create(node)

    @register_comm_lowering(c10d.all_gather_into_tensor)
    def _all_gather_into_tensor(inp, group_size, group_name):
        return _create_out_of_place(
            c10d.all_gather_into_tensor.default,
            inp,
            group_size,
            group_name,
        )

    @register_comm_lowering(c10d.all_gather_into_tensor_coalesced)
    def _all_gather_into_tensor_coalesced(inputs, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                c10d.all_gather_into_tensor_coalesced.default,
                inputs,
                group_size,
                group_name,
            ),
        )

    @register_comm_lowering(c10d.all_gather_into_tensor_out)
    def _all_gather_into_tensor_out(inp, group_size, group_name, *, out):
        ir._CollectiveKernel.create_inplace(
            c10d.all_gather_into_tensor_out.default,
            inp,
            group_size,
            group_name,
            out=out,
        )
        return out

    @register_comm_lowering(c10d.reduce_scatter_tensor)
    def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name):
        return _create_out_of_place(
            c10d.reduce_scatter_tensor.default,
            inp,
            reduce_op,
            group_size,
            group_name,
        )

    @register_comm_lowering(c10d.reduce_scatter_tensor_coalesced)
    def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                c10d.reduce_scatter_tensor_coalesced.default,
                inputs,
                reduce_op,
                group_size,
                group_name,
            ),
        )

    @register_comm_lowering(c10d.all_to_all_single)
    def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name):
        return _create_out_of_place(
            c10d.all_to_all_single.default,
            inp,
            output_split_sizes,
            input_split_sizes,
            group_name,
        )

    @register_comm_lowering(c10d.broadcast)
    def _broadcast(inp, src, group_name):
        inp = clone(inp)
        ir._CollectiveKernel.create_inplace(
            c10d.broadcast_.default, inp, src, group_name
        )
        return inp

    @register_comm_lowering(c10d.broadcast_)
    def _broadcast_(inp, src, group_name):
        ir._CollectiveKernel.create_inplace(
            c10d.broadcast_.default, inp, src, group_name
        )
        return inp

    @register_comm_lowering(torch.ops._dtensor.shard_dim_alltoall)
    def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name):
        return _create_out_of_place(
            torch.ops._dtensor.shard_dim_alltoall.default,
            inp,
            gather_dim,
            shard_dim,
            group_name,
        )

    @register_comm_lowering(c10d.wait_tensor)
    def _wait_tensor(inp):
        if should_skip_wait(inp):
            return inp

        ir._WaitKernel.create_wait(c10d.wait_tensor.default, inp)
        return inp
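

# Note: `register_comm_lowerings()` is expected to be called once while the
# inductor lowering table is built (from torch/_inductor/lowering.py at the
# time of writing). It degrades to a no-op when torch.distributed is not
# built, per the AttributeError guard above.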