Compare commits

...

2 Commits

Author SHA1 Message Date
1a67ca199c Merge branch 'main' into fix_require_class 2025-04-11 17:06:36 +02:00
94ea20ad41 refix 2025-04-10 16:23:41 +00:00

View File

@ -43,6 +43,7 @@ class QuarkConfigTest(unittest.TestCase):
@slow
@require_quark
@require_torch_gpu
@require_read_token
class QuarkTest(unittest.TestCase):
reference_model_name = "meta-llama/Llama-3.1-8B-Instruct"
quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
@ -76,13 +77,11 @@ class QuarkTest(unittest.TestCase):
device_map=cls.device_map,
)
@require_read_token
def test_memory_footprint(self):
mem_quantized = self.quantized_model.get_memory_footprint()
self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
@require_read_token
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
@ -96,7 +95,6 @@ class QuarkTest(unittest.TestCase):
# Tries with a `dtype``
self.quantized_model.to(torch.float16)
@require_read_token
def test_original_dtype(self):
r"""
        A simple test to check if the model successfully stores the original dtype
@ -107,7 +105,6 @@ class QuarkTest(unittest.TestCase):
self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
@require_read_token
def check_inference_correctness(self, model):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
@ -131,7 +128,6 @@ class QuarkTest(unittest.TestCase):
# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_read_token
def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens