mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Fix CI unittest asserts (#4234)
This commit is contained in:
committed by
GitHub
parent
e2c97a805a
commit
521db3520a
@ -1466,12 +1466,12 @@ class TestSFTTrainer(TrlTestCase):
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
|
||||
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
|
||||
|
||||
# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
|
||||
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
|
||||
|
Reference in New Issue
Block a user