[BC breaking] Remove deprecated imports for torch.utils.data.datapipes.iter.grouping (#163438)

This PR removes the import tricks for `SHARDING_PRIORITIES` and `ShardingFilterIterDataPipe` from `torch.utils.data.datapipes.iter.grouping`. They were declared for removal in PyTorch 2.1 but were never actually removed.
Before change:
```
import torch.utils.data.datapipes.iter.grouping.SHARDING_PRIORITIES
import torch.utils.data.datapipes.iter.grouping.ShardingFilterIterDataPipe
```
both imports work (with a `FutureWarning`).
After change:
both imports raise an import error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163438
Approved by: https://github.com/janeyx99
This commit is contained in:
Yuanyuan Chen
2025-09-23 05:02:06 +00:00
committed by PyTorch MergeBot
parent bb5be56619
commit e3b392bdfd
2 changed files with 0 additions and 69 deletions

View File

@ -3356,63 +3356,6 @@ class TestSharding(TestCase):
with self.assertRaises(Exception):
dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
# Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility
# TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is removed
def test_sharding_groups_in_legacy_grouping_package(self):
    """Verify the deprecated re-export of ``SHARDING_PRIORITIES``.

    Importing ``SHARDING_PRIORITIES`` from its legacy location,
    ``torch.utils.data.datapipes.iter.grouping``, must emit a
    ``FutureWarning`` pointing at the new ``...iter.sharding`` module,
    and the re-exported object must still drive sharding correctly.
    """
    # The legacy import path must still resolve, but only with a
    # deprecation warning naming the new module.
    with self.assertWarnsRegex(
        FutureWarning,
        r"Please use `SHARDING_PRIORITIES` "
        "from the `torch.utils.data.datapipes.iter.sharding`",
    ):
        from torch.utils.data.datapipes.iter.grouping import (
            SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES,
        )

    def construct_sharded_pipe():
        # Build a pipeline of three stacked sharding filters — one per
        # sharding group (DISTRIBUTED, MULTIPROCESSING, and the custom
        # group id 300) — returning the final pipe plus each stage.
        sharding_pipes = []
        dp = NumbersDataset(size=90)
        dp = dp.sharding_filter(
            sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
        )
        sharding_pipes.append(dp)
        dp = dp.sharding_filter(
            sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
        )
        sharding_pipes.append(dp)
        dp = dp.sharding_filter(sharding_group_filter=300)
        sharding_pipes.append(dp)
        return dp, sharding_pipes

    dp, sharding_pipes = construct_sharded_pipe()

    # Apply a different (num_of_instances, instance_id) split per
    # sharding group: 2-way distributed, 5-way multiprocessing, 3-way
    # for the custom group — a 2*5*3 = 30-way split overall.
    for pipe in sharding_pipes:
        pipe.apply_sharding(
            2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
        )
        pipe.apply_sharding(
            5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
        )
        pipe.apply_sharding(3, 1, sharding_group=300)

    # 90 elements / 30 shards leaves exactly 3 elements for this
    # instance combination.
    actual = list(dp)
    expected = [17, 47, 77]
    self.assertEqual(expected, actual)
    self.assertEqual(3, len(dp))

    # Applying DEFAULT-group sharding first makes a subsequent
    # group-specific call on the same pipe raise...
    dp, _ = construct_sharded_pipe()
    dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
    with self.assertRaises(Exception):
        dp.apply_sharding(
            5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
        )

    # ...and vice versa: a group-specific call first makes a later
    # DEFAULT-group call raise.
    dp, _ = construct_sharded_pipe()
    dp.apply_sharding(
        5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
    )
    with self.assertRaises(Exception):
        dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
def test_legacy_custom_sharding(self):
dp = self._get_pipeline()
sharded_dp = CustomShardingIterDataPipe(dp)

View File

@ -1,10 +1,8 @@
# mypy: allow-untyped-defs
import warnings
from collections import defaultdict
from collections.abc import Iterator, Sized
from typing import Any, Callable, Optional, TypeVar
import torch.utils.data.datapipes.iter.sharding
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
@ -21,16 +19,6 @@ _T_co = TypeVar("_T_co", covariant=True)
def __getattr__(name: str):
    """Module-level fallback (PEP 562) for deprecated re-exports.

    ``SHARDING_PRIORITIES`` and ``ShardingFilterIterDataPipe`` moved to
    ``torch.utils.data.datapipes.iter.sharding``; accessing them through
    this module still works but emits a ``FutureWarning`` directing
    users to the new location.

    Raises:
        AttributeError: for any other attribute name.
    """
    if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]:
        warnings.warn(
            # Fix: the two fragments were previously concatenated with no
            # separator, yielding "...PyTorch 2.1Please use...".
            f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1. "
            f"Please use `{name}` from the `torch.utils.data.datapipes.iter.sharding`",
            category=FutureWarning,
            stacklevel=2,
        )
        # Forward to the canonical location so the legacy name keeps working.
        return getattr(torch.utils.data.datapipes.iter.sharding, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")