From d08cabe31475dbe307c49781bae6558ac8eafa52 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 15 Sep 2025 12:05:41 -0700 Subject: [PATCH] [BC Breaking] Remove flex + njt code paths (#161734) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161734 Approved by: https://github.com/jbschlosser --- docs/source/nn.attention.flex_attention.md | 3 - test/test_nestedtensor.py | 121 ------------ torch/_higher_order_ops/flex_attention.py | 7 - torch/nested/_internal/ops.py | 138 -------------- torch/nn/attention/flex_attention.py | 208 --------------------- 5 files changed, 477 deletions(-) diff --git a/docs/source/nn.attention.flex_attention.md b/docs/source/nn.attention.flex_attention.md index 4cfb51c5945c..8c51cee27651 100644 --- a/docs/source/nn.attention.flex_attention.md +++ b/docs/source/nn.attention.flex_attention.md @@ -30,9 +30,6 @@ .. autofunction:: create_mask ``` ```{eval-rst} -.. autofunction:: create_nested_block_mask -``` -```{eval-rst} .. autofunction:: and_masks ``` ```{eval-rst} diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index ac97f2beda8e..5affbb74cca0 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -26,7 +26,6 @@ from torch.nested._internal.nested_tensor import ( NestedTensor, ViewNestedFromBuffer, ) -from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FUSED_ATTENTION, SM70OrLater, @@ -36,7 +35,6 @@ from torch.testing._internal.common_cuda import ( from torch.testing._internal.common_device_type import ( dtypes, dtypesIfCUDA, - flex_attention_supported_platform, instantiate_device_type_tests, onlyCPU, onlyCUDA, @@ -60,7 +58,6 @@ from torch.testing._internal.common_utils import ( parametrize, run_tests, serialTest, - skipIfRocm, skipIfSlowGradcheckEnv, skipIfTorchDynamo, subtest, @@ -7285,124 +7282,6 @@ torch.cuda.synchronize() return query, key, value - @unittest.skip( - "Temporarily skip - nested tensor backward pass broken after return-max-scores commit" - ) - @onlyCUDA - @flex_attention_supported_platform - @dtypes(torch.float32) - # non-contiguous with holes not supported yet - @decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"]) - @parametrize("noncontig_with_holes", [False, True]) - @parametrize("cross_attention", [False, True]) - @skipIfRocm - def test_flex_attention(self, device, dtype, noncontig_with_holes, cross_attention): - query, key, value = self._rand_qkv( - device, dtype, noncontig_with_holes, q_and_kv_match=(not cross_attention) - ) - - # Run FlexAttention with a causal mask - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - if cross_attention: - block_mask = create_nested_block_mask( - causal_mask, 1, 1, query, key, _compile=True - ) - else: - block_mask = create_nested_block_mask( - causal_mask, 1, 1, query, _compile=True - ) - - out_flex = flex_attention(query, key, value, block_mask=block_mask) - grad_out = torch.randn_like(out_flex) - grads_flex = torch.autograd.grad( - out_flex, inputs=(query, key, value), grad_outputs=(grad_out,) - ) - flex_outs = [out_flex, *grads_flex] - - # Run FlexAttention with a score_mod that represents causal attention - def causal_score_mod(score, b, h, q_idx, kv_idx): - return torch.where(q_idx >= kv_idx, score, float("-inf")) - - out_flex2 = flex_attention(query, key, value, score_mod=causal_score_mod) - grads_flex2 = torch.autograd.grad( - out_flex2, inputs=(query, key, value), grad_outputs=(grad_out,) - ) - flex_outs2 = [out_flex2, *grads_flex2] - - # Run causal SDPA for comparison - out_sdpa = F.scaled_dot_product_attention(query, key, value, is_causal=True) - grads_sdpa = torch.autograd.grad( - out_sdpa, inputs=(query, key, value), grad_outputs=(grad_out,) - ) - sdpa_outs = [out_sdpa, *grads_sdpa] - - # Compare flex vs. SDPA output and grads - for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs): - self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested) - self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2) - self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2) - - @onlyCUDA - @flex_attention_supported_platform - @dtypes(torch.float32) - def test_flex_attention_converts_stacked_seq_indices(self, device, dtype): - # This test verifies that a score_mod function written to operate within - # NJT sequence index space, such as a lookup table, works correctly. This - # validates that FlexAttention properly converts indices within the - # "stacked sequence" space used for NJT -> sequence-relative indices. - query, key, value = self._rand_qkv(device, dtype) - - # Test with score_mod - score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype) - - def my_score_mod(score, b, h, q_idx, kv_idx): - return score_mod_table[q_idx] - - flex_attention(query, key, value, score_mod=my_score_mod) - - # Test with batch-specific score_mod - batch_size = query.size(0) - batch_table = torch.randn(batch_size, device=device, dtype=dtype) - # Keep score the same for batch index == 0 - batch_table[0].zero_() - - def batch_specific_score_mod(score, b, h, q_idx, kv_idx): - return score + batch_table[b] - - def identity_score_mod(score, b, h, q_idx, kv_idx): - return score - - output = flex_attention(query, key, value, score_mod=batch_specific_score_mod) - output_identity = flex_attention( - query, key, value, score_mod=identity_score_mod - ) - - # Guard against a bug where the batch index passed to score_mod is always b == 0. - # Output would be equivalent to applying an identity score_mod. - # See https://github.com/pytorch/pytorch/issues/143788 - self.assertFalse(torch.allclose(output._values, output_identity._values)) - - # Test with mask_mod - mask_mod_table = score_mod_table > 0.0 - - def my_mask_mod(b, h, q_idx, kv_idx): - return mask_mod_table[q_idx] - - def my_mask_mod2(b, h, q_idx, kv_idx): - return mask_mod_table[q_idx] & (b == 0) - - block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True) - output = flex_attention(query, key, value, block_mask=block_mask) - - block_mask2 = create_nested_block_mask(my_mask_mod2, 1, 1, query, _compile=True) - output2 = flex_attention(query, key, value, block_mask=block_mask2) - - # Guard against a bug where the batch index passed to mask_mod is always b == 0. - # See https://github.com/pytorch/pytorch/issues/143788 - self.assertFalse(torch.allclose(output._values, output2._values)) - @dtypes(torch.float32) def test_apply_(self, device, dtype): nt = random_nt_from_dims( diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 2d352ae03a45..b52bab0e3272 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -505,13 +505,6 @@ def flex_attention_fake_impl( ): return NotImplemented - # TODO: Figure out a better way to handle this for NJT than using sum() - if query.is_nested: - out = torch.empty_like(query, memory_format=torch.contiguous_format) - logsumexp = query.sum(dim=-1) - max_scores = query.max(dim=-1)[0] - return out, logsumexp, max_scores - v_head_dim = value.size(-1) batch_size, num_heads, seq_len_q, _q_head_dim = query.shape logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 19b1fe670835..8cec2634a30f 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -2665,144 +2665,6 @@ def matmul_backward_default(func, *args, **kwargs): return (grad_self, grad_other) -from torch._higher_order_ops.flex_attention import ( - flex_attention as flex_attention_hop, - flex_attention_backward as flex_attention_backward_hop, -) -from torch.fx.graph_module import GraphModule - - -@flex_attention_hop.py_impl(NestedTensor) # type: ignore[misc] -def flex_njt( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - score_mod: Callable, - block_mask: Tuple, - scale: float, - kernel_options: Dict[str, Any], - score_mod_other_buffers: Tuple = (), - mask_mod_other_buffers: Tuple = (), -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4 - - # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense. - if any( - isinstance(buf, torch.Tensor) and buf.is_nested - for buf in score_mod_other_buffers + mask_mod_other_buffers - ): - raise RuntimeError( - "flex_attention(): Nested tensor score_mod / mask_mod buffers are not " - "currently supported. Please file an issue if this is important to you." - ) - - # Always set them since 0 sized elements are not handled gracefully - kernel_options = {**kernel_options, "OUTPUT_MAX": True, "OUTPUT_LOGSUMEXP": True} - - # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D) - output = flex_attention_hop( - query.values().unsqueeze(0), - key.values().unsqueeze(0), - value.values().unsqueeze(0), - score_mod=score_mod, - block_mask=block_mask, - scale=scale, - kernel_options=kernel_options, - score_mod_other_buffers=score_mod_other_buffers, - mask_mod_other_buffers=mask_mod_other_buffers, - ) - - # wrap outputs as NJT - output_njt = torch.nested.nested_tensor_from_jagged( - output[0].transpose(1, 2).squeeze(0), - query._offsets, # type: ignore[attr-defined] - query._lengths, # type: ignore[attr-defined] - min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] - max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] - ).transpose(1, 2) - - logsumexp_njt = torch.nested.nested_tensor_from_jagged( - output[1].transpose(1, 2).squeeze(0), - query._offsets, # type: ignore[attr-defined] - query._lengths, # type: ignore[attr-defined] - min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] - max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] - ).transpose(1, 2) - - max_scores_njt = torch.nested.nested_tensor_from_jagged( - output[2].transpose(1, 2).squeeze(0), - query._offsets, # type: ignore[attr-defined] - query._lengths, # type: ignore[attr-defined] - min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] - max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] - ).transpose(1, 2) - - return (output_njt, logsumexp_njt, max_scores_njt) - - -@flex_attention_backward_hop.py_impl(NestedTensor) # type: ignore[misc] -def flex_njt_backward( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - logsumexp: torch.Tensor, - grad_out: torch.Tensor, - grad_logsumexp: torch.Tensor, - fw_graph: Union[Callable, GraphModule], - joint_graph: GraphModule, - block_mask: Tuple, - scale: float, - kernel_options: Dict[str, Any], - score_mod_other_buffers: Tuple = (), - mask_mod_other_buffers: Tuple = (), -) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] -]: - output = flex_attention_backward_hop( - query.values().unsqueeze(0), - key.values().unsqueeze(0), - value.values().unsqueeze(0), - out=out.values().unsqueeze(0), - logsumexp=logsumexp.values().unsqueeze(0), - grad_out=grad_out.values().unsqueeze(0), - grad_logsumexp=grad_logsumexp.values().unsqueeze(0), - fw_graph=fw_graph, - joint_graph=joint_graph, - block_mask=block_mask, - scale=scale, - kernel_options=kernel_options, - score_mod_other_buffers=score_mod_other_buffers, - mask_mod_other_buffers=mask_mod_other_buffers, - ) - - # wrap grads as NJTs - dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output - njt_q_grad = torch.nested.nested_tensor_from_jagged( - dense_q_grad.transpose(1, 2).squeeze(0), - query._offsets, # type: ignore[attr-defined] - query._lengths, # type: ignore[attr-defined] - min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] - max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] - ).transpose(1, 2) - njt_k_grad = torch.nested.nested_tensor_from_jagged( - dense_k_grad.transpose(1, 2).squeeze(0), - key._offsets, # type: ignore[attr-defined] - key._lengths, # type: ignore[attr-defined] - min_seqlen=key._maybe_min_seqlen, # type: ignore[attr-defined] - max_seqlen=key._maybe_max_seqlen, # type: ignore[attr-defined] - ).transpose(1, 2) - njt_v_grad = torch.nested.nested_tensor_from_jagged( - dense_v_grad.transpose(1, 2).squeeze(0), - value._offsets, # type: ignore[attr-defined] - value._lengths, # type: ignore[attr-defined] - min_seqlen=value._maybe_min_seqlen, # type: ignore[attr-defined] - max_seqlen=value._maybe_max_seqlen, # type: ignore[attr-defined] - ).transpose(1, 2) - - return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads) - - # Make the dummy available on the C++ side. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") def _nested_get_jagged_dummy(func, *args, **kwargs): diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index ccd5697aa49c..a6d6e1228a32 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -74,7 +74,6 @@ __all__ = [ "FlexKernelOptions", "create_block_mask", "create_mask", - "create_nested_block_mask", "or_masks", "and_masks", "noop_mask", @@ -1111,179 +1110,6 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask: ) -def _nested_mod_func_adapter( - orig_mod_func: Union[_score_mod_signature, _mask_mod_signature], - q_nt: torch.Tensor, - kv_nt: torch.Tensor, - is_score_mod: bool, -) -> Union[_score_mod_signature, _mask_mod_signature]: - r"""Adapter to convert a score_mod / mask_mod to be NJT-compatible. The given mod func - should be written as if operating over a single sequence at a item. This adapter will - handle conversion from indices operating over a "stacked sequence" of length ``sum(S)`` - for sequence length ``S`` in the NJT to "sequence relative" indices in range ``[0, S)``. - - Args: - orig_mod_func (Callable): Function to modify attention scores. It takes four or five - arguments, depending on whether a mask_mod or score_mod func is passed. - q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length - structure for query. - kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length - structure for key / value. - is_score_mod (bool): Indicates whether the mod function is a score_mod. - - Returns: - nt_score_mod: An NJT-compatible version of orig_score_mod - """ - - # Used to convert indices within the "stacked" sequence (range [0, sum(*))) - # to "sequence local" indices (range [0, S) for each S). - def _build_seq_idx(offsets, total_length): - range_tensor = torch.arange( - total_length, device=offsets.device, dtype=torch.int32 - ) - - # Use searchsorted to find the index for each position - # NB: This assumes offsets[0] to offsets[-1] spans the packed dim of values. - # If we ever loosen this restriction, this logic will need to be updated. - seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1 - return seq_idx - - q_offsets = q_nt._offsets # type: ignore[attr-defined] - kv_offsets = kv_nt._offsets # type: ignore[attr-defined] - q_seq_idx = _build_seq_idx(q_offsets, q_nt._values.shape[q_nt._ragged_idx - 1]) # type: ignore[attr-defined] - if q_nt is kv_nt: - kv_seq_idx = q_seq_idx - else: - # cross attention case - kv_seq_idx = _build_seq_idx( - kv_offsets, - kv_nt._values.shape[kv_nt._ragged_idx - 1], # type: ignore[attr-defined] - ) - - # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers - # to the sequence length for each sequence in the NJT, for use in given - # score_mod. This allows the user to write a score_mod as if it were - # operating on a single sequence and the "stacked sequence" is split - # automatically into individual sequences for them. - if is_score_mod: - - def nt_score_mod(score, b, h, q_idx, kv_idx): - b_nested = q_seq_idx[q_idx] - q_nested = q_idx - q_offsets[q_seq_idx[q_idx]] - kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]] - is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx] - return torch.where( - is_same_sequence, - orig_mod_func(score, b_nested, h, q_nested, kv_nested), # type: ignore[call-arg] - # don't allow inter-sequence attention - float("-inf"), - ) - - return nt_score_mod - else: - - def nt_mask_mod(b, h, q_idx, kv_idx): - b_nested = q_seq_idx[q_idx] - q_nested = q_idx - q_offsets[q_seq_idx[q_idx]] - kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]] - # don't allow inter-sequence attention - is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx] - return orig_mod_func(b_nested, h, q_nested, kv_nested) & is_same_sequence # type: ignore[call-arg] - - return nt_mask_mod - - -def create_nested_block_mask( - mask_mod: _mask_mod_signature, - B: Optional[int], - H: Optional[int], - q_nt: torch.Tensor, - kv_nt: Optional[torch.Tensor] = None, - BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, - _compile=False, -) -> BlockMask: - r"""This function creates a nested tensor compatible block mask tuple from a mask_mod - function. The returned BlockMask will be on the device specified by the input nested tensor. - - Args: - mask_mod (Callable): mask_mod function. This is a callable that defines the - masking pattern for the attention mechanism. It takes four arguments: - b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index). - It should return a boolean tensor indicating which attention connections are allowed - (True) or masked out (False). - B (int): Batch size. - H (int): Number of query heads. - q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length - structure for query. The block mask will be constructed to operate on a "stacked - sequence" of length ``sum(S)`` for sequence length ``S`` from the NJT. - kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length - structure for key / value, allowing for cross attention. The block mask will be - constructed to operate on a "stacked sequence" of length ``sum(S)`` for sequence - length ``S`` from the NJT. If this is None, ``q_nt`` is used to define the structure - for key / value as well. Default: None - BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is - provided it is used for both query and key/value. - - Returns: - BlockMask: A BlockMask object that contains the block mask information. - - Example Usage: - .. code-block:: python - - # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch - query = torch.nested.nested_tensor(..., layout=torch.jagged) - key = torch.nested.nested_tensor(..., layout=torch.jagged) - value = torch.nested.nested_tensor(..., layout=torch.jagged) - - - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - - block_mask = create_nested_block_mask( - causal_mask, 1, 1, query, _compile=True - ) - output = flex_attention(query, key, value, block_mask=block_mask) - - .. code-block:: python - - # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch - query = torch.nested.nested_tensor(..., layout=torch.jagged) - key = torch.nested.nested_tensor(..., layout=torch.jagged) - value = torch.nested.nested_tensor(..., layout=torch.jagged) - - - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - - # cross attention case: pass both query and key/value NJTs - block_mask = create_nested_block_mask( - causal_mask, 1, 1, query, key, _compile=True - ) - output = flex_attention(query, key, value, block_mask=block_mask) - """ - # use same structure for kv as for q by default - if kv_nt is None: - kv_nt = q_nt - if q_nt.device != kv_nt.device: - raise ValueError( - "create_nested_block_mask(): Expected q_nt and kv_nt to be on the same device" - ) - return create_block_mask( - _nested_mod_func_adapter(mask_mod, q_nt, kv_nt, is_score_mod=False), # type: ignore[arg-type] - B, - H, - q_nt._values.shape[q_nt._ragged_idx - 1], # type: ignore[attr-defined] - kv_nt._values.shape[kv_nt._ragged_idx - 1], # type: ignore[attr-defined] - device=q_nt.device, # type: ignore[arg-type] - # compile is important so we don't materialize a mask_tensor of - # shape (1, 1, total_seqlen, total_seqlen) - BLOCK_SIZE=BLOCK_SIZE, - _compile=_compile, - ) - - def _apply_kernel_options( query: Tensor, key: Tensor, @@ -1359,25 +1185,6 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor): ) -def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor): - # Currently, inputs can only be all nested or no nested. - if query.is_nested != key.is_nested or key.is_nested != value.is_nested: - raise ValueError( - "FlexAttention does not support mixed nested tensor / non-nested tensor inputs. " - "Please file an issue requesting this if it is important to you." - ) - - if ( - (query.is_nested and query._lengths is not None) # type: ignore[attr-defined] - or (key.is_nested and key._lengths is not None) # type: ignore[attr-defined] - or (value.is_nested and value._lengths is not None) # type: ignore[attr-defined] - ): - raise ValueError( - "FlexAttention does not support nested tensors that are non-contiguous with holes. " - "Please file an issue requesting this if it is important to you." - ) - - def _enforce_mem_layouts( query: Tensor, key: Tensor, value: Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1517,7 +1324,6 @@ def flex_attention( _validate_sdpa_input(query, key, value) _validate_embed_dim(query, key, value) _validate_device(query, key, value) - _validate_nestedness(query, key, value) query, key, value = _enforce_mem_layouts(query, key, value) if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: raise NotImplementedError("NYI: query, key, and value must be 4D tensors") @@ -1552,14 +1358,6 @@ def flex_attention( if score_mod is None: score_mod = _identity - elif query.is_nested: - # use same NJT if the ragged structures for sequence lengths match between q and kv - kv = ( - query - if query.size(query._ragged_idx) == key.size(query._ragged_idx) # type: ignore[attr-defined] - else key - ) - score_mod = _nested_mod_func_adapter(score_mod, query, kv, is_score_mod=True) # type: ignore[assignment] if block_mask is None: block_mask = _create_empty_block_mask(query, key) @@ -1570,12 +1368,6 @@ def flex_attention( ): # This corresponds to the case where we essentially have a "no-op" block mask. pass - elif query.is_nested: - if block_mask.shape[-2] != query._values.size(query._ragged_idx - 1): # type: ignore[attr-defined] - raise RuntimeError( - f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input " - f"with total sequence length of {query._values.size(query._ragged_idx - 1)}" # type: ignore[attr-defined] - ) else: block_mask_q_len = block_mask.shape[-2] block_mask_kv_len = block_mask.shape[-1]