Fix slow static cache export tests (#40261)
@@ -325,14 +325,14 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
                 "input_ids": input_ids,
                 "cache_position": cache_position
                 if cache_position is not None
-                else torch.arange(input_ids.shape[-1], dtype=torch.long, model=model_device),
+                else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
             }
         else:  # inputs_embeds
             input_kwargs = {
                 "inputs_embeds": inputs_embeds,
                 "cache_position": cache_position
                 if cache_position is not None
-                else torch.arange(inputs_embeds.shape[1], dtype=torch.long, model=model_device),
+                else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
             }

         exported_program = torch.export.export(
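The hunk above fixes a keyword typo: torch.arange takes a device argument, not model, so the old code raised a TypeError whenever cache_position was left as None. A minimal Python sketch of the corrected default, assuming input_ids is a [batch, seq_len] LongTensor and model_device is the model's device (both placeholders here):

import torch

# Placeholder inputs standing in for the module's real arguments.
input_ids = torch.tensor([[5, 6, 7]], dtype=torch.long)
model_device = input_ids.device

# Default cache_position when the caller does not pass one:
# positions 0..seq_len-1, created on the model's device.
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
print(cache_position)  # tensor([0, 1, 2])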
@@ -463,8 +463,8 @@ class GemmaIntegrationTest(unittest.TestCase):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
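This and the remaining test hunks make the same change: the program is exported from a single example token at cache position 0 (the shape of one decode step against the static cache), while the real prompt is still passed to TorchExportableModuleWithStaticCache.generate. A hedged sketch of that export-then-generate pattern; the model id, cache sizes, prompt, and max_new_tokens below are illustrative placeholders, not values from the tests:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
)

model_id = "google/gemma-2b"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 64},
    ),
)
prompt_token_ids = tokenizer("Hello", return_tensors="pt").input_ids

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
    # Trace one decode step: a single token at cache position 0.
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=20
)
print(tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True))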
@@ -368,8 +368,8 @@ class Gemma2IntegrationTest(unittest.TestCase):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
@@ -354,8 +354,8 @@ class LlamaIntegrationTest(unittest.TestCase):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
@@ -387,8 +387,8 @@ class OlmoIntegrationTest(unittest.TestCase):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
@@ -415,8 +415,8 @@ class Phi3IntegrationTest(unittest.TestCase):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
@@ -305,8 +305,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
             "2.7.0"
         )  # Due to https://github.com/pytorch/pytorch/issues/150994
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
             strict=strict,
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
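The Qwen2 and Qwen3 tests additionally forward a strict flag to export, derived in the surrounding test code from the installed torch version (the "2.7.0" comparison in the context above, due to pytorch issue 150994). A sketch of that call shape, with strict hard-coded to an illustrative value:

# In the real tests, strict depends on the torch version; False here is only illustrative.
strict = False
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
    strict=strict,
)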
@@ -295,8 +295,8 @@ class Qwen3IntegrationTest(unittest.TestCase):

         exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
-            input_ids=prompt_token_ids,
-            cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
             strict=strict,
         )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(