mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BC breaking] Remove deprecated imports for torch.utils.data.datapipes.iter.grouping (#163438)
This PR removes import tricks of `SHARDING_PRIORITIES` and `ShardingFilterIterDataPipe` from `torch.utils.data.datapipes.iter.grouping`. They are declared to be removed in PyTorch 2.1 but not. Before change: ``` import torch.utils.data.datapipes.iter.grouping.SHARDING_PRIORITIES import torch.utils.data.datapipes.iter.grouping.ShardingFilterIterDataPipe ``` works After change: there is an import error exception. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163438 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb5be56619
commit
e3b392bdfd
@ -3356,63 +3356,6 @@ class TestSharding(TestCase):
|
||||
with self.assertRaises(Exception):
|
||||
dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
|
||||
|
||||
# Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility
|
||||
# TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated
|
||||
def test_sharding_groups_in_legacy_grouping_package(self):
|
||||
with self.assertWarnsRegex(
|
||||
FutureWarning,
|
||||
r"Please use `SHARDING_PRIORITIES` "
|
||||
"from the `torch.utils.data.datapipes.iter.sharding`",
|
||||
):
|
||||
from torch.utils.data.datapipes.iter.grouping import (
|
||||
SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES,
|
||||
)
|
||||
|
||||
def construct_sharded_pipe():
|
||||
sharding_pipes = []
|
||||
dp = NumbersDataset(size=90)
|
||||
dp = dp.sharding_filter(
|
||||
sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
|
||||
)
|
||||
sharding_pipes.append(dp)
|
||||
dp = dp.sharding_filter(
|
||||
sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
|
||||
)
|
||||
sharding_pipes.append(dp)
|
||||
dp = dp.sharding_filter(sharding_group_filter=300)
|
||||
sharding_pipes.append(dp)
|
||||
return dp, sharding_pipes
|
||||
|
||||
dp, sharding_pipes = construct_sharded_pipe()
|
||||
|
||||
for pipe in sharding_pipes:
|
||||
pipe.apply_sharding(
|
||||
2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
|
||||
)
|
||||
pipe.apply_sharding(
|
||||
5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
|
||||
)
|
||||
pipe.apply_sharding(3, 1, sharding_group=300)
|
||||
|
||||
actual = list(dp)
|
||||
expected = [17, 47, 77]
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertEqual(3, len(dp))
|
||||
|
||||
dp, _ = construct_sharded_pipe()
|
||||
dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
|
||||
with self.assertRaises(Exception):
|
||||
dp.apply_sharding(
|
||||
5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
|
||||
)
|
||||
|
||||
dp, _ = construct_sharded_pipe()
|
||||
dp.apply_sharding(
|
||||
5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
|
||||
)
|
||||
with self.assertRaises(Exception):
|
||||
dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
|
||||
|
||||
def test_legacy_custom_sharding(self):
|
||||
dp = self._get_pipeline()
|
||||
sharded_dp = CustomShardingIterDataPipe(dp)
|
||||
|
@ -1,10 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sized
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
import torch.utils.data.datapipes.iter.sharding
|
||||
from torch.utils.data.datapipes._decorator import functional_datapipe
|
||||
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
|
||||
@ -21,16 +19,6 @@ _T_co = TypeVar("_T_co", covariant=True)
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]:
|
||||
warnings.warn(
|
||||
f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1"
|
||||
f"Please use `{name}` from the `torch.utils.data.datapipes.iter.sharding`",
|
||||
category=FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return getattr(torch.utils.data.datapipes.iter.sharding, name)
|
||||
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user