[pytrees] Allow tree_map_only to support predicate function as filter (#119974)

In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.

Alternative: wrap nested int SymNodes with something other than SymInt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
This commit is contained in:
soulitzer
2024-02-16 11:23:28 -05:00
committed by PyTorch MergeBot
parent 722e87899a
commit 2e77629b9f
4 changed files with 108 additions and 36 deletions

View File

@ -629,6 +629,18 @@ class TestGenericPytree(TestCase):
pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
)
@parametrize(
"pytree_impl",
[
subtest(py_pytree, name="py"),
# cxx tree_map_only does not support passing predicate fn as filter
],
)
def test_tree_map_only_predicate_fn(self, pytree_impl):
self.assertEqual(
pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
)
@parametrize(
"pytree_impl",
[