# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually identical to what would come from `torch.export`:
- Submodules are stored as attributes, and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by `torch.export`.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.

# Implementation details

This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning. In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new Wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as that of `V.graph.module`, but this is only true for the parent graph, not for subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.

# Test plan

This PR adds a couple of tests using `torch.cond`.
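As a rough sketch (illustrative, not the literal test code; registration of the FX wrapper backend is elided here), the new tests compile a function along these lines:

```python
import torch


def true_fn(x):
    return x.sin()


def false_fn(x):
    return x.cos()


def foo(pred, x):
    # Lowers to torch.ops.higher_order.cond in the generated wrapper FX IR.
    return torch.cond(pred, true_fn, false_fn, (x,))


compiled = torch.compile(foo)
out = compiled(torch.tensor(True), torch.randn(6, device="cuda"))
```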
Here's an example graph generated by one of them:

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test, which checked that a certain error was raised when subgraphs were encountered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
import dataclasses
import functools
import logging
import operator
import textwrap
from collections import Counter
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import sympy

import torch
from torch._export.passes._node_metadata_hook import (
    _node_metadata_hook,
    _set_node_metadata_hook,
)
from torch._export.utils import _detect_fake_mode_from_gm
from torch._higher_order_ops.triton_kernel_wrap import (
    TraceableTritonKernelWrapper,
    tracing_triton_hopifier_singleton,
    triton_kernel_wrapper_mutation,
)
from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
from torch._inductor.utils import convert_shape_to_symint, sympy_product
from torch._inductor.virtualized import V
from torch._library.triton import wrap_triton
from torch.fx import GraphModule
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import FloorDiv
from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis
from torch.utils._sympy.solve import try_solve

from .. import config, ir
from ..runtime.triton_compat import Config
from ..utils import cache_property_on_self, LineContext, ValueWithLineMap
from .common import (
    CodegenSymbol,
    FileBackedGraphModule,
    WorkspaceArg,
    WorkspaceZeroMode,
)
from .wrapper import (
    AllocateLine,
    BufferLike,
    CommBufferAllocateLine,
    CommBufferFreeLine,
    CommentLine,
    ConditionalLine,
    EnterDeviceContextManagerLine,
    EnterSubgraphLine,
    ExitDeviceContextManagerLine,
    ExitSubgraphLine,
    ExternKernelAllocLine,
    ExternKernelOutLine,
    FreeIfNotReusedLine,
    FreeLine,
    IndexPutFallbackLine,
    KernelCallLine,
    KernelDefinitionLine,
    Line,
    MultiOutputLine,
    NullLine,
    PythonWrapperCodegen,
    ReinterpretLine,
    ReuseLine,
    ScatterFallbackLine,
    SubgraphPythonWrapperCodegen,
    SymbolicCallArg,
    SymbolicCallArgLine,
    WrapperLine,
)


aten = torch.ops.aten
log = logging.getLogger(__name__)

@dataclasses.dataclass
class SymbolBuffer(CodegenSymbol):
    """
    Represents a sympy.Symbol graph input.
    """

    symbol: sympy.Symbol

    def get_name(self) -> str:
        return str(self.symbol)

    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
        return self.symbol


CodegenBuffer = Union[BufferLike, SymbolBuffer]


@dataclasses.dataclass
class TritonKernel:
    """
    Stores metadata about Triton kernels for use in FX.
    """

    tuner: CachingAutotuner
    wrapped: TraceableTritonKernelWrapper


def replace_floor_div(expr: sympy.Expr) -> sympy.Expr:
    """
    Replace sympy.floor with FloorDiv.
    """

    def replace(expr: sympy.Expr) -> sympy.Expr:
        expr = sympy.together(expr)

        # Find division operations in the sympy.floor expression.
        # Div is represented either as a Mul with a Rational denominator,
        # or a Pow with a negative exponent.
        if not isinstance(expr, sympy.core.mul.Mul):
            return sympy.floor(expr)

        if isinstance(expr.args[0], sympy.Rational):
            frac = expr.args[0]
            numerator = sympy_product(expr.args[1:]) * frac.numerator
            denominator = frac.denominator

            return FloorDiv(numerator, denominator)
        elif isinstance(expr.args[0], sympy.Pow):
            base = expr.args[0].base
            exp = expr.args[0].exp
            numerator = sympy_product(expr.args[1:])
            if exp < 0:
                denominator = base ** (-exp)
            else:
                numerator = numerator * (base**exp)
                denominator = 1
            return FloorDiv(numerator, denominator)
        else:
            return sympy.floor(expr)

    return expr.replace(sympy.floor, replace)
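

# Worked example (illustrative): with x = sympy.Symbol("x"),
# replace_floor_div(sympy.floor(x / 2)) yields FloorDiv(x, 2), which the FX
# converter can lower, since _generate_sym_node does not support sympy.floor.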


class WrapperFxCodegen(PythonWrapperCodegen):
    """
    Backend to generate wrapper code as an FX IR graph.
    """

    supports_caching = False

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.subgms: dict[str, torch.fx.GraphModule] = {}

    def codegen_inputs(self) -> None:
        """
        This would generate code for symbolic input shapes, strides, etc.
        Since the FX converter handles this, do nothing here.
        """

    def codegen_conditional(self, conditional: ir.Conditional) -> None:
        """
        Conditional codegen normally emits a number of different wrapper lines.
        Instead, FX conversion uses a dedicated line for the whole conditional.
        """
        self.writeline(ConditionalLine(self, conditional))
        for subgraph in (conditional.true_subgraph, conditional.false_subgraph):
            self.codegen_subgraph_common(subgraph)

    def define_subgraph_launcher_fn(
        self, name: str, subgraph_code: Union[ValueWithLineMap, FileBackedGraphModule]
    ) -> None:
        """
        Record subgms as they're generated.
        """
        assert isinstance(subgraph_code, FileBackedGraphModule)
        self.subgms[name] = subgraph_code.gm

    @property
    @cache_property_on_self
    def is_subgraph(self) -> bool:
        return isinstance(self, SubgraphPythonWrapperCodegen)

    def get_fx_graph_inputs(
        self,
    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
        """
        Get the input nodes corresponding to FX graph placeholders.
        """
        if V.aot_compilation and not self.is_subgraph:
            # AOT graphs must match the signature of the input module.
            return {
                node.name: V.graph.graph_inputs.get(node.name)
                for node in V.graph.module.graph.find_nodes(op="placeholder")  # type: ignore[operator, union-attr]
            }

        return self.get_graph_inputs()

    def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
        self.run_wrapper_ir_passes(is_inference)

        prologue = "\n".join(
            [
                self.imports.getvalue(),
                self.header.getvalue(),
            ]
        )
        gm = FxConverter(
            lines=self.lines,
            prologue=prologue,
            graph_inputs=self.get_fx_graph_inputs(),
            graph_outputs=self.get_graph_outputs(),
            subgms=self.subgms,
            is_subgraph=self.is_subgraph,
        ).generate()

        compiled_fn = self.compile_graph(gm)

        return FileBackedGraphModule(gm, compiled_fn), None

    def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
        """
        Converts the graph module into a runnable function. The default implementation
        is simply an interpreter calling kernels in eager mode. Derived backends can
        override this to do further compilation.
        """
        return gm.forward
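
    # Illustrative sketch (not part of this module): a derived backend could
    # override compile_graph to hand the FX graph to a downstream compiler
    # instead of running it eagerly, e.g.
    #
    #     class MyFxBackend(WrapperFxCodegen):
    #         def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
    #             return my_compiler(gm)  # hypothetical downstream compiler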

    def write_header(self) -> None:
        """
        Python subgraphs normally lack headers.
        Override this behavior to generate prologues for FX subgraphs.
        """
        PythonWrapperCodegen.write_header(self)

    @classmethod
    def create(
        cls: type["WrapperFxCodegen"],
        is_subgraph: bool,
        subgraph_name: Optional[str],
        parent_wrapper: Optional[PythonWrapperCodegen],
        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
    ) -> "WrapperFxCodegen":
        if is_subgraph:
            assert subgraph_name is not None
            assert parent_wrapper is not None

            # Subgraphs override some methods of PythonWrapperCodegen.
            # Apply these overrides to the user-provided class, with priority given to
            # user-provided methods.
            class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen):  # type: ignore[misc,valid-type]
                def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
                    """
                    Skip graph compilation for subgraphs.
                    """

                    def crash_if_run(*args: Any) -> None:
                        raise NotImplementedError(
                            "Cannot run a subgraph in isolation!"
                        )

                    return crash_if_run

            return SubgraphFxWrapperCodegen(
                subgraph_name, parent_wrapper, partition_signatures
            )

        return cls()


@dataclasses.dataclass
class FxConverter:
    """
    Generates FX IR from Wrapper IR. As each instance is only meant to be used once,
    the input and output code are stored as attributes.
    """

    lines: list[Line]
    prologue: str
    graph_inputs: dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]
    graph_outputs: list[ir.IRNode]
    subgms: dict[str, torch.fx.GraphModule]
    is_subgraph: bool

    def __post_init__(self) -> None:
        graph = torch.fx.Graph()
        self.gm = GraphModule({}, graph)  # Wrapper FX IR.
        self.buffer_to_node: dict[
            Optional[str], torch.fx.Node
        ] = {}  # Symbol table for codegen.
        self.kernels: dict[str, TritonKernel] = {}  # Table to store Triton kernels.
        self._unique_symbol_ids: Counter[str] = Counter()
        self.tracer = torch.fx.proxy.GraphAppendingTracer(graph)
        self.expr_to_proxy: dict[sympy.Expr, torch.fx.Proxy] = {}

    def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
        """
        Imports a kernel from source, possibly autotuning block parameters.
        """
        module_code = "\n".join([self.prologue, code])
        mod = PyCodeCache.load(module_code)
        kernel = getattr(mod, kernel_name)

        if isinstance(kernel, LambdaFuture):
            kernel = kernel.result()

        if not isinstance(kernel, CachingAutotuner):
            raise NotImplementedError(
                textwrap.dedent(f"""
                Unsupported type for kernel {kernel_name}: {type(kernel)}.
                FX conversion only supports Triton kernels.
                """)
            )

        return kernel

    def _fake_tensor(
        self,
        size: tuple[Any, ...],
        stride: tuple[Any, ...],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        with V.fake_mode:
            return torch.empty_strided(
                convert_shape_to_symint(size),
                convert_shape_to_symint(stride),
                dtype=dtype,
                device=device,
            )

    def _create_as_strided(
        self,
        input_node: torch.fx.Node,
        size: tuple[Any, ...],
        stride: tuple[Any, ...],
        offset: Union[int, sympy.Expr],
    ) -> torch.fx.Node:
        return self.gm.graph.call_function(
            torch.as_strided,
            args=(
                input_node,
                self._generate_sym_nodes(size),
                self._generate_sym_nodes(stride),
                self._generate_sym_node(offset),
            ),
        )

    def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None:
        """
        Updates the symbol table to record that an Inductor buffer maps to the
        result of an FX node.
        """
        assert node not in self.buffer_to_node
        self.buffer_to_node[buffer.get_name()] = node

    def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None:
        """
        Removes the buffer from the symbol table.
        """
        name = buffer.get_name()
        del self.buffer_to_node[name]

    def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
        """
        Maps call args back to FX nodes.
        """
        return tuple(
            self.buffer_to_node[arg]
            if isinstance(arg, str)
            else arg.inner_expr
            if isinstance(arg, SymbolicCallArg)
            else arg
            for arg in args
        )

    def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
        """
        Extract buffer data from an IR node.
        """
        if isinstance(node, (ir.Buffer, WorkspaceArg)):
            return node
        elif isinstance(node, (ir.BaseView, ir.MutableBox)):
            return self._get_buffer(node.data)
        elif isinstance(node, sympy.Symbol):
            return SymbolBuffer(node)
        else:
            raise NotImplementedError(f"Unable to extract buffer from node: {node}")

    def _generate_graph_inputs(self) -> None:
        """
        Converts graph inputs to FX placeholders.
        """

        for name, ir_node in self.graph_inputs.items():
            if ir_node is None:
                # Create dummy input nodes to match the input signature.
                self.gm.graph.placeholder(name)
                continue

            # Introduce a new symbol for constant inputs.
            buffer = (
                SymbolBuffer(sympy.Symbol(name, is_integer=True))
                if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
                else self._get_buffer(ir_node)
            )
            placeholder_node = self.gm.graph.placeholder(buffer.get_name())
            placeholder_node.meta["val"] = buffer.get_example()
            self._record_allocation(buffer, placeholder_node)

    def _generate_graph_input_shapes(self) -> None:
        """
        Generate nodes creating symints that are part of graph input
        shapes/strides.
        """

        def _codegen_symbol(
            sym_or_exp: Union[sympy.Symbol, sympy.Expr],
            base_node: torch.fx.Node,
            target: torch._ops.OpOverload,
            dim: int,
        ) -> None:
            def codegen_proxy() -> torch.fx.Proxy:
                size_node = self.gm.graph.call_function(target, (base_node, dim))
                size_proxy = torch.fx.Proxy(size_node, tracer=self.tracer)
                self.expr_to_proxy[sym_or_exp] = size_proxy
                return size_proxy

            if isinstance(sym_or_exp, sympy.Symbol):
                if sym_or_exp in self.expr_to_proxy:
                    return
                codegen_proxy()

            elif isinstance(sym_or_exp, sympy.Integer):
                return

            elif isinstance(sym_or_exp, sympy.Expr):
                # Check if we need to solve for an undefined symbol.
                undefined_symbols = [
                    sym
                    for sym in sym_or_exp.free_symbols
                    if sym not in self.expr_to_proxy
                ]
                if len(undefined_symbols) == 0:
                    self._sympy_interp(sym_or_exp)
                    return
                elif len(undefined_symbols) > 1:
                    raise ValueError(f"Underdetermined input expression: {sym_or_exp}")

                # Define a new symbol for the input size.
                size_proxy = codegen_proxy()
                size_symbol = sympy.Symbol(
                    size_proxy.node.name, integer=True, nonnegative=True
                )
                self.expr_to_proxy[size_symbol] = size_proxy

                # Solve for the undefined symbol.
                undefined_symbol = undefined_symbols[0]
                solution = try_solve(
                    sympy.Eq(sym_or_exp, size_symbol), undefined_symbol
                )
                if solution is None:
                    raise ValueError(f"Cannot solve input expression: {sym_or_exp}")

                # Since the symbol is a size, it must be an integer.
                # Therefore, we can convert division to FloorDiv.
                undefined_symbol_expr = solution[1]
                if undefined_symbol.is_integer:
                    undefined_symbol_expr = replace_floor_div(
                        sympy.floor(undefined_symbol_expr)
                    )

                # Generate FX for the symbol.
                self._sympy_interp(undefined_symbol_expr)
                self.expr_to_proxy[undefined_symbol] = self.expr_to_proxy[
                    undefined_symbol_expr
                ]

        for node in V.graph.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
            name = node.name
            ir_node = self.graph_inputs.get(name)
            if isinstance(ir_node, ir.TensorBox):
                buffer = self._get_buffer(ir_node)
                placeholder_node = self.buffer_to_node[buffer.get_name()]

                for dim, size in enumerate(ir_node.get_size()):
                    _codegen_symbol(
                        size, placeholder_node, torch.ops.aten.sym_size.int, dim
                    )
                for dim, stride in enumerate(ir_node.get_stride()):
                    _codegen_symbol(
                        stride, placeholder_node, torch.ops.aten.sym_stride.int, dim
                    )
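
    # Worked example of the solving path above (illustrative, not from the tests):
    # if a placeholder has size (2*s0, 3) and s0 is still undefined at dim 0,
    # _codegen_symbol emits sym_size.int for that dim as a fresh symbol u, solves
    # 2*s0 == u for s0, and records s0 -> FloorDiv(u, 2) via _sympy_interp.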

    def _generate_graph_constants(self) -> None:
        for name, value in V.graph.constants.items():
            node = self.gm.graph.get_attr(name)
            node.meta["val"] = value
            setattr(self.gm, name, value)
            self.buffer_to_node[name] = node

    def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
        """
        Generates FX IR for transformations on a buffer, such as ReinterpretView.
        Does nothing if no such transformations are present.
        """

        def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
            if isinstance(node, (ir.Buffer, WorkspaceArg)):
                return node
            elif isinstance(node, ir.NoneAsConstantBuffer):
                return None
            elif isinstance(node, ir.MutableBox):
                return generate_to_buffer(node.data)
            elif isinstance(node, ir.ReinterpretView):
                # We need to introduce a new symbol if the output is a ReinterpretView.
                # Use a WorkspaceArg for this.
                buffer = self._get_buffer(node.data)
                assert isinstance(buffer, (ir.Buffer, WorkspaceArg))
                unique_name = self.gm.graph._graph_namespace.create_name(
                    f"{buffer.get_name()}_view", None
                )
                device = buffer.get_device()
                assert device
                reused_as = WorkspaceArg(
                    count=buffer.get_size(),
                    zero_mode=WorkspaceZeroMode.UNINITIALIZED,
                    device=device,
                    outer_name=unique_name,
                    dtype=buffer.get_dtype(),
                )

                # Generate FX IR for the view.
                self._generate_reinterpret_helper(buffer, reused_as, node.layout)

                return reused_as
            else:
                raise NotImplementedError(f"Unrecognized buffer/view node: {node}")

        buffer = generate_to_buffer(node)
        return self.buffer_to_node[buffer.get_name()] if buffer is not None else None

    def _generate_output(self) -> None:
        """
        Generate FX IR for graph outputs.
        """
        output_nodes = [
            self._generate_buffer(node) for idx, node in enumerate(self.graph_outputs)
        ]

        # Parent graphs with single return elements don't use a tuple.
        output_value = (
            output_nodes[0]
            if len(output_nodes) == 1 and not self.is_subgraph
            else output_nodes
        )

        self.gm.graph.output(output_value)

    def _generate_subgm_getattrs(self) -> None:
        """
        Generate getattr nodes for subgms.
        """

        def generate_getattr(name: str, subgm: torch.fx.GraphModule) -> torch.fx.Node:
            self.gm.add_submodule(name, subgm)
            node = self.gm.graph.get_attr(name)
            node.meta["val"] = subgm
            return node

        self.subgm_getattrs = {
            name: generate_getattr(name, subgm) for name, subgm in self.subgms.items()
        }

    def _get_subgm_attr(self, subgraph: ir.Subgraph) -> torch.fx.Node:
        """
        Look up the getattr node for a subgraph.
        """
        graph = subgraph.graph
        assert graph is not None
        return self.subgm_getattrs[graph.name]

    def generate(self) -> torch.fx.GraphModule:
        """
        Main entrypoint for FX codegen.
        """
        self._generate_graph_inputs()
        self._generate_graph_constants()
        self._generate_subgm_getattrs()

        fake_mode = _detect_fake_mode_from_gm(self.gm)

        with _set_node_metadata_hook(
            self.gm,
            functools.partial(_node_metadata_hook, fake_mode=fake_mode),
        ):
            self._generate_graph_input_shapes()

            # Generate FX IR from Wrapper IR lines.
            for line in self.lines:
                if isinstance(line, WrapperLine):
                    line.codegen_fx(self)(line)
                elif isinstance(line, LineContext):
                    # Ignore line context in FX IR.
                    pass
                else:
                    raise NotImplementedError(
                        textwrap.dedent(
                            f"""
                            Found line of unrecognized type '{type(line)}':
                                '{line}'

                            FX conversion only supports Wrapper IR lines.
                            """
                        )
                    )

        self._generate_output()
        self.gm.recompile()
        return self.gm

    def _sympy_interp(self, expr: sympy.Expr) -> torch.fx.Proxy:
        # Hash cons: reuse a previously generated proxy if one exists.
        if expr in self.expr_to_proxy:
            return self.expr_to_proxy[expr]
        # Base cases; don't cache these.
        if isinstance(
            expr,
            (
                sympy.Integer,
                sympy.Number,
                sympy.Symbol,
                sympy.logic.boolalg.BooleanAtom,
            ),
        ):
            return sympy_interp(
                OptimizedPythonReferenceAnalysis, self.expr_to_proxy, expr
            )

        # Hash cons on arguments, then run the expression handler.
        self.expr_to_proxy[expr] = _run_sympy_handler(
            OptimizedPythonReferenceAnalysis,
            [self._sympy_interp(arg) for arg in expr.args],
            expr,
        )
        return self.expr_to_proxy[expr]
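
    # For example (illustrative): _sympy_interp(s0 * 2 + 1) lowers the expression
    # recursively, reusing the cached proxy for s0 and emitting FX nodes for the
    # multiplication and addition; the resulting proxy is cached in expr_to_proxy.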

    def _generate_sym_node(
        self, s: Union[int, sympy.Expr]
    ) -> Union[int, torch.fx.Node]:
        if isinstance(s, (int, sympy.Integer)):
            return int(s)
        elif isinstance(s, sympy.Symbol):
            assert s in self.expr_to_proxy, (
                f"Could not find a node corresponding to the symbol {s}"
            )
            return self.expr_to_proxy[s].node
        elif isinstance(s, sympy.Expr):
            return self._sympy_interp(s).node

        elif isinstance(s, torch.fx.Node):
            return s

        else:
            raise ValueError(f"{s} of type {type(s)} is not a valid input")

    def _generate_sym_nodes(
        self, shape: Sequence[sympy.Expr]
    ) -> list[Union[int, torch.fx.Node]]:
        return [self._generate_sym_node(s) for s in shape]

    def _generate_allocate(self, line: WrapperLine) -> None:
        assert isinstance(line, AllocateLine)
        buffer = line.node
        name = buffer.get_name()
        assert name not in V.graph.removed_buffers

        device = buffer.get_device()
        dtype = buffer.get_dtype()
        shape = self._generate_sym_nodes(buffer.get_size())
        stride = self._generate_sym_nodes(buffer.get_stride())

        node = self.gm.graph.call_function(
            torch.empty_strided,
            args=(shape, stride),
            kwargs={"dtype": dtype, "device": device},
        )
        assert name
        node.name = name
        self._record_allocation(buffer, node)

    def _generate_conditional(self, line: WrapperLine) -> None:
        assert isinstance(line, ConditionalLine)

        def get_subgm_attr(subgraph: Optional[ir.Subgraph]) -> torch.fx.Node:
            assert subgraph is not None
            return self._get_subgm_attr(subgraph)

        # Access the subgraphs as getattrs.
        ir_node = line.node
        (true_subgm, false_subgm) = [
            get_subgm_attr(subgraph)
            for subgraph in (ir_node.true_subgraph, ir_node.false_subgraph)
        ]

        def generate_buffer(node: Optional[ir.IRNode]) -> Optional[torch.fx.Node]:
            assert node is not None
            return self._generate_buffer(node)

        predicate = generate_buffer(ir_node.predicate)
        assert ir_node.operands is not None
        operands = tuple(generate_buffer(arg) for arg in ir_node.operands)
        fx_node = self.gm.graph.call_function(
            torch.ops.higher_order.cond,
            args=(predicate, true_subgm, false_subgm, operands),
        )
        self._record_allocation(ir_node, fx_node)
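
    # The node emitted above follows the torch.export convention for conditionals:
    #     cond(predicate, true_graph, false_graph, (operands, ...))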

    def _generate_comment(self, line: WrapperLine) -> None:
        assert isinstance(line, CommentLine)
        # We ignore comments in FX IR.

    def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
        assert isinstance(line, EnterDeviceContextManagerLine)
        # We ignore the device context in FX IR.

    def _generate_exit_device_context_manager(self, line: WrapperLine) -> None:
        assert isinstance(line, ExitDeviceContextManagerLine)
        # We ignore the device context in FX IR.

    def _generate_enter_subgraph(self, line: WrapperLine) -> None:
        assert isinstance(line, EnterSubgraphLine)
        # We ignore memory planning lines in FX IR.

    def _generate_exit_subgraph(self, line: WrapperLine) -> None:
        assert isinstance(line, ExitSubgraphLine)
        # We ignore memory planning lines in FX IR.

    def _generate_free(self, line: WrapperLine) -> None:
        assert isinstance(line, FreeLine)

        buf = line.node

        # No need to free placeholders.
        if self.buffer_to_node[buf.get_name()].op == "placeholder":
            return

        self._free(buf)

    def _generate_free_if_not_reused(self, line: WrapperLine) -> None:
        assert isinstance(line, FreeIfNotReusedLine)
        buf = line.node
        assert buf.get_name() not in V.graph.removed_buffers
        if not line.is_reused:
            self._free(buf)

    def _generate_line_context(self, line: WrapperLine) -> None:
        assert isinstance(line, LineContext)
        # We ignore line context in FX IR.

    def _generate_reinterpret(self, line: WrapperLine) -> None:
        assert isinstance(line, ReinterpretLine)
        self._generate_reinterpret_helper(line.node, line.reused_as, line.layout)

    def _generate_reinterpret_helper(
        self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout
    ) -> None:
        input_node = self.buffer_to_node[input_buffer.get_name()]

        # Look up output metadata.
        name = result_buffer.get_name()
        assert name
        size = tuple(layout.size)
        stride = tuple(layout.stride)
        if isinstance(layout, ir.NonOwningLayout):
            # Look up the view's layout.
            view = layout.view
            assert isinstance(view, ir.ReinterpretView), (
                f"unexpected type: {type(view)}"
            )
            layout = view.layout
        offset = input_buffer.get_offset() + layout.offset

        # Map ReinterpretView to as_strided.
        result_node = self._create_as_strided(input_node, size, stride, offset)
        result_node.name = name
        self._record_allocation(result_buffer, result_node)

    def _generate_reuse(self, line: WrapperLine) -> None:
        assert isinstance(line, ReuseLine)
        old = line.node
        new = line.reused_as
        assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new))
        assert old.get_dtype() == new.get_dtype()

        old_node = self.buffer_to_node[old.get_name()]
        result_node = old_node

        # Change shape and stride.
        size = tuple(new.get_size())
        stride = tuple(new.get_stride())
        offset = new.get_offset()
        if (
            tuple(old.get_size()) != size
            or tuple(old.get_stride()) != stride
            or old.get_offset() != offset
        ):
            result_node = self._create_as_strided(old_node, size, stride, offset)

        self._record_allocation(new, result_node)

        # Free the old buffer, if we allocated a new tensor.
        if (
            old.get_name() not in V.graph.get_output_names()
            and line.delete_old
            and result_node is not old_node
        ):
            self._free(old)

    def _generate_multi_output(self, line: WrapperLine) -> None:
        assert isinstance(line, MultiOutputLine)

        arg_node = self.buffer_to_node[line.arg_name]

        # For non-tuple / non-list outputs, map the
        # output to the same node as the input.
        if len(line.indices) == 0:
            self.buffer_to_node[line.result_name] = arg_node
            return

        # Extract the index for tuple access.
        inds = line.indices[0][1:]
        assert len(inds) == 1, f"Cannot convert {inds} to an index."
        idx = inds[0]

        node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
        node.name = line.result_name
        self.buffer_to_node[line.result_name] = node

    def _generate_fallback_call(
        self,
        ir_node: ir.ExternKernel,
        args: Optional[tuple[Any, ...]] = None,
        kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        fx_node = self.gm.graph.call_function(
            ir_node.op_overload,  # type: ignore[arg-type]
            args=args,
            kwargs=kwargs,
        )
        result_buffer = ir_node.codegen_reference()
        self.buffer_to_node[result_buffer] = fx_node

    def _generate_index_put_fallback(self, line: WrapperLine) -> None:
        assert isinstance(line, IndexPutFallbackLine)
        ir_node = line.node

        def generate_buffer_or_none(
            x: Union[ir.IRNode, Sequence[ir.IRNode], None],
        ) -> Optional[torch.fx.Node]:
            """
            Handles None before calling _generate_buffer.
            """
            if x is None:
                return None

            assert isinstance(x, ir.IRNode)
            return self._generate_buffer(x)

        (x, values) = [generate_buffer_or_none(t) for t in ir_node.inputs[:2]]
        indices = tuple(generate_buffer_or_none(t) for t in line.indices)
        accumulate = ir_node.constant_args[0]
        args = (x, indices, values, accumulate)
        self._generate_fallback_call(ir_node, args)

    def _generate_scatter_fallback(self, line: WrapperLine) -> None:
        assert isinstance(line, ScatterFallbackLine)
        ir_node = line.node
        assert ir.is_node_sequence(ir_node.inputs)
        (x, index, src) = [self._generate_buffer(t) for t in ir_node.inputs] + (
            [] if ir_node.src_is_tensor else [ir_node.constant_args[1]]
        )
        args = (x, ir_node.constant_args[0], index, src)
        kwargs = {}
        if reduce := ir_node.kwargs.get("reduce"):
            kwargs["reduce"] = reduce

        self._generate_fallback_call(ir_node, args, kwargs)

    def _generate_null(self, line: WrapperLine) -> None:
        assert isinstance(line, NullLine)
        # Does nothing.

    def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None:
        assert isinstance(line, CommBufferAllocateLine)
        raise NotImplementedError("Comm buffer allocation is not yet supported")

    def _generate_comm_buffer_free(self, line: WrapperLine) -> None:
        assert isinstance(line, CommBufferFreeLine)
        self._free(line.node)

    def _generate_triton_call(self, line: WrapperLine) -> None:
        assert isinstance(line, KernelCallLine)

        # Collect all kwargs, including autotuned block sizes.
        call_args = self._lookup_args(line.call_args)
        kernel = self.kernels[line.kernel_name]
        tuner = kernel.tuner
        # Use python_slow mode instead of python mode to avoid
        # the round-to-neginf behaviour, which is not the convention
        # in other languages.
        tuner.grid_mode = "python_slow"

        # Optionally autotune the kernels.
        # The FX backend currently only supports compile-time tuning.
        kernel_name = tuner.fn.__name__
        if config.triton.autotune_at_compile_time:
            from triton.runtime import driver

            log.info("Autotuning Triton kernel %s at compile time.", kernel_name)
            device = driver.active.get_current_device()
            stream = driver.active.get_current_stream(device)

            def node_to_tuning_arg(arg: Any) -> Any:
                """
                Create real tensors for autotuning arguments, substituting size hints
                for dynamic shapes.
                """
                to_size_hint = functools.partial(
                    pytree.tree_map, V.graph.sizevars.size_hint
                )
                if not isinstance(arg, torch.fx.Node):
                    return to_size_hint(arg)

                fake = arg.meta["val"]
                return torch.empty_strided(
                    to_size_hint(fake.shape),
                    to_size_hint(fake.stride()),
                    device=device,
                ).zero_()

            arg_values = [node_to_tuning_arg(arg) for arg in call_args]
            tuner.run(*arg_values, stream=stream)
        else:
            log.info(
                "Skipping autotuning for kernel %s. Set config.triton.autotune_at_compile_time = True to enable.",
                kernel_name,
            )

        triton_meta = tuner.triton_meta
        signature = triton_meta["signature"]

        def add_constants_to_call_args(
            call_args: Sequence[Any], cfg: Config
        ) -> tuple[Any, ...]:
            """
            Add constant kwargs to the arg list.
            """
            # Add args from the proper Triton signature.
            new_call_args = []
            call_arg_idx = 0
            constants = triton_meta["constants"]
            for arg_name in signature:
                # Config kwargs are tracked separately.
                if arg_name in cfg.kwargs:
                    continue

                try:
                    new_arg = constants[arg_name]
                except KeyError:
                    new_arg = call_args[call_arg_idx]
                    call_arg_idx += 1
                new_call_args.append(new_arg)

            # Add Inductor's extra call args to the end.
            new_call_args.extend(call_args[call_arg_idx:])

            return tuple(new_call_args)

        kernel_config = tuner.compile_results[0].config
        call_args = add_constants_to_call_args(call_args, kernel_config)
        call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
        call_kwargs = dict(zip(signature, call_args))
        call_kwargs.update(kernel_config.kwargs)

        # Replace all sympy.floor with FloorDiv, since
        # _generate_sym_node does not support sympy.floor.
        grid = [replace_floor_div(x) if isinstance(x, sympy.Expr) else x for x in grid]
        wrapper_grid = [tuple(self._generate_sym_nodes(grid))]
        call_kwargs = {
            name: self._generate_sym_node(val) for name, val in call_kwargs.items()
        }

        # Store non-graphable kwargs in the side table.
        (
            call_kwargs,
            constant_args_idx,
        ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs)

        self.gm.graph.call_function(
            triton_kernel_wrapper_mutation,
            kwargs={
                "kernel_idx": kernel.wrapped.kernel_idx,
                "constant_args_idx": constant_args_idx,
                "grid": wrapper_grid,
                "tma_descriptor_metadata": {},
                "kwargs": call_kwargs,
            },
        )

    def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None:
        assert isinstance(line, ExternKernelAllocLine)
        node = line.node
        self._generate_extern_kernel_common(node, node)

    def _generate_extern_kernel_out(
        self,
        line: WrapperLine,
    ) -> None:
        assert isinstance(line, ExternKernelOutLine)
        node = line.node
        out_node = node.output_view if node.output_view else node
        self._generate_extern_kernel_common(node, out_node)

    def _generate_extern_kernel_common(
        self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode
    ) -> None:
        """
        Generates FX IR from either ExternKernelAlloc or ExternKernelOut.
        """

        # Get FX nodes corresponding to the call args.
        assert ir.is_node_sequence(kernel.inputs)
        tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs)
        args = tensor_nodes + tuple(kernel.constant_args)

        # Get the result buffer.
        # Some kernels write to a pre-existing output tensor via the "out" kwarg.
        kwargs = kernel.kwargs.copy()
        result_buffer: Optional[str] = None
        if isinstance(kernel, ir.ExternKernelOut):
            kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()]
        elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)):
            result_buffer = kernel.get_name()
        elif isinstance(kernel.layout, ir.NoneLayout):
            pass
        else:
            raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}")

        fx_node = self.gm.graph.call_function(
            kernel.op_overload,  # type: ignore[arg-type]
            args=args,
            kwargs=kwargs,
        )

        # Assign the result to the given name.
        if result_buffer:
            assert "out" not in kwargs, (
                f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one."
            )
            fx_node.name = result_buffer
            self.buffer_to_node[result_buffer] = fx_node

    def _generate_kernel_call(self, line: WrapperLine) -> None:
        assert isinstance(line, KernelCallLine)
        if not line.triton:
            raise NotImplementedError("FX conversion only supports Triton kernels.")

        self._generate_triton_call(line)

    def _generate_kernel_definition(self, line: WrapperLine) -> None:
        assert isinstance(line, KernelDefinitionLine)

        # Generate code for the kernel.
        kernel_code = PythonWrapperCodegen._format_kernel_definition(
            line.kernel_name, line.kernel_body, metadata=line.metadata
        )

        # Import the module and store the JIT kernel.
        tuner = self._import_kernel(kernel_code, line.kernel_name)
        wrapped = wrap_triton(tuner.fn)
        self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped)

    def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
        assert isinstance(line, SymbolicCallArgLine)
        # Store the arg -> expr mapping for later use.
        arg = line.arg

        inner_expr_proxy = self._sympy_interp(arg.inner_expr)
        self.expr_to_proxy[arg.inner] = inner_expr_proxy