mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytrees] Allow tree_map_only to support predicate function as filter (#119974)
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead. Alternative: wrap nested int SymNodes with something other than SymInt. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974 Approved by: https://github.com/zou3519 ghstack dependencies: #119661
This commit is contained in:
committed by
PyTorch MergeBot
parent
722e87899a
commit
2e77629b9f
@ -629,6 +629,18 @@ class TestGenericPytree(TestCase):
|
||||
pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
# cxx tree_map_only does not support passing predicate fn as filter
|
||||
],
|
||||
)
|
||||
def test_tree_map_only_predicate_fn(self, pytree_impl):
|
||||
self.assertEqual(
|
||||
pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
|
Reference in New Issue
Block a user