mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)
Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292 Approved by: https://github.com/oulgen ghstack dependencies: #148243, #148260, #148261, #148288
This commit is contained in:
committed by
PyTorch MergeBot
parent
8f858e226b
commit
a60b4ed623
@ -15,11 +15,12 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch._C import _fx_map_aggregate as map_aggregate
|
||||
from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
|
||||
from ._compatibility import compatibility
|
||||
from .graph import Graph, magic_methods, reflectable_magic_methods
|
||||
from .immutable_collections import immutable_dict, immutable_list
|
||||
from .node import Argument, base_types, Node, Target
|
||||
from .operator_schemas import check_for_mutable_operation
|
||||
|
||||
@ -302,6 +303,13 @@ class TracerBase:
|
||||
# into the graph. In particular, Tensor operations should go into the graph,
|
||||
# but non-Tensor operations shouldn't. What that means is that constructors
|
||||
# for new types *SHOULD NOT* become nodes in the FX graph.
|
||||
handler = _create_arg_bypass.get(type(a))
|
||||
if handler is not None:
|
||||
# this is just a performance optimization and can be removed if needed
|
||||
# for common types, we have a fast path to avoid isinstance() overhead
|
||||
# this doesn't remove the checks below since we need to handle subclasses
|
||||
return handler(self, a)
|
||||
|
||||
if isinstance(a, Proxy):
|
||||
return a.node # most common arg type goes first
|
||||
elif hasattr(a, "__fx_create_arg__"):
|
||||
@ -318,24 +326,7 @@ class TracerBase:
|
||||
elif isinstance(a, list):
|
||||
return [self.create_arg(elem) for elem in a]
|
||||
elif isinstance(a, dict):
|
||||
|
||||
def no_node(arg):
|
||||
if isinstance(arg, Node):
|
||||
raise RuntimeError(
|
||||
"Keys for dictionaries used as an argument cannot contain a "
|
||||
f"Node. Got key: {k}"
|
||||
)
|
||||
|
||||
r = {}
|
||||
for k, v in a.items():
|
||||
# Check for invalid dict keys. We do not want a Proxy to appear
|
||||
# anywhere within the key. Since keys can be collection types,
|
||||
# we iterate through the key with map_aggregate
|
||||
k = self.create_arg(k)
|
||||
map_aggregate(k, no_node)
|
||||
|
||||
r[k] = self.create_arg(v)
|
||||
return r
|
||||
return _create_arg_dict(self, a)
|
||||
elif isinstance(a, slice):
|
||||
return slice(
|
||||
self.create_arg(a.start),
|
||||
@ -746,3 +737,41 @@ def _define_reflectable(orig_method_name):
|
||||
|
||||
for orig_method_name in reflectable_magic_methods:
|
||||
_define_reflectable(orig_method_name)
|
||||
|
||||
|
||||
def _no_nodes_error(arg):
|
||||
raise RuntimeError(
|
||||
"Keys for dictionaries used as an argument cannot contain a "
|
||||
f"Node. Got key: {arg}"
|
||||
)
|
||||
|
||||
|
||||
def _create_arg_dict(self, a):
|
||||
r = {}
|
||||
for k, v in a.items():
|
||||
if not isinstance(k, str):
|
||||
# Check for invalid dict keys. We do not want a Proxy to appear
|
||||
# anywhere within the key. Since keys can be collection types,
|
||||
# we iterate through the key with map_arg
|
||||
k = self.create_arg(k)
|
||||
map_arg(k, _no_nodes_error)
|
||||
r[k] = self.create_arg(v)
|
||||
return r
|
||||
|
||||
|
||||
_create_arg_bypass = {
|
||||
t: lambda self, a: a
|
||||
for t in [
|
||||
*base_types,
|
||||
type(None),
|
||||
type(...),
|
||||
torch._ops.OpOverload,
|
||||
torch._ops.HigherOrderOperator,
|
||||
]
|
||||
}
|
||||
_create_arg_bypass[Proxy] = lambda self, a: a.node
|
||||
_create_arg_bypass[tuple] = lambda self, a: tuple([self.create_arg(elem) for elem in a])
|
||||
_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a]
|
||||
_create_arg_bypass[dict] = _create_arg_dict
|
||||
_create_arg_bypass[immutable_list] = _create_arg_bypass[list]
|
||||
_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]
|
||||
|
||||
Reference in New Issue
Block a user