[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)

Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds
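
The counts read like output from Python's built-in cProfile profiler. A minimal sketch of how such a measurement can be collected; the toy torch.compile workload is illustrative, not the actual benchmark behind these numbers:

import cProfile
import pstats

import torch


def workload():
    @torch.compile
    def fn(x):
        return x.sin() + x.cos()

    fn(torch.randn(8))


profiler = cProfile.Profile()
profiler.runcall(workload)
# prints the familiar "N function calls (M primitive calls) in S seconds" header
pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)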

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148288
Author: Jason Ansel
Date: 2025-03-09 21:21:58 -07:00
Committed by: PyTorch MergeBot
Parent: 8f858e226b
Commit: a60b4ed623

4 changed files with 132 additions and 104 deletions

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

@@ -1,65 +1,65 @@
-add_loop_eager,compile_time_instruction_count,2853000000,0.015
+add_loop_eager,compile_time_instruction_count,2806000000,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,5525000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5460000000,0.025
-add_loop_inductor,compile_time_instruction_count,27830000000,0.015
+add_loop_inductor,compile_time_instruction_count,27520000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40840000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40410000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,24230000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23970000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,949600000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,953800000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17290000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17070000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15510000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15320000000,0.015
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9800000000,0.2
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9714000000,0.2
-update_hint_regression,compile_time_instruction_count,1544000000,0.02
+update_hint_regression,compile_time_instruction_count,1523000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,1032000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1026000000,0.015
-symint_sum,compile_time_instruction_count,3065000000,0.015
+symint_sum,compile_time_instruction_count,3013000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1980000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1964000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5708000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5672000000,0.015
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,7911000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,7752000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3590000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3537000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9776000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9662000000,0.015
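
Each row pairs a benchmark and metric with an expected instruction count and what appears to be a relative noise threshold in the last column. A sketch of the acceptance check this implies; the column semantics are inferred from the diff, not taken from the benchmark harness:

# row format (inferred): benchmark,metric,expected_value,relative_noise_threshold
row = "add_loop_eager,compile_time_instruction_count,2806000000,0.015"
name, metric, expected_s, tol_s = row.split(",")
expected, tol = int(expected_s), float(tol_s)

measured = 2_790_000_000  # hypothetical fresh measurement
drift = abs(measured - expected) / expected
assert drift <= tol, f"{name}:{metric} drifted {drift:.3%} (allowed {tol:.1%})"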


torch/fx/graph.py

@@ -479,38 +479,24 @@ class CodeGen:
             # Common case: this is a regular module name like 'foo.bar.baz'
             return add_global(typename, o)
 
-        codes = {
-            "yellow": "\033[33m",
-            "cyan": "\033[36m",
-            "green": "\033[32m",
-            "blue": "\033[34m",
-            "red": "\033[31m",
-            "dim": "\033[2m",
-            "dim_blue": "\033[2m\033[34m",
-            "dim_green": "\033[2m\033[32m",
-            "reset": "\033[0m",
-        }
-
-        def make_wrapper_func(name):
-            def f(s):
-                if colored:
-                    return f"{codes[name]}{s}{codes['reset']}"
-                return s
-
-            return f
-
-        yellow = make_wrapper_func("yellow")  # noqa: F841
-        cyan = make_wrapper_func("cyan")  # noqa: F841
-        red = make_wrapper_func("red")
-        green = make_wrapper_func("green")  # noqa: F841
-        dim_green = make_wrapper_func("dim_green")
-        dim = make_wrapper_func("dim")
-        dim_blue = make_wrapper_func("dim_blue")
-        blue = make_wrapper_func("blue")
+        if colored:
+            red = _color_fns["red"]
+            dim_green = _color_fns["dim_green"]
+            dim = _color_fns["dim"]
+            dim_blue = _color_fns["dim_blue"]
+            blue = _color_fns["blue"]
+        else:
+            red = _identity
+            dim_green = _identity
+            dim = _identity
+            dim_blue = _identity
+            blue = _identity
 
         def _get_repr(arg: Any) -> str:
-            # Handle NamedTuples (if it has `_fields`) via add_global.
-            if isinstance(arg, tuple) and hasattr(arg, "_fields"):
+            if isinstance(arg, Node):  # first because common
+                return repr(arg)
+            elif isinstance(arg, tuple) and hasattr(arg, "_fields"):
+                # Handle NamedTuples (if it has `_fields`) via add_global.
                 qualified_name = _get_qualified_name(type(arg))
                 global_name = add_global(qualified_name, type(arg))
                 return f"{global_name}{repr(tuple(arg))}"
@@ -524,8 +510,6 @@ class CodeGen:
                 cls = arg.__class__
                 clsname = add_global(cls.__name__, cls)
                 return f"{clsname}.{arg.name}"
-            elif isinstance(arg, Node):
-                return repr(arg)
             elif isinstance(arg, torch.Tensor):
                 size = list(arg.size())
                 dtype = str(arg.dtype).split(".")[-1]
@@ -545,11 +529,9 @@ class CodeGen:
         def _format_args(
             args: tuple[Argument, ...], kwargs: dict[str, Argument]
         ) -> str:
-            args_s = ", ".join(_get_repr(a) for a in args)
-            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
-            if args_s and kwargs_s:
-                return f"{args_s}, {kwargs_s}"
-            return args_s or kwargs_s
+            res = [_get_repr(a) for a in args]
+            res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()])
+            return ", ".join(res)
 
         # Run through reverse nodes and record the first instance of a use
         # of a given node. This represents the *last* use of the node in the
@@ -564,8 +546,8 @@ class CodeGen:
                 user_to_last_uses.setdefault(user, []).append(n)
 
         for node in reversed(nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+            for input_node in node._input_nodes:
+                register_last_uses(input_node, node)
 
         def delete_unused_values(user: Node):
             """
@@ -604,22 +586,22 @@ class CodeGen:
             nonlocal prev_stacktrace
 
             if node.op not in {"placeholder", "output"}:
-                if node.stack_trace:
-                    if node.stack_trace != prev_stacktrace:
-                        prev_stacktrace = node.stack_trace
-                        summary_str = ""
-
-                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
+                stack_trace = node.stack_trace
+                if stack_trace:
+                    if stack_trace != prev_stacktrace:
+                        prev_stacktrace = stack_trace
+                        if parsed_stack_trace := _parse_stack_trace(stack_trace):
                             summary_str = parsed_stack_trace.get_summary_str()
-
-                        body.append(f'\n {dim("# " + summary_str)}\n')
+                        else:
+                            summary_str = ""
+                        body.append(f'\n {dim(f"# {summary_str}")}\n')
                 elif prev_stacktrace != "":
                     prev_stacktrace = ""
                     no_stacktrace_msg = "# No stacktrace found for following nodes"
                     body.append(f"\n{dim(no_stacktrace_msg)}\n")
 
         def stringify_shape(shape: Iterable) -> str:
-            return f"[{', '.join(str(x) for x in shape)}]"
+            return f"[{', '.join([str(x) for x in shape])}]"
 
         def emit_node(node: Node):
             maybe_type_annotation = (
@@ -777,8 +759,8 @@ class CodeGen:
         new_lines: list[str] = []
         cur_idx = None
         for line in "".join(body).split("\n"):
-            counter = re.search(r"# COUNTER: (\d+)", line)
-            if counter and counter.group(1) is not None:
+            counter = _counter_regexp.search(line)
+            if counter is not None:
                 cur_idx = int(counter.group(1))
             else:
                 lineno_map[len(new_lines) + prologue_len] = cur_idx
@@ -1207,12 +1189,10 @@ class Graph:
 
         # Null out this Node's argument nodes so that the Nodes referred to
         # can update their ``users`` accordingly
-        new_args = map_arg(to_erase.args, lambda n: None)
-        assert isinstance(new_args, tuple)
-        to_erase.args = new_args
-        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
-        assert isinstance(new_kwargs, dict)
-        to_erase.kwargs = new_kwargs
+        to_erase._update_args_kwargs(
+            map_arg(to_erase._args, lambda n: None),
+            map_arg(to_erase._kwargs, lambda n: None),
+        )
 
     @compatibility(is_backward_compatible=True)
     def inserting_before(self, n: Optional[Node] = None):
@@ -1726,21 +1706,14 @@ class Graph:
         seen_names: set[str] = set()
         seen_values: set[Node] = set()
         for node in self.nodes:
-            if node.op not in [
-                "placeholder",
-                "call_method",
-                "call_module",
-                "call_function",
-                "get_attr",
-                "output",
-            ]:
+            if node.op not in _legal_ops:
                 raise RuntimeError(f"Node {node} had unknown opcode {node.op}!")
             if node.graph is not self:
                 raise RuntimeError(f"Node '{node}' does not belong to this Graph!")
             if node not in self._find_nodes_lookup_table:
                 raise RuntimeError(f"Node '{node}' is not added to the side table")
-            map_arg(node.args, lambda arg: check_arg(arg, node))
-            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
+            for arg in node._input_nodes:
+                check_arg(arg, node)
             seen_values.add(node)
 
             if node.name in seen_names:
@@ -1959,6 +1932,32 @@ class Graph:
     return on_generate_code_context_manager()
 
+def _identity(x):
+    return x
+
+
+def _make_color_fn(code):
+    def f(s):
+        reset = "\033[0m"
+        return f"{code}{s}{reset}"
+
+    return f
+
+
+_color_codes = {
+    "yellow": "\033[33m",
+    "cyan": "\033[36m",
+    "green": "\033[32m",
+    "blue": "\033[34m",
+    "red": "\033[31m",
+    "dim": "\033[2m",
+    "dim_blue": "\033[2m\033[34m",
+    "dim_green": "\033[2m\033[32m",
+}
+
+_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()}
+
+_counter_regexp = re.compile(r"# COUNTER: (\d+)")
+
 
 reflectable_magic_methods = {
     "add": "{} + {}",
     "sub": "{} - {}",

torch/fx/interpreter.py

@@ -115,8 +115,8 @@ class Interpreter:
                 self.user_to_last_uses.setdefault(user, []).append(n)
 
         for node in reversed(self.graph.nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+            for n in node._input_nodes:
+                register_last_uses(n, node)
 
     @compatibility(is_backward_compatible=True)
     def run(
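
The Interpreter gets the same traversal change as graph.py: two recursive map_arg walks over node.args and node.kwargs become one loop over node._input_nodes, the flat input-node mapping FX maintains when arguments are assigned. A toy illustration of the difference in work (stand-in structures, not the FX classes):

class ToyNode:
    pass


def walk(obj, visit):
    # map_arg-style: recursively visits every element of nested containers
    if isinstance(obj, (list, tuple)):
        for item in obj:
            walk(item, visit)
    elif isinstance(obj, dict):
        for item in obj.values():
            walk(item, visit)
    elif isinstance(obj, ToyNode):
        visit(obj)


n1, n2 = ToyNode(), ToyNode()
args = ([n1, 3], {"w": n2})

found = []
walk(args, found.append)  # full traversal just to find the two nodes

# _input_nodes-style: the flat node set was recorded once when args were set,
# so later passes (last-use tracking, linting) are a plain iteration
input_nodes = {n1: None, n2: None}
assert list(input_nodes) == found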

torch/fx/proxy.py

@@ -15,11 +15,12 @@ from typing import Any, Callable, Optional
 
 import torch
 import torch.fx.traceback as fx_traceback
-from torch._C import _fx_map_aggregate as map_aggregate
+from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
 from torch.utils._traceback import CapturedTraceback
 
 from ._compatibility import compatibility
 from .graph import Graph, magic_methods, reflectable_magic_methods
+from .immutable_collections import immutable_dict, immutable_list
 from .node import Argument, base_types, Node, Target
 from .operator_schemas import check_for_mutable_operation
 
@@ -302,6 +303,13 @@ class TracerBase:
         # into the graph. In particular, Tensor operations should go into the graph,
         # but non-Tensor operations shouldn't. What that means is that constructors
         # for new types *SHOULD NOT* become nodes in the FX graph.
+        handler = _create_arg_bypass.get(type(a))
+        if handler is not None:
+            # this is just a performance optimization and can be removed if needed
+            # for common types, we have a fast path to avoid isinstance() overhead
+            # this doesn't remove the checks below since we need to handle subclasses
+            return handler(self, a)
+
         if isinstance(a, Proxy):
             return a.node  # most common arg type goes first
         elif hasattr(a, "__fx_create_arg__"):
@@ -318,24 +326,7 @@ class TracerBase:
         elif isinstance(a, list):
             return [self.create_arg(elem) for elem in a]
         elif isinstance(a, dict):
-
-            def no_node(arg):
-                if isinstance(arg, Node):
-                    raise RuntimeError(
-                        "Keys for dictionaries used as an argument cannot contain a "
-                        f"Node. Got key: {k}"
-                    )
-
-            r = {}
-            for k, v in a.items():
-                # Check for invalid dict keys. We do not want a Proxy to appear
-                # anywhere within the key. Since keys can be collection types,
-                # we iterate through the key with map_aggregate
-                k = self.create_arg(k)
-                map_aggregate(k, no_node)
-                r[k] = self.create_arg(v)
-            return r
+            return _create_arg_dict(self, a)
         elif isinstance(a, slice):
             return slice(
                 self.create_arg(a.start),
@@ -746,3 +737,41 @@ def _define_reflectable(orig_method_name):
 
 for orig_method_name in reflectable_magic_methods:
     _define_reflectable(orig_method_name)
+
+
+def _no_nodes_error(arg):
+    raise RuntimeError(
+        "Keys for dictionaries used as an argument cannot contain a "
+        f"Node. Got key: {arg}"
+    )
+
+
+def _create_arg_dict(self, a):
+    r = {}
+    for k, v in a.items():
+        if not isinstance(k, str):
+            # Check for invalid dict keys. We do not want a Proxy to appear
+            # anywhere within the key. Since keys can be collection types,
+            # we iterate through the key with map_arg
+            k = self.create_arg(k)
+            map_arg(k, _no_nodes_error)
+        r[k] = self.create_arg(v)
+    return r
+
+
+_create_arg_bypass = {
+    t: lambda self, a: a
+    for t in [
+        *base_types,
+        type(None),
+        type(...),
+        torch._ops.OpOverload,
+        torch._ops.HigherOrderOperator,
+    ]
+}
+
+_create_arg_bypass[Proxy] = lambda self, a: a.node
+_create_arg_bypass[tuple] = lambda self, a: tuple([self.create_arg(elem) for elem in a])
+_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a]
+_create_arg_bypass[dict] = _create_arg_dict
+_create_arg_bypass[immutable_list] = _create_arg_bypass[list]
+_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]
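
The dispatch table just added is the core of the create_arg speedup: exact-type lookups replace a long isinstance chain for the common argument types, while subclasses deliberately miss the table and fall through to the original checks (hence the commit's comment that the checks below must stay). A standalone sketch of the pattern with toy handlers, not the FX implementation:

from typing import Any, Callable


def convert(x: Any) -> Any:
    handler = _bypass.get(type(x))  # exact type match only; subclasses miss
    if handler is not None:
        return handler(x)  # fast path: a single dict lookup
    # slow path: the isinstance() chain still handles subclasses and rare types
    if isinstance(x, (list, tuple)):
        return type(x)(convert(item) for item in x)
    return x


_bypass: dict[type, Callable[[Any], Any]] = {
    int: lambda x: x,
    float: lambda x: x,
    str: lambda x: x,
    type(None): lambda x: x,
    list: lambda x: [convert(item) for item in x],
    tuple: lambda x: tuple(convert(item) for item in x),
}

assert convert(([1, 2], "a")) == ([1, 2], "a")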