mirror of https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00

[FA3] Fix masking and loading logic in same process (#41217)

fix loading and fa3 masking
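For context, a minimal sketch (not part of the commit) of the scenario the fix targets: loading a model with flash_attention_2 and then switching to flash_attention_3 in the same process. Before this change, the FA2 callables cached at load time were silently reused, and "flash_attention_3" had no entry in the attention-mask mapping. Model id and CUDA/FA3 availability are assumptions taken from the updated test below.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The ETH AI Center is", return_tensors="pt").to("cuda")
out_fa2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Same-process switch: this is what previously kept the stale FA2 kernels and
# could not resolve a mask function for FA3.
model.set_attn_implementation("flash_attention_3")
out_fa3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)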
@@ -191,7 +191,7 @@ def load_and_register_kernel(attn_implementation: str) -> None:
             if attention_wrapper is None:
                 attention_wrapper = flash_attention_forward
             kernel_function = partial(attention_wrapper, implementation=kernel)
-            lazy_import_flash_attention(kernel)
+            lazy_import_flash_attention(kernel, force_import=True)
         elif kernel_name is not None:
             kernel_function = getattr(kernel, kernel_name)
         # Register the kernel as a valid attention
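In load_and_register_kernel, the hub kernel is bound into a generic flash-attention wrapper via functools.partial before being registered; the change above additionally forces a re-import of the flash helpers so a kernel loaded later in the process actually replaces the cached ones. A toy illustration of the partial-binding step (the names below are stand-ins, not the real wrapper or kernel object):

from functools import partial

def attention_wrapper(query, key, value, implementation=None):
    # The real wrapper dispatches to the kernel's flash-attention functions;
    # here we only show that `implementation` is pre-bound at registration time.
    return f"attention via {implementation}"

kernel = "hub-kernel-object"  # stand-in for the loaded kernel module
kernel_function = partial(attention_wrapper, implementation=kernel)
print(kernel_function("q", "k", "v"))  # -> "attention via hub-kernel-object"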
@@ -625,6 +625,7 @@ class AttentionMaskInterface(GeneralInterface):
         "sdpa": sdpa_mask,
         "eager": eager_mask,
         "flash_attention_2": flash_attention_mask,
+        "flash_attention_3": flash_attention_mask,
         "flex_attention": flex_attention_mask,
     }
 
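The one-line addition above gives flash_attention_3 the same mask-preparation function as flash_attention_2; without it, selecting FA3 could not resolve a mask function from the interface. A self-contained sketch of the mapping idea (simplified stand-ins, not the real masking_utils classes); roughly, the flash path keeps the 2D padding mask (or None) rather than building a dense 4D mask:

def sdpa_mask(padding_mask):
    # Dense-mask backends expand the 2D padding mask into a 4D additive mask (sketched).
    return ("dense-4d", padding_mask)

def flash_attention_mask(padding_mask):
    # Flash backends work on unpadded/varlen sequences, so they keep the 2D mask (or None).
    return padding_mask

MASK_FUNCTIONS = {
    "sdpa": sdpa_mask,
    "flash_attention_2": flash_attention_mask,
    "flash_attention_3": flash_attention_mask,  # the entry this commit adds
}

# FA2 and FA3 now resolve to the very same mask function.
assert MASK_FUNCTIONS["flash_attention_3"] is MASK_FUNCTIONS["flash_attention_2"]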
@@ -124,7 +124,7 @@ def _lazy_define_process_function(flash_function):
     return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
 
 
-def lazy_import_flash_attention(implementation: Optional[str]):
+def lazy_import_flash_attention(implementation: Optional[str], force_import: Optional[bool] = False):
     """
     Lazily import flash attention and return the respective functions + flags.
 
@@ -132,11 +132,11 @@ def lazy_import_flash_attention(implementation: Optional[str]):
     work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
     global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
-    if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
+    if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
         _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
 
     global _process_flash_kwargs_fn
-    if _process_flash_kwargs_fn is None:
+    if force_import or _process_flash_kwargs_fn is None:
         _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
 
     return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
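The helpers above are cached in module-level globals, so by default a second call is a no-op; force_import=True bypasses that check so switching implementations in one process re-resolves the callables. A standalone sketch of the pattern (simplified names, not the real module):

from typing import Callable, Optional

_flash_fn: Optional[Callable] = None  # module-level cache, as in the real file

def _resolve(implementation: Optional[str]) -> Callable:
    # Stand-in for _lazy_imports(): pretend each implementation ships its own kernel.
    return lambda: f"{implementation} kernel"

def lazy_import(implementation: Optional[str], force_import: bool = False) -> Callable:
    global _flash_fn
    if force_import or _flash_fn is None:  # the condition this commit changes
        _flash_fn = _resolve(implementation)
    return _flash_fn

print(lazy_import("flash_attention_2")())                      # "flash_attention_2 kernel"
print(lazy_import("flash_attention_3")())                      # still the cached FA2 kernel
print(lazy_import("flash_attention_3", force_import=True)())   # now the FA3 kernel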
@@ -2557,7 +2557,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )
         # preload flash attention here to allow compile with fullgraph
         if applicable_attn_implementation.startswith("flash_attention"):
-            lazy_import_flash_attention(applicable_attn_implementation)
+            lazy_import_flash_attention(applicable_attn_implementation, force_import=True)
 
         return applicable_attn_implementation
 
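This call site is where the selected flash implementation is preloaded; per the existing comment, resolving the kernels eagerly keeps torch.compile with fullgraph=True from hitting a lazy import inside the traced graph, and force_import=True makes that preload also pick up a newly selected backend. A hedged usage sketch (CUDA GPU and installed flash-attention kernels are assumptions):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Kernels were force-imported during attention selection, so the compiled graph
# does not need to trigger any import at trace time.
compiled_forward = torch.compile(model.forward, fullgraph=True)
input_ids = torch.tensor([[1, 2, 3]], device="cuda")
with torch.no_grad():
    _ = compiled_forward(input_ids=input_ids)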
@@ -81,39 +81,31 @@ class FlashAttentionParityTest(unittest.TestCase):
     @slow
     def test_flash_attention_2_3_parity(self):
         model_id = "meta-llama/Llama-3.2-1B-Instruct"
-        prompt = "The ETH AI Center is"
+        prompt = ["The ETH AI Center is", "What is life?"]
 
-        # 1. Load FA2 model and tokenizer
-        model_2 = AutoModelForCausalLM.from_pretrained(
+        # 1. Load model and tokenizer
+        model = AutoModelForCausalLM.from_pretrained(
             model_id,
             dtype=torch.bfloat16,
             attn_implementation="flash_attention_2",
         ).to("cuda")
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
-        # 2. Load FA3 model
-        try:
-            model_3 = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                dtype=torch.bfloat16,
-                attn_implementation="flash_attention_3",
-            ).to("cuda")
-        except (ValueError, ImportError) as e:
-            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")
-
-        # 3. Generate with both models
-        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+        # 2. Generate with both models
+        inputs = tokenizer(prompt, padding=True, padding_side="left", return_tensors="pt").to("cuda")
 
         with torch.no_grad():
-            output_2 = model_2.generate(
+            output_2 = model.generate(
                 **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
             )
-            output_3 = model_3.generate(
+            model.set_attn_implementation("flash_attention_3")
+            output_3 = model.generate(
                 **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
             )
 
-        # 4. Correctness check
-        # 4a. Logits
+        # 3. Correctness check
+        # 3a. Logits
         logits_2 = torch.stack(output_2.scores)
         logits_3 = torch.stack(output_3.scores)
         torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
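The rewritten test batches two prompts, so it sets a pad token (Llama tokenizers do not ship one) and pads on the left, which keeps the last position of every row a real token for decoder-only generation. A small standalone sketch of that tokenization step (same model id as the test):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token_id = tokenizer.eos_token_id  # Llama has no dedicated pad token

batch = tokenizer(
    ["The ETH AI Center is", "What is life?"],
    padding=True,
    padding_side="left",
    return_tensors="pt",
)
# The shorter prompt is padded on the left, so generation starts from real tokens
# in both rows; the attention mask zeroes out the pad positions.
print(batch["input_ids"].shape)
print(batch["attention_mask"])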
@@ -121,22 +113,27 @@ class FlashAttentionParityTest(unittest.TestCase):
         logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
         max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()
 
-        # 4b. Generated text
-        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
-        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
-        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
-        assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"
+        # 3b. Generated text
+        text_2s, text_3s = [], []
+        for i in range(len(prompt)):
+            text_2s.append(tokenizer.decode(output_2.sequences[i], skip_special_tokens=True))
+            text_3s.append(tokenizer.decode(output_3.sequences[i], skip_special_tokens=True))
 
-        # 5. Performance check
+        rouge_scores = self._calculate_rouge_l(text_2s, text_3s)
+        for i in range(len(rouge_scores)):
+            assert rouge_scores[i] > 0.99, f"Generated texts at prompt {i} do not match (ROUGE-L: {rouge_scores[i]})"
+
+        # 4. Performance check
         with torch.no_grad():
-            time_2 = self._benchmark_generation(model_2, inputs)
-            time_3 = self._benchmark_generation(model_3, inputs)
+            time_3 = self._benchmark_generation(model, inputs)
+            model.set_attn_implementation("flash_attention_2")
+            time_2 = self._benchmark_generation(model, inputs)
 
         print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
         print(f"Prompt: '{prompt}'")
-        print(f"Generated text with Flash Attention 2: {text_2}")
-        print(f"Generated text with Flash Attention 3: {text_3}")
-        print(f"ROUGE-L: {rouge_score}")
+        print(f"Generated text with Flash Attention 2: {text_2s}")
+        print(f"Generated text with Flash Attention 3: {text_3s}")
+        print(f"ROUGE-L: {rouge_scores}")
         print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
         print(f"Flash Attention 2 latency: {time_2:.2f} ms")
         print(f"Flash Attention 3 latency: {time_3:.2f} ms")
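The per-prompt assertion relies on the class's _calculate_rouge_l helper, which is outside this hunk. For reference, a generic LCS-based ROUGE-L F-measure looks like the sketch below (an assumption about what the helper computes, not its actual code):

def rouge_l(reference: str, candidate: str) -> float:
    ref, cand = reference.split(), candidate.split()
    # Dynamic-programming longest common subsequence over whitespace tokens.
    dp = [[0] * (len(cand) + 1) for _ in range(len(ref) + 1)]
    for i, r in enumerate(ref, 1):
        for j, c in enumerate(cand, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if r == c else max(dp[i - 1][j], dp[i][j - 1])
    lcs = dp[-1][-1]
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)  # F1 of LCS precision/recall

assert rouge_l("the cat sat", "the cat sat") == 1.0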