From e3b392bdfd10bfee6f92eece2457f82cd2efcb35 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Tue, 23 Sep 2025 05:02:06 +0000 Subject: [PATCH] [BC breaking] Remove deprecated imports for torch.utils.data.datapipes.iter.grouping (#163438) This PR removes the import tricks for `SHARDING_PRIORITIES` and `ShardingFilterIterDataPipe` from `torch.utils.data.datapipes.iter.grouping`. They were declared for removal in PyTorch 2.1 but were never actually removed. Before change: ``` from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe ``` these imports work. After change: these imports raise an ImportError. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163438 Approved by: https://github.com/janeyx99 --- test/test_datapipe.py | 57 --------------------- torch/utils/data/datapipes/iter/grouping.py | 12 ----- 2 files changed, 69 deletions(-) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 2a57bef2075b..cb8dd252ec4b 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -3356,63 +3356,6 @@ class TestSharding(TestCase): with self.assertRaises(Exception): dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) - # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility - # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated - def test_sharding_groups_in_legacy_grouping_package(self): - with self.assertWarnsRegex( - FutureWarning, - r"Please use `SHARDING_PRIORITIES` " - "from the `torch.utils.data.datapipes.iter.sharding`", - ): - from torch.utils.data.datapipes.iter.grouping import ( - SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES, - ) - - def construct_sharded_pipe(): - sharding_pipes = [] - dp = NumbersDataset(size=90) - dp = dp.sharding_filter( - sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED - ) - sharding_pipes.append(dp) - dp = dp.sharding_filter( - 
sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING - ) - sharding_pipes.append(dp) - dp = dp.sharding_filter(sharding_group_filter=300) - sharding_pipes.append(dp) - return dp, sharding_pipes - - dp, sharding_pipes = construct_sharded_pipe() - - for pipe in sharding_pipes: - pipe.apply_sharding( - 2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED - ) - pipe.apply_sharding( - 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING - ) - pipe.apply_sharding(3, 1, sharding_group=300) - - actual = list(dp) - expected = [17, 47, 77] - self.assertEqual(expected, actual) - self.assertEqual(3, len(dp)) - - dp, _ = construct_sharded_pipe() - dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT) - with self.assertRaises(Exception): - dp.apply_sharding( - 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING - ) - - dp, _ = construct_sharded_pipe() - dp.apply_sharding( - 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING - ) - with self.assertRaises(Exception): - dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT) - def test_legacy_custom_sharding(self): dp = self._get_pipeline() sharded_dp = CustomShardingIterDataPipe(dp) diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 055d9c28b09b..e7e50d302e12 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -1,10 +1,8 @@ # mypy: allow-untyped-defs -import warnings from collections import defaultdict from collections.abc import Iterator, Sized from typing import Any, Callable, Optional, TypeVar -import torch.utils.data.datapipes.iter.sharding from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe from torch.utils.data.datapipes.utils.common import _check_unpickable_fn @@ -21,16 +19,6 @@ _T_co = TypeVar("_T_co", covariant=True) def 
__getattr__(name: str): - if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]: - warnings.warn( - f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1" - f"Please use `{name}` from the `torch.utils.data.datapipes.iter.sharding`", - category=FutureWarning, - stacklevel=2, - ) - - return getattr(torch.utils.data.datapipes.iter.sharding, name) - raise AttributeError(f"module {__name__} has no attribute {name}")