[fx] Move map_aggregate to C++ (#148243)

Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
30603618 function calls (29403419 primitive calls) in 13.744 seconds
```
after:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```
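A minimal sketch of how such a measurement can be reproduced (the exact harness is not part of this commit; the `cProfile`/`pstats` usage below is an assumption):

```python
# Hypothetical reproduction of the numbers quoted above (not taken from the PR):
# profile symbolic_trace on a call with 100000 constant operands.
import cProfile
import functools
import operator
import pstats

import torch.fx as fx

with cProfile.Profile() as prof:
    fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))

pstats.Stats(prof).sort_stats("cumulative").print_stats(20)
```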

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243
Approved by: https://github.com/oulgen
Jason Ansel
2025-03-09 21:21:57 -07:00
committed by PyTorch MergeBot
parent b8b1b364c9
commit bec7bdad47
8 changed files with 285 additions and 60 deletions


@@ -1,65 +1,65 @@
-add_loop_eager,compile_time_instruction_count,3121000000,0.015
+add_loop_eager,compile_time_instruction_count,3002000000,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,5807000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5689000000,0.025
-add_loop_inductor,compile_time_instruction_count,30000000000,0.015
+add_loop_inductor,compile_time_instruction_count,28650000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44150000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42620000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26040000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25090000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,969600000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,964300000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19360000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18060000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17330000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16330000000,0.015
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11000000000,0.2
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10240000000,0.2
-update_hint_regression,compile_time_instruction_count,1712000000,0.02
+update_hint_regression,compile_time_instruction_count,1611000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,1076000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1058000000,0.015
-symint_sum,compile_time_instruction_count,3367000000,0.015
+symint_sum,compile_time_instruction_count,3168000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2060000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2015000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5872000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5785000000,0.015
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9298000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8664000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3946000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3777000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10450000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10150000000,0.015



@@ -2513,6 +2513,8 @@ def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ..
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
# Defined in torch/csrc/fx/node.cpp
def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ...
def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ...
class _NodeBase:
_erased: _bool
_prev: FxNode


@@ -0,0 +1,40 @@
from typing import Any, Callable
from torch._C import _fx_map_aggregate, _fx_map_arg
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.node import Node
from ..decorators import substitute_in_graph
@substitute_in_graph(_fx_map_arg, can_constant_fold_through=True)
def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any:
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
@substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True)
def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any:
result: Any
if isinstance(a, tuple):
it = (map_aggregate(elem, fn) for elem in a)
# Support NamedTuple (if it has `_fields`) by repacking into original type.
result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
elif isinstance(a, list):
result = immutable_list([map_aggregate(elem, fn) for elem in a])
elif isinstance(a, dict):
result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()])
elif isinstance(a, slice):
result = slice(
map_aggregate(a.start, fn),
map_aggregate(a.stop, fn),
map_aggregate(a.step, fn),
)
else:
result = fn(a)
return result
__all__ = [
"map_arg",
"map_aggregate",
]
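For reference, a hedged sketch of the behavior this polyfill (mirroring the C++ implementation it substitutes for) provides; the sample values are illustrative only:

```python
# Hedged usage sketch: map_aggregate applies fn to every leaf of a nested
# aggregate (tuple, list, dict, slice) and rebuilds the container structure,
# returning FX immutable collections for lists and dicts.
from torch.fx.node import map_aggregate

nested = {"a": (1, 2), "b": [3, {"c": 4}], "d": slice(5, 6)}
doubled = map_aggregate(nested, lambda x: x * 2 if isinstance(x, int) else x)
# Roughly {"a": (2, 4), "b": [6, {"c": 8}], "d": slice(10, 12)},
# where the list/dict containers are immutable_list/immutable_dict instances.
```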


@@ -20,6 +20,7 @@ POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
"os",
"pytree",
"sys",
"fx",
)
POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
importlib.import_module(f".{submodule}", package=polyfills.__name__)


@@ -1,8 +1,146 @@
#include <torch/csrc/fx/node.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
namespace {
// Thrown to exit out of a C++ function and return an error to Python.
class PythonError : public std::exception {};
inline static PyObject* import_from(const char* module_name, const char* name) {
THPObjectPtr module(PyImport_ImportModule(module_name));
if (!module) {
throw PythonError();
}
PyObject* result = PyObject_GetAttrString(module, name);
if (!result) {
throw PythonError();
}
return result;
}
inline static PyObject* immutable_list_cls() {
static PyObject* immutable_list_cls = nullptr;
if (!immutable_list_cls) {
immutable_list_cls =
import_from("torch.fx.immutable_collections", "immutable_list");
}
return immutable_list_cls;
}
inline static PyObject* immutable_dict_cls() {
static PyObject* immutable_dict_cls = nullptr;
if (!immutable_dict_cls) {
immutable_dict_cls =
import_from("torch.fx.immutable_collections", "immutable_dict");
}
return immutable_dict_cls;
}
inline static bool is_node(PyObject* obj) {
static PyObject* node_cls = nullptr;
if (!node_cls) {
node_cls = import_from("torch.fx.node", "Node");
}
return PyObject_TypeCheck(obj, reinterpret_cast<PyTypeObject*>(node_cls));
}
inline static bool exact_type(PyObject* obj, PyObject* typ) {
return Py_TYPE(obj) == reinterpret_cast<PyTypeObject*>(typ);
}
template <typename F>
inline static PyObject* map_aggregate(PyObject* a, F fn) {
  // Invariant: on error this function throws PythonError; it never returns nullptr.
// Case 1: a is a tuple.
if (PyTuple_Check(a)) {
Py_ssize_t n = PyTuple_GET_SIZE(a);
if (n == 0 && PyTuple_CheckExact(a)) {
return Py_NewRef(a);
}
THPObjectPtr new_tuple(PyTuple_New(n));
if (!new_tuple) {
throw PythonError();
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* elem = PyTuple_GET_ITEM(a, i); // Borrowed reference.
// PyTuple_SET_ITEM steals reference to result of map_aggregate
PyTuple_SET_ITEM(new_tuple.get(), i, map_aggregate(elem, fn));
}
// If the tuple has a "_fields" attribute, assume it is a NamedTuple.
if (!PyTuple_CheckExact(a) && PyObject_HasAttrString(a, "_fields")) {
// Call type_obj with new_tuple as arguments (i.e. type(a)(*new_tuple))
return PyObject_CallObject(
reinterpret_cast<PyObject*>(Py_TYPE(a)), new_tuple);
} else {
return new_tuple.release();
}
}
// Case 2: a is a list.
else if (PyList_Check(a)) {
Py_ssize_t n = PyList_GET_SIZE(a);
if (n == 0 && exact_type(a, immutable_list_cls())) {
return Py_NewRef(a);
}
THPObjectPtr result(PyObject_CallNoArgs(immutable_list_cls()));
if (!result) {
throw PythonError();
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* elem = PyList_GET_ITEM(a, i); // borrowed ref
THPObjectPtr mapped(map_aggregate(elem, fn));
if (PyList_Append(result.get(), mapped.get()) < 0) {
throw PythonError();
}
}
return result.release();
}
// Case 3: a is a dict.
else if (PyDict_Check(a)) {
if (PyDict_GET_SIZE(a) == 0 && exact_type(a, immutable_dict_cls())) {
return Py_NewRef(a);
}
THPObjectPtr result(PyObject_CallNoArgs(immutable_dict_cls()));
if (!result) {
throw PythonError();
}
PyObject *key = nullptr, *value = nullptr; // borrowed
Py_ssize_t pos = 0;
while (PyDict_Next(a, &pos, &key, &value)) {
THPObjectPtr mapped(map_aggregate(value, fn));
if (PyDict_SetItem(result.get(), key, mapped.get()) < 0) {
throw PythonError();
}
}
return result.release();
}
// Case 4: a is a slice.
else if (PySlice_Check(a)) {
// Get start, stop, and step attributes.
THPObjectPtr start(PyObject_GetAttrString(a, "start"));
THPObjectPtr stop(PyObject_GetAttrString(a, "stop"));
THPObjectPtr step(PyObject_GetAttrString(a, "step"));
if (!start || !stop || !step) {
throw PythonError();
}
THPObjectPtr mapped_start(map_aggregate(start, fn));
THPObjectPtr mapped_stop(map_aggregate(stop, fn));
THPObjectPtr mapped_step(map_aggregate(step, fn));
return PySlice_New(
mapped_start.get(), mapped_stop.get(), mapped_step.get());
}
// Default case: call fn(a).
else {
PyObject* result = fn(a);
if (!result) {
throw PythonError();
}
return result;
}
}
////////////////////////////////
// NodeBase
///////////////////////////////
@@ -59,7 +197,7 @@ static void NodeBase_dealloc(PyObject* self) {
Py_TYPE(self)->tp_free(self);
}
static PyTypeObject NodeBaseType = {
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
sizeof(NodeBase), /* tp_basicsize */
@@ -101,12 +239,7 @@ static PyTypeObject NodeBaseType = {
NodeBase_new, /* tp_new */
};
bool NodeBase_init(PyObject* module) {
if (PyModule_AddType(module, &NodeBaseType) < 0) {
return false;
}
return true;
}
} // namespace
////////////////////////////////
// NodeIter
@@ -259,3 +392,71 @@ bool NodeIter_init(PyObject* module) {
}
return true;
}
////////////////////////////////
// Global methods
////////////////////////////////
static PyObject* py_map_aggregate(
PyObject* self,
PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 2) {
PyErr_SetString(
PyExc_TypeError, "map_aggregate() takes exactly two arguments");
return nullptr;
}
try {
PyObject* fn = args[1];
// args[0]: aggregate, args[1]: callable fn
return map_aggregate(
args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); });
} catch (const PythonError& e) {
return nullptr; // error should already be set
}
}
static PyObject* py_map_arg(
PyObject* self,
PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 2) {
PyErr_SetString(PyExc_TypeError, "map_arg() takes exactly two arguments");
return nullptr;
}
try {
PyObject* fn = args[1];
// args[0]: aggregate, args[1]: callable fn
return map_aggregate(args[0], [fn](PyObject* a) {
if (is_node(a)) {
return PyObject_CallOneArg(fn, a);
}
return Py_NewRef(a);
});
} catch (const PythonError& e) {
return nullptr; // error should already be set
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef extra_methods[] = {
{"_fx_map_aggregate",
(PyCFunction)(void*)(py_map_aggregate),
METH_FASTCALL,
"Recursively apply a function to every element in an aggregate object."},
{"_fx_map_arg",
(PyCFunction)(void*)(py_map_arg),
METH_FASTCALL,
"Recursively apply a function to every Node in an aggregate object."},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
bool NodeBase_init(PyObject* module) {
if (PyModule_AddType(module, &NodeBaseType) < 0) {
return false;
}
if (PyModule_AddFunctions(module, extra_methods) < 0) {
return false;
}
return true;
}
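The two functions registered above are reachable from Python as `torch._C._fx_map_aggregate` and `torch._C._fx_map_arg` (matching the stubs added earlier); a hedged sketch of calling them directly:

```python
# Hedged sketch: exercising the new C++ entry points directly. In normal use
# they are reached through the torch.fx.node wrappers.
from torch._C import _fx_map_aggregate, _fx_map_arg

out = _fx_map_aggregate((1, [2, 3], {"k": 4}), lambda x: x + 1)
# Roughly (2, [3, 4], {"k": 5}); list/dict come back as FX immutable collections.

# _fx_map_arg applies the callback only to fx.Node leaves, so plain values
# pass through unchanged here and the lambda is never called.
same = _fx_map_arg((1, "s"), lambda n: n.name)  # -> (1, "s")
```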


@@ -19,11 +19,11 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
import torch
import torch.utils._pytree as pytree
from torch._C import _NodeIter
from torch._C import _fx_map_arg as map_arg, _NodeIter
from . import _pytree as fx_pytree
from ._compatibility import compatibility
from .node import _get_qualified_name, _type_repr, Argument, map_arg, Node, Target
from .node import _get_qualified_name, _type_repr, Argument, Node, Target
__all__ = ["PythonCode", "CodeGen", "Graph"]


@@ -5,10 +5,10 @@ import logging
import operator
import types
from collections.abc import Mapping, Sequence
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
import torch
from torch._C import _NodeBase
from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase
from torch.fx.operator_schemas import (
ArgsKwargsPair,
normalize_function,
@@ -17,7 +17,6 @@ from torch.fx.operator_schemas import (
from .._ops import ops as _ops
from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
if TYPE_CHECKING:
@@ -535,7 +534,7 @@ class Node(_NodeBase):
self._args = args_left + (arg,) + args_right
_new_input_nodes: dict[Node, None] = {}
map_arg(arg, _new_input_nodes.setdefault)
_fx_map_arg(arg, _new_input_nodes.setdefault)
for new_use in _new_input_nodes.keys():
if new_use not in self._input_nodes:
@@ -596,10 +595,10 @@
# - Populate self._input_nodes
# - Populate arg.users[self] for each arg
object.__setattr__(
self, "_args", map_aggregate(new_args, update_users_and_input_nodes)
self, "_args", _fx_map_aggregate(new_args, update_users_and_input_nodes)
)
object.__setattr__(
self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)
self, "_kwargs", _fx_map_aggregate(new_kwargs, update_users_and_input_nodes)
)
def __repr__(self) -> str:
@@ -748,8 +747,8 @@
for replace_hook in m._replace_hooks:
replace_hook(old=self, new=replace_with.name, user=use_node)
new_args = map_arg(use_node.args, maybe_replace_node)
new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
use_node.__update_args_kwargs(new_args, new_kwargs)
@@ -860,8 +859,8 @@
for replace_hook in m._replace_hooks:
replace_hook(old=old_input, new=new_input.name, user=self)
new_args = map_arg(self.args, maybe_replace_node)
new_kwargs = map_arg(self.kwargs, maybe_replace_node)
new_args = _fx_map_arg(self.args, maybe_replace_node)
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
self.__update_args_kwargs(new_args, new_kwargs)
@@ -903,7 +902,7 @@ def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT:
have the same type and structure.
"""
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
return _fx_map_arg(a, fn)
@compatibility(is_backward_compatible=True)
@@ -914,23 +913,4 @@ def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT
arg may be a list, tuple, slice, or dict with string keys: the return value will
have the same type and structure.
"""
result: Argument
if isinstance(a, tuple):
it = (map_aggregate(elem, fn) for elem in a)
# Support NamedTuple (if it has `_fields`) by repacking into original type.
result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
elif isinstance(a, list):
result = immutable_list([map_aggregate(elem, fn) for elem in a])
elif isinstance(a, dict):
result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()])
elif isinstance(a, slice):
result = slice(
map_aggregate(a.start, fn),
map_aggregate(a.stop, fn),
map_aggregate(a.step, fn),
)
else:
result = fn(a)
return cast(ArgumentT, result)
return _fx_map_aggregate(a, fn)
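The public `map_arg` / `map_aggregate` wrappers in `torch.fx.node` keep their signatures and docstrings; only the recursion now happens in C++. A hedged usage sketch over a traced graph:

```python
# Hedged sketch: map_arg visits only fx.Node leaves, leaving other argument
# values (here the constant 1) untouched while preserving structure.
import torch.fx
from torch.fx.node import map_arg

gm = torch.fx.symbolic_trace(lambda x: x + 1)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
print(map_arg(add_node.args, lambda n: n.name))  # e.g. ('x', 1)
```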


@@ -15,11 +15,12 @@ from typing import Any, Callable, Optional
import torch
import torch.fx.traceback as fx_traceback
from torch._C import _fx_map_aggregate as map_aggregate
from torch.utils._traceback import CapturedTraceback
from ._compatibility import compatibility
from .graph import Graph, magic_methods, reflectable_magic_methods
from .node import Argument, base_types, map_aggregate, Node, Target
from .node import Argument, base_types, Node, Target
from .operator_schemas import check_for_mutable_operation
@@ -584,8 +585,8 @@ class Proxy:
if isinstance(a, cls):
tracers[a.tracer] = None
torch.fx.node.map_aggregate(args, find_tracer)
torch.fx.node.map_aggregate(kwargs, find_tracer)
map_aggregate(args, find_tracer)
map_aggregate(kwargs, find_tracer)
if len(tracers) > 1:
raise RuntimeError(