[DataPipe] Improve Mapper to accept input/output index when applying fn (#64951)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64951

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30910035

Pulled By: ejguan

fbshipit-source-id: d687fe10939920a3617a60552fe743e8526438a0
This commit is contained in:
Erjia Guan
2021-09-14 15:44:57 -07:00
committed by Facebook GitHub Bot
parent 670853295a
commit c65128679b
2 changed files with 243 additions and 39 deletions

View File

@ -38,7 +38,7 @@ import torch.utils.data.backward_compatibility
import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.sharding
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
from torch.utils.data import (
DataLoader,
DataChunk,
@ -902,7 +902,7 @@ class TestFunctionalIterDataPipe(TestCase):
with self.assertRaises(TypeError):
len(dp2)
@suppress_warnings # Suppress warning for lambda fn
def test_map_datapipe(self):
input_dp = IDP(range(10))
@ -927,12 +927,137 @@ class TestFunctionalIterDataPipe(TestCase):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
input_dp_nl = IDP_NoLen(range(10))
map_dp_nl = input_dp_nl.map()
map_dp_nl = input_dp_nl.map(lambda x: x)
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(map_dp_nl)
for x, y in zip(map_dp_nl, input_dp_nl):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
@suppress_warnings  # Suppress warning for lambda fn
def test_map_tuple_list_with_col_datapipe(self):
    """Exercise ``map`` with ``input_col``/``output_col`` on list and tuple samples.

    Every case compares ``datapipe.map(fn, input_col, output_col)`` against a
    reference ``datapipe.map(ref_fn)`` whose function rewrites the whole
    sample by hand.
    """
    # Naming: fn_<inputs><outputs>, where "1" is a single value and "n" is many.
    def fn_11(d):
        return -d

    def fn_1n(d):
        return -d, d

    def fn_n1(d0, d1):
        return d0 + d1

    def fn_nn(d0, d1):
        return -d0, -d1, d0 + d1

    def _helper(ref_fn, fn, input_col=None, output_col=None):
        """Check ``fn`` applied on columns against ``ref_fn`` applied on whole samples."""
        # Run each case with both sequence types as the sample container.
        for constr in (list, tuple):
            datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
            res_dp = datapipe.map(fn, input_col, output_col)
            ref_dp = datapipe.map(ref_fn)
            self.assertEqual(list(res_dp), list(ref_dp))
            # Reset: iterate a second time and expect the same result
            self.assertEqual(list(res_dp), list(ref_dp))

    # Replacing with one input column and default output column
    _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
    _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
    # The index of input column is out of range
    with self.assertRaises(IndexError):
        _helper(None, fn_1n, 3)
    # Unmatched input columns with fn arguments
    with self.assertRaises(TypeError):
        _helper(None, fn_n1, 1)
    # Replacing with multiple input columns and default output column (the left-most input column)
    _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
    _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])

    # output_col can only be specified when input_col is not None
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, None, 1)
    # output_col can only be single-element list or tuple.
    # input_col is supplied here so that the error can only come from the
    # multi-element output_col, not from the "input_col is None" rule above.
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, [0, 2], [0, 1])

    # Single-element list as output_col
    _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
    # Replacing with one input column and single specified output column
    _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
    _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
    # The index of output column is out of range
    with self.assertRaises(IndexError):
        _helper(None, fn_1n, 1, 3)
    # Replacing with multiple input columns and single specified output column
    _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
    _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)

    # Appending the output at the end
    _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
    _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
    _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
    _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
@suppress_warnings  # Suppress warning for lambda fn
def test_map_dict_with_col_datapipe(self):
    """Exercise ``map`` with ``input_col``/``output_col`` on dict samples.

    Every case compares ``datapipe.map(fn, input_col, output_col)`` against a
    reference ``datapipe.map(ref_fn)`` whose function builds the expected
    dict by hand.
    """
    # Naming: fn_<inputs><outputs>, where "1" is a single value and "n" is many.
    def fn_11(d):
        return -d

    def fn_1n(d):
        return -d, d

    def fn_n1(d0, d1):
        return d0 + d1

    def fn_nn(d0, d1):
        return -d0, -d1, d0 + d1

    # Prevent modification in-place to support resetting
    def _dict_update(data, newdata, remove_idx=None):
        """Return a copy of ``data`` updated with ``newdata`` and without the ``remove_idx`` keys."""
        _data = dict(data)
        _data.update(newdata)
        if remove_idx:
            for idx in remove_idx:
                del _data[idx]
        return _data

    def _helper(ref_fn, fn, input_col=None, output_col=None):
        """Check ``fn`` applied on columns against ``ref_fn`` applied on whole samples."""
        datapipe = IDP([{"x": 0, "y": 1, "z": 2},
                        {"x": 3, "y": 4, "z": 5},
                        {"x": 6, "y": 7, "z": 8}])
        res_dp = datapipe.map(fn, input_col, output_col)
        ref_dp = datapipe.map(ref_fn)
        self.assertEqual(list(res_dp), list(ref_dp))
        # Reset: iterate a second time and expect the same result
        self.assertEqual(list(res_dp), list(ref_dp))

    # Replacing with one input column and default output column
    _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
    _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
    # The key of input column is not in dict
    with self.assertRaises(KeyError):
        _helper(None, fn_1n, "a")
    # Unmatched input columns with fn arguments
    with self.assertRaises(TypeError):
        _helper(None, fn_n1, "y")
    # Replacing with multiple input columns and default output column (the left-most input column)
    _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
    _helper(
        lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
        fn_nn, ["z", "y"])

    # output_col can only be specified when input_col is not None
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, None, "x")
    # output_col can only be single-element list or tuple.
    # input_col is supplied here so that the error can only come from the
    # multi-element output_col, not from the "input_col is None" rule above.
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, ["x", "z"], ["x", "y"])

    # Single-element list as output_col
    _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
    # Replacing with one input column and single specified output column
    _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
    _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
    # Replacing with multiple input columns and single specified output column
    _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
    _helper(
        lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
        fn_nn, ["y", "z"], "x")

    # Adding new key to dict for the output
    _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
    _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
    _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
    _helper(
        lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
        fn_nn, ["y", "z"], "a")
# TODO(VitalyFedyunin): If dill installed this test fails
def _test_map_datapipe_nested_level(self):