Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Support input_embeds in torch exportable decoders (#39836)
* Support input_embeds in torch exportable decoders
* Hybrid cache update
* Manually change some callsites
* AI changes the rest of the call sites
* Make either input_ids/inputs_embeds mandatory
* Clean up
* Ruff check --fix
* Fix test
* pr review
* Revert config/generation_config changes
* Ruff check
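For context, a minimal usage sketch of the API after this change, assuming the tiny random Llama checkpoint used by the new `tests/test_executorch.py` below; exactly one of `input_ids` or `inputs_embeds` must be passed, and the cache wrappers now read their sizes from `generation_config.cache_config` rather than constructor arguments. This is an illustrative sketch, not the committed code.

```python
# Sketch of the new export entry point (checkpoint and sizes are illustrative,
# mirroring the new tests/test_executorch.py further down this diff).
import torch

from transformers import AutoModelForCausalLM
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()
# Cache geometry now lives on the generation config, not on the wrapper constructor.
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)

# Export from token ids ...
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long)
ep_from_ids = exportable_module.export(input_ids=input_ids, cache_position=cache_position)

# ... or from embeddings (exactly one of input_ids / inputs_embeds is accepted).
inputs_embeds = torch.randn(1, 3, model.config.hidden_size)
ep_from_embeds = exportable_module.export(inputs_embeds=inputs_embeds, cache_position=cache_position)
```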
@ -198,34 +198,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with `HybridCache`.

Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.

Raises:
ValueError: If the model is configured with an unsupported cache implementation.
"""
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
|
||||
config = model.config.get_text_config()
|
||||
_generation_config = model.generation_config
|
||||
|
||||
if not hasattr(config, "use_cache") or config.use_cache is False:
|
||||
raise ValueError("The model must have caching enabled to be performant.")
|
||||
|
||||
if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None:
|
||||
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
|
||||
if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
|
||||
self.model = TorchExportableModuleWithHybridCache(model)
|
||||
else:
|
||||
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
|
||||
# there is only 1 type of layers, so export will use `StaticCache` by default.
|
||||
logging.info(
|
||||
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
|
||||
)
|
||||
self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len)
|
||||
self.model = TorchExportableModuleWithStaticCache(model)
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
@ -233,24 +232,31 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
|
||||
inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
|
||||
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Logits output from the model.
|
||||
"""
|
||||
return self.model.forward(input_ids, cache_position)
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
def export(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
dynamic_shapes: Optional[dict] = None,
|
||||
strict: Optional[bool] = None,
|
||||
@ -260,14 +266,49 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
|
||||
Args:
|
||||
input_ids (`Optional[torch.Tensor]`):
|
||||
Tensor representing current input token id to the module. If not provided, a default tensor will be used.
|
||||
Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
|
||||
inputs_embeds (`Optional[torch.Tensor]`):
|
||||
Tensor representing current input embeddings to the module. Must specify either this or input_ids.
|
||||
cache_position (`Optional[torch.Tensor]`):
|
||||
Tensor representing current input position in the cache. If not provided, a default tensor will be used.
|
||||
dynamic_shapes (`Optional[dict]`):
|
||||
Dynamic shapes to use for export if specified.
|
||||
strict(`Optional[bool]`):
|
||||
Flag to instruct `torch.export` to use `torchdynamo`.
|
||||
|
||||
Returns:
|
||||
torch.export.ExportedProgram: The exported program that can be used for inference.
|
||||
|
||||
Examples:
|
||||
Export with input_ids:
|
||||
```python
|
||||
# Prepare inputs
|
||||
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
|
||||
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
|
||||
|
||||
# Export
|
||||
exported = exportable_module.export(
|
||||
input_ids=input_ids,
|
||||
cache_position=cache_position
|
||||
)
|
||||
```
|
||||
|
||||
Export with inputs_embeds:
|
||||
```python
|
||||
# Prepare embeddings
|
||||
inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768
|
||||
cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
|
||||
|
||||
# Export
|
||||
exported = exportable_module.export(
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position
|
||||
)
|
||||
```
|
||||
"""
|
||||
if not (input_ids is None) ^ (inputs_embeds is None):
|
||||
raise ValueError("Need to specify either input_ids or inputs_embeds.")
|
||||
|
||||
if hasattr(self.model, "base_model_prefix"):
|
||||
base = getattr(self.model, self.model.base_model_prefix, self.model)
|
||||
model_device = base.device
|
||||
@ -279,20 +320,29 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
"TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
|
||||
)
|
||||
|
||||
example_input_ids = (
|
||||
input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device)
|
||||
)
|
||||
example_cache_position = (
|
||||
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
|
||||
)
|
||||
if input_ids is not None:
|
||||
input_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"cache_position": cache_position
|
||||
if cache_position is not None
|
||||
else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
}
else: # inputs_embeds
input_kwargs = {
"inputs_embeds": inputs_embeds,
"cache_position": cache_position
if cache_position is not None
else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
}

exported_program = torch.export.export(
self.model,
args=(example_input_ids, example_cache_position),
kwargs={},
args=(),
kwargs=input_kwargs,
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)

return exported_program

@staticmethod
@ -341,7 +391,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

# Forward pass
_ = exported_module(curr_input_ids, curr_cache_position)
_ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
curr_position += 1

# Generate new tokens
@ -351,7 +401,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

# Forward pass to get next token logits
outputs = exported_module(curr_input_ids, curr_cache_position)
outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)

# Get the next token ID
if do_sample:
@ -418,8 +468,6 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the wrapper module with the pretrained model.
@ -434,27 +482,31 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
"""
super().__init__()

# Sanity checks
if model.generation_config is None:
# Use default generation config if not specified
model.generation_config = GenerationConfig(
use_cache=model.config.use_cache,
cache_implementation="static",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": "cpu",
},
)
config = model.config.get_text_config()
generation_config = model.generation_config

if not model.generation_config.use_cache:
# Sanity checks
if generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not generation_config.use_cache:
raise AssertionError(
"The model must have caching enabled to be exported with static caching. "
"Please set `generation_config.use_cache=True`."
)

if model.generation_config.cache_implementation != "static":
if generation_config.cache_implementation != "static":
raise AssertionError(
"The model must use a 'static' caching implementation to be exported with static caching. "
"Please set `generation_config.cache_implementation='static'`."
@ -462,22 +514,29 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):

self.model = model
self.static_cache = StaticCache(
config=self.model.config,
max_batch_size=self.model.generation_config.cache_config.get("batch_size"),
max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"),
device=self.model.generation_config.cache_config.get("device"),
config=config,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=self.model.dtype,
)

for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)

def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
):
"""
Forward pass of the module, which is compatible with the ExecuTorch runtime.

Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

Returns:
@ -493,15 +552,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache

outs = self.model(
input_ids=input_ids,
attention_mask=None,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
attention_mask=None,
past_key_values=past_key_values,
use_cache=True,
)
@ -576,33 +633,45 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with `HybridCache`.

Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.

Raises:
AssertionError: If the model doesn't have the expected configuration for HybridCache.
"""
super().__init__()
self.model = model
config = model.config.get_text_config()
generation_config = model.generation_config

# Verify the model is configured for HybridCache
if not self.model.config.use_cache:
raise AssertionError("Model must have caching enabled")
if generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not config.use_cache:
raise AssertionError("Model must have caching enabled.")

# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
config=config,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=self.model.dtype,
)

@ -613,32 +682,29 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):

def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.

Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

Returns:
torch.Tensor: Logits output from the model.
"""
batch_size = input_ids.shape[0]

# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

# Forward pass with the model
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
attention_mask=None,
position_ids=position_ids,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)

# Return only the logits to simplify the export
@ -692,8 +758,8 @@ def convert_and_export_with_cache(
if is_torch_greater_or_equal("2.6.0"):
exported_program = torch.export.export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids, example_cache_position),
kwargs={},
args=(),
kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
@ -710,8 +776,8 @@ def convert_and_export_with_cache(
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
args=(),
kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
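The hunks above move cache sizing out of the wrapper constructors: both `TorchExportableModuleWithStaticCache` and `TorchExportableModuleWithHybridCache` now raise unless `model.generation_config` exists, has `use_cache=True`, and provides `batch_size` and `max_cache_len` in `cache_config`. A hedged sketch of the hybrid setup, modeled on the `CacheExportIntegrationTest` change further down; `model` is assumed to be an already-loaded checkpoint whose config defines `layer_types` and a non-null `sliding_window` (e.g. a Gemma-2-style model).

```python
# Sketch only: `model` is assumed to be a loaded hybrid-cache-capable PreTrainedModel.
import torch

from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

max_batch_size = 1
max_cache_len = 23

# Cache geometry now lives on the generation config rather than the wrapper constructor.
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="hybrid",
    max_length=max_cache_len,
    cache_config={
        "batch_size": max_batch_size,
        "max_cache_len": max_cache_len,
        "device": model.device,
    },
)

# The wrapper selects TorchExportableModuleWithHybridCache because the config
# defines layer_types and a non-null sliding_window; otherwise it falls back to StaticCache.
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
```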
@ -460,7 +460,10 @@ class GemmaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

@ -365,7 +365,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
@ -389,7 +392,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)

# Test generation with the exported model
prompt = "What is the capital of France?"

@ -822,7 +822,10 @@ class Gemma3IntegrationTest(unittest.TestCase):
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
logging.info(f"\nExported program: {exported_program}")

# Test generation with the exported model

@ -353,7 +353,10 @@ class LlamaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

@ -384,7 +384,10 @@ class OlmoIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

@ -417,7 +417,10 @@ class Phi3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

@ -303,7 +303,11 @@ class Qwen2IntegrationTest(unittest.TestCase):
strict = version.parse(torch.__version__) != version.parse(
"2.7.0"
) # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = exportable_module.export(strict=strict)
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

@ -293,7 +293,11 @@ class Qwen3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(strict=strict)
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
tests/test_executorch.py (new file, 129 lines)
@ -0,0 +1,129 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from transformers import AutoModelForCausalLM, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
TorchExportableModuleWithHybridCache,
TorchExportableModuleWithStaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.testing_utils import require_torch


@require_torch
class ExecutorchTest(unittest.TestCase):
def setUp(self):
if not is_torch_greater_or_equal_than_2_3:
self.skipTest("torch >= 2.3 is required")

set_seed(0)
self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
self.model.eval()

# Create generation config with static cache for the model
self.model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size)
self.cache_position = torch.arange(3, dtype=torch.long)

def test_static_cache_module_forward(self):
"""Test TorchExportableModuleWithStaticCache forward with both input types"""
generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

# Set generation config on model
self.model.generation_config = generation_config
module = TorchExportableModuleWithStaticCache(self.model)

# Test with input_ids
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)

# Test with inputs_embeds
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)

def test_hybrid_cache_module_forward(self):
"""Test TorchExportableModuleWithHybridCache forward with both input types"""
config = self.model.config
config.sliding_window = 16
config.layer_types = ["full_attention"] * config.num_hidden_layers

generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

# Set generation config on model
self.model.generation_config = generation_config
module = TorchExportableModuleWithHybridCache(self.model)

# Test with input_ids
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)

# Test with inputs_embeds
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)

def test_decoder_only_lm_export_validation(self):
"""Test TorchExportableModuleForDecoderOnlyLM export validation"""
module = TorchExportableModuleForDecoderOnlyLM(self.model)

# Should fail with both input_ids and inputs_embeds
with self.assertRaises(ValueError):
module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds)

# Should fail with neither
with self.assertRaises(ValueError):
module.export()

def test_decoder_only_lm_export(self):
"""Test TorchExportableModuleForDecoderOnlyLM export with both input types"""
module = TorchExportableModuleForDecoderOnlyLM(self.model)

# Test export with input_ids
exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position)
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
exported_output_ids = exported_program_ids.module()(
input_ids=self.input_ids, cache_position=self.cache_position
)
torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4)

# Test export with inputs_embeds
exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
exported_output_embeds = exported_program_embeds.module()(
inputs_embeds=self.inputs_embeds, cache_position=self.cache_position
)
torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4)

@ -841,8 +841,24 @@ class CacheExportIntegrationTest(unittest.TestCase):
model.eval()
max_batch_size = 1
max_cache_len = 23
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
exported_program = exportable_module.export()
# Set generation config on the model for the hybrid cache model
from transformers.generation.configuration_utils import GenerationConfig

model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": model.device,
},
)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
n_g_key_caches = n_g_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("key_cache"):