Files
pytorch/torch/_inductor/shape_propagation.py
nullplay ac529df244 Native matmul (#157743)
### Implementation of #151705

This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.

To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)

### Summary of This PR

This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.

### Notable Changes

1. Adds a new config flag: `config.triton.enable_native_matmul` (see the usage sketch after this list)
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces a tiling suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configurations from the existing matmul templates. This may not be optimal, since tuning currently takes a long time, and there is likely a better way to tackle it.
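
As a rough usage sketch (assuming the flag is exposed under `torch._inductor.config` and defaults to off), enabling the native lowering for a compiled matmul would look like:

```python
import torch
import torch._inductor.config as inductor_config

# Opt in to the native tl.dot lowering (flag name from item 1 above;
# assumed to default to False).
inductor_config.triton.enable_native_matmul = True

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# With the flag enabled, aten.mm is lowered to the new ops.dot IR node and
# code-generated as tl.dot instead of going through the matmul templates.
out = torch.compile(torch.mm)(a, b)
```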

@eellison @jansel @PaulZhang12 @shunting314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
2025-10-14 04:22:30 +00:00


import functools
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, Union

import sympy

import torch

from .virtualized import OpsValue, V

BlockShapeType = Optional[Sequence[Union[int, str]]]


class ShapeVar(Protocol):
    @property
    def shape(self) -> BlockShapeType: ...


ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue, torch.dtype]


# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be
# effective, so first decompose CSEVars -> tuple before calling this.
@functools.lru_cache(None)
def get_broadcasted_shape(a: BlockShapeType, b: BlockShapeType) -> BlockShapeType:
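    """
    Broadcast two block shapes (sequences of ints or symbolic dim names such as
    "XBLOCK") following NumPy-style rules: the shorter shape is left-padded with
    1s, a dim of 1 broadcasts against anything, and all other dims must match,
    e.g. ("YBLOCK", 1) and (1, "XBLOCK") broadcast to ("YBLOCK", "XBLOCK").
    """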
    assert isinstance(a, Sequence)
    assert isinstance(b, Sequence)

    if len(a) > len(b):
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    elif len(a) < len(b):
        b, a = a, b
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    else:

        def _get_broadcasted_dim(
            d1: Union[int, str], d2: Union[int, str]
        ) -> Union[int, str]:
            if str(d1) == "1":
                return d2
            elif str(d2) == "1":
                return d1
            assert str(d1) == str(d2)
            return d1

        return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b))


def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> BlockShapeType:
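    """
    Fold get_broadcasted_shape over the block shapes of all args. Scalars
    contribute a scalar () shape, dtypes are ignored, and None is returned as
    soon as any arg's shape is unknown.
    """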
    result_shape: BlockShapeType = None
    for arg in args:
        if hasattr(arg, "shape"):
            shape = arg.shape
            if shape is None:
                return None
            elif result_shape is None:
                result_shape = tuple(shape)
            else:
                result_shape = get_broadcasted_shape(result_shape, tuple(shape))
        elif isinstance(arg, (int, float)):
            if result_shape is None:
                result_shape = ()
        elif isinstance(arg, torch.dtype):
            continue
        else:
            from torch._inductor.loop_body import LoopBody, LoopBodyBlock

            if isinstance(arg, (LoopBodyBlock, LoopBody, OpsValue)):
                # TODO: fix me
                return None
            raise TypeError(f"Unknown type: {type(arg)}")

    return result_shape


class ShapePropagationOpsHandler:
"""
Propagate shape from args to output
"""
@staticmethod
def constant(value: torch.types.Number, dtype: torch.dtype) -> BlockShapeType:
# See implementation of constant for triton for the reason
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
triton_type = triton_compute_type(dtype)
if isinstance(V.kernel, TritonKernel) and triton_type != "tl.float32":
ndim = V.kernel.triton_tensor_ndim()
return tuple([1] * ndim)
else:
return ()
@staticmethod
def store_reduction(name: str, index: int, value: ShapeArg) -> None:
return None
@staticmethod
def reduction(
dtype: torch.dtype,
src_dtype: torch.dtype,
reduction_type: str,
value: Union[ShapeArg, tuple[ShapeArg, ...]],
) -> Union[BlockShapeType, tuple[BlockShapeType, ...]]:
raise NotImplementedError
@staticmethod
def store(
name: str, index: int, value: ShapeArg, mode: Optional[str] = None
) -> None:
return None
@staticmethod
def to_dtype(
value: ShapeVar,
dtype: torch.dtype,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = True,
) -> BlockShapeType:
return value.shape
@staticmethod
def dot(a: sympy.Expr, b: sympy.Expr) -> BlockShapeType:
from torch._inductor.codegen.triton import TritonKernel
assert isinstance(V.kernel, TritonKernel), "dot supports Triton only"
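        # Shape propagation reports the full 2-D (YBLOCK, XBLOCK) tile for the
        # tl.dot result.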
return ("YBLOCK", "XBLOCK")
@staticmethod
def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> BlockShapeType:
# shape is implicitly embedded in expr.
return None
@staticmethod
def load_seed(name: str, offset: int) -> BlockShapeType:
return ()
@staticmethod
def indirect_indexing(
var: ShapeArg,
size: Union[sympy.Expr, int],
check: bool = True,
wrap_neg: bool = True,
) -> None:
return None
def __getattr__(self, name: str) -> Callable[..., BlockShapeType]:
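        # Fallback for ops without an explicit handler (mostly pointwise ops):
        # the output block shape is the broadcast of the input block shapes.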
        return lambda *args, **kwargs: broadcast_shapes_for_args(args)

    @staticmethod
    def device_assert_async(cond: ShapeArg, msg: str) -> None:
        return None