Compare commits

...

2 Commits

Author SHA1 Message Date
1a67ca199c Merge branch 'main' into fix_require_class 2025-04-11 17:06:36 +02:00
94ea20ad41 refix 2025-04-10 16:23:41 +00:00

View File

@@ -43,6 +43,7 @@ class QuarkConfigTest(unittest.TestCase):
@slow
@require_quark
@require_torch_gpu
@require_read_token
class QuarkTest(unittest.TestCase):
    reference_model_name = "meta-llama/Llama-3.1-8B-Instruct"
    quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
@@ -76,13 +77,11 @@ class QuarkTest(unittest.TestCase):
            device_map=cls.device_map,
        )

    @require_read_token
    def test_memory_footprint(self):
        mem_quantized = self.quantized_model.get_memory_footprint()
        self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)

    @require_read_token
    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
@@ -96,7 +95,6 @@ class QuarkTest(unittest.TestCase):
        # Tries with a `dtype``
        self.quantized_model.to(torch.float16)
@require_read_token
    def test_original_dtype(self):
        r"""
        A simple test to check if the model succesfully stores the original dtype
@@ -107,7 +105,6 @@ class QuarkTest(unittest.TestCase):
        self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
@require_read_token
    def check_inference_correctness(self, model):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
@@ -131,7 +128,6 @@ class QuarkTest(unittest.TestCase):
        # Get the generation
        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_read_token
    def test_generate_quality(self):
        """
        Simple test to check the quality of the model by comparing the generated tokens with the expected tokens