[Inductor][Triton] Update TMA Compatibility Requirements (#157881)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157881
Approved by: https://github.com/Skylion007, https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: e71bb021b9
Commit: ea74fdd24a
@@ -1543,35 +1543,87 @@ def use_triton_template(
     )
 
 
-def use_triton_tma_template(*matrices: IRNode) -> bool:
+def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
+    """
+    Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
+    that Triton relies on today.
+    * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
+
+    A tensor is accepted when:
+    * 2 ≤ rank ≤ 5
+    * dtype ∈ {FP16, BF16, FP8-E4M3FN}
+    * Every logical size ≥ 2
+    * Base pointer 16-byte aligned
+    * All "outer" dims have 16-byte aligned strides
+    * The "inner" dim has stride 1 (contiguous)
+    * For FP8 tensors, inner dim ≥ 32
+    """
     from torch.utils._triton import has_triton_tma_device
 
     from .virtualized import V
 
+    def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
+        return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
+
     def _is_tma_compatible(x: IRNode) -> bool:
-        if len(x.get_size()) != 2:
+        sizes = x.get_size()
+        strides = x.get_stride()
+        rank = len(sizes)
+        dtype = x.get_dtype()
+        itemsize = dtype.itemsize
+
+        # 2 ≤ rank ≤ 5
+        if rank < 2 or rank > 5:
             return False
 
-        dtype = x.get_dtype()
+        # dtype ∈ {FP16, BF16, FP8-E4M3FN}
         if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
             return False
 
-        layout = x.get_layout()
-        transposed = layout.is_transposed()
-        if not (layout.is_contiguous() or transposed):
+        # Base pointer 16-byte aligned
+        if x.get_name() in V.graph.unaligned_buffers:
             return False
 
-        inner_dim = layout.size[1]
-        if transposed:
-            inner_dim = layout.size[0]
+        if add_guards:
+            sizes_i = V.graph.sizevars.guard_int_seq(sizes)
+            strides_i = V.graph.sizevars.guard_int_seq(strides)
+        else:
+            sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
+            strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
 
-        if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt(
+        # Every logical size ≥ 2
+        if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
+            return False
+
+        # Find the single contiguous ("inner") dim
+        inner = [
+            i
+            for i, st in enumerate(strides_i)
+            if V.graph.sizevars.statically_known_equals(st, 1)
+        ]
+        if len(inner) != 1:
+            return False
+        inner_idx = inner[0]
+
+        # All "outer" dims must have 16-byte aligned strides
+        for i, st in enumerate(strides_i):
+            if i == inner_idx:
+                continue
+            if not _aligned(st * itemsize):
+                return False
+
+        # Inner dim byte width must still be a multiple of 16 B
+        inner_dim = sizes_i[inner_idx]
+        if not _aligned(inner_dim * itemsize):
+            return False
+
+        # FP8 special case: inner ≥ 32
+        if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq(
             inner_dim, 32
         ):
             return False
 
-        inner_bytes = inner_dim * dtype.itemsize
-        return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)
+        return True
 
     return (
         config.triton.enable_persistent_tma_matmul
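Note: the docstring's acceptance rules can be approximated outside Inductor for quick sanity checks. The sketch below mirrors those rules on plain torch.Tensor inputs; the helper name tma_compatible and the hard-coded 16-byte TMA_ALIGNMENT constant are assumptions made here for illustration, and the real check operates on Inductor IR nodes with symbolic size/stride reasoning rather than eager tensors.

import torch

TMA_ALIGNMENT = 16  # bytes; assumed to match the constant referenced in the diff

def tma_compatible(x: torch.Tensor) -> bool:
    """Illustrative eager-mode approximation of the TMA acceptance rules."""
    itemsize = x.element_size()

    # 2 <= rank <= 5
    if not (2 <= x.dim() <= 5):
        return False

    # dtype in {FP16, BF16, FP8-E4M3FN}
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
        return False

    # Every logical size >= 2
    if any(s < 2 for s in x.shape):
        return False

    # Base pointer 16-byte aligned
    if x.data_ptr() % TMA_ALIGNMENT != 0:
        return False

    # Exactly one contiguous ("inner") dim with stride 1
    inner = [i for i, st in enumerate(x.stride()) if st == 1]
    if len(inner) != 1:
        return False
    inner_idx = inner[0]

    # Outer strides 16-byte aligned; inner extent a multiple of 16 bytes
    outer_strides = (st for i, st in enumerate(x.stride()) if i != inner_idx)
    if any((st * itemsize) % TMA_ALIGNMENT for st in outer_strides):
        return False
    if (x.shape[inner_idx] * itemsize) % TMA_ALIGNMENT != 0:
        return False

    # FP8 special case: inner dim >= 32
    if x.dtype == torch.float8_e4m3fn and x.shape[inner_idx] < 32:
        return False

    return True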
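Usage note: the TMA template is only considered when the persistent TMA matmul path is turned on via the config flag seen in the final return expression. A minimal sketch of exercising it is shown below; it assumes a TMA-capable GPU with a recent Triton build, and whether the TMA template is actually selected still depends on autotuning and on the inputs passing the constraints above.

import torch
import torch._inductor.config as inductor_config

# Opt in to the persistent TMA matmul templates (off by default).
inductor_config.triton.enable_persistent_tma_matmul = True

a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)

# max-autotune lets Inductor benchmark Triton templates, including the
# TMA-based one, against the default matmul choices.
compiled_mm = torch.compile(torch.mm, mode="max-autotune")
out = compiled_mm(a, b)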