Compare commits

...

1 Commit

Author  SHA1        Message  Date
        41d1847519  try      2025-04-30 15:10:49 +02:00


@@ -13,6 +13,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch GotOcr2 model."""

+import accelerate
 import unittest

 from transformers import (
@@ -291,9 +292,8 @@ class Mistral3IntegrationTest(unittest.TestCase):
     @require_read_token
     def test_mistral3_integration_generate_text_only(self):
         processor = AutoProcessor.from_pretrained(self.model_checkpoint)
-        model = Mistral3ForConditionalGeneration.from_pretrained(
-            self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
-        )
+        model = Mistral3ForConditionalGeneration.from_pretrained(self.model_checkpoint, torch_dtype=torch.bfloat16)
+        accelerate.cpu_offload(model, execution_device=torch_device)

         messages = [
             {
@@ -319,9 +319,8 @@ class Mistral3IntegrationTest(unittest.TestCase):
     @require_read_token
     def test_mistral3_integration_generate(self):
         processor = AutoProcessor.from_pretrained(self.model_checkpoint)
-        model = Mistral3ForConditionalGeneration.from_pretrained(
-            self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
-        )
+        model = Mistral3ForConditionalGeneration.from_pretrained(self.model_checkpoint, torch_dtype=torch.bfloat16)
+        accelerate.cpu_offload(model, execution_device=torch_device)
         messages = [
             {
                 "role": "user",
@@ -346,9 +345,8 @@ class Mistral3IntegrationTest(unittest.TestCase):
     @require_read_token
     def test_mistral3_integration_batched_generate(self):
         processor = AutoProcessor.from_pretrained(self.model_checkpoint)
-        model = Mistral3ForConditionalGeneration.from_pretrained(
-            self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
-        )
+        model = Mistral3ForConditionalGeneration.from_pretrained(self.model_checkpoint, torch_dtype=torch.bfloat16)
+        accelerate.cpu_offload(model, execution_device=torch_device)
         messages = [
             [
                 {
@@ -402,6 +400,7 @@ class Mistral3IntegrationTest(unittest.TestCase):
         model = Mistral3ForConditionalGeneration.from_pretrained(
            self.model_checkpoint, quantization_config=quantization_config
         )
+        accelerate.cpu_offload(model, execution_device=torch_device)

         # Prepare inputs
         messages = [
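The change is the same in every test: instead of placing the whole model on the target device at load time via device_map=torch_device, the model is loaded on CPU and handed to accelerate.cpu_offload, which keeps the weights in CPU RAM and copies each submodule to execution_device only for the duration of its forward pass. Below is a minimal sketch of the same pattern outside this test suite; the checkpoint, prompt, and "cuda:0" device are illustrative placeholders, not values from the commit.

import accelerate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the tests above use self.model_checkpoint instead.
checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load on CPU (no device_map), then attach Accelerate's offload hooks:
# the state dict stays in CPU RAM and each submodule is moved to the
# execution device just for its forward pass, then offloaded again.
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
accelerate.cpu_offload(model, execution_device="cuda:0")  # assumes a CUDA device is available

# Inputs can stay on CPU; the hooks move them to the execution device.
inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The trade-off is lower peak device memory in exchange for per-module host-to-device transfers; the commit message ("try") gives no rationale, but presumably the goal is to let these integration tests run on GPUs too small to hold the full bfloat16 model.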