[FA3] Fix masking and loading logic in same process (#41217)

fix loading and fa3 masking
Author:       Anton Vlasjuk
Date:         2025-10-01 16:36:12 +02:00
Committed by: GitHub
Parent:       3256773974
Commit:       025531981c
5 changed files with 33 additions and 35 deletions
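Taken together, the changes let one process load a model with flash_attention_2 and then switch the same instance to flash_attention_3 (or back): the mask mapping now knows about FA3, and the cached flash functions can be force-re-imported. A minimal usage sketch of that workflow, assuming FA2/FA3 kernels are installed; the model id and generation settings are borrowed from the test below, everything else is illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The ETH AI Center is", return_tensors="pt").to("cuda")
out_fa2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Switching in the same process now re-resolves the flash functions
# instead of silently reusing the FA2 imports.
model.set_attn_implementation("flash_attention_3")
out_fa3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)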


@@ -191,7 +191,7 @@ def load_and_register_kernel(attn_implementation: str) -> None:
             if attention_wrapper is None:
                 attention_wrapper = flash_attention_forward
             kernel_function = partial(attention_wrapper, implementation=kernel)
-            lazy_import_flash_attention(kernel)
+            lazy_import_flash_attention(kernel, force_import=True)
         elif kernel_name is not None:
             kernel_function = getattr(kernel, kernel_name)
         # Register the kernel as a valid attention
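For context on the surrounding lines: `partial` pins the loaded hub kernel onto a generic wrapper so the rest of the stack can call it like any other attention function, and the new `force_import=True` makes sure the freshly registered kernel replaces whatever was cached before. A self-contained sketch of the binding pattern only; all names below are illustrative stand-ins, not the library's:

from functools import partial

def generic_attention_forward(q, k, v, implementation=None):
    # Dispatch to whichever kernel object was bound in at registration time.
    return implementation.attention(q, k, v)

class DummyKernel:
    @staticmethod
    def attention(q, k, v):
        return q  # stand-in for a real fused attention kernel

# Same shape as `kernel_function = partial(attention_wrapper, implementation=kernel)` above.
kernel_function = partial(generic_attention_forward, implementation=DummyKernel)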


@@ -625,6 +625,7 @@ class AttentionMaskInterface(GeneralInterface):
     _global_mapping = {
         "sdpa": sdpa_mask,
         "eager": eager_mask,
         "flash_attention_2": flash_attention_mask,
+        "flash_attention_3": flash_attention_mask,
         "flex_attention": flex_attention_mask,
     }
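The one-line addition is the masking half of the fix: "flash_attention_3" previously had no entry in this mapping, so mask construction could not resolve it the way it does for FA2. A hedged sketch of the dispatch effect, using a stand-in dict and a placeholder mask function rather than the real interface object:

def flash_attention_mask(*args, **kwargs):
    # Stand-in: flash kernels handle the causal structure internally, so this
    # path mainly forwards padding information (or nothing at all).
    return None

MASK_FUNCTIONS = {
    "flash_attention_2": flash_attention_mask,
    "flash_attention_3": flash_attention_mask,  # the new entry: FA3 reuses the FA2 mask path
}

mask_fn = MASK_FUNCTIONS["flash_attention_3"]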


@@ -124,7 +124,7 @@ def _lazy_define_process_function(flash_function):
     return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
 
 
-def lazy_import_flash_attention(implementation: Optional[str]):
+def lazy_import_flash_attention(implementation: Optional[str], force_import: Optional[bool] = False):
     """
     Lazily import flash attention and return the respective functions + flags.
@@ -132,11 +132,11 @@ def lazy_import_flash_attention(implementation: Optional[str]):
     work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
     global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
-    if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
+    if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
         _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
 
     global _process_flash_kwargs_fn
-    if _process_flash_kwargs_fn is None:
+    if force_import or _process_flash_kwargs_fn is None:
         _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
 
     return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
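The behavioural change is the `force_import` escape hatch for the module-level cache: without it, whichever implementation was imported first in a process stays cached, even when a different one is requested later. A minimal stand-alone sketch of that caching pattern; names are illustrative, not the library's:

from typing import Optional

_cached_fn = None  # module-level cache, mirroring _flash_fn and friends above

def _expensive_import(implementation: Optional[str]):
    # Placeholder for importing flash-attn or a hub kernel.
    return lambda *args: f"flash function for {implementation}"

def lazy_get(implementation: Optional[str] = None, force_import: bool = False):
    global _cached_fn
    if force_import or _cached_fn is None:
        _cached_fn = _expensive_import(implementation)
    return _cached_fn

fa2 = lazy_get("flash_attention_2")
stale = lazy_get("flash_attention_3")                     # without the flag: still the first import
fresh = lazy_get("flash_attention_3", force_import=True)  # with the flag: re-imported, as in this commit

The default `force_import=False` keeps the original lazy behaviour for the common case of a single implementation per process.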


@@ -2557,7 +2557,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )
 
         # preload flash attention here to allow compile with fullgraph
         if applicable_attn_implementation.startswith("flash_attention"):
-            lazy_import_flash_attention(applicable_attn_implementation)
+            lazy_import_flash_attention(applicable_attn_implementation, force_import=True)
 
         return applicable_attn_implementation
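As the comment in the hunk notes, flash attention is imported eagerly when the implementation is selected so that compiled graphs never hit a lazy import; with `force_import=True`, picking a new implementation also refreshes the cached functions. A hedged usage sketch of the compile angle; the compile call is illustrative and not prescribed by this commit:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # flash functions are resolved here, up front
).to("cuda")

# Because the import already happened at load time, a fullgraph compile of the
# forward pass does not have to trace through a lazy import.
model.forward = torch.compile(model.forward, fullgraph=True)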


@@ -81,39 +81,31 @@ class FlashAttentionParityTest(unittest.TestCase):
     @slow
     def test_flash_attention_2_3_parity(self):
         model_id = "meta-llama/Llama-3.2-1B-Instruct"
-        prompt = "The ETH AI Center is"
+        prompt = ["The ETH AI Center is", "What is life?"]
 
-        # 1. Load FA2 model and tokenizer
-        model_2 = AutoModelForCausalLM.from_pretrained(
+        # 1. Load model and tokenizer
+        model = AutoModelForCausalLM.from_pretrained(
             model_id,
             dtype=torch.bfloat16,
             attn_implementation="flash_attention_2",
         ).to("cuda")
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
-        # 2. Load FA3 model
-        try:
-            model_3 = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                dtype=torch.bfloat16,
-                attn_implementation="flash_attention_3",
-            ).to("cuda")
-        except (ValueError, ImportError) as e:
-            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")
-
-        # 3. Generate with both models
-        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+        # 2. Generate with both models
+        inputs = tokenizer(prompt, padding=True, padding_side="left", return_tensors="pt").to("cuda")
         with torch.no_grad():
-            output_2 = model_2.generate(
+            output_2 = model.generate(
                 **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
             )
-            output_3 = model_3.generate(
+            model.set_attn_implementation("flash_attention_3")
+            output_3 = model.generate(
                 **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
             )
 
-        # 4. Correctness check
-        # 4a. Logits
+        # 3. Correctness check
+        # 3a. Logits
         logits_2 = torch.stack(output_2.scores)
         logits_3 = torch.stack(output_3.scores)
         torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
@@ -121,22 +113,27 @@ class FlashAttentionParityTest(unittest.TestCase):
         logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
         max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()
 
-        # 4b. Generated text
-        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
-        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
-        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
-        assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"
+        # 3b. Generated text
+        text_2s, text_3s = [], []
+        for i in range(len(prompt)):
+            text_2s.append(tokenizer.decode(output_2.sequences[i], skip_special_tokens=True))
+            text_3s.append(tokenizer.decode(output_3.sequences[i], skip_special_tokens=True))
 
-        # 5. Performance check
+        rouge_scores = self._calculate_rouge_l(text_2s, text_3s)
+        for i in range(len(rouge_scores)):
+            assert rouge_scores[i] > 0.99, f"Generated texts at prompt {i} do not match (ROUGE-L: {rouge_scores[i]})"
+
+        # 4. Performance check
         with torch.no_grad():
-            time_2 = self._benchmark_generation(model_2, inputs)
-            time_3 = self._benchmark_generation(model_3, inputs)
+            time_3 = self._benchmark_generation(model, inputs)
+            model.set_attn_implementation("flash_attention_2")
+            time_2 = self._benchmark_generation(model, inputs)
 
         print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
         print(f"Prompt: '{prompt}'")
-        print(f"Generated text with Flash Attention 2: {text_2}")
-        print(f"Generated text with Flash Attention 3: {text_3}")
-        print(f"ROUGE-L: {rouge_score}")
+        print(f"Generated text with Flash Attention 2: {text_2s}")
+        print(f"Generated text with Flash Attention 3: {text_3s}")
+        print(f"ROUGE-L: {rouge_scores}")
         print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
         print(f"Flash Attention 2 latency: {time_2:.2f} ms")
         print(f"Flash Attention 3 latency: {time_3:.2f} ms")