### Implementation of #151705

This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly, without relying on predefined templates. To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)

### Summary of This PR

This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.

### Notable Changes

1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tiling suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I've copied the configuration from the existing matmul templates. However, this may not be optimal; it currently takes a long time to tune, and I think there must be a better way to tackle this.

@eellison @jansel @PaulZhang12 @shunting314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
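As a quick orientation, here is a minimal sketch of how the new flag might be exercised. It assumes the flag is reachable as `torch._inductor.config.triton.enable_native_matmul` as listed above; the exact attribute path, defaults, and interaction with autotuning are assumptions, not something verified against this PR:

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: the flag added by this PR is reachable here; the name is taken
# from the description above (config.triton.enable_native_matmul).
inductor_config.triton.enable_native_matmul = True


@torch.compile
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b  # aten.mm, which this PR lowers to ops.dot / tl.dot


a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
out = matmul(a, b)
```

With the flag left at its default, matmuls should continue to go through the existing template/extern-kernel paths, so the sketch is mainly useful for comparing the generated Triton code (e.g., via `TORCH_LOGS=output_code`).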
import functools
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, Union

import sympy

import torch

from .virtualized import OpsValue, V


# A block shape is a sequence of dims, each either a concrete int or a symbolic
# size string such as "XBLOCK"; None means the shape is unknown.
BlockShapeType = Optional[Sequence[Union[int, str]]]


class ShapeVar(Protocol):
    @property
    def shape(self) -> BlockShapeType: ...


ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue, torch.dtype]


# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective
# So first decompose CSEVars -> tuple before calling this
@functools.lru_cache(None)
def get_broadcasted_shape(a: BlockShapeType, b: BlockShapeType) -> BlockShapeType:
    assert isinstance(a, Sequence)
    assert isinstance(b, Sequence)
    if len(a) > len(b):
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    elif len(a) < len(b):
        b, a = a, b
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    else:

        def _get_broadcasted_dim(
            d1: Union[int, str], d2: Union[int, str]
        ) -> Union[int, str]:
            if str(d1) == "1":
                return d2
            elif str(d2) == "1":
                return d1
            assert str(d1) == str(d2)
            return d1

        return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b))


def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> BlockShapeType:
    result_shape: BlockShapeType = None

    for arg in args:
        if hasattr(arg, "shape"):
            shape = arg.shape
            if shape is None:
                return None
            elif result_shape is None:
                result_shape = tuple(shape)
            else:
                result_shape = get_broadcasted_shape(result_shape, tuple(shape))
        elif isinstance(arg, (int, float)):
            if result_shape is None:
                result_shape = ()
        elif isinstance(arg, torch.dtype):
            continue
        else:
            from torch._inductor.loop_body import LoopBody, LoopBodyBlock

            if isinstance(arg, (LoopBodyBlock, LoopBody, OpsValue)):
                # TODO: fix me
                return None
            raise TypeError(f"Unknown type: {type(arg)}")

    return result_shape


class ShapePropagationOpsHandler:
    """
    Propagate shape from args to output
    """

    @staticmethod
    def constant(value: torch.types.Number, dtype: torch.dtype) -> BlockShapeType:
        # See implementation of constant for triton for the reason
        from torch._inductor.codegen.triton import triton_compute_type, TritonKernel

        triton_type = triton_compute_type(dtype)

        if isinstance(V.kernel, TritonKernel) and triton_type != "tl.float32":
            ndim = V.kernel.triton_tensor_ndim()
            return tuple([1] * ndim)
        else:
            return ()

    @staticmethod
    def store_reduction(name: str, index: int, value: ShapeArg) -> None:
        return None

    @staticmethod
    def reduction(
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: str,
        value: Union[ShapeArg, tuple[ShapeArg, ...]],
    ) -> Union[BlockShapeType, tuple[BlockShapeType, ...]]:
        raise NotImplementedError

    @staticmethod
    def store(
        name: str, index: int, value: ShapeArg, mode: Optional[str] = None
    ) -> None:
        return None

    @staticmethod
    def to_dtype(
        value: ShapeVar,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = True,
    ) -> BlockShapeType:
        return value.shape

    @staticmethod
    def dot(a: sympy.Expr, b: sympy.Expr) -> BlockShapeType:
        from torch._inductor.codegen.triton import TritonKernel

        assert isinstance(V.kernel, TritonKernel), "dot supports Triton only"
        # The result of tl.dot spans the full (YBLOCK, XBLOCK) tile.
        return ("YBLOCK", "XBLOCK")

    @staticmethod
    def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> BlockShapeType:
        # shape is implicitly embedded in expr.
        return None

    @staticmethod
    def load_seed(name: str, offset: int) -> BlockShapeType:
        return ()

    @staticmethod
    def indirect_indexing(
        var: ShapeArg,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> None:
        return None

    def __getattr__(self, name: str) -> Callable[..., BlockShapeType]:
        # Fallback for ops without an explicit rule: treat them as pointwise
        # and broadcast the argument shapes.
        return lambda *args, **kwargs: broadcast_shapes_for_args(args)

    @staticmethod
    def device_assert_async(cond: ShapeArg, msg: str) -> None:
        return None
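For intuition about the helpers above, here is a small self-contained sketch of the block-shape broadcasting rule that `get_broadcasted_shape` and `broadcast_shapes_for_args` implement, restated outside Inductor so it runs standalone (the helper name `broadcast_dims` and the symbolic dimension strings are illustrative, not part of the file):

```python
from collections.abc import Sequence
from typing import Union


def broadcast_dims(
    a: Sequence[Union[int, str]], b: Sequence[Union[int, str]]
) -> tuple[Union[int, str], ...]:
    # Pad the shorter shape with leading 1s, then merge dim-by-dim:
    # a "1" broadcasts against anything; otherwise the dims must match.
    rank = max(len(a), len(b))
    a = (*[1] * (rank - len(a)), *a)
    b = (*[1] * (rank - len(b)), *b)
    out = []
    for d1, d2 in zip(a, b):
        if str(d1) == "1":
            out.append(d2)
        elif str(d2) == "1":
            out.append(d1)
        else:
            assert str(d1) == str(d2)
            out.append(d1)
    return tuple(out)


assert broadcast_dims(("YBLOCK", 1), (1, "XBLOCK")) == ("YBLOCK", "XBLOCK")
assert broadcast_dims(("XBLOCK",), (1, 1)) == (1, "XBLOCK")
```

Dims are compared by string, so a literal 1 (or the string "1") broadcasts against anything, while mismatched symbolic names trip the assert, mirroring `_get_broadcasted_dim` above.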