mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE][pytree] cleanup parameterized pytree tests (#160842)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160842 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
01edcd4df8
commit
2fa0520a64
@ -87,12 +87,15 @@ from torch.testing._internal.common_methods_invocations import (
|
|||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
freeze_rng_state,
|
freeze_rng_state,
|
||||||
|
instantiate_parametrized_tests,
|
||||||
IS_FBCODE,
|
IS_FBCODE,
|
||||||
|
parametrize,
|
||||||
scoped_load_inline,
|
scoped_load_inline,
|
||||||
set_default_dtype,
|
set_default_dtype,
|
||||||
skipIfHpu,
|
skipIfHpu,
|
||||||
skipIfNNModuleInlined,
|
skipIfNNModuleInlined,
|
||||||
skipIfWindows,
|
skipIfWindows,
|
||||||
|
subtest,
|
||||||
TEST_HPU,
|
TEST_HPU,
|
||||||
TEST_XPU,
|
TEST_XPU,
|
||||||
wrapDeterministicFlagAPITest,
|
wrapDeterministicFlagAPITest,
|
||||||
@ -101,11 +104,21 @@ from torch.testing._internal.jit_utils import JitTestCase
|
|||||||
from torch.testing._internal.logging_utils import logs_to_string
|
from torch.testing._internal.logging_utils import logs_to_string
|
||||||
|
|
||||||
|
|
||||||
|
pytree_modules = {
|
||||||
|
"python": python_pytree,
|
||||||
|
}
|
||||||
if python_pytree._cxx_pytree_dynamo_traceable:
|
if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
import torch.utils._cxx_pytree as cxx_pytree
|
import torch.utils._cxx_pytree as cxx_pytree
|
||||||
|
|
||||||
|
pytree_modules["cxx"] = cxx_pytree
|
||||||
else:
|
else:
|
||||||
cxx_pytree = None
|
cxx_pytree = None
|
||||||
|
|
||||||
|
parametrize_pytree_module = parametrize(
|
||||||
|
"pytree",
|
||||||
|
[subtest(module, name=name) for name, module in pytree_modules.items()],
|
||||||
|
)
|
||||||
|
|
||||||
MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
|
MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
|
||||||
T = typing.TypeVar("T")
|
T = typing.TypeVar("T")
|
||||||
|
|
||||||
@ -9107,71 +9120,6 @@ def ___make_guard_fn():
|
|||||||
opt = torch.compile(fn, backend="eager")
|
opt = torch.compile(fn, backend="eager")
|
||||||
opt()
|
opt()
|
||||||
|
|
||||||
def test_tracing_py_tree(self):
|
|
||||||
def fn(xs):
|
|
||||||
flat_xs, spec = python_pytree.tree_flatten(xs)
|
|
||||||
res = [x.clone() for x in flat_xs]
|
|
||||||
return python_pytree.tree_unflatten(res, spec)
|
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
torch.compile(fn, backend=counter, fullgraph=True)(xs)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 3)
|
|
||||||
|
|
||||||
def test_tracing_nested_py_tree(self):
|
|
||||||
def fn(xs):
|
|
||||||
flat_xs, spec = python_pytree.tree_flatten(xs)
|
|
||||||
res = [x.clone() for x in flat_xs]
|
|
||||||
return python_pytree.tree_unflatten(res, spec)
|
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
|
||||||
xsl = [xs, xs, xs, xs]
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
|
||||||
real_out = fn(xsl)
|
|
||||||
self.assertEqual(comp_out, real_out)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 12)
|
|
||||||
|
|
||||||
def test_tracing_nested_py_tree_tuples(self):
|
|
||||||
def fn(xs):
|
|
||||||
flat_xs, spec = python_pytree.tree_flatten(xs)
|
|
||||||
res = [x.clone() for x in flat_xs]
|
|
||||||
return python_pytree.tree_unflatten(res, spec)
|
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
|
||||||
xsl = (xs, xs, xs, xs)
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
|
||||||
real_out = fn(xsl)
|
|
||||||
self.assertEqual(comp_out, real_out)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 12)
|
|
||||||
|
|
||||||
def test_tracing_nested_py_tree_dicts(self):
|
|
||||||
def fn(xs):
|
|
||||||
flat_xs, spec = python_pytree.tree_flatten(xs)
|
|
||||||
res = [x.clone() for x in flat_xs]
|
|
||||||
return python_pytree.tree_unflatten(res, spec)
|
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
|
||||||
xsl = {
|
|
||||||
"a": xs,
|
|
||||||
"b": xs,
|
|
||||||
"c": xs,
|
|
||||||
}
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
|
||||||
real_out = fn(xsl)
|
|
||||||
self.assertEqual(comp_out, real_out)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 9)
|
|
||||||
|
|
||||||
def test_dynamic_one_hot(self):
|
def test_dynamic_one_hot(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
x = x + 1
|
x = x + 1
|
||||||
@ -9188,28 +9136,6 @@ def ___make_guard_fn():
|
|||||||
self.assertEqual(counter.frame_count, 2)
|
self.assertEqual(counter.frame_count, 2)
|
||||||
self.assertEqual(counter.op_count, 2)
|
self.assertEqual(counter.op_count, 2)
|
||||||
|
|
||||||
def test_tracing_nested_py_tree_mixed_all(self):
|
|
||||||
def fn(xs):
|
|
||||||
flat_xs, spec = python_pytree.tree_flatten(xs)
|
|
||||||
res = [x.clone() for x in flat_xs]
|
|
||||||
return python_pytree.tree_unflatten(res, spec)
|
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
|
||||||
xsa = (xs, xs)
|
|
||||||
xsb = {"aa": xsa, "ab": xs}
|
|
||||||
xsl = {
|
|
||||||
"a": xs,
|
|
||||||
"b": xsa,
|
|
||||||
"c": xsb,
|
|
||||||
}
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
|
||||||
real_out = fn(xsl)
|
|
||||||
self.assertEqual(comp_out, real_out)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 18)
|
|
||||||
|
|
||||||
def test_any_all_symnode(self):
|
def test_any_all_symnode(self):
|
||||||
cnt = CompileCounter()
|
cnt = CompileCounter()
|
||||||
|
|
||||||
@ -9236,46 +9162,6 @@ def ___make_guard_fn():
|
|||||||
self.assertEqual(fn(y3), y3 - 3)
|
self.assertEqual(fn(y3), y3 - 3)
|
||||||
self.assertEqual(cnt.frame_count, 2)
|
self.assertEqual(cnt.frame_count, 2)
|
||||||
|
|
||||||
def test_tracing_py_tree_tensor_subclass(self):
|
|
||||||
from torch.testing._internal.two_tensor import TwoTensor
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
|
|
||||||
def fn(xs):
|
|
||||||
nested_xs = [[xs]]
|
|
||||||
flat_xs, spec = python_pytree.tree_flatten(xs)
|
|
||||||
return flat_xs[0].clone()
|
|
||||||
|
|
||||||
# use checkpoint to trigger a "sourceless" tensor subclass
|
|
||||||
def checkpoint_fn(xs):
|
|
||||||
return checkpoint(fn, xs, use_reentrant=True)
|
|
||||||
|
|
||||||
xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 2)
|
|
||||||
|
|
||||||
def test_tracing_tree_map_only(self):
|
|
||||||
def fn(xs):
|
|
||||||
def mapper(x):
|
|
||||||
return x.clone()
|
|
||||||
|
|
||||||
y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
|
|
||||||
return y
|
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)] + ["hi"]
|
|
||||||
xsa = (xs, xs)
|
|
||||||
xsb = {"aa": xsa, "ab": xs}
|
|
||||||
|
|
||||||
counter = CompileCounter()
|
|
||||||
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb)
|
|
||||||
real_out = fn(xsb)
|
|
||||||
|
|
||||||
self.assertEqual(comp_out, real_out)
|
|
||||||
self.assertEqual(counter.frame_count, 1)
|
|
||||||
self.assertEqual(counter.op_count, 9)
|
|
||||||
|
|
||||||
@torch._dynamo.config.patch(
|
@torch._dynamo.config.patch(
|
||||||
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
|
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
|
||||||
)
|
)
|
||||||
@ -10718,139 +10604,6 @@ def ___make_guard_fn():
|
|||||||
expected = fn(*inps)
|
expected = fn(*inps)
|
||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_pytree_tree_leaves(self):
|
|
||||||
implementations = [("python", python_pytree)]
|
|
||||||
if cxx_pytree is not None:
|
|
||||||
implementations.append(("cxx", cxx_pytree))
|
|
||||||
|
|
||||||
for name, module in implementations:
|
|
||||||
with self.subTest(f"pytree implement: {name}"):
|
|
||||||
|
|
||||||
def fn(x):
|
|
||||||
tree = {
|
|
||||||
"a": [x, x - 1],
|
|
||||||
"b": x + 2,
|
|
||||||
"c": (
|
|
||||||
x,
|
|
||||||
3.0,
|
|
||||||
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
|
||||||
),
|
|
||||||
"d": collections.OrderedDict(
|
|
||||||
{
|
|
||||||
"e": torch.return_types.qr((2 * x, None)),
|
|
||||||
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
leaves = module.tree_leaves(tree)
|
|
||||||
return leaves
|
|
||||||
|
|
||||||
x = torch.randn(3, 2)
|
|
||||||
expected = fn(x)
|
|
||||||
fn_opt = torch.compile(fullgraph=True)(fn)
|
|
||||||
actual = fn_opt(x)
|
|
||||||
|
|
||||||
self.assertEqual(actual, expected)
|
|
||||||
|
|
||||||
def test_pytree_tree_flatten_unflatten(self):
|
|
||||||
implementations = [("python", python_pytree)]
|
|
||||||
if cxx_pytree is not None:
|
|
||||||
implementations.append(("cxx", cxx_pytree))
|
|
||||||
|
|
||||||
for name, module in implementations:
|
|
||||||
with self.subTest(f"pytree implement: {name}"):
|
|
||||||
|
|
||||||
def fn(x, y):
|
|
||||||
tree = {
|
|
||||||
"a": [x, x - 1],
|
|
||||||
"b": x + 2,
|
|
||||||
"c": (
|
|
||||||
x,
|
|
||||||
3.0,
|
|
||||||
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
|
||||||
),
|
|
||||||
"d": collections.OrderedDict(
|
|
||||||
{
|
|
||||||
"e": torch.return_types.qr((2 * x, None)),
|
|
||||||
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
leaves, treespec = module.tree_flatten(tree)
|
|
||||||
new_leaves = [
|
|
||||||
x - 1,
|
|
||||||
y,
|
|
||||||
x * y,
|
|
||||||
3.0,
|
|
||||||
y - 2,
|
|
||||||
1,
|
|
||||||
torch.zeros(2, 2),
|
|
||||||
2 * y,
|
|
||||||
-y,
|
|
||||||
x + y,
|
|
||||||
x - y,
|
|
||||||
torch.ones(3, 2),
|
|
||||||
1,
|
|
||||||
]
|
|
||||||
new_tree = module.tree_unflatten(new_leaves, treespec)
|
|
||||||
return leaves, new_tree
|
|
||||||
|
|
||||||
x = torch.randn(3, 2)
|
|
||||||
y = torch.randn(3, 2)
|
|
||||||
expected = fn(x, y)
|
|
||||||
fn_opt = torch.compile(fullgraph=True)(fn)
|
|
||||||
actual = fn_opt(x, y)
|
|
||||||
|
|
||||||
self.assertEqual(actual, expected)
|
|
||||||
|
|
||||||
def test_pytree_tree_map(self):
|
|
||||||
implementations = [("python", python_pytree)]
|
|
||||||
if cxx_pytree is not None:
|
|
||||||
implementations.append(("cxx", cxx_pytree))
|
|
||||||
|
|
||||||
for name, module in implementations:
|
|
||||||
with self.subTest(f"pytree implement: {name}"):
|
|
||||||
|
|
||||||
def fn(x, y):
|
|
||||||
tree1 = {
|
|
||||||
"a": [x, x - 1],
|
|
||||||
"b": x + 2,
|
|
||||||
"c": (
|
|
||||||
x,
|
|
||||||
3.0,
|
|
||||||
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
|
||||||
),
|
|
||||||
"d": collections.OrderedDict(
|
|
||||||
{
|
|
||||||
"e": torch.return_types.qr((2 * x, None)),
|
|
||||||
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
tree2 = collections.OrderedDict(
|
|
||||||
[
|
|
||||||
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
|
|
||||||
("a", [y, y + 1]),
|
|
||||||
("b", y + 2),
|
|
||||||
(
|
|
||||||
"d",
|
|
||||||
{
|
|
||||||
"f": MyTuple(torch.ones(4, 3), -y, y + 1),
|
|
||||||
"e": torch.return_types.qr((2 * y, None)),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return module.tree_map(lambda u, v: (u, v), tree1, tree2)
|
|
||||||
|
|
||||||
x = torch.randn(3, 2)
|
|
||||||
y = torch.randn(3, 2)
|
|
||||||
expected = fn(x, y)
|
|
||||||
fn_opt = torch.compile(fullgraph=True)(fn)
|
|
||||||
actual = fn_opt(x, y)
|
|
||||||
|
|
||||||
self.assertEqual(actual, expected)
|
|
||||||
|
|
||||||
def test_shape_env_no_recording(self):
|
def test_shape_env_no_recording(self):
|
||||||
main = ShapeEnv(should_record_events=False)
|
main = ShapeEnv(should_record_events=False)
|
||||||
|
|
||||||
@ -12886,6 +12639,257 @@ fn
|
|||||||
self.assertRaises(Unsupported, f, "1 + j")
|
self.assertRaises(Unsupported, f, "1 + j")
|
||||||
|
|
||||||
|
|
||||||
|
class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_tracing_pytree(self, pytree):
|
||||||
|
def fn(xs):
|
||||||
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
|
res = [x.clone() for x in flat_xs]
|
||||||
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
torch.compile(fn, backend=counter, fullgraph=True)(xs)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 3)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_tracing_nested_pytree(self, pytree):
|
||||||
|
def fn(xs):
|
||||||
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
|
res = [x.clone() for x in flat_xs]
|
||||||
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
xsl = [xs, xs, xs, xs]
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
||||||
|
real_out = fn(xsl)
|
||||||
|
self.assertEqual(comp_out, real_out)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 12)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_tracing_nested_tuples(self, pytree):
|
||||||
|
def fn(xs):
|
||||||
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
|
res = [x.clone() for x in flat_xs]
|
||||||
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
xsl = (xs, xs, xs, xs)
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
||||||
|
real_out = fn(xsl)
|
||||||
|
self.assertEqual(comp_out, real_out)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 12)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_tracing_nested_dicts(self, pytree):
|
||||||
|
def fn(xs):
|
||||||
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
|
res = [x.clone() for x in flat_xs]
|
||||||
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
xsl = {
|
||||||
|
"a": xs,
|
||||||
|
"b": xs,
|
||||||
|
"c": xs,
|
||||||
|
}
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
||||||
|
real_out = fn(xsl)
|
||||||
|
self.assertEqual(comp_out, real_out)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 9)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_tracing_nested_mixed_all(self, pytree):
|
||||||
|
def fn(xs):
|
||||||
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
|
res = [x.clone() for x in flat_xs]
|
||||||
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
xsa = (xs, xs)
|
||||||
|
xsb = {"aa": xsa, "ab": xs}
|
||||||
|
xsl = {
|
||||||
|
"a": xs,
|
||||||
|
"b": xsa,
|
||||||
|
"c": xsb,
|
||||||
|
}
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
|
||||||
|
real_out = fn(xsl)
|
||||||
|
self.assertEqual(comp_out, real_out)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 18)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_tracing_nested_tensor_subclass(self, pytree):
|
||||||
|
from torch.testing._internal.two_tensor import TwoTensor
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
def fn(xs):
|
||||||
|
nested_xs = [[xs]]
|
||||||
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
|
return flat_xs[0].clone()
|
||||||
|
|
||||||
|
# use checkpoint to trigger a "sourceless" tensor subclass
|
||||||
|
def checkpoint_fn(xs):
|
||||||
|
return checkpoint(fn, xs, use_reentrant=True)
|
||||||
|
|
||||||
|
xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 2)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_pytree_tree_leaves(self, pytree):
|
||||||
|
def fn(x):
|
||||||
|
tree = {
|
||||||
|
"a": [x, x - 1],
|
||||||
|
"b": x + 2,
|
||||||
|
"c": (
|
||||||
|
x,
|
||||||
|
3.0,
|
||||||
|
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
||||||
|
),
|
||||||
|
"d": collections.OrderedDict(
|
||||||
|
{
|
||||||
|
"e": torch.return_types.qr((2 * x, None)),
|
||||||
|
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
leaves = pytree.tree_leaves(tree)
|
||||||
|
return leaves
|
||||||
|
|
||||||
|
x = torch.randn(3, 2)
|
||||||
|
expected = fn(x)
|
||||||
|
fn_opt = torch.compile(fullgraph=True)(fn)
|
||||||
|
actual = fn_opt(x)
|
||||||
|
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_pytree_tree_flatten_unflatten(self, pytree):
|
||||||
|
def fn(x, y):
|
||||||
|
tree = {
|
||||||
|
"a": [x, x - 1],
|
||||||
|
"b": x + 2,
|
||||||
|
"c": (
|
||||||
|
x,
|
||||||
|
3.0,
|
||||||
|
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
||||||
|
),
|
||||||
|
"d": collections.OrderedDict(
|
||||||
|
{
|
||||||
|
"e": torch.return_types.qr((2 * x, None)),
|
||||||
|
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
leaves, treespec = pytree.tree_flatten(tree)
|
||||||
|
new_leaves = [
|
||||||
|
x - 1,
|
||||||
|
y,
|
||||||
|
x * y,
|
||||||
|
3.0,
|
||||||
|
y - 2,
|
||||||
|
1,
|
||||||
|
torch.zeros(2, 2),
|
||||||
|
2 * y,
|
||||||
|
-y,
|
||||||
|
x + y,
|
||||||
|
x - y,
|
||||||
|
torch.ones(3, 2),
|
||||||
|
1,
|
||||||
|
]
|
||||||
|
new_tree = pytree.tree_unflatten(new_leaves, treespec)
|
||||||
|
return leaves, new_tree
|
||||||
|
|
||||||
|
x = torch.randn(3, 2)
|
||||||
|
y = torch.randn(3, 2)
|
||||||
|
expected = fn(x, y)
|
||||||
|
fn_opt = torch.compile(fullgraph=True)(fn)
|
||||||
|
actual = fn_opt(x, y)
|
||||||
|
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_pytree_tree_map(self, pytree):
|
||||||
|
def fn(x, y):
|
||||||
|
tree1 = {
|
||||||
|
"a": [x, x - 1],
|
||||||
|
"b": x + 2,
|
||||||
|
"c": (
|
||||||
|
x,
|
||||||
|
3.0,
|
||||||
|
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
||||||
|
),
|
||||||
|
"d": collections.OrderedDict(
|
||||||
|
{
|
||||||
|
"e": torch.return_types.qr((2 * x, None)),
|
||||||
|
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
tree2 = collections.OrderedDict(
|
||||||
|
[
|
||||||
|
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
|
||||||
|
("a", [y, y + 1]),
|
||||||
|
("b", y + 2),
|
||||||
|
(
|
||||||
|
"d",
|
||||||
|
{
|
||||||
|
"f": MyTuple(torch.ones(4, 3), -y, y + 1),
|
||||||
|
"e": torch.return_types.qr((2 * y, None)),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return pytree.tree_map(lambda u, v: (u, v), tree1, tree2)
|
||||||
|
|
||||||
|
x = torch.randn(3, 2)
|
||||||
|
y = torch.randn(3, 2)
|
||||||
|
expected = fn(x, y)
|
||||||
|
fn_opt = torch.compile(fullgraph=True)(fn)
|
||||||
|
actual = fn_opt(x, y)
|
||||||
|
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
@parametrize_pytree_module
|
||||||
|
def test_pytree_tree_map_only(self, pytree):
|
||||||
|
def fn(xs):
|
||||||
|
def mapper(x):
|
||||||
|
return x.clone()
|
||||||
|
|
||||||
|
y = pytree.tree_map_only(torch.Tensor, mapper, xs)
|
||||||
|
return y
|
||||||
|
|
||||||
|
xs = [torch.tensor(i) for i in range(3)] + ["hi"]
|
||||||
|
xsa = (xs, xs)
|
||||||
|
xsb = {"aa": xsa, "ab": xs}
|
||||||
|
|
||||||
|
counter = CompileCounter()
|
||||||
|
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb)
|
||||||
|
real_out = fn(xsb)
|
||||||
|
|
||||||
|
self.assertEqual(comp_out, real_out)
|
||||||
|
self.assertEqual(counter.frame_count, 1)
|
||||||
|
self.assertEqual(counter.op_count, 9)
|
||||||
|
|
||||||
|
|
||||||
class TestTracer(JitTestCase):
|
class TestTracer(JitTestCase):
|
||||||
def test_jit_save(self):
|
def test_jit_save(self):
|
||||||
def fn():
|
def fn():
|
||||||
@ -13266,10 +13270,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
|
|||||||
# RuntimeError: value cannot be converted to type at::Half without overflow
|
# RuntimeError: value cannot be converted to type at::Half without overflow
|
||||||
|
|
||||||
|
|
||||||
|
instantiate_parametrized_tests(MiscTestsPyTree)
|
||||||
|
|
||||||
devices = ("cuda", "hpu", "xpu")
|
devices = ("cuda", "hpu", "xpu")
|
||||||
instantiate_device_type_tests(
|
instantiate_device_type_tests(
|
||||||
MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
|
MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user