Simplify test logic (#2697)

* simplify test logic 😅

* 😅
This commit is contained in:
Sourab Mangrulkar
2024-04-23 02:49:55 +05:30
committed by GitHub
parent baafaf4a6e
commit ef0f62c12a

View File

@ -255,9 +255,8 @@ def test_data_loader(data_loader, accelerator):
sorted_all_examples = sorted(all_examples)
# Check if all elements are present in the sorted list of iterated samples
label_data = list(range(NUM_ELEMENTS))
assert set(sorted_all_examples).intersection(set(label_data)) == set(
label_data
assert (
len(set(sorted_all_examples)) == NUM_ELEMENTS
), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."