Fix slow static cache export tests (#40261)

Jack
2025-08-19 02:24:07 -07:00
committed by GitHub
parent 56c44213b3
commit 2b59207a72
8 changed files with 16 additions and 16 deletions
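
The change repeated across the test files below is to export with a single placeholder token at cache position 0 instead of with the full prompt, which is presumably what made these static cache export tests slow; `TorchExportableModuleWithStaticCache.generate` still decodes the real prompt through the exported program. A rough sketch of that pattern is shown here for context only; the checkpoint, generation config, prompt ids, and `max_new_tokens` are illustrative assumptions, not values taken from these tests.

```python
# Sketch of the export-then-generate pattern the tests switch to.
# Checkpoint, cache sizes, prompt ids, and max_new_tokens are assumptions for illustration.
import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
)

# Load a decoder-only LM configured for a static KV cache (assumed tiny checkpoint).
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=64,
        cache_config={"batch_size": 1, "max_cache_len": 64},
    ),
)

# The real prompt used for generation (illustrative token ids).
prompt_token_ids = torch.tensor([[1, 450, 4996, 17354]], dtype=torch.long, device=model.device)

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)

# Export over one placeholder token at cache position 0 (fast) rather than
# tracing over the whole prompt (slow, as in the removed lines below).
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)

# Generation still feeds the full prompt, step by step, through the exported
# static-cache program.
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=8,
)
print(ep_generated_ids)
```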


@@ -325,14 +325,14 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"input_ids": input_ids,
"cache_position": cache_position
if cache_position is not None
-else torch.arange(input_ids.shape[-1], dtype=torch.long, model=model_device),
+else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
}
else: # inputs_embeds
input_kwargs = {
"inputs_embeds": inputs_embeds,
"cache_position": cache_position
if cache_position is not None
-else torch.arange(inputs_embeds.shape[1], dtype=torch.long, model=model_device),
+else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
}
exported_program = torch.export.export(


@@ -463,8 +463,8 @@ class GemmaIntegrationTest(unittest.TestCase):
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens


@@ -368,8 +368,8 @@ class Gemma2IntegrationTest(unittest.TestCase):
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens


@@ -354,8 +354,8 @@ class LlamaIntegrationTest(unittest.TestCase):
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens


@@ -387,8 +387,8 @@ class OlmoIntegrationTest(unittest.TestCase):
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens


@@ -415,8 +415,8 @@ class Phi3IntegrationTest(unittest.TestCase):
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens


@@ -305,8 +305,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
"2.7.0"
) # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(


@@ -295,8 +295,8 @@ class Qwen3IntegrationTest(unittest.TestCase):
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
-input_ids=prompt_token_ids,
-cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(