Mirror of https://github.com/pytorch/pytorch.git
[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)
Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148288
Commit: a60b4ed623
Parent: 8f858e226b
Committed by: PyTorch MergeBot
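The before/after figures are cProfile output. A minimal sketch of how such a measurement can be reproduced, run as a script; the traced module below is illustrative, not the workload the PR was profiled on:

    import cProfile

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            # each iteration adds a call_function node, so tracing this
            # exercises TracerBase.create_arg many times
            for _ in range(100):
                x = x + 1
            return x

    # symbolic_trace also compiles the GraphModule, which exercises
    # Graph python code generation (_gen_python_code)
    cProfile.run("torch.fx.symbolic_trace(M())", sort="cumulative")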
@@ -1,65 +1,65 @@
-add_loop_eager,compile_time_instruction_count,2853000000,0.015
+add_loop_eager,compile_time_instruction_count,2806000000,0.015

-add_loop_eager_dynamic,compile_time_instruction_count,5525000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5460000000,0.025

-add_loop_inductor,compile_time_instruction_count,27830000000,0.015
+add_loop_inductor,compile_time_instruction_count,27520000000,0.015

-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40840000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40410000000,0.025

-add_loop_inductor_gpu,compile_time_instruction_count,24230000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23970000000,0.015

-basic_modules_ListOfLinears_eager,compile_time_instruction_count,949600000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,953800000,0.015

-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17290000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17070000000,0.015

-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15510000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15320000000,0.015

-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9800000000,0.2
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9714000000,0.2

-update_hint_regression,compile_time_instruction_count,1544000000,0.02
+update_hint_regression,compile_time_instruction_count,1523000000,0.02

-sum_floordiv_regression,compile_time_instruction_count,1032000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1026000000,0.015

-symint_sum,compile_time_instruction_count,3065000000,0.015
+symint_sum,compile_time_instruction_count,3013000000,0.015

-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1980000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1964000000,0.015

-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5708000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5672000000,0.015

-aotdispatcher_partitioner_cpu,compile_time_instruction_count,7911000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,7752000000,0.015

-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3590000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3537000000,0.015

-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9776000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9662000000,0.015
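Each row above is benchmark,metric,expected_value,noise_margin. A sketch of the tolerance check such a row implies; within_expected is a hypothetical helper, not the actual harness:

    def within_expected(measured: int, expected: int, noise_margin: float) -> bool:
        # e.g. add_loop_eager with margin 0.015 accepts anything within 1.5%
        return abs(measured - expected) / expected <= noise_margin

    assert within_expected(2790_000_000, 2806_000_000, 0.015)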
@@ -479,38 +479,24 @@ class CodeGen:
             # Common case: this is a regular module name like 'foo.bar.baz'
             return add_global(typename, o)

-        codes = {
-            "yellow": "\033[33m",
-            "cyan": "\033[36m",
-            "green": "\033[32m",
-            "blue": "\033[34m",
-            "red": "\033[31m",
-            "dim": "\033[2m",
-            "dim_blue": "\033[2m\033[34m",
-            "dim_green": "\033[2m\033[32m",
-            "reset": "\033[0m",
-        }
-
-        def make_wrapper_func(name):
-            def f(s):
-                if colored:
-                    return f"{codes[name]}{s}{codes['reset']}"
-                return s
-
-            return f
-
-        yellow = make_wrapper_func("yellow")  # noqa: F841
-        cyan = make_wrapper_func("cyan")  # noqa: F841
-        red = make_wrapper_func("red")
-        green = make_wrapper_func("green")  # noqa: F841
-        dim_green = make_wrapper_func("dim_green")
-        dim = make_wrapper_func("dim")
-        dim_blue = make_wrapper_func("dim_blue")
-        blue = make_wrapper_func("blue")
+        if colored:
+            red = _color_fns["red"]
+            dim_green = _color_fns["dim_green"]
+            dim = _color_fns["dim"]
+            dim_blue = _color_fns["dim_blue"]
+            blue = _color_fns["blue"]
+        else:
+            red = _identity
+            dim_green = _identity
+            dim = _identity
+            dim_blue = _identity
+            blue = _identity

         def _get_repr(arg: Any) -> str:
-            # Handle NamedTuples (if it has `_fields`) via add_global.
-            if isinstance(arg, tuple) and hasattr(arg, "_fields"):
+            if isinstance(arg, Node):  # first because common
+                return repr(arg)
+            elif isinstance(arg, tuple) and hasattr(arg, "_fields"):
+                # Handle NamedTuples (if it has `_fields`) via add_global.
                 qualified_name = _get_qualified_name(type(arg))
                 global_name = add_global(qualified_name, type(arg))
                 return f"{global_name}{repr(tuple(arg))}"
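The old code rebuilt the codes dict and eight wrapper closures on every _gen_python_code call, even with colored=False; the new code picks module-level _color_fns entries (defined near the end of this file's diff) or a shared _identity. A standalone sketch of the cost difference, with illustrative names:

    import timeit

    CODES = {"red": "\033[31m", "reset": "\033[0m"}

    def old_style():
        # old pattern: rebuild the closure on every codegen call
        def make_wrapper_func(name):
            def f(s):
                return f"{CODES[name]}{s}{CODES['reset']}"
            return f
        red = make_wrapper_func("red")
        return red("x")

    def _red(s):
        return f"{CODES['red']}{s}{CODES['reset']}"

    def new_style():
        # new pattern: reuse a function created once at import time
        return _red("x")

    print(timeit.timeit(old_style, number=500_000))
    print(timeit.timeit(new_style, number=500_000))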
@@ -524,8 +510,6 @@ class CodeGen:
                 cls = arg.__class__
                 clsname = add_global(cls.__name__, cls)
                 return f"{clsname}.{arg.name}"
-            elif isinstance(arg, Node):
-                return repr(arg)
             elif isinstance(arg, torch.Tensor):
                 size = list(arg.size())
                 dtype = str(arg.dtype).split(".")[-1]
@@ -545,11 +529,9 @@ class CodeGen:
         def _format_args(
             args: tuple[Argument, ...], kwargs: dict[str, Argument]
         ) -> str:
-            args_s = ", ".join(_get_repr(a) for a in args)
-            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
-            if args_s and kwargs_s:
-                return f"{args_s}, {kwargs_s}"
-            return args_s or kwargs_s
+            res = [_get_repr(a) for a in args]
+            res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()])
+            return ", ".join(res)

         # Run through reverse nodes and record the first instance of a use
         # of a given node. This represents the *last* use of the node in the
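The new _format_args builds one list and joins once instead of joining twice and branching on string truthiness. A standalone sketch of the two shapes, with plain repr standing in for _get_repr:

    def format_old(args, kwargs):
        args_s = ", ".join(repr(a) for a in args)
        kwargs_s = ", ".join(f"{k} = {v!r}" for k, v in kwargs.items())
        if args_s and kwargs_s:
            return f"{args_s}, {kwargs_s}"
        return args_s or kwargs_s

    def format_new(args, kwargs):
        res = [repr(a) for a in args]
        res.extend([f"{k} = {v!r}" for k, v in kwargs.items()])
        return ", ".join(res)

    assert format_old((1, 2), {"dim": 0}) == format_new((1, 2), {"dim": 0}) == "1, 2, dim = 0"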
@@ -564,8 +546,8 @@ class CodeGen:
                 user_to_last_uses.setdefault(user, []).append(n)

         for node in reversed(nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+            for input_node in node._input_nodes:
+                register_last_uses(input_node, node)

         def delete_unused_values(user: Node):
             """
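Node._input_nodes is bookkeeping torch.fx already maintains as args/kwargs are assigned, so iterating it avoids re-walking the nested argument containers with map_arg on every codegen pass. A quick way to inspect it; note it is a private attribute, shown here for illustration only:

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return torch.cat([x, x + 1], dim=0)

    gm = symbolic_trace(f)
    for node in gm.graph.nodes:
        # _input_nodes maps each input Node to None; iterating it yields the
        # Nodes this node consumes, without traversing the list inside args
        print(node.name, [n.name for n in node._input_nodes])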
@@ -604,22 +586,22 @@ class CodeGen:
             nonlocal prev_stacktrace

             if node.op not in {"placeholder", "output"}:
-                if node.stack_trace:
-                    if node.stack_trace != prev_stacktrace:
-                        prev_stacktrace = node.stack_trace
-                        summary_str = ""
-
-                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
+                stack_trace = node.stack_trace
+                if stack_trace:
+                    if stack_trace != prev_stacktrace:
+                        prev_stacktrace = stack_trace
+                        if parsed_stack_trace := _parse_stack_trace(stack_trace):
                             summary_str = parsed_stack_trace.get_summary_str()
-
-                        body.append(f'\n {dim("# " + summary_str)}\n')
+                        else:
+                            summary_str = ""
+                        body.append(f'\n {dim(f"# {summary_str}")}\n')
                 elif prev_stacktrace != "":
                     prev_stacktrace = ""
                     no_stacktrace_msg = "# No stacktrace found for following nodes"
                     body.append(f"\n{dim(no_stacktrace_msg)}\n")

         def stringify_shape(shape: Iterable) -> str:
-            return f"[{', '.join(str(x) for x in shape)}]"
+            return f"[{', '.join([str(x) for x in shape])}]"

         def emit_node(node: Node):
             maybe_type_annotation = (
@@ -777,8 +759,8 @@ class CodeGen:
         new_lines: list[str] = []
         cur_idx = None
         for line in "".join(body).split("\n"):
-            counter = re.search(r"# COUNTER: (\d+)", line)
-            if counter and counter.group(1) is not None:
+            counter = _counter_regexp.search(line)
+            if counter is not None:
                 cur_idx = int(counter.group(1))
             else:
                 lineno_map[len(new_lines) + prologue_len] = cur_idx
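Hoisting the pattern into module-level _counter_regexp (added later in this diff) skips re's per-call pattern-cache lookup, and the `counter is not None` test drops a redundant group check, since group 1 always exists on a match of this pattern. A minimal sketch of the effect:

    import re
    import timeit

    _COUNTER = re.compile(r"# COUNTER: (\d+)")
    LINE = "x = foo(a)  # COUNTER: 42"

    # re.search resolves the pattern through re's internal cache on each call;
    # a precompiled pattern object skips that per-call overhead
    print(timeit.timeit(lambda: re.search(r"# COUNTER: (\d+)", LINE), number=200_000))
    print(timeit.timeit(lambda: _COUNTER.search(LINE), number=200_000))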
@@ -1207,12 +1189,10 @@ class Graph:

         # Null out this Node's argument nodes so that the Nodes referred to
         # can update their ``users`` accordingly
-        new_args = map_arg(to_erase.args, lambda n: None)
-        assert isinstance(new_args, tuple)
-        to_erase.args = new_args
-        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
-        assert isinstance(new_kwargs, dict)
-        to_erase.kwargs = new_kwargs
+        to_erase._update_args_kwargs(
+            map_arg(to_erase._args, lambda n: None),
+            map_arg(to_erase._kwargs, lambda n: None),
+        )

     @compatibility(is_backward_compatible=True)
     def inserting_before(self, n: Optional[Node] = None):
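Calling Node._update_args_kwargs once replaces two property assignments, each of which triggers its own users-bookkeeping pass, plus two asserts. A toy model of that shape, under the assumption that the real setters funnel into one update method as they do here; this is not torch.fx's actual Node implementation:

    class ToyNode:
        def __init__(self):
            self._args, self._kwargs = (), {}

        @property
        def args(self):
            return self._args

        @args.setter
        def args(self, a):
            # each property assignment pays for a full bookkeeping update
            self._update_args_kwargs(a, self._kwargs)

        @property
        def kwargs(self):
            return self._kwargs

        @kwargs.setter
        def kwargs(self, k):
            self._update_args_kwargs(self._args, k)

        def _update_args_kwargs(self, new_args, new_kwargs):
            # single place where producer/user links would be rewired
            self._args, self._kwargs = new_args, new_kwargs

    n = ToyNode()
    n.args, n.kwargs = (1,), {"k": 2}        # old: two bookkeeping passes
    n._update_args_kwargs((1,), {"k": 2})    # new: one pass, no asserts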
@@ -1726,21 +1706,14 @@ class Graph:
         seen_names: set[str] = set()
         seen_values: set[Node] = set()
         for node in self.nodes:
-            if node.op not in [
-                "placeholder",
-                "call_method",
-                "call_module",
-                "call_function",
-                "get_attr",
-                "output",
-            ]:
+            if node.op not in _legal_ops:
                 raise RuntimeError(f"Node {node} had unknown opcode {node.op}!")
             if node.graph is not self:
                 raise RuntimeError(f"Node '{node}' does not belong to this Graph!")
             if node not in self._find_nodes_lookup_table:
                 raise RuntimeError(f"Node '{node}' is not added to the side table")
-            map_arg(node.args, lambda arg: check_arg(arg, node))
-            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
+            for arg in node._input_nodes:
+                check_arg(arg, node)
             seen_values.add(node)

             if node.name in seen_names:
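_legal_ops replaces an inline list literal, rebuilt on every lint() call, with a shared constant whose membership test is a single hash lookup. Its definition lives elsewhere in torch/fx and is not in this hunk; a sketch under the assumption it covers the same six opcodes:

    # assumed shape of the shared constant (actual definition not shown here)
    _legal_ops = frozenset(
        ("placeholder", "call_method", "call_module", "call_function", "get_attr", "output")
    )

    def check_op(op: str) -> None:
        # O(1) hash probe instead of a linear scan over a list
        if op not in _legal_ops:
            raise RuntimeError(f"unknown opcode {op}!")

    check_op("call_function")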
@@ -1959,6 +1932,32 @@ class Graph:
         return on_generate_code_context_manager()


+def _identity(x):
+    return x
+
+
+def _make_color_fn(code):
+    def f(s):
+        reset = "\033[0m"
+        return f"{code}{s}{reset}"
+
+    return f
+
+
+_color_codes = {
+    "yellow": "\033[33m",
+    "cyan": "\033[36m",
+    "green": "\033[32m",
+    "blue": "\033[34m",
+    "red": "\033[31m",
+    "dim": "\033[2m",
+    "dim_blue": "\033[2m\033[34m",
+    "dim_green": "\033[2m\033[32m",
+}
+_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()}
+_counter_regexp = re.compile(r"# COUNTER: (\d+)")
+
+
 reflectable_magic_methods = {
     "add": "{} + {}",
     "sub": "{} - {}",
@@ -115,8 +115,8 @@ class Interpreter:
             self.user_to_last_uses.setdefault(user, []).append(n)

         for node in reversed(self.graph.nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+            for n in node._input_nodes:
+                register_last_uses(n, node)

     @compatibility(is_backward_compatible=True)
     def run(
@@ -15,11 +15,12 @@ from typing import Any, Callable, Optional

 import torch
 import torch.fx.traceback as fx_traceback
-from torch._C import _fx_map_aggregate as map_aggregate
+from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
 from torch.utils._traceback import CapturedTraceback

 from ._compatibility import compatibility
 from .graph import Graph, magic_methods, reflectable_magic_methods
+from .immutable_collections import immutable_dict, immutable_list
 from .node import Argument, base_types, Node, Target
 from .operator_schemas import check_for_mutable_operation
@@ -302,6 +303,13 @@ class TracerBase:
         # into the graph. In particular, Tensor operations should go into the graph,
         # but non-Tensor operations shouldn't. What that means is that constructors
         # for new types *SHOULD NOT* become nodes in the FX graph.
+        handler = _create_arg_bypass.get(type(a))
+        if handler is not None:
+            # this is just a performance optimization and can be removed if needed
+            # for common types, we have a fast path to avoid isinstance() overhead
+            # this doesn't remove the checks below since we need to handle subclasses
+            return handler(self, a)
+
         if isinstance(a, Proxy):
             return a.node  # most common arg type goes first
         elif hasattr(a, "__fx_create_arg__"):
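The bypass table keys on type(a) exactly, so one dict probe replaces a chain of isinstance checks for the common argument types, while subclasses still fall through to the checks below. A self-contained sketch of the pattern; all names here are illustrative:

    from typing import Any, Callable

    # exact-type dispatch table: type(x) -> handler
    _BYPASS: dict[type, Callable[[Any], Any]] = {
        int: lambda x: x,
        str: lambda x: x,
        type(None): lambda x: x,
        list: lambda x: [convert(e) for e in x],
    }

    def convert(a: Any) -> Any:
        handler = _BYPASS.get(type(a))
        if handler is not None:
            return handler(a)        # fast path: exact type match
        if isinstance(a, list):      # slow path still handles subclasses
            return [convert(e) for e in a]
        raise TypeError(f"unsupported: {type(a).__name__}")

    class MyList(list):
        pass

    assert convert([1, "a", None]) == [1, "a", None]   # fast path
    assert convert(MyList([1])) == [1]                 # subclass takes the slow path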
@@ -318,24 +326,7 @@ class TracerBase:
         elif isinstance(a, list):
             return [self.create_arg(elem) for elem in a]
         elif isinstance(a, dict):
-
-            def no_node(arg):
-                if isinstance(arg, Node):
-                    raise RuntimeError(
-                        "Keys for dictionaries used as an argument cannot contain a "
-                        f"Node. Got key: {k}"
-                    )
-
-            r = {}
-            for k, v in a.items():
-                # Check for invalid dict keys. We do not want a Proxy to appear
-                # anywhere within the key. Since keys can be collection types,
-                # we iterate through the key with map_aggregate
-                k = self.create_arg(k)
-                map_aggregate(k, no_node)
-
-                r[k] = self.create_arg(v)
-            return r
+            return _create_arg_dict(self, a)
         elif isinstance(a, slice):
             return slice(
                 self.create_arg(a.start),
@@ -746,3 +737,41 @@ def _define_reflectable(orig_method_name):

 for orig_method_name in reflectable_magic_methods:
     _define_reflectable(orig_method_name)
+
+
+def _no_nodes_error(arg):
+    raise RuntimeError(
+        "Keys for dictionaries used as an argument cannot contain a "
+        f"Node. Got key: {arg}"
+    )
+
+
+def _create_arg_dict(self, a):
+    r = {}
+    for k, v in a.items():
+        if not isinstance(k, str):
+            # Check for invalid dict keys. We do not want a Proxy to appear
+            # anywhere within the key. Since keys can be collection types,
+            # we iterate through the key with map_arg
+            k = self.create_arg(k)
+            map_arg(k, _no_nodes_error)
+        r[k] = self.create_arg(v)
+    return r
+
+
+_create_arg_bypass = {
+    t: lambda self, a: a
+    for t in [
+        *base_types,
+        type(None),
+        type(...),
+        torch._ops.OpOverload,
+        torch._ops.HigherOrderOperator,
+    ]
+}
+_create_arg_bypass[Proxy] = lambda self, a: a.node
+_create_arg_bypass[tuple] = lambda self, a: tuple([self.create_arg(elem) for elem in a])
+_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a]
+_create_arg_bypass[dict] = _create_arg_dict
+_create_arg_bypass[immutable_list] = _create_arg_bypass[list]
+_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]
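immutable_list and immutable_dict need their own entries because an exact type(a) lookup never matches subclasses; without them those types would always take the slow path. A two-line illustration, with ImmutableList as a stand-in:

    class ImmutableList(list):  # stand-in for torch.fx's immutable_list
        pass

    table = {list: "list handler"}
    assert table.get(type(ImmutableList())) is None  # exact-type lookup misses subclasses
    table[ImmutableList] = table[list]               # hence the explicit extra entries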