pytorch/torch/distributed/pipelining/microbatch.py
Chien-Chin Huang e3ae80fc03 [PP] Let PP split BlockMask into micro-BlockMask (#164111)
BlockMask carries batch-dimension information, so PP has to split it along the batch just like all other tensors. All of the tensors inside a BlockMask have a batch dimension, so splitting them is straightforward. However, `mask_mod` takes the batch index as an input, and that index changes after the split, so we wrap `mask_mod` in a closure that remaps the batch index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164111
Approved by: https://github.com/H-Huang
2025-10-07 23:25:34 +00:00
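
A minimal sketch of the closure wrapping described above, assuming `torch` is imported and using a hypothetical standalone helper name (the in-tree version is `create_mask_mod` inside `_split_block_mask` below):

    def _with_batch_offset(mask_mod, batch_offset):
        # Remap the micro-batch-local batch index back to its index in the
        # original mini-batch before calling the user's mask_mod.
        def wrapped(b, h, q_idx, kv_idx):
            return mask_mod(b + torch.full_like(b, batch_offset), h, q_idx, kv_idx)
        return wrapped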


# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import operator
from typing import Any, Optional, Sequence
import torch
from torch.fx.node import map_aggregate
from torch.nn.attention.flex_attention import BlockMask
from torch.utils._pytree import tree_flatten, tree_unflatten
__all__ = [
"TensorChunkSpec",
"split_args_kwargs_into_chunks",
"merge_chunks",
]
logger = logging.getLogger(__name__)
"""
_debug_mask_minibatches specifies whether to send masked versions of the
mini-batch through instead of micro-batch slices; this can be used for more
stable numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False
class _CustomReducer:
"""
Custom reducer class that can be used to specify a custom operation that
reduces losses of multiple microbatches into one value.
Example:
>>> # xdoctest: +SKIP
>>> sum_reducer = _CustomReducer(
>>> torch.tensor(0.0),
>>> lambda a, b: a + b
>>> )
"""
def __init__(self, init_value, reduce_fn):
self.init_value = init_value
self.reduce_fn = reduce_fn
class _LossReducer(_CustomReducer):
pass
sum_reducer = _LossReducer(torch.tensor(0.0), operator.add)
# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0
class TensorChunkSpec:
"""
Class used to specify chunking of inputs
"""
def __init__(self, split_dim):
self.split_dim = split_dim
split_dim: int
def __repr__(self):
return (
f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
)
def __str__(self):
return f"TensorChunkSpec({self.split_dim})"
@staticmethod
def from_tuple(
chunk_dims: tuple[int, ...],
):
"""
A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
dimensions (int's).
Example:
>>> # xdoctest: +SKIP
>>> # There are three positional arguments to the model, and
>>> # we are chunking them along dimension 0, 0 and 1, respectively
>>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
"""
args_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
)
return args_chunk_spec
@staticmethod
def from_dict(
chunk_dims: dict[str, int],
):
"""
A helper for creating a dictionary of `TensorChunkSpec` from a
dictionary of chunk dimensions (int's).
Example:
>>> # xdoctest: +SKIP
>>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
>>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
"""
kwargs_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
)
return kwargs_chunk_spec
# Class used to specify replication of inputs
class _Replicate:
pass
def _split_block_mask(
block_mask: BlockMask,
num_chunks: int,
) -> list[BlockMask]:
"""Given a block mask, split the block mask along the batch dimension (dim0).
Args:
block_mask: Block mask to split
num_chunks: Number of chunks to split the block mask into
Returns:
chunk_block_masks: List of chunked block masks
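    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch (hypothetical sizes): split a batch-8 BlockMask
        >>> # into 4 micro-BlockMasks of batch 2 each.
        >>> from torch.nn.attention.flex_attention import create_block_mask
        >>> causal = lambda b, h, q_idx, kv_idx: q_idx >= kv_idx
        >>> bm = create_block_mask(causal, B=8, H=1, Q_LEN=1024, KV_LEN=1024)
        >>> micro_masks = _split_block_mask(bm, num_chunks=4)
        >>> [m.kv_num_blocks.size(0) for m in micro_masks]
        [2, 2, 2, 2]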
"""
    assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
        "Block mask batch size is smaller than the number of chunks."
    )
batch_dim = 0
kv_num_blocks_chunks = torch.tensor_split(
block_mask.kv_num_blocks, num_chunks, batch_dim
)
kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
full_kv_num_blocks_chunks = (
torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
if block_mask.full_kv_num_blocks is not None
else [None] * num_chunks
)
full_kv_indices_chunks = (
torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
if block_mask.full_kv_indices is not None
else [None] * num_chunks
)
chunk_block_masks = []
batch_offset = 0
for chunk_idx in range(num_chunks):
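        # mask_mod receives batch indices local to this chunk (starting at 0),
        # while the original mask_mod expects indices into the full mini-batch.
        # Use a closure factory so each chunk captures its own running batch
        # offset (rather than late-binding the loop variable).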
def create_mask_mod(idx):
def batch_offset_mask_mod(b, h, q_idx, kv_idx):
b_offset = torch.full_like(b, idx)
return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)
return batch_offset_mask_mod
chunk_block_masks.append(
BlockMask.from_kv_blocks(
kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
kv_indices=kv_indices_chunks[chunk_idx],
full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
full_kv_indices=full_kv_indices_chunks[chunk_idx],
BLOCK_SIZE=block_mask.BLOCK_SIZE,
mask_mod=create_mask_mod(batch_offset),
seq_lengths=block_mask.seq_lengths,
)
)
batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
return chunk_block_masks
def _split_tensor(
tensor: torch.Tensor,
spec: TensorChunkSpec,
num_chunks: int,
) -> Sequence[torch.Tensor]:
"""Given a tensor, and a chunking spec, split the tensor.
Args:
tensor: Tensor to split
spec: Chunking spec
num_chunks: Number of chunks to split the tensor into
Returns:
chunk_tensors: List of chunked tensors
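    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch: with the default _debug_mask_minibatches=False,
        >>> # a size-6 batch split 4 ways follows torch.tensor_split's uneven rule.
        >>> x = torch.randn(6, 4)
        >>> [c.size(0) for c in _split_tensor(x, TensorChunkSpec(0), num_chunks=4)]
        [2, 2, 1, 1]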
"""
assert tensor.size(spec.split_dim) >= num_chunks, (
f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
)
chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)
if not _debug_mask_minibatches:
return chunk_tensors
expanded_chunks = []
split_dim_idx = 0
for chunk_tensor in chunk_tensors:
new_val = torch.zeros_like(tensor)
upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)
slice_indices = [slice(None, None, None)] * new_val.ndim
slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
new_val[slice_indices] = chunk_tensor
expanded_chunks.append(new_val)
split_dim_idx += chunk_tensor.size(spec.split_dim)
return expanded_chunks
def _shard_dict_of_args(
args_dict,
args_chunk_spec,
num_chunks,
):
"""
Given a dictionary of args, and a dictionary of chunking specs, shard the
args according to the chunking specs.
Args:
args_dict: Dictionary of args
args_chunk_spec: Dictionary of chunking specs
num_chunks: Number of chunks to shard the args into
Returns:
args_split: List of sharded args
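    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch: one tensor split along dim 0, one value replicated.
        >>> shards = _shard_dict_of_args(
        >>>     {"x": torch.randn(4, 8), "lr": 0.1},
        >>>     {"x": TensorChunkSpec(0), "lr": _Replicate},
        >>>     num_chunks=2,
        >>> )
        >>> len(shards), shards[0]["x"].shape, shards[0]["lr"]
        (2, torch.Size([2, 8]), 0.1)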
"""
if not args_dict:
return [{} for _ in range(num_chunks)]
assert len(args_dict) == len(args_chunk_spec), (
f"args_dict.keys() = {list(args_dict.keys())} "
f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
)
assert args_chunk_spec is not None # Should have been set by caller
values, tree_spec = tree_flatten(args_dict)
chunk_specs, _ = tree_flatten(args_chunk_spec)
# First check and find the actual number of chunks
split_sizes = []
for v, spec in zip(values, chunk_specs, strict=True):
if spec is _Replicate:
split_sizes.append(num_chunks)
elif isinstance(v, torch.Tensor):
assert isinstance(spec, TensorChunkSpec)
split_sizes.append(v.size(spec.split_dim))
elif isinstance(v, BlockMask):
assert isinstance(spec, TensorChunkSpec)
assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
split_sizes.append(v.kv_num_blocks.size(0))
else:
raise ValueError(
f"Unsupported chunk spec: {spec} and value: {v} combination."
)
result_num_chunks = min(*split_sizes, num_chunks)
flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
for v, spec in zip(values, chunk_specs, strict=True):
v_splits: Sequence[Any] = []
if spec is _Replicate:
v_splits = [v] * result_num_chunks
elif isinstance(v, torch.Tensor):
v_splits = _split_tensor(v, spec, result_num_chunks)
elif isinstance(v, BlockMask):
v_splits = _split_block_mask(v, result_num_chunks)
else:
raise ValueError(
f"Unsupported chunk spec: {spec} and value: {v} combination."
)
for _flat_split_result, _v_split in zip(
flat_split_results, v_splits, strict=True
):
_flat_split_result.append(_v_split)
return [
tree_unflatten(_flat_split_result, tree_spec)
for _flat_split_result in flat_split_results
]
def split_args_kwargs_into_chunks(
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
chunks: int,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
"""
Given a sequence of args and kwargs, split them into a number of chunks
according to their respective chunking specs.
Args:
args: Tuple of args
kwargs: Dict of kwargs
chunks: Number of chunks to split the args and kwargs into
args_chunk_spec: chunking specs for args, in same shape as args
kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
Returns:
args_split: List of sharded args
kwargs_split: List of sharded kwargs
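    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch: split a batch of 8 into 4 microbatches along dim 0
        >>> # (the default when no chunk specs are given).
        >>> args = (torch.randn(8, 16),)
        >>> kwargs = {"mask": torch.ones(8, 16, dtype=torch.bool)}
        >>> args_split, kwargs_split = split_args_kwargs_into_chunks(args, kwargs, 4)
        >>> len(args_split), args_split[0][0].shape, kwargs_split[0]["mask"].shape
        (4, torch.Size([2, 16]), torch.Size([2, 16]))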
"""
# Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
# the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
# and `kwargs_chunk_spec` specifications. The steps are as follows:
#
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
# To use a running example: suppose our inputs look like
#
# args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
# (kwargs not shown but it's a similar process)
#
# Then for this step we would end up with
#
# args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
#
# 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
#
# args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
#
# 3. Rotate the nesting order such that chunks are the outer dimension
#
# args_chunks = [
# ([A, B, C_1], D),
# ([A, B, C_2], D),
# ]
#
# 4. Unflatten each chunk according to the spec
#
# args_chunks = [
# ([A, [B, C_1]], D),
# ([A, [B, C_2]], D),
# ]
# TODO: _debug_mask_minibatches
# Handle the case where kwargs is None
if kwargs is None:
kwargs = {}
# If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
# their format and use default chunking along dim 0
if args_chunk_spec is None:
args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
if kwargs_chunk_spec is None:
kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))
args_split_dict = _shard_dict_of_args(
dict(enumerate(args)),
dict(enumerate(args_chunk_spec)),
chunks,
)
real_num_chunks = len(args_split_dict)
kwargs_split = _shard_dict_of_args(
kwargs,
kwargs_chunk_spec,
real_num_chunks,
)
if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into fewer chunks,
        # e.g. when `args` has no tensors, just plain values
real_num_chunks = len(kwargs_split)
# Re-shard args
args_split_dict = _shard_dict_of_args(
dict(enumerate(args)),
dict(enumerate(args_chunk_spec)),
real_num_chunks,
)
if len(args_split_dict) != len(kwargs_split):
raise RuntimeError(
"args and kwargs are split into different number of chunks: "
f"{len(args_split_dict)}, {len(kwargs_split)}"
)
args_split = [
tuple(chunk_args[i] for i in range(len(chunk_args)))
for chunk_args in args_split_dict
]
return args_split, kwargs_split
def merge_chunks(
chunks: list[Any],
chunk_spec,
):
"""
Given a list of chunks, merge them into a single value according to
the chunk spec.
Args:
chunks: list of chunks
chunk_spec: Chunking spec for the chunks
Returns:
value: Merged value
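    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch: with chunk_spec=None, chunks are concatenated
        >>> # along the default dimension 0.
        >>> chunks = [torch.randn(2, 16) for _ in range(4)]
        >>> merge_chunks(chunks, None).shape
        torch.Size([8, 16])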
"""
# This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
# steps are similar to the steps in that function but in reverse. Given the
# input values:
#
# chunks = [
# ([A, [B, C_1]], D),
# ([A, [B, C_2]], D),
# ]
# args_spec = ([None, [None, TensorChunkSpec]], None)
#
# 1. Flatten the chunks according to the chunk_spec
#
# chunks_flat = [
# ([A, B, C_1], D),
# ([A, B, C_2], D),
# ]
#
# 2. Rotate the nesting order such that chunks are the inner dimension
#
# value_inner = ([A, B, [C_1, C_2]], D)
#
# 3. Concatenate sharded arguments
#
# value_combined = ([A, B, C], D)
#
# 4. Unflatten the combined args given the spec
#
# value = ([A, [B, C]], D)
# Preliminary: flatten the chunk spec
if chunk_spec is not None:
spec_flattened, flatten_spec = tree_flatten(chunk_spec)
else:
# If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
# We obtain the output structure by flattening chunk 0 and generate the chunk_spec
chunk0_flat, flatten_spec = tree_flatten(chunks[0])
spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
# Stage 1: flatten chunks
# chunks_flattened : [num chunks, num args]
chunks_flattened = []
for chunk in chunks:
chunk_flattened, _ = tree_flatten(chunk)
if len(chunk_flattened) != len(spec_flattened):
raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
chunks_flattened.append(chunk_flattened)
# Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
# concatenate sharded operands
# args_flattened : [num args]
args_flattened = []
for arg_idx, arg in enumerate(spec_flattened):
if isinstance(arg, TensorChunkSpec):
partial_values = [
chunks_flattened[chunk_idx][arg_idx]
for chunk_idx in range(len(chunks_flattened))
]
if _debug_mask_minibatches:
# Infer size of individual chunks by running `tensor_split` again
overall_shape = partial_values[0].shape
for val in partial_values[1:]:
assert val.shape == overall_shape
meta_chunks = torch.tensor_split(
torch.empty(*overall_shape, device="meta"),
sections=len(partial_values),
dim=arg.split_dim,
)
values_to_cat = []
chunk_start_idx = 0
assert len(partial_values) == len(meta_chunks)
for partial_value, meta_chunk in zip(partial_values, meta_chunks):
chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
slice_indices = [slice(None, None, None)] * partial_value.ndim
slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
sliced = partial_value[slice_indices]
values_to_cat.append(sliced)
chunk_start_idx = chunk_end_idx
else:
values_to_cat = partial_values
args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
elif isinstance(arg, _CustomReducer):
reduced_val = arg.init_value
for chunk_idx in range(len(chunks_flattened)):
reduced_val = arg.reduce_fn(
reduced_val, chunks_flattened[chunk_idx][arg_idx]
)
args_flattened.append(reduced_val)
else:
value = chunks_flattened[0][arg_idx]
for chunk_idx in range(1, len(chunks_flattened)):
assert chunks_flattened[chunk_idx][arg_idx] == value
args_flattened.append(value)
# Stage 4: Unflatten combined args
return tree_unflatten(args_flattened, flatten_spec)