# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import operator
from collections.abc import Sequence
from typing import Any, Optional

import torch
from torch.fx.node import map_aggregate
from torch.nn.attention.flex_attention import BlockMask
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

"""
_debug_mask_minibatches specifies whether to send masked versions of the
mini-batch through instead of micro-batch slices -- this can be used for more
stable numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False


class _CustomReducer:
    """
    Custom reducer class that can be used to specify a custom operation that
    reduces losses of multiple microbatches into one value.

    Example:
    >>> # xdoctest: +SKIP
    >>> sum_reducer = _CustomReducer(
    >>>     torch.tensor(0.0),
    >>>     lambda a, b: a + b
    >>> )
    """

    def __init__(self, init_value, reduce_fn):
        self.init_value = init_value
        self.reduce_fn = reduce_fn


class _LossReducer(_CustomReducer):
    pass


sum_reducer = _LossReducer(torch.tensor(0.0), operator.add)

# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0


class TensorChunkSpec:
    """
    Class used to specify chunking of inputs
    """

    def __init__(self, split_dim):
        self.split_dim = split_dim

    split_dim: int

    def __repr__(self):
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
        )

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of
        chunk dimensions (int's).

        Example:
        >>> # xdoctest: +SKIP
        >>> # There are three positional arguments to the model, and
        >>> # we are chunking them along dimension 0, 0 and 1, respectively
        >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        args_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return args_chunk_spec

    @staticmethod
    def from_dict(
        chunk_dims: dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (int's).

        Example:
        >>> # xdoctest: +SKIP
        >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
        >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        kwargs_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return kwargs_chunk_spec


# Class used to specify replication of inputs
class _Replicate:
    pass


def _split_block_mask(
    block_mask: BlockMask,
    num_chunks: int,
) -> list[BlockMask]:
    """Given a block mask, split the block mask along the batch dimension (dim0).

    Args:
        block_mask: Block mask to split
        num_chunks: Number of chunks to split the block mask into

    Returns:
        chunk_block_masks: List of chunked block masks
    """
    # BlockMask will broadcast if B is 1.
    if block_mask.kv_num_blocks.size(0) == 1:
        return [block_mask] * num_chunks

    assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
        "Block mask has a smaller batch size than the number of chunks."
    )

    batch_dim = 0
    kv_num_blocks_chunks = torch.tensor_split(
        block_mask.kv_num_blocks, num_chunks, batch_dim
    )
    kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
    full_kv_num_blocks_chunks = (
        torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
        if block_mask.full_kv_num_blocks is not None
        else [None] * num_chunks
    )
    full_kv_indices_chunks = (
        torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
        if block_mask.full_kv_indices is not None
        else [None] * num_chunks
    )

    chunk_block_masks = []
    batch_offset = 0
    for chunk_idx in range(num_chunks):

        def create_mask_mod(idx):
            def batch_offset_mask_mod(b, h, q_idx, kv_idx):
                b_offset = torch.full_like(b, idx)
                return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)

            return batch_offset_mask_mod

        chunk_block_masks.append(
            BlockMask.from_kv_blocks(
                kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
                kv_indices=kv_indices_chunks[chunk_idx],
                full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
                full_kv_indices=full_kv_indices_chunks[chunk_idx],
                BLOCK_SIZE=block_mask.BLOCK_SIZE,
                mask_mod=create_mask_mod(batch_offset),
                seq_lengths=block_mask.seq_lengths,
            )
        )
        batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
    return chunk_block_masks


def _split_tensor(
    tensor: torch.Tensor,
    spec: TensorChunkSpec,
    num_chunks: int,
) -> Sequence[torch.Tensor]:
    """Given a tensor and a chunking spec, split the tensor.

    Args:
        tensor: Tensor to split
        spec: Chunking spec
        num_chunks: Number of chunks to split the tensor into

    Returns:
        chunk_tensors: List of chunked tensors
    """
    assert tensor.size(spec.split_dim) >= num_chunks, (
        f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
    )
    chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)

    if not _debug_mask_minibatches:
        return chunk_tensors

    expanded_chunks = []
    split_dim_idx = 0
    for chunk_tensor in chunk_tensors:
        new_val = torch.zeros_like(tensor)
        upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)

        slice_indices = [slice(None, None, None)] * new_val.ndim
        slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
        new_val[slice_indices] = chunk_tensor
        expanded_chunks.append(new_val)
        split_dim_idx += chunk_tensor.size(spec.split_dim)

    return expanded_chunks


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args
    """
    if not args_dict:
        return [{} for _ in range(num_chunks)]

    assert len(args_dict) == len(args_chunk_spec), (
        f"args_dict.keys() = {list(args_dict.keys())} "
        f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
    )
    assert args_chunk_spec is not None  # Should have been set by caller

    values, tree_spec = tree_flatten(args_dict)
    chunk_specs, _ = tree_flatten(args_chunk_spec)

    # First check and find the actual number of chunks
    split_sizes = []
    for v, spec in zip(values, chunk_specs, strict=True):
        # The original logic is "spec is _Replicate". This doesn't seem to be
        # correct, but it is kept for backward compatibility.
        if spec is _Replicate or isinstance(spec, _Replicate):
            split_sizes.append(num_chunks)
        elif isinstance(v, torch.Tensor):
            assert isinstance(spec, TensorChunkSpec)
            split_sizes.append(v.size(spec.split_dim))
        elif isinstance(v, BlockMask):
            assert isinstance(spec, TensorChunkSpec)
            assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
            # BlockMask will broadcast if B is 1.
            if v.kv_num_blocks.size(0) == 1:
                split_sizes.append(num_chunks)
            else:
                split_sizes.append(v.kv_num_blocks.size(0))
        else:
            raise ValueError(
                f"Unsupported chunk spec: {spec} and value: {v} combination."
            )

    result_num_chunks = min(*split_sizes, num_chunks)

    flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
    for v, spec in zip(values, chunk_specs, strict=True):
        v_splits: Sequence[Any] = []
        if spec is _Replicate or isinstance(spec, _Replicate):
            v_splits = [v] * result_num_chunks
        elif isinstance(v, torch.Tensor):
            v_splits = _split_tensor(v, spec, result_num_chunks)
        elif isinstance(v, BlockMask):
            v_splits = _split_block_mask(v, result_num_chunks)
        else:
            raise ValueError(
                f"Unsupported chunk spec: {spec} and value: {v} combination."
            )

        # pyrefly: ignore  # no-matching-overload
        for _flat_split_result, _v_split in zip(
            flat_split_results, v_splits, strict=True
        ):
            _flat_split_result.append(_v_split)

    return [
        tree_unflatten(_flat_split_result, tree_spec)
        for _flat_split_result in flat_split_results
    ]
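

# A minimal usage sketch of _shard_dict_of_args (hypothetical shapes and values,
# not taken from the test suite): a tensor chunked along dim 0 is split across
# the chunks, while a replicated value is repeated in each chunk.
#
#   >>> # xdoctest: +SKIP
#   >>> _shard_dict_of_args(
#   >>>     {"x": torch.randn(4, 3), "lr": 0.1},
#   >>>     {"x": TensorChunkSpec(0), "lr": _Replicate()},
#   >>>     num_chunks=2,
#   >>> )
#   [{"x": <tensor of shape (2, 3)>, "lr": 0.1},
#    {"x": <tensor of shape (2, 3)>, "lr": 0.1}]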


def split_args_kwargs_into_chunks(
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs
    """
    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and
    # kwargs such that the constituent Tensor values have been sharded/replicated
    # according to the `args_chunk_spec` and `kwargs_chunk_spec` specifications.
    # The steps are as follows:
    #
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d
    #    array of values. To use a running example: suppose our inputs look like
    #
    #       args = ([A, [B, C]], D)
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    #    (kwargs not shown but it's a similar process)
    #
    #    Then for this step we would end up with
    #
    #       args = ([A, B, C], D)
    #       args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate the arguments subject to the policy in the spec.
    #    Suppose chunks = 2
    #
    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting order such that chunks are the outer dimension
    #
    #       args_chunks = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 4. Unflatten each chunk according to the spec
    #
    #       args_chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]

    # TODO: _debug_mask_minibatches

    # Handle the case where kwargs is None
    if kwargs is None:
        kwargs = {}

    # If the user did not provide args_chunk_spec or kwargs_chunk_spec, build
    # them here: tensors and BlockMasks are chunked along dim 0 by default,
    # and everything else is replicated.
    def default_spec(v):
        if isinstance(v, torch.Tensor | BlockMask):
            return TensorChunkSpec(DEFAULT_CHUNK_DIM)
        else:
            return _Replicate()

    if args_chunk_spec is None:
        args_chunk_spec = tree_map(default_spec, args)

    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = tree_map(default_spec, kwargs)

    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into fewer chunks,
        # e.g. when `args` has no tensor, just values
        real_num_chunks = len(kwargs_split)
        # Re-shard args
        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = [
        tuple(chunk_args[i] for i in range(len(chunk_args)))
        for chunk_args in args_split_dict
    ]

    return args_split, kwargs_split


def merge_chunks(
    chunks: list[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #       chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #       chunks_flat = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #       value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #       value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #       value = ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default
        # dimension (0) for all output fields. We obtain the output structure by
        # flattening chunk 0 and generate the chunk_spec from it.
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)
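
    # For example (hypothetical values), with two chunks of the form
    # ([X_i, Y_i], loss_i) and a spec of
    # ([TensorChunkSpec(0), TensorChunkSpec(0)], sum_reducer), Stage 1 gives
    # chunks_flattened == [[X_1, Y_1, loss_1], [X_2, Y_2, loss_2]], and the loop
    # below produces [cat([X_1, X_2]), cat([Y_1, Y_2]), loss_1 + loss_2].
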
    # Stage 2 and 3: Rotate nesting order so that chunks are the inner dimension
    # and concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [
                chunks_flattened[chunk_idx][arg_idx]
                for chunk_idx in range(len(chunks_flattened))
            ]

            if _debug_mask_minibatches:
                # Infer size of individual chunks by running `tensor_split` again
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    sliced = partial_value[slice_indices]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx
            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value

            for chunk_idx in range(len(chunks_flattened)):
                reduced_val = arg.reduce_fn(
                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
                )

            args_flattened.append(reduced_val)
        else:
            value = chunks_flattened[0][arg_idx]

            for chunk_idx in range(1, len(chunks_flattened)):
                assert chunks_flattened[chunk_idx][arg_idx] == value

            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)
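

# A minimal end-to-end sketch (hypothetical `model` and shapes, not part of the
# doctest suite): split a mini-batch into 4 microbatches, run each one, and
# merge the outputs back along the default dimension (0).
#
#   >>> # xdoctest: +SKIP
#   >>> args = (torch.randn(8, 16),)
#   >>> kwargs = {"mask": torch.ones(8, 16, dtype=torch.bool)}
#   >>> args_split, kwargs_split = split_args_kwargs_into_chunks(
#   >>>     args, kwargs, chunks=4
#   >>> )
#   >>> outputs = [model(*a, **k) for a, k in zip(args_split, kwargs_split)]
#   >>> merged = merge_chunks(outputs, TensorChunkSpec(DEFAULT_CHUNK_DIM))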