# Mirror of https://github.com/pytorch/pytorch.git
# Synced 2025-10-20 12:54:11 +08:00
# This is a follow-up to #164653 to continue applying `UP035` fixes. The purpose
# is to finally enable this rule. Pull Request resolved:
# https://github.com/pytorch/pytorch/pull/165214 Approved by: https://github.com/ezyang
# 536 lines, 17 KiB, Python
# mypy: allow-untyped-defs
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
import logging
|
|
import operator
|
|
from collections.abc import Sequence
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
from torch.fx.node import map_aggregate
|
|
from torch.nn.attention.flex_attention import BlockMask
|
|
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["TensorChunkSpec", "split_args_kwargs_into_chunks", "merge_chunks"]
|
|
|
|
logger = logging.getLogger(__name__)

# When True, each microbatch is sent through as a masked version of the full
# mini-batch instead of a micro-batch slice. This can be used for more stable
# numerical testing (see [A Note About Correctness Testing]).
_debug_mask_minibatches = False
|
|
|
|
|
|
class _CustomReducer:
    """
    Describes how to combine the values of one output across microbatches into
    a single value.

    The reduction starts from ``init_value``, and each microbatch value is
    folded in with ``reduce_fn(accumulated, value)``.

    Example:
        >>> # xdoctest: +SKIP
        >>> sum_reducer = _CustomReducer(
        >>>     torch.tensor(0.0),
        >>>     lambda a, b: a + b
        >>> )
    """

    def __init__(self, init_value, reduce_fn):
        # Starting value of the reduction.
        self.init_value = init_value
        # Binary callable: (accumulated, next_value) -> new accumulated value.
        self.reduce_fn = reduce_fn
|
|
|
|
|
|
class _LossReducer(_CustomReducer):
    """``_CustomReducer`` subclass for losses; it adds no behavior of its own."""
|
|
|
|
|
|
# Reducer that adds up per-microbatch values, starting from a 0.0 tensor.
sum_reducer = _LossReducer(init_value=torch.tensor(0.0), reduce_fn=operator.add)

# Chunking dimension used whenever the user did not specify one.
DEFAULT_CHUNK_DIM = 0
|
|
|
|
|
|
class TensorChunkSpec:
    """
    Class used to specify chunking of inputs

    Records the dimension (``split_dim``) along which a tensor input is split
    into microbatches.
    """

    split_dim: int

    def __init__(self, split_dim):
        self.split_dim = split_dim

    def __repr__(self):
        cls = type(self)
        return f"{cls.__module__}.{cls.__name__}({self.split_dim})"

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: tuple[int, ...],
    ):
        """
        Build a tuple of `TensorChunkSpec` from a tuple of chunk dimensions
        (int's), one spec per positional argument.

        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        # `map_aggregate` keeps the container shape and calls the class on
        # every leaf dimension.
        return map_aggregate(chunk_dims, TensorChunkSpec)  # type: ignore[arg-type,return-value]

    @staticmethod
    def from_dict(
        chunk_dims: dict[str, int],
    ):
        """
        Build a dictionary of `TensorChunkSpec` from a dictionary of chunk
        dimensions (int's), keyed by keyword-argument name.

        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        return map_aggregate(chunk_dims, TensorChunkSpec)  # type: ignore[arg-type,return-value]
|
|
|
|
|
|
class _Replicate:
    """Chunk spec marking an input that is replicated to every microbatch
    instead of being split."""
|
|
|
|
|
|
def _split_block_mask(
|
|
block_mask: BlockMask,
|
|
num_chunks: int,
|
|
) -> list[BlockMask]:
|
|
"""Given a block mask, split the block mask along the batch dimension (dim0).
|
|
|
|
Args:
|
|
block_mask: Block mask to split
|
|
num_chunks: Number of chunks to split the block mask into
|
|
|
|
Returns:
|
|
chunk_block_masks: List of chunked block masks
|
|
"""
|
|
|
|
# BlockMask will broadcast if B is 1.
|
|
if block_mask.kv_num_blocks.size(0) == 1:
|
|
return [block_mask] * num_chunks
|
|
|
|
assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
|
|
"Block mask has fewer batch size than the number of chunks. "
|
|
)
|
|
|
|
batch_dim = 0
|
|
kv_num_blocks_chunks = torch.tensor_split(
|
|
block_mask.kv_num_blocks, num_chunks, batch_dim
|
|
)
|
|
kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
|
|
full_kv_num_blocks_chunks = (
|
|
torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
|
|
if block_mask.full_kv_num_blocks is not None
|
|
else [None] * num_chunks
|
|
)
|
|
full_kv_indices_chunks = (
|
|
torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
|
|
if block_mask.full_kv_indices is not None
|
|
else [None] * num_chunks
|
|
)
|
|
|
|
chunk_block_masks = []
|
|
batch_offset = 0
|
|
for chunk_idx in range(num_chunks):
|
|
|
|
def create_mask_mod(idx):
|
|
def batch_offset_mask_mod(b, h, q_idx, kv_idx):
|
|
b_offset = torch.full_like(b, idx)
|
|
return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)
|
|
|
|
return batch_offset_mask_mod
|
|
|
|
chunk_block_masks.append(
|
|
BlockMask.from_kv_blocks(
|
|
kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
|
|
kv_indices=kv_indices_chunks[chunk_idx],
|
|
full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
|
|
full_kv_indices=full_kv_indices_chunks[chunk_idx],
|
|
BLOCK_SIZE=block_mask.BLOCK_SIZE,
|
|
mask_mod=create_mask_mod(batch_offset),
|
|
seq_lengths=block_mask.seq_lengths,
|
|
)
|
|
)
|
|
batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
|
|
return chunk_block_masks
|
|
|
|
|
|
def _split_tensor(
    tensor: torch.Tensor,
    spec: TensorChunkSpec,
    num_chunks: int,
) -> Sequence[torch.Tensor]:
    """Given a tensor, and a chunking spec, split the tensor.

    Args:
        tensor: Tensor to split
        spec: Chunking spec; only ``spec.split_dim`` is read
        num_chunks: Number of chunks to split the tensor into

    Returns:
        chunk_tensors: Sequence of chunked tensors. Normally these are the
        pieces returned by ``torch.tensor_split``. When the module-level
        ``_debug_mask_minibatches`` flag is set, each element is instead a
        tensor shaped like ``tensor``, zero everywhere except the region
        covered by that chunk.
    """
    assert tensor.size(spec.split_dim) >= num_chunks, (
        f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
    )
    chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)

    if not _debug_mask_minibatches:
        return chunk_tensors

    expanded_chunks = []
    split_dim_idx = 0
    for chunk_tensor in chunk_tensors:
        new_val = torch.zeros_like(tensor)
        upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)

        slice_indices = [slice(None, None, None)] * new_val.ndim
        slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
        # Index with a tuple: a plain list of slices is only interpreted as a
        # multi-dimensional index through PyTorch's sequence-as-tuple heuristic.
        new_val[tuple(slice_indices)] = chunk_tensor

        expanded_chunks.append(new_val)

        # The next chunk begins where this one ended.
        split_dim_idx = upper_idx

    return expanded_chunks
|
|
|
|
|
|
def _shard_dict_of_args(
|
|
args_dict,
|
|
args_chunk_spec,
|
|
num_chunks,
|
|
):
|
|
"""
|
|
Given a dictionary of args, and a dictionary of chunking specs, shard the
|
|
args according to the chunking specs.
|
|
|
|
Args:
|
|
args_dict: Dictionary of args
|
|
args_chunk_spec: Dictionary of chunking specs
|
|
num_chunks: Number of chunks to shard the args into
|
|
|
|
Returns:
|
|
args_split: List of sharded args
|
|
"""
|
|
|
|
if not args_dict:
|
|
return [{} for _ in range(num_chunks)]
|
|
|
|
assert len(args_dict) == len(args_chunk_spec), (
|
|
f"args_dict.keys() = {list(args_dict.keys())} "
|
|
f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
|
)
|
|
assert args_chunk_spec is not None # Should have been set by caller
|
|
|
|
values, tree_spec = tree_flatten(args_dict)
|
|
chunk_specs, _ = tree_flatten(args_chunk_spec)
|
|
|
|
# First check and find the actual number of chunks
|
|
split_sizes = []
|
|
for v, spec in zip(values, chunk_specs, strict=True):
|
|
# The original logic is "spec is _Replicate". This doesn't seem to be
|
|
# correct. But we keep it for backward compatibility.
|
|
if spec is _Replicate or isinstance(spec, _Replicate):
|
|
split_sizes.append(num_chunks)
|
|
elif isinstance(v, torch.Tensor):
|
|
assert isinstance(spec, TensorChunkSpec)
|
|
split_sizes.append(v.size(spec.split_dim))
|
|
elif isinstance(v, BlockMask):
|
|
assert isinstance(spec, TensorChunkSpec)
|
|
assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
|
|
# BlockMask will broadcast if B is 1.
|
|
if v.kv_num_blocks.size(0) == 1:
|
|
split_sizes.append(num_chunks)
|
|
else:
|
|
split_sizes.append(v.kv_num_blocks.size(0))
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported chunk spec: {spec} and value: {v} combination."
|
|
)
|
|
result_num_chunks = min(*split_sizes, num_chunks)
|
|
|
|
flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
|
|
for v, spec in zip(values, chunk_specs, strict=True):
|
|
v_splits: Sequence[Any] = []
|
|
if spec is _Replicate or isinstance(spec, _Replicate):
|
|
v_splits = [v] * result_num_chunks
|
|
elif isinstance(v, torch.Tensor):
|
|
v_splits = _split_tensor(v, spec, result_num_chunks)
|
|
elif isinstance(v, BlockMask):
|
|
v_splits = _split_block_mask(v, result_num_chunks)
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported chunk spec: {spec} and value: {v} combination."
|
|
)
|
|
|
|
# pyrefly: ignore # no-matching-overload
|
|
for _flat_split_result, _v_split in zip(
|
|
flat_split_results, v_splits, strict=True
|
|
):
|
|
_flat_split_result.append(_v_split)
|
|
|
|
return [
|
|
tree_unflatten(_flat_split_result, tree_spec)
|
|
for _flat_split_result in flat_split_results
|
|
]
|
|
|
|
|
|
def split_args_kwargs_into_chunks(
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
    """
    Split positional and keyword arguments into microbatch chunks, following
    their chunking specs.

    At most ``chunks`` chunks are produced; fewer are produced when some
    sharded value is too small to be split ``chunks`` ways.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs (``None`` is treated as empty)
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args, one tuple per chunk
        kwargs_split: List of sharded kwargs, one dict per chunk

    Raises:
        RuntimeError: if args and kwargs end up with different chunk counts.
    """
    # Every value is sharded or replicated according to its spec. Running
    # example, with R standing for a `_Replicate()` spec and chunks = 2:
    #
    #   args      = ([A, [B, C]], D)
    #   args_spec = ([R, [R, TensorChunkSpec(0)]], R)
    #
    # 1. Flatten values and specs with pytree into matching 1-d lists:
    #      values = [A, B, C, D]    specs = [R, R, TensorChunkSpec(0), R]
    # 2. Shard or replicate each value according to its spec:
    #      [[A, A], [B, B], [C_1, C_2], [D, D]]
    # 3. Rotate the nesting order so chunks become the outer dimension:
    #      [[A, B, C_1, D], [A, B, C_2, D]]
    # 4. Unflatten each chunk back into the original structure:
    #      [([A, [B, C_1]], D), ([A, [B, C_2]], D)]
    #
    # kwargs go through the same process.

    # TODO: _debug_mask_minibatches
    kwargs = {} if kwargs is None else kwargs

    # Without user-provided specs, tensors and BlockMasks are chunked along the
    # default dimension and every other value is replicated.
    def default_spec(v):
        if not isinstance(v, torch.Tensor | BlockMask):
            return _Replicate()
        return TensorChunkSpec(DEFAULT_CHUNK_DIM)

    if args_chunk_spec is None:
        args_chunk_spec = tree_map(default_spec, args)
    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = tree_map(default_spec, kwargs)

    def shard_args(n):
        # Positional args are sharded as a dict keyed by their position.
        return _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            n,
        )

    args_split_dict = shard_args(chunks)
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(kwargs, kwargs_chunk_spec, real_num_chunks)

    # kwargs may split into fewer chunks than args did (e.g. when `args` has
    # no tensors, only values); re-shard args to that smaller count.
    if len(kwargs_split) < real_num_chunks:
        real_num_chunks = len(kwargs_split)
        args_split_dict = shard_args(real_num_chunks)

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    # Convert each {position: value} dict back into a positional tuple.
    args_split = []
    for positional in args_split_dict:
        args_split.append(tuple(positional[pos] for pos in range(len(positional))))

    return args_split, kwargs_split
|
|
|
|
|
|
def merge_chunks(
    chunks: list[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Merge rule for each flattened leaf of ``chunk_spec``:
      * ``TensorChunkSpec``: the per-chunk tensors are concatenated along
        ``split_dim``.
      * ``_CustomReducer``: the per-chunk values are folded with ``reduce_fn``,
        starting from ``init_value``.
      * anything else: the per-chunk values are asserted to be equal, and the
        first chunk's value is used.

    If ``chunk_spec`` is None, every output leaf is concatenated along
    ``DEFAULT_CHUNK_DIM``, and the output structure is taken from ``chunks[0]``.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value

    Raises:
        ValueError: if a chunk does not flatten to the same number of leaves as
            the chunk spec.
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #       chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #       chunks_flat = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #       value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #       value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #       value = ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default
        # dimension (0), for all output fields. We obtain the output structure
        # by flattening chunk 0 and generate the chunk_spec.
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []
    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [chunk_flat[arg_idx] for chunk_flat in chunks_flattened]

            if _debug_mask_minibatches:
                # Each partial value is a full-size masked tensor. Recover each
                # chunk's extent by re-running `tensor_split` on a meta tensor
                # of the full shape, then slice that extent out of each value.
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    # Index with a tuple rather than relying on PyTorch treating
                    # a list of slices as a multi-dimensional index.
                    values_to_cat.append(partial_value[tuple(slice_indices)])

                    chunk_start_idx = chunk_end_idx
            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value
            for chunk_flat in chunks_flattened:
                reduced_val = arg.reduce_fn(reduced_val, chunk_flat[arg_idx])
            args_flattened.append(reduced_val)
        else:
            # Non-sharded leaf: every chunk must carry the same value.
            value = chunks_flattened[0][arg_idx]
            for chunk_flat in chunks_flattened[1:]:
                assert chunk_flat[arg_idx] == value
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)
|