Support input_embeds in torch exportable decoders (#39836)

* Support input_embeds in torch exportable decoders

* Hybrid cache update

* Manually change some callsites

* AI changes the rest of the call sites

* Make either input_ids/inputs_embeds mandatory

* Clean up

* Ruff check --fix

* Fix test

* pr review

* Revert config/generation_config changes

* Ruff check
This commit is contained in:
Jack
2025-08-07 01:51:31 -07:00
committed by GitHub
parent cdeaad96b7
commit 6121e9e46c
11 changed files with 325 additions and 85 deletions

View File

@ -198,34 +198,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with `HybridCache`.
Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Raises:
ValueError: If the model is configured with a unsupported cache implementation.
"""
super().__init__()
if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
config = model.config.get_text_config()
_generation_config = model.generation_config
if not hasattr(config, "use_cache") or config.use_cache is False:
raise ValueError("The model must have caching enabled to be performant.")
if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None:
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
self.model = TorchExportableModuleWithHybridCache(model)
else:
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
# there is only 1 type of layers, so export will use `StaticCache` by default.
logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len)
self.model = TorchExportableModuleWithStaticCache(model)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@ -233,24 +232,31 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
"""
return self.model.forward(input_ids, cache_position)
return self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
)
def export(
self,
input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
dynamic_shapes: Optional[dict] = None,
strict: Optional[bool] = None,
@ -260,14 +266,49 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
Args:
input_ids (`Optional[torch.Tensor]`):
Tensor representing current input token id to the module. If not provided, a default tensor will be used.
Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
inputs_embeds (`Optional[torch.Tensor]`):
Tensor representing current input embeddings to the module. Must specify either this or input_ids.
cache_position (`Optional[torch.Tensor]`):
Tensor representing current input position in the cache. If not provided, a default tensor will be used.
dynamic_shapes (`Optional[dict]`):
Dynamic shapes to use for export if specified.
strict(`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
Returns:
torch.export.ExportedProgram: The exported program that can be used for inference.
Examples:
Export with input_ids:
```python
# Prepare inputs
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
# Export
exported = exportable_module.export(
input_ids=input_ids,
cache_position=cache_position
)
```
Export with inputs_embeds:
```python
# Prepare embeddings
inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768
cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
# Export
exported = exportable_module.export(
inputs_embeds=inputs_embeds,
cache_position=cache_position
)
```
"""
if not (input_ids is None) ^ (inputs_embeds is None):
raise ValueError("Need to specify either input_ids or inputs_embeds.")
if hasattr(self.model, "base_model_prefix"):
base = getattr(self.model, self.model.base_model_prefix, self.model)
model_device = base.device
@ -279,20 +320,29 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
)
example_input_ids = (
input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device)
)
example_cache_position = (
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
)
if input_ids is not None:
input_kwargs = {
"input_ids": input_ids,
"cache_position": cache_position
if cache_position is not None
else torch.arange(input_ids.shape[-1], dtype=torch.long, model=model_device),
}
else: # inputs_embeds
input_kwargs = {
"inputs_embeds": inputs_embeds,
"cache_position": cache_position
if cache_position is not None
else torch.arange(inputs_embeds.shape[1], dtype=torch.long, model=model_device),
}
exported_program = torch.export.export(
self.model,
args=(example_input_ids, example_cache_position),
kwargs={},
args=(),
kwargs=input_kwargs,
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
return exported_program
@staticmethod
@ -341,7 +391,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
# Forward pass
_ = exported_module(curr_input_ids, curr_cache_position)
_ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
curr_position += 1
# Generate new tokens
@ -351,7 +401,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
# Forward pass to get next token logits
outputs = exported_module(curr_input_ids, curr_cache_position)
outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
# Get the next token ID
if do_sample:
@ -418,8 +468,6 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the wrapper module with the pretrained model.
@ -434,27 +482,31 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
"""
super().__init__()
# Sanity checks
if model.generation_config is None:
# Use default generation config if not specified
model.generation_config = GenerationConfig(
use_cache=model.config.use_cache,
cache_implementation="static",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": "cpu",
},
)
config = model.config.get_text_config()
generation_config = model.generation_config
if not model.generation_config.use_cache:
# Sanity checks
if generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not generation_config.use_cache:
raise AssertionError(
"The model must have caching enabled to be exported with static caching. "
"Please set `generation_config.use_cache=True`."
)
if model.generation_config.cache_implementation != "static":
if generation_config.cache_implementation != "static":
raise AssertionError(
"The model must use a 'static' caching implementation to be exported with static caching. "
"Please set `generation_config.cache_implementation='static'`."
@ -462,22 +514,29 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
self.model = model
self.static_cache = StaticCache(
config=self.model.config,
max_batch_size=self.model.generation_config.cache_config.get("batch_size"),
max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"),
device=self.model.generation_config.cache_config.get("device"),
config=config,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=self.model.dtype,
)
for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
):
"""
Forward pass of the module, which is compatible with the ExecuTorch runtime.
Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns:
@ -493,15 +552,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache
outs = self.model(
input_ids=input_ids,
attention_mask=None,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
attention_mask=None,
past_key_values=past_key_values,
use_cache=True,
)
@ -576,33 +633,45 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with `HybridCache`.
Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Raises:
AssertionError: If the model doesn't have the expected configuration for HybridCache.
"""
super().__init__()
self.model = model
config = model.config.get_text_config()
generation_config = model.generation_config
# Verify the model is configured for HybridCache
if not self.model.config.use_cache:
raise AssertionError("Model must have caching enabled")
if generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not config.use_cache:
raise AssertionError("Model must have caching enabled.")
# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
config=config,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=self.model.dtype,
)
@ -613,32 +682,29 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
"""
batch_size = input_ids.shape[0]
# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
# Forward pass with the model
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
attention_mask=None,
position_ids=position_ids,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)
# Return only the logits to simplify the export
@ -692,8 +758,8 @@ def convert_and_export_with_cache(
if is_torch_greater_or_equal("2.6.0"):
exported_program = torch.export.export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids, example_cache_position),
kwargs={},
args=(),
kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
@ -710,8 +776,8 @@ def convert_and_export_with_cache(
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
args=(),
kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)

View File

@ -460,7 +460,10 @@ class GemmaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

View File

@ -365,7 +365,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
@ -389,7 +392,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
# Test generation with the exported model
prompt = "What is the capital of France?"

View File

@ -822,7 +822,10 @@ class Gemma3IntegrationTest(unittest.TestCase):
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
logging.info(f"\nExported program: {exported_program}")
# Test generation with the exported model

View File

@ -353,7 +353,10 @@ class LlamaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

View File

@ -384,7 +384,10 @@ class OlmoIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

View File

@ -417,7 +417,10 @@ class Phi3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

View File

@ -303,7 +303,11 @@ class Qwen2IntegrationTest(unittest.TestCase):
strict = version.parse(torch.__version__) != version.parse(
"2.7.0"
) # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = exportable_module.export(strict=strict)
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

View File

@ -293,7 +293,11 @@ class Qwen3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(strict=strict)
exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)

129
tests/test_executorch.py Normal file
View File

@ -0,0 +1,129 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoModelForCausalLM, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
TorchExportableModuleWithHybridCache,
TorchExportableModuleWithStaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.testing_utils import require_torch
@require_torch
class ExecutorchTest(unittest.TestCase):
def setUp(self):
if not is_torch_greater_or_equal_than_2_3:
self.skipTest("torch >= 2.3 is required")
set_seed(0)
self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
self.model.eval()
# Create generation config with static cache for the model
self.model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)
self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size)
self.cache_position = torch.arange(3, dtype=torch.long)
def test_static_cache_module_forward(self):
"""Test TorchExportableModuleWithStaticCache forward with both input types"""
generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)
# Set generation config on model
self.model.generation_config = generation_config
module = TorchExportableModuleWithStaticCache(self.model)
# Test with input_ids
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)
# Test with inputs_embeds
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)
def test_hybrid_cache_module_forward(self):
"""Test TorchExportableModuleWithHybridCache forward with both input types"""
config = self.model.config
config.sliding_window = 16
config.layer_types = ["full_attention"] * config.num_hidden_layers
generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)
# Set generation config on model
self.model.generation_config = generation_config
module = TorchExportableModuleWithHybridCache(self.model)
# Test with input_ids
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)
# Test with inputs_embeds
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)
def test_decoder_only_lm_export_validation(self):
"""Test TorchExportableModuleForDecoderOnlyLM export validation"""
module = TorchExportableModuleForDecoderOnlyLM(self.model)
# Should fail with both input_ids and inputs_embeds
with self.assertRaises(ValueError):
module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds)
# Should fail with neither
with self.assertRaises(ValueError):
module.export()
def test_decoder_only_lm_export(self):
"""Test TorchExportableModuleForDecoderOnlyLM export with both input types"""
module = TorchExportableModuleForDecoderOnlyLM(self.model)
# Test export with input_ids
exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position)
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
exported_output_ids = exported_program_ids.module()(
input_ids=self.input_ids, cache_position=self.cache_position
)
torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4)
# Test export with inputs_embeds
exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
exported_output_embeds = exported_program_embeds.module()(
inputs_embeds=self.inputs_embeds, cache_position=self.cache_position
)
torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4)

View File

@ -841,8 +841,24 @@ class CacheExportIntegrationTest(unittest.TestCase):
model.eval()
max_batch_size = 1
max_cache_len = 23
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
exported_program = exportable_module.export()
# Set generation config on the model for the hybrid cache model
from transformers.generation.configuration_utils import GenerationConfig
model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": model.device,
},
)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
n_g_key_caches = n_g_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("key_cache"):