mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

Add block mask utility support for batches and heads > 1 (#130227)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130227
Approved by: https://github.com/yanboliang
ghstack dependencies: #130160, #130106, #130224
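
For orientation, a minimal usage sketch of what this change enables, mirroring the new tests further down in the diff. The import path is an assumption (the flex_attention API was still private when this PR landed), and running it may require a CUDA device:

import torch
# Assumed import location; adjust to wherever create_block_mask lives in your build.
from torch.nn.attention.flex_attention import create_block_mask

def causal(score, b, h, q, kv):
    # score_mod-style callable: keep scores where the query position is >= the kv position
    return torch.where(q >= kv, score, -float("inf"))

# Block masks can now carry explicit batch and head dimensions (here B=4, H=2).
block_mask = create_block_mask(causal, 4, 2, 2048, 2048)
print(block_mask.shape)       # (4, 2, 2048, 2048)
print(block_mask.numel())     # 4 * 2 * 2048 * 2048
print(block_mask.sparsity())  # 46.875 for a causal mask with the default 128x128 blocks

# Indexing slices off leading batch/head dimensions and returns a new BlockMask.
print(block_mask[0].shape)    # (2, 2048, 2048)
print(block_mask[0, 0].shape) # (2048, 2048)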
@@ -48,18 +48,12 @@ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTen
  return result;
}

static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
  for (const auto& tensor : indices) {
    if (tensor.has_value() && tensor->defined()) {
      auto scalarType = tensor->scalar_type();
      if (allow_int) {
        if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
        }
      } else {
        if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
        }
        if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
        }
      }
    }
  }

@@ -57,7 +57,7 @@ const Tensor& value){
}

inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
  checkIndexTensorTypes(orig, /*allow_int*/ true);
  checkIndexTensorTypes(orig);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = expandTensors(self, orig);
  // next broadcast all index tensors together

@@ -389,7 +389,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {


static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
  checkIndexTensorTypes(orig, /*allow_int*/true);
  checkIndexTensorTypes(orig);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = expandTensors(self, orig);
  for (auto & i : indices) {
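
The "expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors" step referenced in the hunks above is the usual advanced-indexing canonicalization. A simplified Python illustration of the idea (not the C++ code itself):

import torch

x = torch.arange(6).reshape(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])

# A boolean mask used as an index is equivalent to the LongTensors of its
# nonzero positions, one index tensor per masked dimension.
rows, cols = mask.nonzero(as_tuple=True)
assert torch.equal(x[mask], x[rows, cols])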
@@ -943,7 +943,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
        self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)

    @supported_platform
    # @skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571
    @common_utils.parametrize("dtype", test_dtypes)
    def test_njt_causal(self, dtype):
        offsets = torch.tensor(
@@ -1210,6 +1209,29 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
            )
        )

    @supported_platform
    def test_block_mask_attributes(self):
        offset = torch.zeros(8, device="cuda")

        def causal(score, b, h, q, kv):
            return torch.where(q + offset[b] * 128 >= kv, score, -float("inf"))

        block_mask = create_block_mask(causal, 4, 2, 2048, 2048)
        self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
        self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
        self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
        self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048)
        self.assertEqual(block_mask.sparsity(), 46.875)
        self.assertEqual(block_mask[0].sparsity(), 46.875)
        self.assertEqual(block_mask[1, 0].sparsity(), 46.875)
        self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())

        offset = torch.arange(8, device="cuda")
        block_mask = create_block_mask(causal, 8, 1, 2048, 2048)
        self.assertEqual(block_mask.sparsity(), 29.1015625)
        self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
        self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
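
The sparsity values asserted above can be checked by hand for the 128x128 block size the test implies. A back-of-the-envelope check in plain Python, assuming a block is kept whenever any element in it is unmasked:

blocks = 2048 // 128  # 16 block rows/cols per (batch, head)

# Plain causal mask (offset 0): block (i, j) is kept iff j <= i,
# i.e. 16 * 17 / 2 = 136 of the 256 blocks are computed.
kept = sum(i + 1 for i in range(blocks))
print(100 * (1 - kept / blocks**2))  # 46.875

# With offset[b] = b (a shift of b * 128), row i of batch b keeps blocks j <= i + b.
total = 8 * blocks * blocks
kept = sum(min(blocks, i + b + 1) for b in range(8) for i in range(blocks))
print(100 * (1 - kept / total))  # 29.1015625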

    @supported_platform
    def test_block_mask_viz(self):
        def causal(score, b, h, q, kv):
@@ -1230,7 +1252,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
        self.assertExpectedInline(
            replace_non_printable(str(block_mask)),
            """\
BlockMask(sparsity=46.88%,smask=
BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
(0,s0)
@@ssssssssssssssssssssssssssssss
@@@@ssssssssssssssssssssssssssss
@@@@@@ssssssssssssssssssssssssss
@@ -1250,6 +1273,15 @@ BlockMask(sparsity=46.88%,smask=
)""",
        )

        offset = torch.arange(8, device="cuda")

        def causal_offset(score, b, h, q, kv):
            return torch.where(q + offset[b] * 128 >= kv, score, -float("inf"))

        block_mask = create_block_mask(causal_offset, 8, 1, 2048, 2048)
        str_block_mask = str(block_mask)
        self.assertTrue("sparsity=29.10" in str_block_mask)

    @supported_platform
    def test_fw_bw_graph_correctness(self):
        cnt = CompileCounterWithBackend("aot_eager")
@@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
"""This module implements the user facing API for flex_attention in PyTorch."""
import functools
import itertools
import operator
from typing import Callable, Optional

import torch
@@ -62,6 +64,8 @@ class BlockMask:
        KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
        Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
    ):
        if kv_indices.dim() < 2:
            raise RuntimeError("BlockMask kv_indices must have at least 2 dimensions")
        self.kv_num_blocks = kv_num_blocks
        self.kv_indices = kv_indices
        self.q_num_blocks = q_num_blocks
@@ -80,17 +84,54 @@ class BlockMask:
        )

    def __str__(self):
        s = f"BlockMask(sparsity={self.sparsity():.2f}%, mask=\n"
        s += self.to_string()
        s += ")"
        s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
        mask_str = self.to_string().strip()
        s += mask_str
        s += "\n)"
        return s

    def __getitem__(self, index) -> "BlockMask":
        tensors = self.as_tuple()[:-2]
        tensors = [x[index] for x in tensors]
        return BlockMask(
            tensors[0],
            tensors[1],
            tensors[2],
            tensors[3],
            KV_BLOCK_SIZE=self.KV_BLOCK_SIZE,
            Q_BLOCK_SIZE=self.Q_BLOCK_SIZE,
        )

    @property
    def shape(self):
        """
        Returns the shape of the mask.
        """
        *batch_dims, q_length, _ = self.kv_indices.shape
        q_length = self.kv_num_blocks.shape[-1] * self.KV_BLOCK_SIZE
        kv_length = self.q_num_blocks.shape[-1] * self.Q_BLOCK_SIZE
        return tuple(batch_dims + [q_length, kv_length])

    def numel(self):
        """
        Returns the number of elements (not accounting for sparsity) in the mask.
        """
        shape = self.shape

        def _prod(xs):
            return functools.reduce(operator.mul, xs, 1)

        return _prod(shape)

    def sparsity(self) -> float:
        """
        Computes the percentage of blocks that are sparse (i.e. not computed)
        """
        dense_mask = self.to_dense()
        dense_ratio = ((dense_mask != 0).sum()) / dense_mask.numel()
        total_size = self.numel()
        computed_size = (
            self.kv_num_blocks.sum().item() * self.KV_BLOCK_SIZE * self.Q_BLOCK_SIZE
        )
        dense_ratio = computed_size / total_size
        return 100 * (1 - dense_ratio)

    def to_dense(self) -> torch.Tensor:
@@ -99,9 +140,8 @@ class BlockMask:
        """
        num_rows = self.kv_num_blocks.shape[-1]
        num_cols = self.q_num_blocks.shape[-1]
        batch, head = self.kv_num_blocks.shape[:2]
        batch_dims = self.kv_num_blocks.shape[:-1]
        device = self.kv_num_blocks.device
        assert batch == 1, head == 1

        def create_dense_one(kv_num_blocks, kv_indices):
            dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32)
@@ -116,51 +156,84 @@ class BlockMask:
            valid_indices = torch.where(index_mask, kv_indices, num_cols)

            # set the values in 'a' to 1 where the indices are valid
            dense_mask[row_indices, valid_indices] = 1
            dense_mask[row_indices, valid_indices] = torch.tensor(
                1, device=dense_mask.device, dtype=dense_mask.dtype
            )
            return dense_mask[:, :num_cols]

        out = create_dense_one(self.kv_num_blocks[0, 0], self.kv_indices[0, 0])
        create_dense_batched = create_dense_one
        for _ in range(len(batch_dims)):
            create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0))

        out = create_dense_batched(self.kv_num_blocks, self.kv_indices)
        return out
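
The repeated torch.vmap wrapping above is how a helper written for a single (batch, head) slice gets lifted over however many leading dimensions the block tensors carry. A standalone sketch of the same pattern, with made-up shapes (illustrative, not the BlockMask code itself):

import torch

def per_slice(num_blocks, indices):
    # Works on one (rows,) / (rows, cols) slice at a time.
    return (indices + num_blocks.unsqueeze(-1)).to(torch.int32)

kv_num_blocks = torch.randint(1, 4, (4, 2, 16))    # (batch, head, rows)
kv_indices = torch.randint(0, 16, (4, 2, 16, 16))  # (batch, head, rows, cols)

# Lift the per-slice helper over each leading dimension, as to_dense() does.
batched = per_slice
for _ in range(kv_num_blocks.dim() - 1):
    batched = torch.vmap(batched, in_dims=(0, 0))

print(batched(kv_num_blocks, kv_indices).shape)  # torch.Size([4, 2, 16, 16])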

    def to_string(self, grid_size=(20, 20)):
    def to_string(self, grid_size=(20, 20), limit=4):
        """
        Returns a string representation of the block mask. Quite nifty.

        If grid_size is None, prints out an uncompressed version. Warning, it can be quite big!
        """
        dense_mask = self.to_dense()
        num_rows, num_cols = dense_mask.shape
        *batch_dims, num_rows, num_cols = dense_mask.shape
        if isinstance(grid_size, int):
            max_rows = grid_size
            max_cols = grid_size
        elif grid_size is None:
        elif grid_size == -1:
            max_rows = num_rows
            max_cols = num_cols
        else:
            max_rows, max_cols = grid_size
        vis = ""

        def summarize_section(section):
            percentage = section.float().mean().item()
            if percentage == 1:
                return "█"
            elif percentage == 0:
                return " "
            else:
                return "░"
        def create_block_vis(*batch_idx):
            descriptors = []

        def cdiv(a, b):
            return (a + (b - 1)) // b
            descriptors.append(f"{batch_idx}")

        row_step = max(1, cdiv(num_rows, max_rows))
        col_step = max(1, cdiv(num_cols, max_cols))
            vis = ", ".join(reversed(descriptors)) + "\n"

        for r in range(0, num_rows, row_step):
            for c in range(0, num_cols, col_step):
                char = summarize_section(dense_mask[r : r + row_step, c : c + col_step])
                vis += char * 2
            vis += "\n"
        return vis
            def summarize_section(section):
                percentage = section.float().mean().item()
                if percentage == 1:
                    return "█"
                elif percentage == 0:
                    return " "
                else:
                    return "░"

            def cdiv(a, b):
                return (a + (b - 1)) // b

            row_step = max(1, cdiv(num_rows, max_rows))
            col_step = max(1, cdiv(num_cols, max_cols))

            for r in range(0, num_rows, row_step):
                for c in range(0, num_cols, col_step):
                    cur_mask = dense_mask
                    for idx in batch_idx:
                        cur_mask = cur_mask[idx]
                    char = summarize_section(
                        cur_mask[r : r + row_step, c : c + col_step]
                    )
                    vis += char * 2
                vis += "\n"
            return vis

        total_vis = []
        for idx, batch_idx in enumerate(
            itertools.product(*[range(i) for i in batch_dims])
        ):
            if idx == limit:
                total_vis.append("...")
                total_vis.append("To print out more, set BlockMask.to_string(limit=N)")
                total_vis.append(
                    "You can also index (BlockMask[batch, head]) to choose a specific batch or head"
                )
                break
            block_vis = create_block_vis(*batch_idx)
            total_vis.append(block_vis)

        return "\n".join(total_vis)


def _broadcast_to_dim(x, dim):
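
As a usage note for the printing changes above, continuing the sketch at the top of this page (exact output formatting may differ between versions):

# str(block_mask) now reports the full (batch, head, q, kv) shape and prints a
# per-(batch, head) grid, truncated after `limit` entries (default 4).
print(block_mask)
print(block_mask.to_string(limit=8))       # show more (batch, head) grids
print(block_mask.to_string(grid_size=-1))  # uncompressed grid; can be very large
print(block_mask[1, 0])                    # or index down to one batch/head slice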