Add block mask utility support for batches and heads > 1 (#130227)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130227
Approved by: https://github.com/yanboliang
ghstack dependencies: #130160, #130106, #130224
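
In short: a BlockMask can now carry leading batch/head dimensions, and all of its utilities (shape, numel(), sparsity(), to_string(), indexing via __getitem__) understand them. A minimal usage sketch mirroring the new tests below (assumes a CUDA device; treat the import as illustrative, since the flex_attention module path has moved between releases):

import torch

# Illustrative import; the tests in this PR get create_block_mask from the
# flex_attention module, whose exact path varies by release.
from torch.nn.attention.flex_attention import create_block_mask

def causal(score, b, h, q, kv):
    return torch.where(q >= kv, score, -float("inf"))

# batch=4, heads=2, 2048x2048 positions, 128-wide sparse blocks by default
block_mask = create_block_mask(causal, 4, 2, 2048, 2048)
print(block_mask.shape)        # (4, 2, 2048, 2048)
print(block_mask.numel())      # 4 * 2 * 2048 * 2048
print(block_mask.sparsity())   # 46.875 -- percent of blocks skipped
print(block_mask[0, 1].shape)  # (2048, 2048): new -- index out one batch/head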
Author: chilli
Date: 2024-07-07 19:10:58 -07:00
Committed by: PyTorch MergeBot
Parent: cd683212a2
Commit: 64139987c0
5 changed files with 143 additions and 44 deletions


@@ -48,18 +48,12 @@ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTen
   return result;
 }
-static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
+static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
   for (const auto& tensor : indices) {
     if (tensor.has_value() && tensor->defined()) {
       auto scalarType = tensor->scalar_type();
-      if (allow_int) {
-        if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
-          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
-        }
-      } else {
-        if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
-          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
-        }
-      }
+      if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
+        TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
+      }
     }
   }
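
With the allow_int flag removed, int32 index tensors pass checkIndexTensorTypes on every indexing path, not only the call sites that previously opted in with /*allow_int*/ true. A quick sketch of the user-visible effect (illustrative; the non-opted-in paths used to reject these with "tensors used as indices must be long, byte or bool tensors"):

import torch

x = torch.arange(12).reshape(3, 4)
idx = torch.tensor([0, 2], dtype=torch.int32)  # int32 instead of the usual int64
print(x[idx])  # rows 0 and 2; the type check now admits kInt everywhere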


@@ -57,7 +57,7 @@ const Tensor& value){
 }
 inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
-  checkIndexTensorTypes(orig, /*allow_int*/ true);
+  checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together


@@ -389,7 +389,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
 static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
-  checkIndexTensorTypes(orig, /*allow_int*/true);
+  checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
   for (auto & i : indices) {
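
makeLinearIndex feeds the linear-index paths used by in-place scatter ops such as index_put_, so those now apply the same relaxed type check. A sketch under the assumption that the int path is enabled end to end by this stack:

import torch

x = torch.zeros(4)
idx = torch.tensor([1, 3], dtype=torch.int32)
x.index_put_((idx,), torch.ones(2))  # int32 indices accepted on this path too
print(x)  # tensor([0., 1., 0., 1.])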


@@ -943,7 +943,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)

     @supported_platform
-    # @skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571
     @common_utils.parametrize("dtype", test_dtypes)
     def test_njt_causal(self, dtype):
         offsets = torch.tensor(
@@ -1210,6 +1209,29 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
         )

+    @supported_platform
+    def test_block_mask_attributes(self):
+        offset = torch.zeros(8, device="cuda")
+
+        def causal(score, b, h, q, kv):
+            return torch.where(q + offset[b] * 128 >= kv, score, -float("inf"))
+
+        block_mask = create_block_mask(causal, 4, 2, 2048, 2048)
+        self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
+        self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
+        self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
+        self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048)
+        self.assertEqual(block_mask.sparsity(), 46.875)
+        self.assertEqual(block_mask[0].sparsity(), 46.875)
+        self.assertEqual(block_mask[1, 0].sparsity(), 46.875)
+        self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
+
+        offset = torch.arange(8, device="cuda")
+        block_mask = create_block_mask(causal, 8, 1, 2048, 2048)
+        self.assertEqual(block_mask.sparsity(), 29.1015625)
+        self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
+        self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
+
     @supported_platform
     def test_block_mask_viz(self):
         def causal(score, b, h, q, kv):
@@ -1230,7 +1252,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.assertExpectedInline(
             replace_non_printable(str(block_mask)),
             """\
-BlockMask(sparsity=46.88%,smask=
+BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
+(0,s0)
 @@ssssssssssssssssssssssssssssss
 @@@@ssssssssssssssssssssssssssss
 @@@@@@ssssssssssssssssssssssssss
@@ -1250,6 +1273,15 @@ BlockMask(sparsity=46.88%,smask=
 )""",
         )

+        offset = torch.arange(8, device="cuda")
+
+        def causal_offset(score, b, h, q, kv):
+            return torch.where(q + offset[b] * 128 >= kv, score, -float("inf"))
+
+        block_mask = create_block_mask(causal_offset, 8, 1, 2048, 2048)
+        str_block_mask = str(block_mask)
+        self.assertTrue("sparsity=29.10" in str_block_mask)
+
     @supported_platform
     def test_fw_bw_graph_correctness(self):
         cnt = CompileCounterWithBackend("aot_eager")
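
Where the 46.875 in these tests comes from: 2048 tokens with 128-token blocks give a 16x16 grid of blocks, a causal mask computes only the 136 blocks on or below the diagonal, and 1 - 136/256 = 0.46875. A standalone sanity check (plain arithmetic, not part of the diff):

blocks = 2048 // 128                   # 16 blocks per side
computed = blocks * (blocks + 1) // 2  # 136 blocks on or below the diagonal
total = blocks * blocks                # 256 blocks overall
print(100 * (1 - computed / total))    # 46.875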


@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
 """This module implements the user facing API for flex_attention in PyTorch."""
 import functools
+import itertools
+import operator
 from typing import Callable, Optional

 import torch
@@ -62,6 +64,8 @@ class BlockMask:
         KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
         Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
     ):
+        if kv_indices.dim() < 2:
+            raise RuntimeError("BlockMask kv_indices must have at least 2 dimensions")
         self.kv_num_blocks = kv_num_blocks
         self.kv_indices = kv_indices
         self.q_num_blocks = q_num_blocks
@@ -80,17 +84,54 @@
         )

     def __str__(self):
-        s = f"BlockMask(sparsity={self.sparsity():.2f}%, mask=\n"
-        s += self.to_string()
-        s += ")"
+        s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
+        mask_str = self.to_string().strip()
+        s += mask_str
+        s += "\n)"
         return s

+    def __getitem__(self, index) -> "BlockMask":
+        tensors = self.as_tuple()[:-2]
+        tensors = [x[index] for x in tensors]
+        return BlockMask(
+            tensors[0],
+            tensors[1],
+            tensors[2],
+            tensors[3],
+            KV_BLOCK_SIZE=self.KV_BLOCK_SIZE,
+            Q_BLOCK_SIZE=self.Q_BLOCK_SIZE,
+        )
+
+    @property
+    def shape(self):
+        """
+        Returns the shape of the mask.
+        """
+        *batch_dims, q_length, _ = self.kv_indices.shape
+        q_length = self.kv_num_blocks.shape[-1] * self.KV_BLOCK_SIZE
+        kv_length = self.q_num_blocks.shape[-1] * self.Q_BLOCK_SIZE
+        return tuple(batch_dims + [q_length, kv_length])
+
+    def numel(self):
+        """
+        Returns the number of elements (not accounting for sparsity) in the mask.
+        """
+        shape = self.shape
+
+        def _prod(xs):
+            return functools.reduce(operator.mul, xs, 1)
+
+        return _prod(shape)
+
     def sparsity(self) -> float:
         """
         Computes the percentage of blocks that are sparse (i.e. not computed)
         """
-        dense_mask = self.to_dense()
-        dense_ratio = ((dense_mask != 0).sum()) / dense_mask.numel()
+        total_size = self.numel()
+        computed_size = (
+            self.kv_num_blocks.sum().item() * self.KV_BLOCK_SIZE * self.Q_BLOCK_SIZE
+        )
+        dense_ratio = computed_size / total_size
         return 100 * (1 - dense_ratio)

     def to_dense(self) -> torch.Tensor:
@@ -99,9 +140,8 @@ class BlockMask:
         """
         num_rows = self.kv_num_blocks.shape[-1]
         num_cols = self.q_num_blocks.shape[-1]
-        batch, head = self.kv_num_blocks.shape[:2]
+        batch_dims = self.kv_num_blocks.shape[:-1]
         device = self.kv_num_blocks.device
-        assert batch == 1, head == 1

         def create_dense_one(kv_num_blocks, kv_indices):
             dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32)
@@ -116,51 +156,84 @@
             valid_indices = torch.where(index_mask, kv_indices, num_cols)
             # set the values in 'a' to 1 where the indices are valid
-            dense_mask[row_indices, valid_indices] = 1
+            dense_mask[row_indices, valid_indices] = torch.tensor(
+                1, device=dense_mask.device, dtype=dense_mask.dtype
+            )
             return dense_mask[:, :num_cols]

-        out = create_dense_one(self.kv_num_blocks[0, 0], self.kv_indices[0, 0])
+        create_dense_batched = create_dense_one
+        for _ in range(len(batch_dims)):
+            create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0))
+
+        out = create_dense_batched(self.kv_num_blocks, self.kv_indices)
         return out

-    def to_string(self, grid_size=(20, 20)):
+    def to_string(self, grid_size=(20, 20), limit=4):
         """
         Returns a string representation of the block mask. Quite nifty.

         If grid_size is None, prints out an uncompressed version. Warning, it can be quite big!
         """
         dense_mask = self.to_dense()
-        num_rows, num_cols = dense_mask.shape
+        *batch_dims, num_rows, num_cols = dense_mask.shape
         if isinstance(grid_size, int):
             max_rows = grid_size
             max_cols = grid_size
-        elif grid_size is None:
+        elif grid_size == -1:
             max_rows = num_rows
             max_cols = num_cols
         else:
             max_rows, max_cols = grid_size
-        vis = ""

-        def summarize_section(section):
-            percentage = section.float().mean().item()
-            if percentage == 1:
-                return "█"
-            elif percentage == 0:
-                return " "
-            else:
-                return "░"
+        def create_block_vis(*batch_idx):
+            descriptors = []
+            descriptors.append(f"{batch_idx}")
+            vis = ", ".join(reversed(descriptors)) + "\n"

-        def cdiv(a, b):
-            return (a + (b - 1)) // b
+            def summarize_section(section):
+                percentage = section.float().mean().item()
+                if percentage == 1:
+                    return "█"
+                elif percentage == 0:
+                    return " "
+                else:
+                    return "░"

-        row_step = max(1, cdiv(num_rows, max_rows))
-        col_step = max(1, cdiv(num_cols, max_cols))
-        for r in range(0, num_rows, row_step):
-            for c in range(0, num_cols, col_step):
-                char = summarize_section(dense_mask[r : r + row_step, c : c + col_step])
-                vis += char * 2
-            vis += "\n"
-        return vis
+            def cdiv(a, b):
+                return (a + (b - 1)) // b
+
+            row_step = max(1, cdiv(num_rows, max_rows))
+            col_step = max(1, cdiv(num_cols, max_cols))
+            for r in range(0, num_rows, row_step):
+                for c in range(0, num_cols, col_step):
+                    cur_mask = dense_mask
+                    for idx in batch_idx:
+                        cur_mask = cur_mask[idx]
+                    char = summarize_section(
+                        cur_mask[r : r + row_step, c : c + col_step]
+                    )
+                    vis += char * 2
+                vis += "\n"
+            return vis
+
+        total_vis = []
+        for idx, batch_idx in enumerate(
+            itertools.product(*[range(i) for i in batch_dims])
+        ):
+            if idx == limit:
+                total_vis.append("...")
+                total_vis.append("To print out more, set BlockMask.to_string(limit=N)")
+                total_vis.append(
+                    "You can also index (BlockMask[batch, head]) to choose a specific batch or head"
+                )
+                break
+
+            block_vis = create_block_vis(*batch_idx)
+            total_vis.append(block_vis)
+
+        return "\n".join(total_vis)

 def _broadcast_to_dim(x, dim):
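
The batched to_dense above lifts create_dense_one over each leading batch dimension by stacking torch.vmap once per dimension. The same pattern in isolation (toy per-row function and hypothetical shapes, purely illustrative):

import torch

def row_fn(x):  # operates on a single 1-D row
    return x.cumsum(0)

batched = row_fn
for _ in range(2):  # lift over two leading dims, e.g. (batch, head)
    batched = torch.vmap(batched, in_dims=0)

t = torch.arange(24.0).reshape(2, 3, 4)
print(torch.allclose(batched(t), t.cumsum(-1)))  # True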