[BE][pytree] cleanup parameterized pytree tests (#160842)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160842
Approved by: https://github.com/Skylion007
This commit is contained in:
Xuehai Pan
2025-09-06 01:14:41 +08:00
committed by PyTorch MergeBot
parent 01edcd4df8
commit 2fa0520a64
2 changed files with 595 additions and 691 deletions

View File

@ -87,12 +87,15 @@ from torch.testing._internal.common_methods_invocations import (
) )
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
freeze_rng_state, freeze_rng_state,
instantiate_parametrized_tests,
IS_FBCODE, IS_FBCODE,
parametrize,
scoped_load_inline, scoped_load_inline,
set_default_dtype, set_default_dtype,
skipIfHpu, skipIfHpu,
skipIfNNModuleInlined, skipIfNNModuleInlined,
skipIfWindows, skipIfWindows,
subtest,
TEST_HPU, TEST_HPU,
TEST_XPU, TEST_XPU,
wrapDeterministicFlagAPITest, wrapDeterministicFlagAPITest,
@ -101,11 +104,21 @@ from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string from torch.testing._internal.logging_utils import logs_to_string
pytree_modules = {
"python": python_pytree,
}
if python_pytree._cxx_pytree_dynamo_traceable: if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree import torch.utils._cxx_pytree as cxx_pytree
pytree_modules["cxx"] = cxx_pytree
else: else:
cxx_pytree = None cxx_pytree = None
parametrize_pytree_module = parametrize(
"pytree",
[subtest(module, name=name) for name, module in pytree_modules.items()],
)
MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"]) MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
T = typing.TypeVar("T") T = typing.TypeVar("T")
@ -9107,71 +9120,6 @@ def ___make_guard_fn():
opt = torch.compile(fn, backend="eager") opt = torch.compile(fn, backend="eager")
opt() opt()
def test_tracing_py_tree(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
counter = CompileCounter()
torch.compile(fn, backend=counter, fullgraph=True)(xs)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 3)
def test_tracing_nested_py_tree(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsl = [xs, xs, xs, xs]
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 12)
def test_tracing_nested_py_tree_tuples(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsl = (xs, xs, xs, xs)
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 12)
def test_tracing_nested_py_tree_dicts(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsl = {
"a": xs,
"b": xs,
"c": xs,
}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
def test_dynamic_one_hot(self): def test_dynamic_one_hot(self):
def fn(x): def fn(x):
x = x + 1 x = x + 1
@ -9188,28 +9136,6 @@ def ___make_guard_fn():
self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.frame_count, 2)
self.assertEqual(counter.op_count, 2) self.assertEqual(counter.op_count, 2)
def test_tracing_nested_py_tree_mixed_all(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsa = (xs, xs)
xsb = {"aa": xsa, "ab": xs}
xsl = {
"a": xs,
"b": xsa,
"c": xsb,
}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 18)
def test_any_all_symnode(self): def test_any_all_symnode(self):
cnt = CompileCounter() cnt = CompileCounter()
@ -9236,46 +9162,6 @@ def ___make_guard_fn():
self.assertEqual(fn(y3), y3 - 3) self.assertEqual(fn(y3), y3 - 3)
self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.frame_count, 2)
def test_tracing_py_tree_tensor_subclass(self):
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import checkpoint
def fn(xs):
nested_xs = [[xs]]
flat_xs, spec = python_pytree.tree_flatten(xs)
return flat_xs[0].clone()
# use checkpoint to trigger a "sourceless" tensor subclass
def checkpoint_fn(xs):
return checkpoint(fn, xs, use_reentrant=True)
xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
counter = CompileCounter()
torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 2)
def test_tracing_tree_map_only(self):
def fn(xs):
def mapper(x):
return x.clone()
y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
return y
xs = [torch.tensor(i) for i in range(3)] + ["hi"]
xsa = (xs, xs)
xsb = {"aa": xsa, "ab": xs}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb)
real_out = fn(xsb)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
@torch._dynamo.config.patch( @torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
) )
@ -10718,139 +10604,6 @@ def ___make_guard_fn():
expected = fn(*inps) expected = fn(*inps)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
def test_pytree_tree_leaves(self):
implementations = [("python", python_pytree)]
if cxx_pytree is not None:
implementations.append(("cxx", cxx_pytree))
for name, module in implementations:
with self.subTest(f"pytree implement: {name}"):
def fn(x):
tree = {
"a": [x, x - 1],
"b": x + 2,
"c": (
x,
3.0,
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
"e": torch.return_types.qr((2 * x, None)),
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
},
),
}
leaves = module.tree_leaves(tree)
return leaves
x = torch.randn(3, 2)
expected = fn(x)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x)
self.assertEqual(actual, expected)
def test_pytree_tree_flatten_unflatten(self):
implementations = [("python", python_pytree)]
if cxx_pytree is not None:
implementations.append(("cxx", cxx_pytree))
for name, module in implementations:
with self.subTest(f"pytree implement: {name}"):
def fn(x, y):
tree = {
"a": [x, x - 1],
"b": x + 2,
"c": (
x,
3.0,
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
"e": torch.return_types.qr((2 * x, None)),
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
},
),
}
leaves, treespec = module.tree_flatten(tree)
new_leaves = [
x - 1,
y,
x * y,
3.0,
y - 2,
1,
torch.zeros(2, 2),
2 * y,
-y,
x + y,
x - y,
torch.ones(3, 2),
1,
]
new_tree = module.tree_unflatten(new_leaves, treespec)
return leaves, new_tree
x = torch.randn(3, 2)
y = torch.randn(3, 2)
expected = fn(x, y)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x, y)
self.assertEqual(actual, expected)
def test_pytree_tree_map(self):
implementations = [("python", python_pytree)]
if cxx_pytree is not None:
implementations.append(("cxx", cxx_pytree))
for name, module in implementations:
with self.subTest(f"pytree implement: {name}"):
def fn(x, y):
tree1 = {
"a": [x, x - 1],
"b": x + 2,
"c": (
x,
3.0,
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
"e": torch.return_types.qr((2 * x, None)),
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
},
),
}
tree2 = collections.OrderedDict(
[
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
("a", [y, y + 1]),
("b", y + 2),
(
"d",
{
"f": MyTuple(torch.ones(4, 3), -y, y + 1),
"e": torch.return_types.qr((2 * y, None)),
},
),
],
)
return module.tree_map(lambda u, v: (u, v), tree1, tree2)
x = torch.randn(3, 2)
y = torch.randn(3, 2)
expected = fn(x, y)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x, y)
self.assertEqual(actual, expected)
def test_shape_env_no_recording(self): def test_shape_env_no_recording(self):
main = ShapeEnv(should_record_events=False) main = ShapeEnv(should_record_events=False)
@ -12886,6 +12639,257 @@ fn
self.assertRaises(Unsupported, f, "1 + j") self.assertRaises(Unsupported, f, "1 + j")
class MiscTestsPyTree(torch._inductor.test_case.TestCase):
@parametrize_pytree_module
def test_tracing_pytree(self, pytree):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
counter = CompileCounter()
torch.compile(fn, backend=counter, fullgraph=True)(xs)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 3)
@parametrize_pytree_module
def test_tracing_nested_pytree(self, pytree):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsl = [xs, xs, xs, xs]
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 12)
@parametrize_pytree_module
def test_tracing_nested_tuples(self, pytree):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsl = (xs, xs, xs, xs)
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 12)
@parametrize_pytree_module
def test_tracing_nested_dicts(self, pytree):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsl = {
"a": xs,
"b": xs,
"c": xs,
}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
@parametrize_pytree_module
def test_tracing_nested_mixed_all(self, pytree):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsa = (xs, xs)
xsb = {"aa": xsa, "ab": xs}
xsl = {
"a": xs,
"b": xsa,
"c": xsb,
}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 18)
@parametrize_pytree_module
def test_tracing_nested_tensor_subclass(self, pytree):
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import checkpoint
def fn(xs):
nested_xs = [[xs]]
flat_xs, spec = pytree.tree_flatten(xs)
return flat_xs[0].clone()
# use checkpoint to trigger a "sourceless" tensor subclass
def checkpoint_fn(xs):
return checkpoint(fn, xs, use_reentrant=True)
xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
counter = CompileCounter()
torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 2)
@parametrize_pytree_module
def test_pytree_tree_leaves(self, pytree):
def fn(x):
tree = {
"a": [x, x - 1],
"b": x + 2,
"c": (
x,
3.0,
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
"e": torch.return_types.qr((2 * x, None)),
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
},
),
}
leaves = pytree.tree_leaves(tree)
return leaves
x = torch.randn(3, 2)
expected = fn(x)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x)
self.assertEqual(actual, expected)
@parametrize_pytree_module
def test_pytree_tree_flatten_unflatten(self, pytree):
def fn(x, y):
tree = {
"a": [x, x - 1],
"b": x + 2,
"c": (
x,
3.0,
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
"e": torch.return_types.qr((2 * x, None)),
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
},
),
}
leaves, treespec = pytree.tree_flatten(tree)
new_leaves = [
x - 1,
y,
x * y,
3.0,
y - 2,
1,
torch.zeros(2, 2),
2 * y,
-y,
x + y,
x - y,
torch.ones(3, 2),
1,
]
new_tree = pytree.tree_unflatten(new_leaves, treespec)
return leaves, new_tree
x = torch.randn(3, 2)
y = torch.randn(3, 2)
expected = fn(x, y)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x, y)
self.assertEqual(actual, expected)
@parametrize_pytree_module
def test_pytree_tree_map(self, pytree):
def fn(x, y):
tree1 = {
"a": [x, x - 1],
"b": x + 2,
"c": (
x,
3.0,
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
"e": torch.return_types.qr((2 * x, None)),
"f": MyTuple(x, x + 1, torch.zeros(4, 3)),
},
),
}
tree2 = collections.OrderedDict(
[
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
("a", [y, y + 1]),
("b", y + 2),
(
"d",
{
"f": MyTuple(torch.ones(4, 3), -y, y + 1),
"e": torch.return_types.qr((2 * y, None)),
},
),
],
)
return pytree.tree_map(lambda u, v: (u, v), tree1, tree2)
x = torch.randn(3, 2)
y = torch.randn(3, 2)
expected = fn(x, y)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x, y)
self.assertEqual(actual, expected)
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
def fn(xs):
def mapper(x):
return x.clone()
y = pytree.tree_map_only(torch.Tensor, mapper, xs)
return y
xs = [torch.tensor(i) for i in range(3)] + ["hi"]
xsa = (xs, xs)
xsb = {"aa": xsa, "ab": xs}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb)
real_out = fn(xsb)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
class TestTracer(JitTestCase): class TestTracer(JitTestCase):
def test_jit_save(self): def test_jit_save(self):
def fn(): def fn():
@ -13266,10 +13270,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
# RuntimeError: value cannot be converted to type at::Half without overflow # RuntimeError: value cannot be converted to type at::Half without overflow
instantiate_parametrized_tests(MiscTestsPyTree)
devices = ("cuda", "hpu", "xpu") devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests( instantiate_device_type_tests(
MiscTestsDevice, globals(), only_for=devices, allow_xpu=True MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
) )
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

File diff suppressed because it is too large Load Diff