Fix CI unittest asserts (#4234)

This commit is contained in:
Albert Villanova del Moral
2025-10-08 21:18:41 +02:00
committed by GitHub
parent e2c97a805a
commit 521db3520a

View File

@ -1466,12 +1466,12 @@ class TestSFTTrainer(TrlTestCase):
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[-1]["train_loss"] is not None
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.