[PP] Let PP split BlockMask into micro-BlockMask (#164111)

BlockMask carries batch-dimension information, so PP has to split it along the batch dimension just like every other input. All of the tensors inside a BlockMask have a batch dimension, so the split itself is straightforward. However, `mask_mod` takes the batch index as an input, and that index changes after the split. We therefore wrap `mask_mod` in a closure that adds the chunk's batch offset back, so each micro-BlockMask still indexes into the original (unsplit) batch.
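A minimal standalone sketch of the closure idea (the `lengths` tensor, `length_mask`, and `make_offset_mask_mod` names below are illustrative only, not the exact code added in this PR):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    # Toy, batch-dependent mask data: per-sample valid KV length for a global batch of 4.
    lengths = torch.tensor([32, 64, 96, 128])

    def length_mask(b, h, q_idx, kv_idx):
        # Sample b may only attend to its first lengths[b] keys.
        return kv_idx < lengths[b]

    def make_offset_mask_mod(orig_mask_mod, batch_offset: int):
        # A micro-batch sees local batch indices starting at 0, so shift them
        # back to the global indices the original mask_mod expects.
        def offset_mask_mod(b, h, q_idx, kv_idx):
            return orig_mask_mod(b + batch_offset, h, q_idx, kv_idx)

        return offset_mask_mod

    # Full mask over the global batch of 4.
    full_mask = create_block_mask(length_mask, B=4, H=None, Q_LEN=128, KV_LEN=128, device="cpu")

    # Micro-batch covering global samples 2..3: its local sample 0 must use lengths[2].
    micro_mask = create_block_mask(
        make_offset_mask_mod(length_mask, batch_offset=2),
        B=2, H=None, Q_LEN=128, KV_LEN=128, device="cpu",
    )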

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164111
Approved by: https://github.com/H-Huang
Author: Chien-Chin Huang
Date: 2025-10-07 11:27:01 -07:00
Committed by: PyTorch MergeBot
Parent: 483f4e0db9
Commit: e3ae80fc03
2 changed files with 266 additions and 99 deletions


@@ -2,10 +2,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import logging
 import operator
-from typing import Any, Optional
+from typing import Any, Optional, Sequence

 import torch
 from torch.fx.node import map_aggregate
+from torch.nn.attention.flex_attention import BlockMask
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -115,6 +116,107 @@ class _Replicate:
     pass


+def _split_block_mask(
+    block_mask: BlockMask,
+    num_chunks: int,
+) -> list[BlockMask]:
+    """Given a block mask, split the block mask along the batch dimension (dim0).
+
+    Args:
+        block_mask: Block mask to split
+        num_chunks: Number of chunks to split the block mask into
+
+    Returns:
+        chunk_block_masks: List of chunked block masks
+    """
+    assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
+        "Block mask has fewer batch size than the number of chunks. "
+    )
+    batch_dim = 0
+    kv_num_blocks_chunks = torch.tensor_split(
+        block_mask.kv_num_blocks, num_chunks, batch_dim
+    )
+    kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
+    full_kv_num_blocks_chunks = (
+        torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
+        if block_mask.full_kv_num_blocks is not None
+        else [None] * num_chunks
+    )
+    full_kv_indices_chunks = (
+        torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
+        if block_mask.full_kv_indices is not None
+        else [None] * num_chunks
+    )
+
+    chunk_block_masks = []
+    batch_offset = 0
+    for chunk_idx in range(num_chunks):
+
+        def create_mask_mod(idx):
+            def batch_offset_mask_mod(b, h, q_idx, kv_idx):
+                b_offset = torch.full_like(b, idx)
+                return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)
+
+            return batch_offset_mask_mod
+
+        chunk_block_masks.append(
+            BlockMask.from_kv_blocks(
+                kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
+                kv_indices=kv_indices_chunks[chunk_idx],
+                full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
+                full_kv_indices=full_kv_indices_chunks[chunk_idx],
+                BLOCK_SIZE=block_mask.BLOCK_SIZE,
+                mask_mod=create_mask_mod(batch_offset),
+                seq_lengths=block_mask.seq_lengths,
+            )
+        )
+        batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
+    return chunk_block_masks
+
+
+def _split_tensor(
+    tensor: torch.Tensor,
+    spec: TensorChunkSpec,
+    num_chunks: int,
+) -> Sequence[torch.Tensor]:
+    """Given a tensor, and a chunking spec, split the tensor.
+
+    Args:
+        tensor: Tensor to split
+        spec: Chunking spec
+        num_chunks: Number of chunks to split the tensor into
+
+    Returns:
+        chunk_tensors: List of chunked tensors
+    """
+    assert tensor.size(spec.split_dim) >= num_chunks, (
+        f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
+    )
+    chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)
+
+    if not _debug_mask_minibatches:
+        return chunk_tensors
+
+    expanded_chunks = []
+    split_dim_idx = 0
+    for chunk_tensor in chunk_tensors:
+        new_val = torch.zeros_like(tensor)
+        upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)
+        slice_indices = [slice(None, None, None)] * new_val.ndim
+        slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
+        new_val[slice_indices] = chunk_tensor
+        expanded_chunks.append(new_val)
+        split_dim_idx += chunk_tensor.size(spec.split_dim)
+    return expanded_chunks
+
+
 def _shard_dict_of_args(
     args_dict,
     args_chunk_spec,
@@ -132,114 +234,60 @@ def _shard_dict_of_args(
     Returns:
         args_split: List of sharded args
     """
-    # Stage 1+2: flatten and shard/replicate
-    # args_sharded_replicated : [num args, num flat values, num chunks]
-    args_sharded_replicated = {}
-    arg_specs = []
-    real_num_chunks = num_chunks
-    first_tensor = True
+    if not args_dict:
+        return [{} for _ in range(num_chunks)]

     assert len(args_dict) == len(args_chunk_spec), (
-        f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
+        f"args_dict.keys() = {list(args_dict.keys())} "
+        f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
     )
     assert args_chunk_spec is not None  # Should have been set by caller

-    for arg_key, arg in args_dict.items():
-        flat, spec = tree_flatten(arg)
-        arg_specs.append(spec)
+    values, tree_spec = tree_flatten(args_dict)
+    chunk_specs, _ = tree_flatten(args_chunk_spec)

-        chunk_spec = args_chunk_spec[arg_key]
-        assert chunk_spec is not None  # Should have been set by caller
-        chunk_spec_flat, _ = tree_flatten(chunk_spec)
-        if len(flat) != len(chunk_spec_flat):
+    # First check and find the actual number of chunks
+    split_sizes = []
+    for v, spec in zip(values, chunk_specs, strict=True):
+        if spec is _Replicate:
+            split_sizes.append(num_chunks)
+        elif isinstance(v, torch.Tensor):
+            assert isinstance(spec, TensorChunkSpec)
+            split_sizes.append(v.size(spec.split_dim))
+        elif isinstance(v, BlockMask):
+            assert isinstance(spec, TensorChunkSpec)
+            assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
+            split_sizes.append(v.kv_num_blocks.size(0))
+        else:
             raise ValueError(
-                f"Argument value {arg} did not have the same number of "
-                f"values as as chunk spec {chunk_spec}"
+                f"Unsupported chunk spec: {spec} and value: {v} combination."
             )

+    result_num_chunks = min(*split_sizes, num_chunks)
+    flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
+    for v, spec in zip(values, chunk_specs, strict=True):
+        v_splits: Sequence[Any] = []
+        if spec is _Replicate:
+            v_splits = [v] * result_num_chunks
+        elif isinstance(v, torch.Tensor):
+            v_splits = _split_tensor(v, spec, result_num_chunks)
+        elif isinstance(v, BlockMask):
+            v_splits = _split_block_mask(v, result_num_chunks)
+        else:
+            raise ValueError(
+                f"Unsupported chunk spec: {spec} and value: {v} combination."
+            )

-        sharded_arg_flat = []
+        for _flat_split_result, _v_split in zip(
+            flat_split_results, v_splits, strict=True
+        ):
+            _flat_split_result.append(_v_split)

-        for v, chunk_v in zip(flat, chunk_spec_flat):
-            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
-                sharded_arg_flat.append([v] * real_num_chunks)
-            elif isinstance(chunk_v, TensorChunkSpec):
-                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
-                # If it's a collection type, split it as you would expect. Otherwise,
-                # Throw an error
-                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
-                v_split_dim_size = v.size(chunk_v.split_dim)
-                if v_split_dim_size < real_num_chunks:
-                    if first_tensor:
-                        # We can only adjust number of chunks when we hit this
-                        # issue at the first tensor encountered
-                        logger.warning(
-                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
-                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
-                        )
-                        real_num_chunks = v_split_dim_size
-                    else:
-                        raise RuntimeError(
-                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
-                            f"smaller than the number of chunks {num_chunks}. "
-                            "PiPPy cannot reduce the number of chunks because "
-                            "other arguments have bigger chunk-dimension sizes. "
-                            "Please adjust your num_chunks setting."
-                        )
-                chunk_tensors = torch.tensor_split(
-                    v, real_num_chunks, chunk_v.split_dim
-                )
-                if _debug_mask_minibatches:
-                    expanded_chunks = []
-                    split_dim_idx = 0
-                    for chunk_tensor in chunk_tensors:
-                        new_val = torch.zeros_like(v)
-                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)
-                        slice_indices = [slice(None, None, None)] * new_val.ndim
-                        slice_indices[chunk_v.split_dim] = slice(
-                            split_dim_idx, upper_idx
-                        )
-                        new_val[slice_indices] = chunk_tensor
-                        expanded_chunks.append(new_val)
-                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)
-                    sharded_arg_flat.append(expanded_chunks)
-                else:
-                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]
-                first_tensor = False
-            else:
-                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")
-        args_sharded_replicated[arg_key] = sharded_arg_flat
-    # chunks_flat : [num chunks, num args, num flat values]
-    chunks_flat = []
-    for chunk_idx in range(real_num_chunks):
-        chunk_args = {}
-        for key, arg in args_sharded_replicated.items():
-            arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
-            chunk_args[key] = arg_single_chunk
-        chunks_flat.append(chunk_args)
-    # args_split : [num chunks, num args]
-    args_split = []
-    for chunk in chunks_flat:
-        per_chunk_args = {}
-        assert len(arg_specs) == len(chunk)
-        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
-            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
-        args_split.append(per_chunk_args)
-    return args_split
+    return [
+        tree_unflatten(_flat_split_result, tree_spec)
+        for _flat_split_result in flat_split_results
+    ]


 def split_args_kwargs_into_chunks(