Compare commits

...

2 Commits

Author SHA1 Message Date
c337ad8dc4 Merge branch 'main' into divyanshk/fix_datalaoder_test_314 2025-11-14 12:48:34 -08:00
a4628bb761 fix unit tests 2025-11-09 16:02:28 -08:00

View File

@ -3480,7 +3480,7 @@ class TestIndividualWorkerQueue(TestCase):
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
timeout=5,
timeout=JOIN_TIMEOUT,
worker_init_fn=self.dataset.worker_init_fn,
)
current_worker_idx = 0
@ -3498,33 +3498,31 @@ class TestIndividualWorkerQueue(TestCase):
"Flaky on Windows and MacOS https://github.com/pytorch/pytorch/issues/68643",
)
def test_ind_worker_queue(self):
max_num_workers = None
if hasattr(os, "sched_getaffinity"):
try:
max_num_workers = len(os.sched_getaffinity(0))
except Exception:
pass
if max_num_workers is None:
cpu_count = os.cpu_count()
if cpu_count is not None:
# Use half number of CPUs
max_num_workers = cpu_count // 2
if max_num_workers is None:
max_num_workers = 1
for batch_size in (8, 16, 32, 64):
for num_workers in range(min(6, max_num_workers)):
for batch_size in (8, 32, 64):
for num_workers in range(1, 6):
self._run_ind_worker_queue_test(
batch_size=batch_size, num_workers=num_workers + 1
batch_size=batch_size, num_workers=num_workers
)
class SetAffinityDataset(IterableDataset):
def __init__(self, expected_affinity=None):
self.expected_affinity = expected_affinity
def __iter__(self):
torch.randperm(1)
after = os.sched_getaffinity(0)
return iter(after)
affinity_mask = os.sched_getaffinity(0)
return iter(affinity_mask)
def _worker_set_affinity_init(worker_id):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
dataset = worker_info.dataset
if (
isinstance(dataset, SetAffinityDataset)
and dataset.expected_affinity is not None
):
os.sched_setaffinity(0, [dataset.expected_affinity])
@unittest.skipIf(
@ -3539,19 +3537,14 @@ class TestSetAffinity(TestCase):
# Choose any
expected_affinity = list(old_affinity)[-1]
def worker_set_affinity(_):
os.sched_setaffinity(0, [expected_affinity])
dataset = SetAffinityDataset()
if not IS_WINDOWS and not IS_MACOS:
import multiprocessing as py_mp
py_mp.set_start_method("fork", force=True)
# Pass expected affinity through the dataset
dataset = SetAffinityDataset(expected_affinity=expected_affinity)
dataloader = torch.utils.data.DataLoader(
dataset, num_workers=2, worker_init_fn=worker_set_affinity
dataset,
num_workers=2,
worker_init_fn=_worker_set_affinity_init,
)
for sample in dataloader:
self.assertEqual(sample, [expected_affinity])