[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)

Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148288
This commit is contained in:
Jason Ansel
2025-03-09 21:21:58 -07:00
committed by PyTorch MergeBot
parent 8f858e226b
commit a60b4ed623
4 changed files with 132 additions and 104 deletions

View File

@ -15,11 +15,12 @@ from typing import Any, Callable, Optional
import torch
import torch.fx.traceback as fx_traceback
from torch._C import _fx_map_aggregate as map_aggregate
from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
from torch.utils._traceback import CapturedTraceback
from ._compatibility import compatibility
from .graph import Graph, magic_methods, reflectable_magic_methods
from .immutable_collections import immutable_dict, immutable_list
from .node import Argument, base_types, Node, Target
from .operator_schemas import check_for_mutable_operation
@ -302,6 +303,13 @@ class TracerBase:
# into the graph. In particular, Tensor operations should go into the graph,
# but non-Tensor operations shouldn't. What that means is that constructors
# for new types *SHOULD NOT* become nodes in the FX graph.
handler = _create_arg_bypass.get(type(a))
if handler is not None:
# this is just a performance optimization and can be removed if needed
# for common types, we have a fast path to avoid isinstance() overhead
# this doesn't remove the checks below since we need to handle subclasses
return handler(self, a)
if isinstance(a, Proxy):
return a.node # most common arg type goes first
elif hasattr(a, "__fx_create_arg__"):
@ -318,24 +326,7 @@ class TracerBase:
elif isinstance(a, list):
return [self.create_arg(elem) for elem in a]
elif isinstance(a, dict):
def no_node(arg):
if isinstance(arg, Node):
raise RuntimeError(
"Keys for dictionaries used as an argument cannot contain a "
f"Node. Got key: {k}"
)
r = {}
for k, v in a.items():
# Check for invalid dict keys. We do not want a Proxy to appear
# anywhere within the key. Since keys can be collection types,
# we iterate through the key with map_aggregate
k = self.create_arg(k)
map_aggregate(k, no_node)
r[k] = self.create_arg(v)
return r
return _create_arg_dict(self, a)
elif isinstance(a, slice):
return slice(
self.create_arg(a.start),
@ -746,3 +737,41 @@ def _define_reflectable(orig_method_name):
# Call _define_reflectable once for every name listed in
# reflectable_magic_methods (imported from .graph).
for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
def _no_nodes_error(arg):
raise RuntimeError(
"Keys for dictionaries used as an argument cannot contain a "
f"Node. Got key: {arg}"
)
def _create_arg_dict(self, a):
r = {}
for k, v in a.items():
if not isinstance(k, str):
# Check for invalid dict keys. We do not want a Proxy to appear
# anywhere within the key. Since keys can be collection types,
# we iterate through the key with map_arg
k = self.create_arg(k)
map_arg(k, _no_nodes_error)
r[k] = self.create_arg(v)
return r
# Fast-path dispatch table for create_arg, keyed by the *exact* type of the
# argument (it is consulted with ``_create_arg_bypass.get(type(a))``).
# Each handler takes ``(self, a)`` and returns the converted argument.

# Types that are returned unchanged.
_create_arg_bypass = dict.fromkeys(
    (
        *base_types,
        type(None),
        type(...),
        torch._ops.OpOverload,
        torch._ops.HigherOrderOperator,
    ),
    lambda self, a: a,
)

# Types that need real work: unwrap a Proxy to its node, or convert a
# container's elements through create_arg.
_create_arg_bypass.update(
    {
        Proxy: lambda self, a: a.node,
        tuple: lambda self, a: tuple([self.create_arg(elem) for elem in a]),
        list: lambda self, a: [self.create_arg(elem) for elem in a],
        dict: _create_arg_dict,
    }
)

# The immutable collection types reuse the list/dict handlers.
_create_arg_bypass[immutable_list] = _create_arg_bypass[list]
_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]