@@ -1,4 +1,3 @@
-# mypy: allow-untyped-defs
import functools
import itertools
import logging
@@ -9,13 +8,17 @@ import sys
import time
from collections import defaultdict
from contextlib import contextmanager
+from types import ModuleType
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
+    Iterable,
    List,
+    NoReturn,
    Optional,
+    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
@@ -23,17 +26,21 @@ from typing import (
)

import sympy
from sympy import Expr

import torch
import torch._logging
import torch.fx
from torch import device, Tensor
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString, trace_structured
from torch._prims_common import make_channels_last_strides_for
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator

from torch.fx.experimental.symbolic_shapes import (
    free_unbacked_symbols,
    has_free_symbols,
@@ -42,6 +49,8 @@ from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    SymTypes,
)
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.utils._mode_utils import no_dispatch
from torch.utils._sympy.numbers import int_oo
@@ -82,6 +91,8 @@ from .lowering import (
    needs_realized_inputs,
    unsupported_output_tensor,
)

from .scheduler import BaseSchedulerNode
from .sizevars import SizeVarAllocator
from .utils import (
    convert_shape_to_inductor,
@@ -112,11 +123,11 @@ if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

-    def log_module_code(*args, **kwargs):
+    def log_module_code(*args: Any, **kwargs: Any) -> None:
        pass


-def supported_dtype_of_cpp_wrapper(dtype, cuda):
+def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool:
    supported_dtype = {
        torch.float32,
        torch.float64,
@@ -141,7 +152,7 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
    return dtype in supported_dtype


-def may_get_constant_buffer_dtype(constant_buffer):
+def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
@@ -159,12 +170,14 @@ def may_get_constant_buffer_dtype(constant_buffer):
    return None


-def is_magic_method(op):
+def is_magic_method(op: Any) -> bool:
    magic_ops = {method_to_operator(m) for m in magic_methods}
    return op in magic_ops


-def getattr_recursive(obj, target):
+def getattr_recursive(
+    obj: GraphModule, target: str
+) -> Union[Tensor, torch._C.ScriptObject, GraphModule]:
    target_atoms = target.split(".")
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
@@ -176,7 +189,7 @@ def getattr_recursive(obj, target):
    return attr_itr


-def mark_nodes_dislike_padding(g):
+def mark_nodes_dislike_padding(g: Graph) -> None:
    """
    Nodes like convolution/convolution_backward want its input to be dense.
    If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction.
@@ -207,7 +220,9 @@ def mark_nodes_dislike_padding(g):
        aten.scatter_reduce,
    }

-    def _get_overload_packet(node):
+    def _get_overload_packet(
+        node: torch.fx.Node,
+    ) -> Optional[torch._ops.OpOverloadPacket]:
        return (
            node.target._overloadpacket
            if node.op == "call_function"
@@ -237,7 +252,9 @@ def mark_nodes_dislike_padding(g):
class GraphLowering(torch.fx.Interpreter):
    graph_outputs: List[ir.IRNode]

-    def symbolic_sizes_strides(self, ex: torch.Tensor):
+    def symbolic_sizes_strides(
+        self, ex: torch.Tensor
+    ) -> Tuple[Union[List[int], List[Expr]], Union[List[int], List[Expr]]]:
        """
        Support dynamic shapes and dynamic strides by assigning variables
        to each dimension. We duck-shape tensors, so if two tensors
@@ -272,7 +289,9 @@ class GraphLowering(torch.fx.Interpreter):
        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
        return size, stride

-    def static_sizes_strides(self, ex: torch.Tensor):
+    def static_sizes_strides(
+        self, ex: torch.Tensor
+    ) -> Tuple[List[sympy.Expr], List[sympy.Expr]]:
        """
        Primarily used to weights
        """
@@ -284,19 +303,21 @@ class GraphLowering(torch.fx.Interpreter):
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Optional[List[torch.Tensor]] = None,
-        shape_env=None,
-        graph_id=None,
-        cpp_wrapper=False,
-        aot_mode=False,
-        user_visible_outputs=None,
-        layout_opt=None,
-        extern_node_serializer=None,
-        is_inference=False,
-        is_const_graph=False,
-        const_output_index=None,
-        const_code=None,
-        const_module=None,
-        name=None,
+        shape_env: Optional[ShapeEnv] = None,
+        graph_id: Optional[int] = None,
+        cpp_wrapper: bool = False,
+        aot_mode: bool = False,
+        user_visible_outputs: Optional[Dict[str, None]] = None,
+        layout_opt: Optional[bool] = None,
+        extern_node_serializer: Optional[
+            Callable[[List[ir.ExternKernelNode]], Any]
+        ] = None,
+        is_inference: bool = False,
+        is_const_graph: bool = False,
+        const_output_index: Optional[Dict[str, int]] = None,
+        const_code: Optional[str] = None,
+        const_module: Optional["GraphLowering"] = None,
+        name: Optional[str] = None,
    ):
        super().__init__(gm)
        self.example_inputs = example_inputs
@@ -370,7 +391,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
        self.name_to_op: Dict[str, ir.Operation] = {}
        self.creation_time = time.time()
-        self.name = name
+        self.name = name  # type: ignore[assignment]
        self.cpp_wrapper = cpp_wrapper

        # record multi_kernel choice for cpp_wrapper so the second pass knows
@@ -407,7 +428,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
            "dynamo_flat_name_to_original_fqn", {}
        )
-        self.allocated_constant_name = (
+        self.allocated_constant_name: Dict[str, str] = (
            const_module.allocated_constant_name if const_module is not None else {}
        )
        init_backend_registration()
@@ -417,12 +438,14 @@ class GraphLowering(torch.fx.Interpreter):
        self.aligned_inputs: Set[str] = set()
        self.no_fuse_buffer_names: Set[str] = set()

-    def has_feature(self, device, feature):
+    def has_feature(
+        self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature
+    ) -> bool:
        assert isinstance(feature, BackendFeature), feature
        return feature in self.get_backend_features(get_device_type(device))

    @staticmethod
-    def decide_layout_opt(gm, *, is_inference) -> bool:
+    def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
        """
        Decide if we should enable layout optimization for this graph based on
        heuristics.
@@ -470,19 +493,21 @@ class GraphLowering(torch.fx.Interpreter):
            )
            return False

-        def is_grouped(n):
-            return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1
+        def is_grouped(n: Any) -> bool:
+            meta_val = n.args[1].meta["val"]  # type: ignore[union-attr, operator]
+            assert isinstance(meta_val, torch.Tensor)
+            return n.args[-1] > 1 and meta_val.size(1) > 1  # type: ignore[union-attr, operator]

-        def is_in_out_channel(n):
+        def is_in_out_channel(n: torch.fx.Node) -> bool:
            return (
-                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
-                and n.args[1].meta["val"].size(2) > 1
+                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)  # type: ignore[union-attr, operator]
+                and n.args[1].meta["val"].size(2) > 1  # type: ignore[union-attr, operator]
            )

-        def is_small_channel(n):
+        def is_small_channel(n: torch.fx.Node) -> bool:
            return (
-                n.args[1].meta["val"].size(0) <= 64
-                and n.args[1].meta["val"].size(1) <= 64
+                n.args[1].meta["val"].size(0) <= 64  # type: ignore[union-attr, operator]
+                and n.args[1].meta["val"].size(1) <= 64  # type: ignore[union-attr, operator]
            )

        # only grouped convolutions benchmarked as slower in conv samples for inference only
@@ -614,7 +639,7 @@ class GraphLowering(torch.fx.Interpreter):
            name=self.qualify_name(subgraph_name),
        )

-    def find_nodes_prefer_channels_last(self):
+    def find_nodes_prefer_channels_last(self) -> Set[Node]:
        """
        The rule to decide if an node prefer channels last is simple.
        1. if it's input/output of a convolution
@@ -664,12 +689,12 @@ class GraphLowering(torch.fx.Interpreter):

        return output_set

-    def warn_fallback(self, name):
+    def warn_fallback(self, name: str) -> None:
        if name not in self._warned_fallback:
            self._warned_fallback.add(name)
            perf_hint_log.info("Using FallbackKernel: %s", name)

-    def add_device_info(self, device: torch.device):
+    def add_device_info(self, device: torch.device) -> None:
        self.device_types.add(device.type)
        if device.index is not None:
            self.device_idxs.add(device.index)
@@ -677,10 +702,12 @@ class GraphLowering(torch.fx.Interpreter):
            self.device_node_mapping[device] = V.graph.current_node

    @property
-    def fake_mode(self):
+    def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode:
        return V.fake_mode

-    def get_buffer(self, buffer_name: str):
+    def try_get_buffer(
+        self, buffer_name: str
+    ) -> Optional[Union[ir.TensorBox, ir.Buffer]]:
        if buffer_name in self.name_to_buffer:
            return self.name_to_buffer[buffer_name]
        if buffer_name in self.graph_inputs:
@@ -693,9 +720,16 @@ class GraphLowering(torch.fx.Interpreter):
                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
                ),
            )

        return None

-    def get_dtype(self, buffer_name: str):
+    def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]:
+        buf = self.try_get_buffer(buffer_name)
+        if buf is not None:
+            return buf
+        raise RuntimeError(f"Failed to find buffer matching name {buffer_name}")
+
+    def get_dtype(self, buffer_name: str) -> torch.dtype:
        if buffer_name in self.constants:
            return self.constants[buffer_name].dtype
        if buffer_name in self.name_to_buffer:
@@ -707,7 +741,7 @@ class GraphLowering(torch.fx.Interpreter):
            return self.get_dtype(m.group(1))
        raise KeyError(f"could not find {buffer_name}")

-    def get_numel(self, buffer_name: str):
+    def get_numel(self, buffer_name: str) -> Union[int, Expr]:
        from .ir import MultiOutputLayout

        if buffer_name in self.constants:
@@ -722,10 +756,10 @@ class GraphLowering(torch.fx.Interpreter):
        raise KeyError(f"could not find {buffer_name}")

    @dynamo_timed
-    def run(self, *args):
+    def run(self, *args: Any) -> Any:
        return super().run(*args)

-    def register_operation(self, op: ir.Operation):
+    def register_operation(self, op: ir.Operation) -> str:
        assert op.operation_name is None, f"Operation registered twice: {op}"
        assert isinstance(op, ir.Operation)
        name = self.qualify_name(f"op{len(self.operations)}")
@@ -734,7 +768,7 @@ class GraphLowering(torch.fx.Interpreter):
        op.operation_name = name
        return name

-    def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False):
+    def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
        name = self.qualify_name(f"buf{len(self.buffers)}")
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
@@ -754,8 +788,10 @@ class GraphLowering(torch.fx.Interpreter):
        self.lists[name] = operation_names
        return name

-    def register_users_of(self, node_output):
-        def register(value):
+    def register_users_of(
+        self, node_output: Union[Iterable[ir.IRNode], ir.IRNode]
+    ) -> None:
+        def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None:
            if isinstance(value, (list, tuple)):
                for x in value:
                    register(x)
@@ -765,7 +801,7 @@ class GraphLowering(torch.fx.Interpreter):

        register(node_output)

-    def mark_buffer_mutated(self, name: str):
+    def mark_buffer_mutated(self, name: str) -> None:
        """
        When a buffer is mutated we need to make sure all the reads to
        the old version are realized before the mutation happens.
@@ -779,7 +815,7 @@ class GraphLowering(torch.fx.Interpreter):
        for user in self.name_to_users[name]:
            user.realize()

-    def get_original_value_of_constant(self, name: str):
+    def get_original_value_of_constant(self, name: str) -> torch.Tensor:
        """
        In AOTI, module buffers may have been mutated during the tracing and compilation.
        Thus we need to read from previously stored original buffers, to make sure the
@@ -795,7 +831,9 @@ class GraphLowering(torch.fx.Interpreter):
            else self.constants[name]
        )

-    def allocate_non_dup_const_name(self, name, data):
+    def allocate_non_dup_const_name(
+        self, name: Optional[str], data: Union[Tensor]
+    ) -> str:
        orig_name = name
        if not config.aot_inductor.use_runtime_constant_folding:
            for constant_name, value in self.constants.items():
@@ -813,6 +851,7 @@ class GraphLowering(torch.fx.Interpreter):

        if name is None:
            name = f"constant{len(self.constants)}"
+        assert name is not None
        if name[0].isdigit():
            name = f"constant_{name}"
        name = self.qualify_name(name)
@@ -830,10 +869,12 @@ class GraphLowering(torch.fx.Interpreter):
            f"{tuple(data.size())!r} {tuple(data.stride())!r} "
            f"{hash(data):x}"
        )
-        self.allocated_constant_name[name] = orig_name
+        self.allocated_constant_name[name] = orig_name  # type: ignore[assignment]
        return name

-    def add_tensor_constant(self, data, name=None):
+    def add_tensor_constant(
+        self, data: Tensor, name: Optional[str] = None
+    ) -> TensorBox:
        new_name = self.allocate_non_dup_const_name(name, data)
        return TensorBox.create(
            ir.ConstantBuffer(
@@ -842,7 +883,7 @@ class GraphLowering(torch.fx.Interpreter):
            )
        )

-    def constant_name(self, name: str, device_override: Optional[torch.device]):
+    def constant_name(self, name: str, device_override: Optional[torch.device]) -> str:
        """
        We AOT copy constants to the devices they are needed on.
        If device_override doesn't match the constant's device, then
@@ -858,7 +899,9 @@ class GraphLowering(torch.fx.Interpreter):
            self.constants[name].to(device_override),
        )

-    def placeholder(self, target: str, args, kwargs):
+    def placeholder(
+        self, target: str, args: Tuple[object], kwargs: Dict[str, object]
+    ) -> Union[Expr, TensorBox, None]:
        example = super().placeholder(target, args, kwargs)
        self.graph_input_names.append(target)
        if isinstance(example, SymTypes):
@@ -882,7 +925,7 @@ class GraphLowering(torch.fx.Interpreter):
            # the first N inputs are weights
            sizes, strides = self.static_sizes_strides(example)
        else:
-            sizes, strides = self.symbolic_sizes_strides(example)
+            sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]
        # TODO(jansel): handle input aliasing
        target = self.qualify_name(target)
        tensor = TensorBox.create(
@@ -912,7 +955,7 @@ class GraphLowering(torch.fx.Interpreter):
            self.aligned_inputs.add(target)
        return tensor

-    def call_function(self, target, args, kwargs):
+    def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any:  # type: ignore[type-arg]
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)
@@ -923,7 +966,9 @@ class GraphLowering(torch.fx.Interpreter):
            # passthrough lowerings from .pattern_matcher
            return target(*args, **kwargs)

-        def get_custom_op_layout_constraints(target, args, kwargs):
+        def get_custom_op_layout_constraints(
+            target: torch._ops.OpOverload, args: Any, kwargs: Dict[str, Any]
+        ) -> Tuple[Optional[Callable], Tuple[Any], Dict[str, Any]]:  # type: ignore[type-arg]
            # Custom operations that require preserving stride order
            # which run through implicit fallback must constrain their
            # arguments' fx strides
@@ -971,8 +1016,8 @@ class GraphLowering(torch.fx.Interpreter):
                raise MissingOperatorWithoutDecomp(target, args, kwargs)

        try:
-            log.debug(" via %s", lowerings[target])
-            out = lowerings[target](*args, **kwargs)
+            log.debug(" via %s", lowerings[target])  # type: ignore[index]
+            out = lowerings[target](*args, **kwargs)  # type: ignore[index]
            return out
        except Exception as e:
            raise LoweringException(e, target, args, kwargs).with_traceback(
@@ -986,9 +1031,11 @@ class GraphLowering(torch.fx.Interpreter):
        """
        return len(t.shape) == 1 and t.shape[0] <= 8

-    def get_attr(self, target, args, kwargs):
+    def get_attr(
+        self, target: str, args: Tuple[()], kwargs: Dict[str, object]
+    ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
        # this is a constant
-        value = getattr_recursive(self.module, target)
+        value = getattr_recursive(self.module, target)  # type: ignore[arg-type]

        if isinstance(value, torch.fx.GraphModule):
            return ir.Subgraph(name=target, graph_module=value)
@@ -998,6 +1045,7 @@ class GraphLowering(torch.fx.Interpreter):
            self.constant_reprs[target] = ""
            return TorchBindObject(target, value)

+        assert isinstance(value, torch.Tensor)
        if (
            config.aot_inductor.use_runtime_constant_folding
            or config.always_keep_tensor_constants
@@ -1017,13 +1065,15 @@ class GraphLowering(torch.fx.Interpreter):

        return self.add_tensor_constant(value, target)

-    def call_module(self, target, args, kwargs):
+    def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
        raise AssertionError

-    def call_method(self, target, args, kwargs):
+    def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
        raise AssertionError

-    def output(self, target, args, kwargs):
+    def output(
+        self, target: str, args: Tuple[object], kwargs: Dict[str, object]
+    ) -> None:
        result = super().output(target, args, kwargs)
        if not isinstance(result, (tuple, list)):
            # nested subgraphs can have singleton outputs
@@ -1099,12 +1149,12 @@ class GraphLowering(torch.fx.Interpreter):
            self.graph_id if self.graph_id is not None else -1,
        )

-    def finalize(self):
+    def finalize(self) -> None:
        for buf in self.buffers:
            buf.decide_layout()

    @contextmanager
-    def set_current_node(self, node: torch.fx.Node):
+    def set_current_node(self, node: torch.fx.Node):  # type: ignore[no-untyped-def]
        old = self.current_node
        try:
            self.current_node = node
@@ -1114,9 +1164,9 @@ class GraphLowering(torch.fx.Interpreter):

    def try_match_insignificant_strides(
        self,
-        tensor,
+        tensor: Union[ir.TensorBox, ir.BaseView],
        meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
-    ) -> ir.TensorBox:
+    ) -> Union[ir.TensorBox, ir.BaseView]:
        """
        Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
        dimensions - size 0 or 1 - will be updated.
@@ -1135,9 +1185,13 @@ class GraphLowering(torch.fx.Interpreter):
            self.sizevars.statically_known_equals(s1, s2)
            for s1, s2 in zip(meta_strides, tensor.get_stride())
        ):
-            return tensor
+            return tensor  # type: ignore[arg-type]

-        def significant_strides_equal(shape, meta_strides, tensor_strides):
+        def significant_strides_equal(
+            shape: Sequence[Union[Expr, int]],
+            meta_strides: Sequence[Union[Expr, int]],
+            tensor_strides: Sequence[Union[Expr, int]],
+        ) -> bool:
            for dim, s1, s2 in zip(shape, meta_strides, tensor_strides):
                if self.sizevars.statically_known_leq(dim, 1):  # type: ignore[arg-type]
                    continue
@@ -1167,8 +1221,8 @@ class GraphLowering(torch.fx.Interpreter):
        )
        return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout))

-    def run_node(self, n: torch.fx.Node):
-        def debug(msg):
+    def run_node(self, n: torch.fx.Node) -> object:
+        def debug(msg: str) -> None:
            log.debug("lowering %s %s", LazyString(n.format_node), msg)

        buffer_watermark = len(self.buffers)
@@ -1193,7 +1247,7 @@ class GraphLowering(torch.fx.Interpreter):
            elif n.op == "call_function" and n.target in layout_constraints:
                debug("layout_constraints")
                args, kwargs = layout_constraints[n.target](n, *args, **kwargs)  # type: ignore[index]
-                result = self.call_function(n.target, args, kwargs)
+                result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
            elif is_magic_method(n.target):
                # TODO: this is sus, it probably should be handled in the
                # lowerings themselves similarly to sym_size/sym-stride
@@ -1392,7 +1446,7 @@ class GraphLowering(torch.fx.Interpreter):
        for op in self.operations[operation_watermark:]:
            new_unbacked_defs |= op.get_unbacked_symbol_defs()

-        def format_new_defs():
+        def format_new_defs() -> str:
            r = []
            for buf in self.buffers[buffer_watermark:]:
                r.append(
@@ -1427,7 +1481,7 @@ class GraphLowering(torch.fx.Interpreter):
        # This is all doable, it just hasn't been done yet.
        shape_env = V.graph.sizevars.shape_env

-        def make_assert(expr, msg):
+        def make_assert(expr: Expr, msg: str) -> None:
            assert_op = ir.AssertScalar(expr, msg)
            self.register_buffer(assert_op, set_name=True)
            self.register_operation(assert_op)
@@ -1438,7 +1492,7 @@ class GraphLowering(torch.fx.Interpreter):
            vr = shape_env.var_to_range[i0]
            if not shape_env._default_unspecified_value_range().issubset(vr):

-                def is_convertible(s):
+                def is_convertible(s: Expr) -> bool:
                    if s in (int_oo, -int_oo):
                        return False
                    try:
@@ -1492,7 +1546,7 @@ class GraphLowering(torch.fx.Interpreter):

        return result

-    def validate_can_generate_cpp_wrapper(self):
+    def validate_can_generate_cpp_wrapper(self) -> None:
        if config.disable_cpp_codegen:
            raise CppWrapperCodeGenError("C++ codegen is disabled")
@@ -1511,7 +1565,7 @@ class GraphLowering(torch.fx.Interpreter):
            if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
                raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")

-    def init_wrapper_code(self):
+    def init_wrapper_code(self) -> None:
        self.cuda = "cuda" in self.device_types
        if self.cpp_wrapper:
            self.validate_can_generate_cpp_wrapper()
@@ -1541,7 +1595,7 @@ class GraphLowering(torch.fx.Interpreter):
            self.const_module.wrapper_code.src_to_kernel
        )

-    def codegen_with_cpp_wrapper(self):
+    def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]:
        """
        For CPU, the cpp wrapper codegen is done in one pass.
        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
@@ -1559,7 +1613,9 @@ class GraphLowering(torch.fx.Interpreter):

            if not config.triton.autotune_at_compile_time:

-                def materialize(x):
+                def materialize(
+                    x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
+                ) -> Union[int, float, torch.Tensor]:
                    if isinstance(x, (torch.SymInt, torch.SymFloat)):
                        # Need concrete value to run dynamic shapes and tune the result
                        return x.node.hint
@@ -1617,7 +1673,10 @@ class GraphLowering(torch.fx.Interpreter):
                        # f, the inputs x will be mutated twice in the process:
                        # once here, and again when running the compiled model;
                        # this will also lead to a numerically incorrect output
-                        real_inputs[idx] = clone_preserve_strides(real_inputs[idx])
+                        mutated_inp = real_inputs[idx]
+                        assert isinstance(mutated_inp, torch.Tensor)
+                        real_inputs[idx] = clone_preserve_strides(mutated_inp)
+                        del mutated_inp

                with torch.utils._python_dispatch._disable_current_modes():
                    compiled(real_inputs)
@@ -1636,7 +1695,7 @@ class GraphLowering(torch.fx.Interpreter):
            # cpu
            return self.codegen()

-    def codegen(self):
+    def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]:
        from .scheduler import Scheduler

        self.init_wrapper_code()
@@ -1650,7 +1709,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.wrapper_code.pop_codegened_graph()
        return result

-    def codegen_subgraph(self, parent_graph):
+    def codegen_subgraph(self, parent_graph: "GraphLowering") -> None:
        """
        This is a more compact version of the `codegen()` above
        where we codegen this graph as a subgraph of some parent
@@ -1669,7 +1728,11 @@ class GraphLowering(torch.fx.Interpreter):
        self.scheduler = Scheduler(self.operations)
        self.scheduler.codegen()

-    def count_bytes(self):
+    def count_bytes(
+        self,
+    ) -> Tuple[
+        int, List[Tuple[BaseSchedulerNode, int]], List[Tuple[BaseSchedulerNode, float]]
+    ]:
        total_bytes = 0
        node_counts = []
        node_runtimes = []
@@ -1678,15 +1741,16 @@ class GraphLowering(torch.fx.Interpreter):
            total_bytes += num_bytes
            node_counts.append((node, num_bytes // 4))
            node_runtimes.append((node, node.get_estimated_runtime()))

        return total_bytes, node_counts, node_runtimes

    @staticmethod
-    def save_output_code(code: str):
+    def save_output_code(code: str) -> None:
        # No-op to be patched for unit tests
        pass

    @dynamo_timed(phase_name="code_gen", fwd_only=False)
-    def compile_to_module(self):
+    def compile_to_module(self) -> ModuleType:
        from .codecache import PyCodeCache

        code, linemap = (
@@ -1696,7 +1760,7 @@ class GraphLowering(torch.fx.Interpreter):
        GraphLowering.save_output_code(code)
        output_code_log.debug("Output code: \n%s", code)
        try:
-            linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
+            linemap = [(line_no, node.stack_trace) for line_no, node in linemap]  # type: ignore[misc]
            key, path = PyCodeCache.write(code)
        except Exception:
            trace_structured(
@@ -1715,12 +1779,12 @@ class GraphLowering(torch.fx.Interpreter):
        mod = PyCodeCache.load_by_key_path(
            key,
            path,
-            linemap=linemap,
+            linemap=linemap,  # type: ignore[arg-type]
            attrs={**self.constants, **self.torchbind_constants},
        )
        self.cache_key = key
        self.cache_path = path
-        self.cache_linemap = linemap
+        self.cache_linemap = linemap  # type: ignore[assignment]

        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
        # TODO. Revisit this once the logging API is more mature
@@ -1735,7 +1799,7 @@ class GraphLowering(torch.fx.Interpreter):
        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
        return mod

-    def compile_to_fn(self):
+    def compile_to_fn(self) -> Any:
        if self.aot_mode:
            from .codecache import AotCodeCompiler

@@ -1764,7 +1828,7 @@ class GraphLowering(torch.fx.Interpreter):
        else:
            return self.compile_to_module().call

-    def get_output_names(self):
+    def get_output_names(self) -> List[str]:
        return [
            node.get_name()
            for node in self.graph_outputs
@@ -1772,7 +1836,7 @@ class GraphLowering(torch.fx.Interpreter):
            and not isinstance(node, ir.ShapeAsConstantBuffer)
        ]

-    def is_unspec_arg(self, name: str):
+    def is_unspec_arg(self, name: str) -> bool:
        # dynamo wraps unspec variable as 0d CPU tensor,
        # need to convert to scalar during codegen (triton only)
        return (