# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type

import sympy

import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map

from .virtualized import V


# Check whether the pattern (nn.Module, F.function/torch.Tensor.method) is matched.
# Works for length-2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure node.args[0]'s output is only used by the current node
    if len(node.args[0].users) > 1:
        return False
    return True


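# Illustrative (untested) sketch of how the helper above might be used when
# looking for a Linear module followed by F.relu; the pattern tuple and the
# GraphModule `gm` below are hypothetical:
#
#   pattern = (torch.nn.Linear, torch.nn.functional.relu)
#   modules = dict(gm.named_modules())
#   for node in gm.graph.nodes:
#       if matches_module_function_pattern(pattern, node, modules):
#           ...  # fuse modules[node.args[0].target] with the relu call

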
class FakeTensorUpdater:
    """
    The main idea here is that it is difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    To detect which nodes have changed, we hash each node together with its
    target and the identity of its argument lists (which are immutable in FX).

    Then, whenever incremental_update is called, we check which FX nodes have
    a new hash and recompute the faketensor metadata for those nodes. We then
    continue recursively recomputing the faketensors of their users until the
    fake tensors stop changing.
    """

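    # Illustrative (untested) sketch of the intended workflow, assuming an FX
    # graph `graph` whose nodes were traced under V.fake_mode (the surrounding
    # pass and variable names here are hypothetical):
    #
    #   updater = FakeTensorUpdater(graph)
    #   ...  # mutate the graph: replace nodes, rewrite args, etc.
    #   updater.incremental_update()  # refresh node.meta["val"] only where needed
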
    def __init__(self, graph: torch.fx.Graph) -> None:
        self.processed_hashes = OrderedSet[Any]()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if new.device != old.device:
                return False

            if get_storage(new) == get_storage(old):
                return True

            # This is the case where the op returns a completely fresh storage
            # that isn't used anywhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # node.target for nodes returning True from this function is
            # called under fake mode, which does not work for inductor
            # lowerings. So we only process nodes whose target is an aten
            # operator, or operator.getitem, which is used when an op
            # returns multiple tensors.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )

        to_process = OrderedSet[int]()
        for node in self.graph.nodes:
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            with V.fake_mode, enable_python_dispatcher():
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes doesn't have a
    fake tensor (i.e. it still appears as an fx.Node leaf after substitution);
    otherwise it is True and args/kwargs contain the fake values.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs


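# Illustrative (untested) example: for a node `n` computing aten.add(x, y)
# where both input nodes already carry fake tensors in meta["val"],
# get_fake_args_kwargs(n) returns (True, (fake_x, fake_y), {}); if either
# input is missing its fake tensor, the first element is False and the raw
# fx.Node objects are returned in its place.

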
def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns True if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives; e.g. it doesn't handle buffers
    realized heuristically during lowering, or buffers realized indirectly
    through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume the node isn't realized
    return False
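

# Illustrative (untested) sketch of how is_node_realized might be consumed,
# e.g. by a pass deciding whether an intermediate will be materialized as a
# buffer; the GraphModule `gm` here is hypothetical:
#
#   for node in gm.graph.nodes:
#       if is_node_realized(node):
#           ...  # expect a concrete buffer for this node in inductor IR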