Files
pytorch/torch/_functorch/compile_utils.py

213 lines
7.5 KiB
Python

# mypy: ignore-errors
import operator
from typing import Callable
import sympy
import torch
import torch.fx as fx
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten
# Module-level shorthand for the ``torch.ops.aten`` operator namespace.
aten = torch.ops.aten
def get_aten_target(node: fx.Node) -> Callable:
    """Return the overload packet behind ``node.target`` when it has one.

    A target that exposes an ``overloadpacket`` attribute is collapsed to
    that packet; any other target is handed back unchanged.
    """
    target = node.target
    return getattr(target, "overloadpacket", target)
# Random aten operators. fx_graph_cse copies any node whose target resolves
# (via get_aten_target) to one of these as-is and never deduplicates it.
rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]
# Build and return a fresh torch.fx.graph.Graph: the input graph with CSE applied.
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Copy ``fx_g`` into a new graph while merging duplicated computations.

    Nodes are visited in graph order. A node is copied verbatim (and never
    used as a CSE candidate) when it is a placeholder / output / get_attr
    node, targets a random op or ``aten.empty``, shares storage with a graph
    output, or carries a ``sympy.Symbol`` value with free unbacked symbols.
    Every other node is keyed by its target and its flattened, remapped
    arguments; when an identical node has already been emitted in the same
    mutation region, the earlier node is reused instead of emitting a copy.

    Args:
        fx_g: The graph to deduplicate. Its last node must be the output node.

    Returns:
        A newly constructed ``fx.Graph``.
    """
    from torch._inductor.pattern_matcher import (
        compute_mutation_region_ids,
        same_mutation_regions,
    )

    cse_graph = fx.Graph()
    env = {}  # node of fx_g -> node that stands in for it in cse_graph
    hash_env = {}  # (target, hash of args) -> representative node in cse_graph
    token_map = {}  # (target, hash of args) -> full token of that representative

    compute_mutation_region_ids(fx_g)  # type: ignore[arg-type]

    # Tensors returned from the graph must stay distinct objects: two outputs
    # that went through identical operations are still separate data
    # structures in eager mode, so anything sharing storage with an output is
    # excluded from deduplication below.
    output_node: fx.Node = list(fx_g.nodes)[-1]
    assert output_node.op == "output"

    def checkable_node(node: fx.Node) -> bool:
        """Tell whether the node's ``val`` meta is a tensor with defined storage."""
        fake_val = node.meta.get("val")
        if not isinstance(fake_val, torch.Tensor):
            return False
        try:
            fake_val.untyped_storage()
        except NotImplementedError:
            return False
        return True

    output_storages = {
        StorageWeakRef(out_arg.meta["val"].untyped_storage())
        for out_arg in output_node.all_input_nodes
        if checkable_node(out_arg)
    }
    nodes_that_alias_outputs = {
        candidate
        for candidate in fx_g.nodes
        if checkable_node(candidate)
        and StorageWeakRef(candidate.meta["val"].untyped_storage()) in output_storages
    }

    def remap(old_node):
        """Translate a node of ``fx_g`` into its counterpart in the new graph."""
        return env[old_node]

    def copied_verbatim(n: fx.Node) -> bool:
        """Decide whether ``n`` bypasses CSE and is copied unchanged."""
        # Structural nodes are always carried over as they are.
        if n.op in ("placeholder", "output", "get_attr"):
            return True
        aten_target = get_aten_target(n)
        # Random operations must not be merged.
        if aten_target in rand_ops:
            return True
        # aten.empty is non-deterministic, and it is almost always fusible
        # into its consumer, so deduplicating it is not worthwhile.
        if aten_target is aten.empty:
            return True
        if n in nodes_that_alias_outputs:
            return True
        # This pass does not re-propagate unbacked meta: it can eliminate a
        # _local_scalar_dense without rewriting the meta of downstream users.
        # One bug that has been seen:
        #
        # _local_scalar_dense_11: "Sym(u14)" = torch.ops.aten._local_scalar_dense.default(select_10);
        # sym_sum_2: "Sym(u19 + u20 + u21)" = torch.sym_sum((_local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13))  # noqa: B950
        #
        # _local_scalar_dense_11 is u14, yet sym_sum_2's meta still refers to
        # the stale pre-CSE value u19. Such nodes are therefore left alone.
        meta_val = n.meta.get("val")
        return isinstance(meta_val, sympy.Symbol) and bool(
            free_unbacked_symbols(meta_val)
        )

    def canonical(leaf):
        """Map one flattened argument leaf to the form used for comparison."""
        if isinstance(leaf, torch.fx.node.Node) and leaf in env:
            return env[leaf]
        if isinstance(leaf, (torch.SymBool, torch.SymInt, torch.SymFloat)):
            return leaf.node
        return leaf

    def substitute(arg_list):
        """Flatten ``arg_list`` and remap its leaves.

        The returned spec allows nested lists/dicts to be reconstructed.
        """
        leaves, spec = tree_flatten(arg_list)
        return tuple(canonical(leaf) for leaf in leaves), spec

    for n in fx_g.nodes:
        if copied_verbatim(n):
            env[n] = cse_graph.node_copy(n, remap)
            continue

        # From here on n.op == 'call_function'; 'call_module' and
        # 'call_method' nodes are never expected.
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)

        # A token identifies one distinct computation; nodes with equal
        # tokens are interchangeable.
        token = {
            "target": n.target,
            "args": args,
            "args_spec": args_spec,
            "kwargs": kwargs,
            "kwargs_spec": kwargs_spec,
        }

        # Specs are not hashable, so only the substituted leaves are hashed.
        # Each leaf is paired with its type to keep apart cases such as
        # hash((primals_2, 1.0)) == hash((primals_2, 1)).
        typed_args = tuple((a, type(a)) for a in args)
        typed_kwargs = tuple((a, type(a)) for a in kwargs)
        hash_val = (n.target, hash((typed_args, typed_kwargs)))

        already_seen = hash_val in hash_env
        record_as_representative = not already_seen
        if already_seen and token_map[hash_val] == token:
            earlier = hash_env[hash_val]
            if same_mutation_regions(n, earlier):
                # Genuine duplicate: reuse the earlier node, emit nothing.
                env[n] = earlier
                continue
            # A mutation separates the two nodes. Later duplicates should be
            # replaced with n rather than with the earlier node.
            record_as_representative = True

        emitted = cse_graph.node_copy(n, remap)
        env[n] = emitted
        if record_as_representative:
            hash_env[hash_val] = emitted
            token_map[hash_val] = token

    return cse_graph
def raise_getitems(gm: fx.GraphModule) -> fx.GraphModule:
    """Move each ``operator.getitem`` call so it directly follows its input node.

    Args:
        gm: Graph module whose graph is reordered in place.

    Returns:
        The same ``gm`` object, after ``gm.recompile()`` has been called.
    """
    # Take a snapshot before touching anything: reordering nodes while the
    # lookup is still being iterated can lead to an infinite loop.
    pending = [*gm.graph.find_nodes(op="call_function", target=operator.getitem)]

    # Walking the snapshot backwards keeps the getitems of one parent in
    # their original relative order, since each one is placed right after
    # the parent.
    for getitem in pending[::-1]:
        sources = getitem.all_input_nodes
        assert len(sources) == 1
        sources[0].append(getitem)

    gm.recompile()
    return gm
def strip_overloads(gm):
    """
    Rewrites every overload-specific node target in :attr:`gm` to its
    overload packet, then recompiles :attr:`gm`.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    overload_type = torch._ops.OpOverload
    for fx_node in gm.graph.nodes:
        current_target = fx_node.target
        if not isinstance(current_target, overload_type):
            # Targets that are not overloads are left untouched.
            continue
        fx_node.target = current_target.overloadpacket
    gm.recompile()
def get_placeholders(graph):
    """Return the ``placeholder`` nodes of ``graph`` as given by ``graph.find_nodes``."""
    placeholder_nodes = graph.find_nodes(op="placeholder")
    return placeholder_nodes
def get_outputs(graph):
for node in graph.find_nodes(op="output"):
return pytree.tree_leaves(node.args[0])
raise AssertionError("No output node found")