BlockMask carries batch-dimension information, so pipeline parallelism (PP) has to split it just like all other tensors. All the tensors inside a BlockMask have a batch dimension, so splitting them is straightforward. However, `mask_mod` takes the batch index as an input, and that index changes after the split, so we wrap `mask_mod` in a closure that remaps the batch index. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164111 Approved by: https://github.com/H-Huang
518 lines | 17 KiB | Python
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import operator
from typing import Any, Optional, Sequence

import torch
from torch.fx.node import map_aggregate
from torch.nn.attention.flex_attention import BlockMask
from torch.utils._pytree import tree_flatten, tree_unflatten


__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

"""
_debug_mask_minibatches specifies whether to send masked versions of the
mini-batch through instead of micro-batch slices--this can be used for more
stable numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False


class _CustomReducer:
    """
    Custom reducer class that can be used to specify a custom operation that
    reduces losses of multiple microbatches into one value.

    Example:
    >>> # xdoctest: +SKIP
    >>> sum_reducer = _CustomReducer(
    >>>     torch.tensor(0.0),
    >>>     lambda a, b: a + b
    >>> )
    """

    def __init__(self, init_value, reduce_fn):
        self.init_value = init_value
        self.reduce_fn = reduce_fn


class _LossReducer(_CustomReducer):
    pass


sum_reducer = _LossReducer(torch.tensor(0.0), operator.add)
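# NOTE: `sum_reducer` (or any `_CustomReducer`) can appear in an output chunk
# spec passed to `merge_chunks`; the per-microbatch values at that position are
# then folded together with `reduce_fn`, starting from `init_value`, instead of
# being concatenated.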

# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0


class TensorChunkSpec:
    """
    Class used to specify chunking of inputs
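
    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative usage: split inputs along dimension 0 (the batch dim)
        >>> batch_chunk_spec = TensorChunkSpec(0)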
    """

    def __init__(self, split_dim):
        self.split_dim = split_dim

    split_dim: int

    def __repr__(self):
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
        )

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
        dimensions (ints).
        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        args_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return args_chunk_spec

    @staticmethod
    def from_dict(
        chunk_dims: dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (ints).
        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        kwargs_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return kwargs_chunk_spec


# Class used to specify replication of inputs
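# A `_Replicate` entry in a chunk spec (used in place of a `TensorChunkSpec`)
# means the corresponding value is not split; the same object is handed to
# every microbatch as-is.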
class _Replicate:
    pass


def _split_block_mask(
    block_mask: BlockMask,
    num_chunks: int,
) -> list[BlockMask]:
    """Given a block mask, split the block mask along the batch dimension (dim0).

    Args:
        block_mask: Block mask to split
        num_chunks: Number of chunks to split the block mask into

    Returns:
        chunk_block_masks: List of chunked block masks
    """

    assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
        "Block mask batch size is smaller than the number of chunks."
    )

    batch_dim = 0
    kv_num_blocks_chunks = torch.tensor_split(
        block_mask.kv_num_blocks, num_chunks, batch_dim
    )
    kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
    full_kv_num_blocks_chunks = (
        torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
        if block_mask.full_kv_num_blocks is not None
        else [None] * num_chunks
    )
    full_kv_indices_chunks = (
        torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
        if block_mask.full_kv_indices is not None
        else [None] * num_chunks
    )

    chunk_block_masks = []
    batch_offset = 0
    for chunk_idx in range(num_chunks):
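
        # Each chunk's BlockMask tensors are indexed from batch 0 again, but the
        # original mask_mod expects global batch indices. Wrap mask_mod in a
        # closure that adds this chunk's starting batch offset back before
        # delegating to it. The offset is captured by value through `idx`, so the
        # closure is not affected by later updates to `batch_offset`.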
        def create_mask_mod(idx):
            def batch_offset_mask_mod(b, h, q_idx, kv_idx):
                b_offset = torch.full_like(b, idx)
                return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)

            return batch_offset_mask_mod

        chunk_block_masks.append(
            BlockMask.from_kv_blocks(
                kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
                kv_indices=kv_indices_chunks[chunk_idx],
                full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
                full_kv_indices=full_kv_indices_chunks[chunk_idx],
                BLOCK_SIZE=block_mask.BLOCK_SIZE,
                mask_mod=create_mask_mod(batch_offset),
                seq_lengths=block_mask.seq_lengths,
            )
        )
        batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
    return chunk_block_masks


def _split_tensor(
    tensor: torch.Tensor,
    spec: TensorChunkSpec,
    num_chunks: int,
) -> Sequence[torch.Tensor]:
    """Given a tensor and a chunking spec, split the tensor.

    Args:
        tensor: Tensor to split
        spec: Chunking spec
        num_chunks: Number of chunks to split the tensor into

    Returns:
        chunk_tensors: List of chunked tensors
    """

    assert tensor.size(spec.split_dim) >= num_chunks, (
        f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
    )
    chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)

    if not _debug_mask_minibatches:
        return chunk_tensors
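
    # In _debug_mask_minibatches mode, each "chunk" is a full-size zero tensor
    # with only its own slice along `spec.split_dim` filled in, so every
    # microbatch keeps the original shape (see the module-level note on
    # `_debug_mask_minibatches`).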
    expanded_chunks = []
    split_dim_idx = 0
    for chunk_tensor in chunk_tensors:
        new_val = torch.zeros_like(tensor)
        upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)

        slice_indices = [slice(None, None, None)] * new_val.ndim
        slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
        new_val[slice_indices] = chunk_tensor

        expanded_chunks.append(new_val)

        split_dim_idx += chunk_tensor.size(spec.split_dim)

    return expanded_chunks


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args
    """
    if not args_dict:
        return [{} for _ in range(num_chunks)]

    assert len(args_dict) == len(args_chunk_spec), (
        f"args_dict.keys() = {list(args_dict.keys())} "
        f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
    )
    assert args_chunk_spec is not None  # Should have been set by caller

    values, tree_spec = tree_flatten(args_dict)
    chunk_specs, _ = tree_flatten(args_chunk_spec)

    # First check and find the actual number of chunks
    split_sizes = []
    for v, spec in zip(values, chunk_specs, strict=True):
        if spec is _Replicate:
            split_sizes.append(num_chunks)
        elif isinstance(v, torch.Tensor):
            assert isinstance(spec, TensorChunkSpec)
            split_sizes.append(v.size(spec.split_dim))
        elif isinstance(v, BlockMask):
            assert isinstance(spec, TensorChunkSpec)
            assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
            split_sizes.append(v.kv_num_blocks.size(0))
        else:
            raise ValueError(
                f"Unsupported chunk spec: {spec} and value: {v} combination."
            )
    result_num_chunks = min(*split_sizes, num_chunks)

    flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
    for v, spec in zip(values, chunk_specs, strict=True):
        v_splits: Sequence[Any] = []
        if spec is _Replicate:
            v_splits = [v] * result_num_chunks
        elif isinstance(v, torch.Tensor):
            v_splits = _split_tensor(v, spec, result_num_chunks)
        elif isinstance(v, BlockMask):
            v_splits = _split_block_mask(v, result_num_chunks)
        else:
            raise ValueError(
                f"Unsupported chunk spec: {spec} and value: {v} combination."
            )

        for _flat_split_result, _v_split in zip(
            flat_split_results, v_splits, strict=True
        ):
            _flat_split_result.append(_v_split)

    return [
        tree_unflatten(_flat_split_result, tree_spec)
        for _flat_split_result in flat_split_results
    ]


def split_args_kwargs_into_chunks(
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs
    """
    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
    # and `kwargs_chunk_spec` specifications. The steps are as follows:
    #
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
    #    To use a running example: suppose our inputs look like
    #
    #       args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
    #       (kwargs not shown but it's a similar process)
    #
    #    Then for this step we would end up with
    #
    #       args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
    #
    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting order such that chunks are the outer dimension
    #
    #       args_chunks = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 4. Unflatten each chunk according to the spec
    #
    #       args_chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
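    #
    # Illustrative sketch (hypothetical shapes): with
    #
    #       x = torch.randn(8, 16)
    #       y = torch.randn(8, 4)
    #       args_split, kwargs_split = split_args_kwargs_into_chunks((x, y), None, chunks=4)
    #
    # `args_split` has 4 entries, each a tuple of tensors with shapes [2, 16] and
    # [2, 4] (the inputs sliced along DEFAULT_CHUNK_DIM), and `kwargs_split` is a
    # list of 4 empty dicts.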

    # TODO: _debug_mask_minibatches
    # Handle the case where kwargs is None
    if kwargs is None:
        kwargs = {}

    # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
    # their format and use default chunking along dim 0
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)

    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into fewer chunks
        # e.g. when `args` has no tensor, just values
        real_num_chunks = len(kwargs_split)
        # Re-shard args
        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = [
        tuple(chunk_args[i] for i in range(len(chunk_args)))
        for chunk_args in args_split_dict
    ]

    return args_split, kwargs_split


def merge_chunks(
    chunks: list[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #     chunks = [
    #         ([A, [B, C_1]], D),
    #         ([A, [B, C_2]], D),
    #     ]
    #     args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #     chunks_flat = [
    #         ([A, B, C_1], D),
    #         ([A, B, C_2], D),
    #     ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #     value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #     value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #     value = ([A, [B, C]], D)
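    #
    # Illustrative sketch (hypothetical shapes): if `chunks` holds 4 per-microbatch
    # outputs of shape [2, 16] that came from splitting an [8, 16] input along
    # dim 0, `merge_chunks(chunks, None)` concatenates them back into a single
    # [8, 16] tensor, using the default TensorChunkSpec(0) for every output leaf.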

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
        # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [
                chunks_flattened[chunk_idx][arg_idx]
                for chunk_idx in range(len(chunks_flattened))
            ]

            if _debug_mask_minibatches:
                # Infer size of individual chunks by running `tensor_split` again
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    sliced = partial_value[slice_indices]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx

            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value

            for chunk_idx in range(len(chunks_flattened)):
                reduced_val = arg.reduce_fn(
                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
                )

            args_flattened.append(reduced_val)
        else:
            value = chunks_flattened[0][arg_idx]
            for chunk_idx in range(1, len(chunks_flattened)):
                assert chunks_flattened[chunk_idx][arg_idx] == value
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)