[DataLoader] Typing Enforcement for DataPipe at construct-time (#54066)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54066

## Feature
- Add a decorator `construct_time_validation` to validate each input datapipe according to the corresponding type hint.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D27327236

Pulled By: ejguan

fbshipit-source-id: a9d4c6edb5b05090bd5a369eee50a6fb4d7cf957
This commit is contained in:
Erjia Guan
2021-04-02 15:19:06 -07:00
committed by Facebook GitHub Bot
parent 44edf8c421
commit 1535520f08
3 changed files with 79 additions and 4 deletions

View File

@ -13,7 +13,7 @@ from unittest import skipIf
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader, construct_time_validation
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Set, Union
import torch.utils.data.datapipes as dp
@ -697,6 +697,40 @@ class TestTyping(TestCase):
dp = DP6() # type: ignore
self.assertTrue(dp.type.param == int)
def test_construct_time(self):
    """Exercise ``construct_time_validation`` on DataPipe constructors.

    Three scenarios are expected to raise ``TypeError``:

    * a plain list is passed where the hint is ``IterDataPipe``;
    * a pipe declared as ``IterDataPipe[Tuple]`` is passed where the hint
      is ``IterDataPipe[Tuple[int, str]]``;
    * the decorator is applied to ``__iter__`` rather than ``__init__``.
    """
    class DP0(IterDataPipe[Tuple]):
        @construct_time_validation
        def __init__(self, dp: IterDataPipe):
            self.dp = dp

        def __iter__(self) -> Iterator[Tuple]:
            for item in self.dp:
                yield item, str(item)

    class DP1(IterDataPipe[int]):
        @construct_time_validation
        def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
            self.dp = dp

        def __iter__(self) -> Iterator[int]:
            for first, _second in self.dp:
                yield first

    # Non-DataPipe input with DataPipe hint
    plain_pairs = [(1, '1'), (2, '2'), (3, '3')]
    with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
        DP0(plain_pairs)

    # A real DataPipe is accepted by DP0, but the resulting pipe is then
    # expected to be rejected by DP1's stricter type hint.
    source_pipe = DP0(IDP(range(10)))
    with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
        DP1(source_pipe)

    # Decorating a method other than the constructor is expected to fail
    # while the class body is being evaluated.
    with self.assertRaisesRegex(TypeError, r"Can not decorate"):
        class InvalidDP1(IterDataPipe[int]):
            @construct_time_validation
            def __iter__(self):
                yield 0
# Invoke run_tests() only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    run_tests()