mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx] Move map_aggregate to C++ (#148243)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 30603618 function calls (29403419 primitive calls) in 13.744 seconds ``` after: ``` 25203549 function calls (24403352 primitive calls) in 12.090 seconds ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
b8b1b364c9
commit
bec7bdad47
@ -1,65 +1,65 @@
|
||||
add_loop_eager,compile_time_instruction_count,3121000000,0.015
|
||||
add_loop_eager,compile_time_instruction_count,3002000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,5807000000,0.025
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,5689000000,0.025
|
||||
|
||||
|
||||
|
||||
add_loop_inductor,compile_time_instruction_count,30000000000,0.015
|
||||
add_loop_inductor,compile_time_instruction_count,28650000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44150000000,0.025
|
||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42620000000,0.025
|
||||
|
||||
|
||||
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,26040000000,0.015
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,25090000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969600000,0.015
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,964300000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19360000000,0.015
|
||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18060000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17330000000,0.015
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16330000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11000000000,0.2
|
||||
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10240000000,0.2
|
||||
|
||||
|
||||
|
||||
update_hint_regression,compile_time_instruction_count,1712000000,0.02
|
||||
update_hint_regression,compile_time_instruction_count,1611000000,0.02
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression,compile_time_instruction_count,1076000000,0.015
|
||||
sum_floordiv_regression,compile_time_instruction_count,1058000000,0.015
|
||||
|
||||
|
||||
|
||||
symint_sum,compile_time_instruction_count,3367000000,0.015
|
||||
symint_sum,compile_time_instruction_count,3168000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2060000000,0.015
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2015000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5872000000,0.015
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5785000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,9298000000,0.015
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8664000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3946000000,0.015
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3777000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10450000000,0.015
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10150000000,0.015
|
||||
|
|
@ -2513,6 +2513,8 @@ def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ..
|
||||
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
|
||||
|
||||
# Defined in torch/csrc/fx/node.cpp
|
||||
def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ...
|
||||
def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ...
|
||||
class _NodeBase:
|
||||
_erased: _bool
|
||||
_prev: FxNode
|
||||
|
40
torch/_dynamo/polyfills/fx.py
Normal file
40
torch/_dynamo/polyfills/fx.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from torch._C import _fx_map_aggregate, _fx_map_arg
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from torch.fx.node import Node
|
||||
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
@substitute_in_graph(_fx_map_arg, can_constant_fold_through=True)
|
||||
def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any:
|
||||
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
|
||||
|
||||
|
||||
@substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True)
|
||||
def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any:
|
||||
result: Any
|
||||
if isinstance(a, tuple):
|
||||
it = (map_aggregate(elem, fn) for elem in a)
|
||||
# Support NamedTuple (if it has `_fields`) by repacking into original type.
|
||||
result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
|
||||
elif isinstance(a, list):
|
||||
result = immutable_list([map_aggregate(elem, fn) for elem in a])
|
||||
elif isinstance(a, dict):
|
||||
result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()])
|
||||
elif isinstance(a, slice):
|
||||
result = slice(
|
||||
map_aggregate(a.start, fn),
|
||||
map_aggregate(a.stop, fn),
|
||||
map_aggregate(a.step, fn),
|
||||
)
|
||||
else:
|
||||
result = fn(a)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"map_arg",
|
||||
"map_aggregate",
|
||||
]
|
@ -20,6 +20,7 @@ POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
|
||||
"os",
|
||||
"pytree",
|
||||
"sys",
|
||||
"fx",
|
||||
)
|
||||
POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
|
||||
importlib.import_module(f".{submodule}", package=polyfills.__name__)
|
||||
|
@ -1,8 +1,146 @@
|
||||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
|
||||
namespace {
|
||||
|
||||
// Thrown to exit out of a C++ function and return an error to Python.
|
||||
class PythonError : public std::exception {};
|
||||
|
||||
inline static PyObject* import_from(const char* module_name, const char* name) {
|
||||
THPObjectPtr module(PyImport_ImportModule(module_name));
|
||||
if (!module) {
|
||||
throw PythonError();
|
||||
}
|
||||
PyObject* result = PyObject_GetAttrString(module, name);
|
||||
if (!result) {
|
||||
throw PythonError();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline static PyObject* immutable_list_cls() {
|
||||
static PyObject* immutable_list_cls = nullptr;
|
||||
if (!immutable_list_cls) {
|
||||
immutable_list_cls =
|
||||
import_from("torch.fx.immutable_collections", "immutable_list");
|
||||
}
|
||||
return immutable_list_cls;
|
||||
}
|
||||
|
||||
inline static PyObject* immutable_dict_cls() {
|
||||
static PyObject* immutable_dict_cls = nullptr;
|
||||
if (!immutable_dict_cls) {
|
||||
immutable_dict_cls =
|
||||
import_from("torch.fx.immutable_collections", "immutable_dict");
|
||||
}
|
||||
return immutable_dict_cls;
|
||||
}
|
||||
|
||||
inline static bool is_node(PyObject* obj) {
|
||||
static PyObject* node_cls = nullptr;
|
||||
if (!node_cls) {
|
||||
node_cls = import_from("torch.fx.node", "Node");
|
||||
}
|
||||
return PyObject_TypeCheck(obj, reinterpret_cast<PyTypeObject*>(node_cls));
|
||||
}
|
||||
|
||||
inline static bool exact_type(PyObject* obj, PyObject* typ) {
|
||||
return Py_TYPE(obj) == reinterpret_cast<PyTypeObject*>(typ);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
inline static PyObject* map_aggregate(PyObject* a, F fn) {
|
||||
// Invariant: this function will throw an exception and never return nullptr.
|
||||
// Case 1: a is a tuple.
|
||||
if (PyTuple_Check(a)) {
|
||||
Py_ssize_t n = PyTuple_GET_SIZE(a);
|
||||
if (n == 0 && PyTuple_CheckExact(a)) {
|
||||
return Py_NewRef(a);
|
||||
}
|
||||
THPObjectPtr new_tuple(PyTuple_New(n));
|
||||
if (!new_tuple) {
|
||||
throw PythonError();
|
||||
}
|
||||
for (Py_ssize_t i = 0; i < n; i++) {
|
||||
PyObject* elem = PyTuple_GET_ITEM(a, i); // Borrowed reference.
|
||||
// PyTuple_SET_ITEM steals reference to result of map_aggregate
|
||||
PyTuple_SET_ITEM(new_tuple.get(), i, map_aggregate(elem, fn));
|
||||
}
|
||||
// If the tuple has a "_fields" attribute, assume it is a NamedTuple.
|
||||
if (!PyTuple_CheckExact(a) && PyObject_HasAttrString(a, "_fields")) {
|
||||
// Call type_obj with new_tuple as arguments (i.e. type(a)(*new_tuple))
|
||||
return PyObject_CallObject(
|
||||
reinterpret_cast<PyObject*>(Py_TYPE(a)), new_tuple);
|
||||
} else {
|
||||
return new_tuple.release();
|
||||
}
|
||||
}
|
||||
// Case 2: a is a list.
|
||||
else if (PyList_Check(a)) {
|
||||
Py_ssize_t n = PyList_GET_SIZE(a);
|
||||
if (n == 0 && exact_type(a, immutable_list_cls())) {
|
||||
return Py_NewRef(a);
|
||||
}
|
||||
THPObjectPtr result(PyObject_CallNoArgs(immutable_list_cls()));
|
||||
if (!result) {
|
||||
throw PythonError();
|
||||
}
|
||||
for (Py_ssize_t i = 0; i < n; i++) {
|
||||
PyObject* elem = PyList_GET_ITEM(a, i); // borrowed ref
|
||||
THPObjectPtr mapped(map_aggregate(elem, fn));
|
||||
if (PyList_Append(result.get(), mapped.get()) < 0) {
|
||||
throw PythonError();
|
||||
}
|
||||
}
|
||||
return result.release();
|
||||
}
|
||||
// Case 3: a is a dict.
|
||||
else if (PyDict_Check(a)) {
|
||||
if (PyDict_GET_SIZE(a) == 0 && exact_type(a, immutable_dict_cls())) {
|
||||
return Py_NewRef(a);
|
||||
}
|
||||
THPObjectPtr result(PyObject_CallNoArgs(immutable_dict_cls()));
|
||||
if (!result) {
|
||||
throw PythonError();
|
||||
}
|
||||
PyObject *key = nullptr, *value = nullptr; // borrowed
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(a, &pos, &key, &value)) {
|
||||
THPObjectPtr mapped(map_aggregate(value, fn));
|
||||
if (PyDict_SetItem(result.get(), key, mapped.get()) < 0) {
|
||||
throw PythonError();
|
||||
}
|
||||
}
|
||||
return result.release();
|
||||
}
|
||||
// Case 4: a is a slice.
|
||||
else if (PySlice_Check(a)) {
|
||||
// Get start, stop, and step attributes.
|
||||
THPObjectPtr start(PyObject_GetAttrString(a, "start"));
|
||||
THPObjectPtr stop(PyObject_GetAttrString(a, "stop"));
|
||||
THPObjectPtr step(PyObject_GetAttrString(a, "step"));
|
||||
if (!start || !stop || !step) {
|
||||
throw PythonError();
|
||||
}
|
||||
THPObjectPtr mapped_start(map_aggregate(start, fn));
|
||||
THPObjectPtr mapped_stop(map_aggregate(stop, fn));
|
||||
THPObjectPtr mapped_step(map_aggregate(step, fn));
|
||||
return PySlice_New(
|
||||
mapped_start.get(), mapped_stop.get(), mapped_step.get());
|
||||
}
|
||||
// Default case: call fn(a).
|
||||
else {
|
||||
PyObject* result = fn(a);
|
||||
if (!result) {
|
||||
throw PythonError();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////
|
||||
// NodeBase
|
||||
///////////////////////////////
|
||||
@ -59,7 +197,7 @@ static void NodeBase_dealloc(PyObject* self) {
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
static PyTypeObject NodeBaseType = {
|
||||
PyTypeObject NodeBaseType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch._C._NodeBase", /* tp_name */
|
||||
sizeof(NodeBase), /* tp_basicsize */
|
||||
@ -101,12 +239,7 @@ static PyTypeObject NodeBaseType = {
|
||||
NodeBase_new, /* tp_new */
|
||||
};
|
||||
|
||||
bool NodeBase_init(PyObject* module) {
|
||||
if (PyModule_AddType(module, &NodeBaseType) < 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
////////////////////////////////
|
||||
// NodeIter
|
||||
@ -259,3 +392,71 @@ bool NodeIter_init(PyObject* module) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
////////////////////////////////
|
||||
// Global methods
|
||||
////////////////////////////////
|
||||
|
||||
static PyObject* py_map_aggregate(
|
||||
PyObject* self,
|
||||
PyObject* const* args,
|
||||
Py_ssize_t nargs) {
|
||||
if (nargs != 2) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError, "map_aggregate() takes exactly two arguments");
|
||||
return nullptr;
|
||||
}
|
||||
try {
|
||||
PyObject* fn = args[1];
|
||||
// args[0]: aggregate, args[1]: callable fn
|
||||
return map_aggregate(
|
||||
args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); });
|
||||
} catch (const PythonError& e) {
|
||||
return nullptr; // error should already be set
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* py_map_arg(
|
||||
PyObject* self,
|
||||
PyObject* const* args,
|
||||
Py_ssize_t nargs) {
|
||||
if (nargs != 2) {
|
||||
PyErr_SetString(PyExc_TypeError, "map_arg() takes exactly two arguments");
|
||||
return nullptr;
|
||||
}
|
||||
try {
|
||||
PyObject* fn = args[1];
|
||||
// args[0]: aggregate, args[1]: callable fn
|
||||
return map_aggregate(args[0], [fn](PyObject* a) {
|
||||
if (is_node(a)) {
|
||||
return PyObject_CallOneArg(fn, a);
|
||||
}
|
||||
return Py_NewRef(a);
|
||||
});
|
||||
} catch (const PythonError& e) {
|
||||
return nullptr; // error should already be set
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static PyMethodDef extra_methods[] = {
|
||||
{"_fx_map_aggregate",
|
||||
(PyCFunction)(void*)(py_map_aggregate),
|
||||
METH_FASTCALL,
|
||||
"Recursively apply a function to every element in an aggregate object."},
|
||||
{"_fx_map_arg",
|
||||
(PyCFunction)(void*)(py_map_arg),
|
||||
METH_FASTCALL,
|
||||
"Recursively apply a function to every Node in an aggregate object."},
|
||||
{nullptr, nullptr, 0, nullptr} // Sentinel
|
||||
};
|
||||
|
||||
bool NodeBase_init(PyObject* module) {
|
||||
if (PyModule_AddType(module, &NodeBaseType) < 0) {
|
||||
return false;
|
||||
}
|
||||
if (PyModule_AddFunctions(module, extra_methods) < 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -19,11 +19,11 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import _NodeIter
|
||||
from torch._C import _fx_map_arg as map_arg, _NodeIter
|
||||
|
||||
from . import _pytree as fx_pytree
|
||||
from ._compatibility import compatibility
|
||||
from .node import _get_qualified_name, _type_repr, Argument, map_arg, Node, Target
|
||||
from .node import _get_qualified_name, _type_repr, Argument, Node, Target
|
||||
|
||||
|
||||
__all__ = ["PythonCode", "CodeGen", "Graph"]
|
||||
|
@ -5,10 +5,10 @@ import logging
|
||||
import operator
|
||||
import types
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch._C import _NodeBase
|
||||
from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase
|
||||
from torch.fx.operator_schemas import (
|
||||
ArgsKwargsPair,
|
||||
normalize_function,
|
||||
@ -17,7 +17,6 @@ from torch.fx.operator_schemas import (
|
||||
|
||||
from .._ops import ops as _ops
|
||||
from ._compatibility import compatibility
|
||||
from .immutable_collections import immutable_dict, immutable_list
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -535,7 +534,7 @@ class Node(_NodeBase):
|
||||
self._args = args_left + (arg,) + args_right
|
||||
|
||||
_new_input_nodes: dict[Node, None] = {}
|
||||
map_arg(arg, _new_input_nodes.setdefault)
|
||||
_fx_map_arg(arg, _new_input_nodes.setdefault)
|
||||
|
||||
for new_use in _new_input_nodes.keys():
|
||||
if new_use not in self._input_nodes:
|
||||
@ -596,10 +595,10 @@ class Node(_NodeBase):
|
||||
# - Populate self._input_nodes
|
||||
# - Populate arg.users[self] for each arg
|
||||
object.__setattr__(
|
||||
self, "_args", map_aggregate(new_args, update_users_and_input_nodes)
|
||||
self, "_args", _fx_map_aggregate(new_args, update_users_and_input_nodes)
|
||||
)
|
||||
object.__setattr__(
|
||||
self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)
|
||||
self, "_kwargs", _fx_map_aggregate(new_kwargs, update_users_and_input_nodes)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -748,8 +747,8 @@ class Node(_NodeBase):
|
||||
for replace_hook in m._replace_hooks:
|
||||
replace_hook(old=self, new=replace_with.name, user=use_node)
|
||||
|
||||
new_args = map_arg(use_node.args, maybe_replace_node)
|
||||
new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
|
||||
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
|
||||
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
|
||||
assert isinstance(new_args, tuple)
|
||||
assert isinstance(new_kwargs, dict)
|
||||
use_node.__update_args_kwargs(new_args, new_kwargs)
|
||||
@ -860,8 +859,8 @@ class Node(_NodeBase):
|
||||
for replace_hook in m._replace_hooks:
|
||||
replace_hook(old=old_input, new=new_input.name, user=self)
|
||||
|
||||
new_args = map_arg(self.args, maybe_replace_node)
|
||||
new_kwargs = map_arg(self.kwargs, maybe_replace_node)
|
||||
new_args = _fx_map_arg(self.args, maybe_replace_node)
|
||||
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
|
||||
assert isinstance(new_args, tuple)
|
||||
assert isinstance(new_kwargs, dict)
|
||||
self.__update_args_kwargs(new_args, new_kwargs)
|
||||
@ -903,7 +902,7 @@ def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT:
|
||||
have the same type and structure.
|
||||
"""
|
||||
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
|
||||
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
|
||||
return _fx_map_arg(a, fn)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@ -914,23 +913,4 @@ def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT
|
||||
arg may be a list, tuple, slice, or dict with string keys: the return value will
|
||||
have the same type and structure.
|
||||
"""
|
||||
result: Argument
|
||||
|
||||
if isinstance(a, tuple):
|
||||
it = (map_aggregate(elem, fn) for elem in a)
|
||||
# Support NamedTuple (if it has `_fields`) by repacking into original type.
|
||||
result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
|
||||
elif isinstance(a, list):
|
||||
result = immutable_list([map_aggregate(elem, fn) for elem in a])
|
||||
elif isinstance(a, dict):
|
||||
result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()])
|
||||
elif isinstance(a, slice):
|
||||
result = slice(
|
||||
map_aggregate(a.start, fn),
|
||||
map_aggregate(a.stop, fn),
|
||||
map_aggregate(a.step, fn),
|
||||
)
|
||||
else:
|
||||
result = fn(a)
|
||||
|
||||
return cast(ArgumentT, result)
|
||||
return _fx_map_aggregate(a, fn)
|
||||
|
@ -15,11 +15,12 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch._C import _fx_map_aggregate as map_aggregate
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
|
||||
from ._compatibility import compatibility
|
||||
from .graph import Graph, magic_methods, reflectable_magic_methods
|
||||
from .node import Argument, base_types, map_aggregate, Node, Target
|
||||
from .node import Argument, base_types, Node, Target
|
||||
from .operator_schemas import check_for_mutable_operation
|
||||
|
||||
|
||||
@ -584,8 +585,8 @@ class Proxy:
|
||||
if isinstance(a, cls):
|
||||
tracers[a.tracer] = None
|
||||
|
||||
torch.fx.node.map_aggregate(args, find_tracer)
|
||||
torch.fx.node.map_aggregate(kwargs, find_tracer)
|
||||
map_aggregate(args, find_tracer)
|
||||
map_aggregate(kwargs, find_tracer)
|
||||
|
||||
if len(tracers) > 1:
|
||||
raise RuntimeError(
|
||||
|
Reference in New Issue
Block a user