mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
[DataLoader] Typing Enforcement for DataPipe at construct-time (#54066)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54066 ## Feature - Add a decorator `construct_time_validation` to validate each input datapipe according to the corresponding type hint. Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D27327236 Pulled By: ejguan fbshipit-source-id: a9d4c6edb5b05090bd5a369eee50a6fb4d7cf957
This commit is contained in:
committed by
Facebook GitHub Bot
parent
44edf8c421
commit
1535520f08
@ -13,7 +13,7 @@ from unittest import skipIf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing._internal.common_utils import (TestCase, run_tests)
|
||||
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader
|
||||
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader, construct_time_validation
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Set, Union
|
||||
|
||||
import torch.utils.data.datapipes as dp
|
||||
@ -697,6 +697,40 @@ class TestTyping(TestCase):
|
||||
dp = DP6() # type: ignore
|
||||
self.assertTrue(dp.type.param == int)
|
||||
|
||||
def test_construct_time(self):
|
||||
class DP0(IterDataPipe[Tuple]):
|
||||
@construct_time_validation
|
||||
def __init__(self, dp: IterDataPipe):
|
||||
self.dp = dp
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple]:
|
||||
for d in self.dp:
|
||||
yield d, str(d)
|
||||
|
||||
class DP1(IterDataPipe[int]):
|
||||
@construct_time_validation
|
||||
def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
|
||||
self.dp = dp
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
for a, b in self.dp:
|
||||
yield a
|
||||
|
||||
# Non-DataPipe input with DataPipe hint
|
||||
datasource = [(1, '1'), (2, '2'), (3, '3')]
|
||||
with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
|
||||
dp = DP0(datasource)
|
||||
|
||||
dp = DP0(IDP(range(10)))
|
||||
with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
|
||||
dp = DP1(dp)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"Can not decorate"):
|
||||
class InvalidDP1(IterDataPipe[int]):
|
||||
@construct_time_validation
|
||||
def __iter__(self):
|
||||
yield 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
||||
Reference in New Issue
Block a user