mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[PP] Let PP split BlockMask into micro-BlockMask (#164111)
BlockMask has batch dimension information. So PP has to split it as well just like all other tensors. All the tensors in BlockMask have the batch dimension, so we can just split it without too many issues. However, `mask_mod` requires the batch index as the input, which the value is going to be changed after the split. So we have to wrap it inside a closure to modify the batch index. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164111 Approved by: https://github.com/H-Huang
This commit is contained in:
committed by
PyTorch MergeBot
parent
483f4e0db9
commit
e3ae80fc03
@ -2,10 +2,11 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from torch.fx.node import map_aggregate
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
@ -115,6 +116,107 @@ class _Replicate:
|
||||
pass
|
||||
|
||||
|
||||
def _split_block_mask(
|
||||
block_mask: BlockMask,
|
||||
num_chunks: int,
|
||||
) -> list[BlockMask]:
|
||||
"""Given a block mask, split the block mask along the batch dimension (dim0).
|
||||
|
||||
Args:
|
||||
block_mask: Block mask to split
|
||||
num_chunks: Number of chunks to split the block mask into
|
||||
|
||||
Returns:
|
||||
chunk_block_masks: List of chunked block masks
|
||||
"""
|
||||
|
||||
assert block_mask.kv_num_blocks.size(0) >= num_chunks, (
|
||||
"Block mask has fewer batch size than the number of chunks. "
|
||||
)
|
||||
|
||||
batch_dim = 0
|
||||
kv_num_blocks_chunks = torch.tensor_split(
|
||||
block_mask.kv_num_blocks, num_chunks, batch_dim
|
||||
)
|
||||
kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
|
||||
full_kv_num_blocks_chunks = (
|
||||
torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
|
||||
if block_mask.full_kv_num_blocks is not None
|
||||
else [None] * num_chunks
|
||||
)
|
||||
full_kv_indices_chunks = (
|
||||
torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
|
||||
if block_mask.full_kv_indices is not None
|
||||
else [None] * num_chunks
|
||||
)
|
||||
|
||||
chunk_block_masks = []
|
||||
batch_offset = 0
|
||||
for chunk_idx in range(num_chunks):
|
||||
|
||||
def create_mask_mod(idx):
|
||||
def batch_offset_mask_mod(b, h, q_idx, kv_idx):
|
||||
b_offset = torch.full_like(b, idx)
|
||||
return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)
|
||||
|
||||
return batch_offset_mask_mod
|
||||
|
||||
chunk_block_masks.append(
|
||||
BlockMask.from_kv_blocks(
|
||||
kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
|
||||
kv_indices=kv_indices_chunks[chunk_idx],
|
||||
full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
|
||||
full_kv_indices=full_kv_indices_chunks[chunk_idx],
|
||||
BLOCK_SIZE=block_mask.BLOCK_SIZE,
|
||||
mask_mod=create_mask_mod(batch_offset),
|
||||
seq_lengths=block_mask.seq_lengths,
|
||||
)
|
||||
)
|
||||
batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
|
||||
return chunk_block_masks
|
||||
|
||||
|
||||
def _split_tensor(
|
||||
tensor: torch.Tensor,
|
||||
spec: TensorChunkSpec,
|
||||
num_chunks: int,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
"""Given a tensor, and a chunking spec, split the tensor.
|
||||
Args:
|
||||
|
||||
tensor: Tensor to split
|
||||
spec: Chunking spec
|
||||
num_chunks: Number of chunks to split the tensor into
|
||||
|
||||
Returns:
|
||||
chunk_tensors: List of chunked tensors
|
||||
"""
|
||||
|
||||
assert tensor.size(spec.split_dim) >= num_chunks, (
|
||||
f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
|
||||
)
|
||||
chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)
|
||||
|
||||
if not _debug_mask_minibatches:
|
||||
return chunk_tensors
|
||||
|
||||
expanded_chunks = []
|
||||
split_dim_idx = 0
|
||||
for chunk_tensor in chunk_tensors:
|
||||
new_val = torch.zeros_like(tensor)
|
||||
upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)
|
||||
|
||||
slice_indices = [slice(None, None, None)] * new_val.ndim
|
||||
slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
|
||||
new_val[slice_indices] = chunk_tensor
|
||||
|
||||
expanded_chunks.append(new_val)
|
||||
|
||||
split_dim_idx += chunk_tensor.size(spec.split_dim)
|
||||
|
||||
return expanded_chunks
|
||||
|
||||
|
||||
def _shard_dict_of_args(
|
||||
args_dict,
|
||||
args_chunk_spec,
|
||||
@ -132,114 +234,60 @@ def _shard_dict_of_args(
|
||||
Returns:
|
||||
args_split: List of sharded args
|
||||
"""
|
||||
# Stage 1+2: flatten and shard/replicate
|
||||
|
||||
# args_sharded_replicated : [num args, num flat values, num chunks]
|
||||
args_sharded_replicated = {}
|
||||
arg_specs = []
|
||||
|
||||
real_num_chunks = num_chunks
|
||||
first_tensor = True
|
||||
if not args_dict:
|
||||
return [{} for _ in range(num_chunks)]
|
||||
|
||||
assert len(args_dict) == len(args_chunk_spec), (
|
||||
f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
||||
f"args_dict.keys() = {list(args_dict.keys())} "
|
||||
f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
||||
)
|
||||
assert args_chunk_spec is not None # Should have been set by caller
|
||||
|
||||
for arg_key, arg in args_dict.items():
|
||||
flat, spec = tree_flatten(arg)
|
||||
arg_specs.append(spec)
|
||||
values, tree_spec = tree_flatten(args_dict)
|
||||
chunk_specs, _ = tree_flatten(args_chunk_spec)
|
||||
|
||||
chunk_spec = args_chunk_spec[arg_key]
|
||||
assert chunk_spec is not None # Should have been set by caller
|
||||
chunk_spec_flat, _ = tree_flatten(chunk_spec)
|
||||
if len(flat) != len(chunk_spec_flat):
|
||||
# First check and find the actual number of chunks
|
||||
split_sizes = []
|
||||
for v, spec in zip(values, chunk_specs, strict=True):
|
||||
if spec is _Replicate:
|
||||
split_sizes.append(num_chunks)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
assert isinstance(spec, TensorChunkSpec)
|
||||
split_sizes.append(v.size(spec.split_dim))
|
||||
elif isinstance(v, BlockMask):
|
||||
assert isinstance(spec, TensorChunkSpec)
|
||||
assert spec.split_dim == 0, "BlockMask only supports split_dim=0"
|
||||
split_sizes.append(v.kv_num_blocks.size(0))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Argument value {arg} did not have the same number of "
|
||||
f"values as as chunk spec {chunk_spec}"
|
||||
f"Unsupported chunk spec: {spec} and value: {v} combination."
|
||||
)
|
||||
result_num_chunks = min(*split_sizes, num_chunks)
|
||||
|
||||
flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
|
||||
for v, spec in zip(values, chunk_specs, strict=True):
|
||||
v_splits: Sequence[Any] = []
|
||||
if spec is _Replicate:
|
||||
v_splits = [v] * result_num_chunks
|
||||
elif isinstance(v, torch.Tensor):
|
||||
v_splits = _split_tensor(v, spec, result_num_chunks)
|
||||
elif isinstance(v, BlockMask):
|
||||
v_splits = _split_block_mask(v, result_num_chunks)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported chunk spec: {spec} and value: {v} combination."
|
||||
)
|
||||
|
||||
sharded_arg_flat = []
|
||||
for _flat_split_result, _v_split in zip(
|
||||
flat_split_results, v_splits, strict=True
|
||||
):
|
||||
_flat_split_result.append(_v_split)
|
||||
|
||||
for v, chunk_v in zip(flat, chunk_spec_flat):
|
||||
if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
|
||||
sharded_arg_flat.append([v] * real_num_chunks)
|
||||
elif isinstance(chunk_v, TensorChunkSpec):
|
||||
# TODO: check type of v. If it's a tensor, use chunk (or debug mask).
|
||||
# If it's a collection type, split it as you would expect. Otherwise,
|
||||
# Throw an error
|
||||
assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
|
||||
|
||||
v_split_dim_size = v.size(chunk_v.split_dim)
|
||||
if v_split_dim_size < real_num_chunks:
|
||||
if first_tensor:
|
||||
# We can only adjust number of chunks when we hit this
|
||||
# issue at the first tensor encountered
|
||||
logger.warning(
|
||||
f"Tensor size on chunking dimension is {v_split_dim_size}, " # noqa: G004
|
||||
f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
|
||||
)
|
||||
real_num_chunks = v_split_dim_size
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
|
||||
f"smaller than the number of chunks {num_chunks}. "
|
||||
"PiPPy cannot reduce the number of chunks because "
|
||||
"other arguments have bigger chunk-dimension sizes. "
|
||||
"Please adjust your num_chunks setting."
|
||||
)
|
||||
|
||||
chunk_tensors = torch.tensor_split(
|
||||
v, real_num_chunks, chunk_v.split_dim
|
||||
)
|
||||
|
||||
if _debug_mask_minibatches:
|
||||
expanded_chunks = []
|
||||
|
||||
split_dim_idx = 0
|
||||
for chunk_tensor in chunk_tensors:
|
||||
new_val = torch.zeros_like(v)
|
||||
upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)
|
||||
|
||||
slice_indices = [slice(None, None, None)] * new_val.ndim
|
||||
slice_indices[chunk_v.split_dim] = slice(
|
||||
split_dim_idx, upper_idx
|
||||
)
|
||||
new_val[slice_indices] = chunk_tensor
|
||||
|
||||
expanded_chunks.append(new_val)
|
||||
|
||||
split_dim_idx += chunk_tensor.size(chunk_v.split_dim)
|
||||
|
||||
sharded_arg_flat.append(expanded_chunks)
|
||||
else:
|
||||
sharded_arg_flat.append(chunk_tensors) # type: ignore[arg-type]
|
||||
|
||||
first_tensor = False
|
||||
else:
|
||||
raise TypeError(f"Unrecognized chunk spec: {chunk_v}")
|
||||
|
||||
args_sharded_replicated[arg_key] = sharded_arg_flat
|
||||
|
||||
# chunks_flat : [num chunks, num args, num flat values]
|
||||
chunks_flat = []
|
||||
for chunk_idx in range(real_num_chunks):
|
||||
chunk_args = {}
|
||||
for key, arg in args_sharded_replicated.items():
|
||||
arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
|
||||
chunk_args[key] = arg_single_chunk
|
||||
chunks_flat.append(chunk_args)
|
||||
|
||||
# args_split : [num chunks, num args]
|
||||
args_split = []
|
||||
|
||||
for chunk in chunks_flat:
|
||||
per_chunk_args = {}
|
||||
assert len(arg_specs) == len(chunk)
|
||||
for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
|
||||
per_chunk_args[key] = tree_unflatten(arg, arg_spec)
|
||||
args_split.append(per_chunk_args)
|
||||
|
||||
return args_split
|
||||
return [
|
||||
tree_unflatten(_flat_split_result, tree_spec)
|
||||
for _flat_split_result in flat_split_results
|
||||
]
|
||||
|
||||
|
||||
def split_args_kwargs_into_chunks(
|
||||
|
Reference in New Issue
Block a user