mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
[DataPipe] Improve Mapper to accept input/output index when apply fn (#64951)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64951 Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D30910035 Pulled By: ejguan fbshipit-source-id: d687fe10939920a3617a60552fe743e8526438a0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
670853295a
commit
c65128679b
@ -38,7 +38,7 @@ import torch.utils.data.backward_compatibility
|
||||
import torch.utils.data.datapipes as dp
|
||||
import torch.utils.data.graph
|
||||
import torch.utils.data.sharding
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
DataChunk,
|
||||
@ -902,7 +902,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
len(dp2)
|
||||
|
||||
|
||||
@suppress_warnings # Suppress warning for lambda fn
|
||||
def test_map_datapipe(self):
|
||||
input_dp = IDP(range(10))
|
||||
|
||||
@ -927,12 +927,137 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
|
||||
|
||||
input_dp_nl = IDP_NoLen(range(10))
|
||||
map_dp_nl = input_dp_nl.map()
|
||||
map_dp_nl = input_dp_nl.map(lambda x: x)
|
||||
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
|
||||
len(map_dp_nl)
|
||||
for x, y in zip(map_dp_nl, input_dp_nl):
|
||||
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
|
||||
|
||||
@suppress_warnings  # Suppress warning for lambda fn
def test_map_tuple_list_with_col_datapipe(self):
    """Check ``map`` with ``input_col``/``output_col`` on list and tuple rows.

    Every scenario maps the same three-row pipe twice: once through
    ``map(fn, input_col, output_col)`` and once through ``map(ref_fn)``,
    where ``ref_fn`` builds the expected row by hand. The two results must
    be equal, both on the first pass and when iterated a second time.
    """

    def fn_11(d):
        # One input -> one output.
        return -d

    def fn_1n(d):
        # One input -> tuple of outputs.
        return -d, d

    def fn_n1(d0, d1):
        # Two inputs -> one output.
        return d0 + d1

    def fn_nn(d0, d1):
        # Two inputs -> tuple of outputs.
        return -d0, -d1, d0 + d1

    def _helper(ref_fn, fn, input_col=None, output_col=None):
        """Compare ``map(fn, input_col, output_col)`` against ``map(ref_fn)``.

        Runs once with list rows and once with tuple rows.
        """
        for constr in (list, tuple):
            datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
            res_dp = datapipe.map(fn, input_col, output_col)
            ref_dp = datapipe.map(ref_fn)
            self.assertEqual(list(res_dp), list(ref_dp))
            # Reset: iterate both pipes a second time
            self.assertEqual(list(res_dp), list(ref_dp))

    # Replacing with one input column and default output column
    _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
    _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
    # The index of input column is out of range
    with self.assertRaises(IndexError):
        _helper(None, fn_1n, 3)
    # Unmatched input columns with fn arguments
    with self.assertRaises(TypeError):
        _helper(None, fn_n1, 1)
    # Replacing with multiple input columns and default output column.
    # Per the reference rows below, the result lands in the first-listed
    # input column and the remaining input columns are dropped.
    _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
    _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])

    # output_col can only be specified when input_col is not None
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, None, 1)
    # output_col can only be single-element list or tuple.
    # A valid, arity-matching input_col is passed so that the multi-element
    # output_col is the only possible cause of the ValueError; with
    # input_col=None this case could not be told apart from the one above.
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, [0, 2], [0, 1])
    # Single-element list as output_col
    _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
    # Replacing with one input column and single specified output column
    _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
    _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
    # The index of output column is out of range
    with self.assertRaises(IndexError):
        _helper(None, fn_1n, 1, 3)
    _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
    _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)

    # Appending the output at the end (output_col=-1)
    _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
    _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
    _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
    _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
|
||||
|
||||
@suppress_warnings  # Suppress warning for lambda fn
def test_map_dict_with_col_datapipe(self):
    """Check ``map`` with ``input_col``/``output_col`` on dict rows.

    Every scenario maps the same three-row pipe of ``{"x", "y", "z"}``
    dicts twice: once through ``map(fn, input_col, output_col)`` and once
    through ``map(ref_fn)``, where ``ref_fn`` builds the expected dict by
    hand. The two results must be equal, both on the first pass and when
    iterated a second time.
    """

    def fn_11(d):
        # One input -> one output.
        return -d

    def fn_1n(d):
        # One input -> tuple of outputs.
        return -d, d

    def fn_n1(d0, d1):
        # Two inputs -> one output.
        return d0 + d1

    def fn_nn(d0, d1):
        # Two inputs -> tuple of outputs.
        return -d0, -d1, d0 + d1

    # Prevent modification in-place to support resetting
    def _dict_update(data, newdata, remove_idx=None):
        """Return a copy of ``data`` updated with ``newdata``.

        Keys listed in ``remove_idx`` (if any) are deleted from the copy.
        """
        _data = dict(data)
        _data.update(newdata)
        if remove_idx:
            for idx in remove_idx:
                del _data[idx]
        return _data

    def _helper(ref_fn, fn, input_col=None, output_col=None):
        """Compare ``map(fn, input_col, output_col)`` against ``map(ref_fn)``."""
        datapipe = IDP([{"x": 0, "y": 1, "z": 2},
                        {"x": 3, "y": 4, "z": 5},
                        {"x": 6, "y": 7, "z": 8}])
        res_dp = datapipe.map(fn, input_col, output_col)
        ref_dp = datapipe.map(ref_fn)
        self.assertEqual(list(res_dp), list(ref_dp))
        # Reset: iterate both pipes a second time
        self.assertEqual(list(res_dp), list(ref_dp))

    # Replacing with one input column and default output column
    _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
    _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
    # The key of input column is not in dict
    with self.assertRaises(KeyError):
        _helper(None, fn_1n, "a")
    # Unmatched input columns with fn arguments
    with self.assertRaises(TypeError):
        _helper(None, fn_n1, "y")
    # Replacing with multiple input columns and default output column.
    # Per the reference dicts below, the result lands under the first-listed
    # input key and the remaining input keys are removed.
    _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
    _helper(
        lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
        fn_nn,
        ["z", "y"],
    )

    # output_col can only be specified when input_col is not None
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, None, "x")
    # output_col can only be single-element list or tuple.
    # A valid, arity-matching input_col is passed so that the multi-element
    # output_col is the only possible cause of the ValueError; with
    # input_col=None this case could not be told apart from the one above.
    with self.assertRaises(ValueError):
        _helper(None, fn_n1, ["x", "z"], ["x", "y"])
    # Single-element list as output_col
    _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
    # Replacing with one input column and single specified output column
    _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
    _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
    _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
    _helper(
        lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
        fn_nn,
        ["y", "z"],
        "x",
    )

    # Adding new key to dict for the output
    _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
    _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
    _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
    _helper(
        lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
        fn_nn,
        ["y", "z"],
        "a",
    )
|
||||
|
||||
# TODO(VitalyFedyunin): If dill installed this test fails
|
||||
def _test_map_datapipe_nested_level(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user