Compare commits

...

4 Commits

Author SHA1 Message Date
cc5a1c93de Update on "[dynamo][compile time] Special case for torch.utils._pytree._get_node_type"
<Replace this line with a title. Use 1 line only, 67 chars or less>

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-17 22:16:25 -08:00
dd3575139b Update on "[dynamo][compile time] Special case for torch.utils._pytree._get_node_type"
<Replace this line with a title. Use 1 line only, 67 chars or less>

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-17 22:10:35 -08:00
1a55a4777c [dynamo][compile time] Special case for torch.utils._pytree._get_node_type
<Replace this line with a title. Use 1 line only, 67 chars or less>

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
2025-11-17 16:21:20 -08:00
456833f78a [pytree][compile] Slightly faster TreeSpec init
Reduces Dynamo tracing time. Previously, the generator expressions passed
to `sum()` caused Dynamo to trace through additional polyfills.

[ghstack-poisoned]
2025-11-17 11:39:41 -08:00
6 changed files with 108 additions and 4 deletions

View File

@ -8184,6 +8184,65 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)
def test_pytree_get_node_type_not_traced(self):
    """Compile a function that calls ``_get_node_type`` as a single full graph.

    Regression test for the Dynamo special case of
    ``torch.utils._pytree._get_node_type``. What is asserted here is that
    the call works under ``fullgraph=True``, that the compiled result
    matches eager ``x + y``, and that exactly one frame is compiled.
    """
    from torch.utils._pytree import _get_node_type

    cnt = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnt, fullgraph=True)
    def fn(x, y):
        # Call _get_node_type which is used internally by pytree operations
        node_type = _get_node_type([x, y])
        assert node_type is list

        # Do some work with pytree structures
        data = {"a": x, "b": y}
        flat, _spec = pytree.tree_flatten(data)  # the spec is not needed here
        return flat[0] + flat[1]

    x = torch.randn(3, 4)
    y = torch.randn(3, 4)

    # Use assertEqual, like the neighbouring tests, so that a mismatch is
    # reported with tensor details instead of a bare "False is not true".
    self.assertEqual(fn(x, y), x + y)

    # Should compile successfully with fullgraph=True
    self.assertEqual(cnt.frame_count, 1)
def test_pytree_get_node_type_with_namedtuple(self):
    """Compile a function that calls ``_get_node_type`` on a namedtuple.

    Covers the branch of the Dynamo special case for
    ``torch.utils._pytree._get_node_type`` where ``is_namedtuple_class`` is
    true. What is asserted here is that the call works under
    ``fullgraph=True``, that the compiled result matches eager ``x + y``,
    and that exactly one frame is compiled.
    """
    from collections import namedtuple

    from torch.utils._pytree import _get_node_type

    Point = namedtuple("Point", ["x", "y"])

    cnt = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnt, fullgraph=True)
    def fn(a, b):
        # Create a namedtuple
        point = Point(a, b)

        # Call _get_node_type with a namedtuple instance. For namedtuple
        # classes it returns the `collections.namedtuple` factory itself,
        # not the concrete class (here `Point`).
        node_type = _get_node_type(point)
        assert node_type is namedtuple

        # Use pytree operations with namedtuples
        flat, _spec = pytree.tree_flatten(point)  # the spec is not needed here
        return flat[0] + flat[1]

    x = torch.randn(3, 4)
    y = torch.randn(3, 4)

    # Use assertEqual, like the neighbouring tests, so that a mismatch is
    # reported with tensor details instead of a bare "False is not true".
    self.assertEqual(fn(x, y), x + y)

    # Should compile successfully with fullgraph=True
    self.assertEqual(cnt.frame_count, 1)
instantiate_parametrized_tests(ReproTests)

View File

@ -201,8 +201,11 @@ class PyTreeSpec:
num_children = 0
else:
assert callable(self._unflatten_func)
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_nodes = 1
num_leaves = 0
for child in self._children:
num_nodes += child.num_nodes
num_leaves += child.num_leaves
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes)

View File

@ -64,6 +64,7 @@ from .variables import (
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
PyTreeGetNodeTypeFunctionVariable,
ReparametrizeModuleCallVariable,
SkipFunctionVariable,
TorchInGraphFunctionVariable,
@ -378,6 +379,7 @@ manual_torch_name_rule_map: dict[
f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable,
"torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable,
f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable,
"torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable,
}

View File

@ -64,6 +64,7 @@ from .functions import (
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
PyTreeGetNodeTypeFunctionVariable,
SkipFunctionVariable,
TMADescriptorExperimentalVariable,
TMADescriptorStableVariable,

View File

@ -29,6 +29,7 @@ import logging
import sys
import traceback
import types
from collections import namedtuple
from collections.abc import Callable, Sequence
from types import CellType, FunctionType
from typing import Any, Optional, TYPE_CHECKING, TypeVar
@ -38,6 +39,7 @@ from weakref import WeakKeyDictionary
import torch
from torch._dynamo.exc import get_stack_above_dynamo
from torch._guards import Source
from torch.utils._pytree import is_namedtuple_class
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
@ -2717,3 +2719,37 @@ class CreateTMADescriptorStableVariable(VariableTracker):
tensor=tensor, # type: ignore[arg-type]
block_shape=block_shape, # type: ignore[arg-type]
)
class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable):
    """
    Special-case handler for `torch.utils._pytree._get_node_type`.

    `_get_node_type` is a very hot function, so its result is computed
    directly in `call_function` to reduce Dynamo tracing time. The function
    being mirrored is:

        def _get_node_type(tree: Any) -> Any:
            node_type = type(tree)
            # All namedtuple types are implicitly registered as pytree nodes.
            # XXX: Other parts of the codebase expect namedtuple types always return
            # `namedtuple` instead of the actual namedtuple type. Even if the type
            # is explicitly registered.
            if is_namedtuple_class(node_type):
                return namedtuple
            return node_type
    """

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: Sequence[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Return a variable for the node type of the single positional argument.

        Only positional arguments are inspected; `kwargs` is not consulted.
        A type error is raised through `raise_type_error_exc` unless exactly
        one positional argument is given.
        """
        if len(args) != 1:
            raise_type_error_exc(
                tx,
                f"pytree_get_node_type requires exactly 1 argument, got {len(args)}",
            )

        tree_vt = args[0]

        # The answer depends only on the argument's type, so when the argument
        # has a source, guard on that type before reading it.
        tree_source = tree_vt.source
        if tree_source:
            install_guard(tree_source.make_guard(GuardBuilder.TYPE_MATCH))

        node_type = tree_vt.python_type()
        # Every namedtuple class maps to the `namedtuple` factory itself.
        result = namedtuple if is_namedtuple_class(node_type) else node_type
        return VariableTracker.build(tx, result)

View File

@ -1113,8 +1113,11 @@ class TreeSpec:
num_leaves = 1
num_children = 0
else:
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_nodes = 1
num_leaves = 0
for child in self._children:
num_nodes += child.num_nodes
num_leaves += child.num_leaves
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)