TritonTemplate dtype fixes (#141991)

- Set the dtype of "acc" appropriately so that epilogue fusion args carry a dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating a tensor (see the sketch after this list)
- Throw if we have a string arg where we should have a proper CSEVariable, unless we are in the Modification Subgraph path, which is not yet implemented; everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg).

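A quick illustration of the `type_to_dtype` change (a sketch, not part of the PR): `type_to_dtype` is the `torch._prims_common` helper imported in the diff below, and it looks up the dtype for a Python scalar type directly instead of allocating a throwaway tensor. Printed values assume the default dtype is float32.

```python
import torch
from torch._prims_common import type_to_dtype

# Old approach: materialize a scalar tensor just to read its dtype.
old_dtype = torch.tensor(2.0).dtype      # torch.float32, but allocates a tensor

# New approach: map the Python type straight to a torch.dtype.
new_dtype = type_to_dtype(type(2.0))     # torch.float32, no allocation
assert old_dtype == new_dtype

# Works for the other Python scalar types as well.
print(type_to_dtype(bool))               # torch.bool
print(type_to_dtype(int))                # torch.int64
```

The observable dtype is the same; the benefit is skipping a tensor allocation on every propagation step.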
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
eellison
2024-12-03 13:52:56 -08:00
committed by PyTorch MergeBot
parent 920e4364b7
commit fd35be2fd3
5 changed files with 48 additions and 22 deletions

View File

@@ -1563,7 +1563,9 @@ class CSE:
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
     ) -> CSEVariable:
-        assert name not in self.varname_map, "duplicate name"
+        torch._check_value(
+            name not in self.varname_map, lambda: f"duplicate name: {name}"
+        )
         var = V.kernel.create_cse_var(name, bounds, dtype)
         self.varname_map[name] = var
         return var

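For context on the hunk above (illustrative only, not part of the diff): `torch._check_value` raises a `ValueError` built from the lazily evaluated message when the condition is false, so the duplicate name shows up in the error instead of a bare `AssertionError`, and the check is not stripped under `python -O`.

```python
import torch

varname_map = {"tmp0": "existing var"}
name = "tmp0"

try:
    # Same pattern as the new check: the lambda defers message construction
    # until the check actually fails.
    torch._check_value(
        name not in varname_map, lambda: f"duplicate name: {name}"
    )
except ValueError as e:
    print(e)  # duplicate name: tmp0
```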
View File

@@ -8,7 +8,6 @@ import functools
 import itertools
 import logging
 import os
-import re
 import textwrap
 from functools import lru_cache
 from typing import (
@@ -62,6 +61,7 @@ from ..utils import (
     is_welford_reduction,
     Placeholder,
     sympy_subs,
+    triton_type,
     upcast_compute_type,
 )
 from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
@@ -659,22 +659,6 @@ class TritonPrinter(PythonPrinter):
 texpr = TritonPrinter().doprint
-# correct cases where Triton types names don't match PyTorch
-_triton_type_mapping = {
-    "tl.bool": "tl.int1",
-    "tl.float8_e4m3fn": "tl.float8e4nv",
-    "tl.float8_e5m2": "tl.float8e5",
-    "tl.float8_e4m3fnuz": "tl.float8e4b8",
-    "tl.float8_e5m2fnuz": "tl.float8e5b16",
-}
-_triton_type_re = re.compile(r"^.*[.]")
-def triton_type(dtype: torch.dtype) -> str:
-    """Convert torch.dtype to triton type"""
-    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
-    return _triton_type_mapping.get(triton_type_name, triton_type_name)
 def triton_compute_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type and upcast [b]float16 to float32"""

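The helpers removed here move to `torch/_inductor/utils.py` (last file in this diff) so they can be shared with the template code. A small usage sketch of `triton_type` from its new location; the printed names follow the mapping table above:

```python
import torch
from torch._inductor.utils import triton_type  # new home after this PR

# Most dtypes just swap the "torch." prefix for "tl.".
print(triton_type(torch.float32))        # tl.float32
print(triton_type(torch.int64))          # tl.int64

# A handful of names differ between PyTorch and Triton and go through the table.
print(triton_type(torch.bool))           # tl.int1
print(triton_type(torch.float8_e4m3fn))  # tl.float8e4nv
```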
View File

@@ -19,8 +19,9 @@ if TYPE_CHECKING:
 import torch
 from torch._inductor.virtualized import V
-from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
+from . import config
 from .utils import upcast_compute_type
 from .virtualized import OpsValue
@@ -72,7 +73,12 @@ def promote_types(
     for arg in args:
         if isinstance(arg, str):
-            # comes from templates.. TODO
+            # TODO: fix the flex attention instances, enable internally
+            if not config.is_fbcode():
+                assert isinstance(
+                    V.get_ops_handler(),
+                    torch._inductor.select_algorithm.ModificationWrapper,
+                )
             continue
         if isinstance(arg, OpsValue):
@@ -80,7 +86,7 @@ def promote_types(
         assert isinstance(arg, torch._prims_common.Number) or hasattr(arg, "dtype")
         if isinstance(arg, torch._prims_common.Number):
-            dtype_prop_candidates.append((torch.tensor(arg).dtype, True))
+            dtype_prop_candidates.append((type_to_dtype(type(arg)), True))
             continue
         dtype_prop_candidates.append((arg.dtype, False))

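To show the idea behind the `type_to_dtype` switch in `promote_types`, here is a simplified, standalone sketch (the helper `promoted_dtype_sketch` is hypothetical and is not Inductor's actual promotion logic): Python scalars contribute a "weak" dtype candidate derived from their type, dtyped args contribute a "strong" one, and no tensor is materialized along the way.

```python
import torch
from torch._prims_common import type_to_dtype

def promoted_dtype_sketch(*args):
    """Simplified sketch: dtyped args win over bare Python scalars ("weak" candidates)."""
    candidates = []
    for arg in args:
        if isinstance(arg, (bool, int, float, complex)):
            # Python scalars: derive a weak candidate from the type, no tensor needed.
            candidates.append((type_to_dtype(type(arg)), True))
        else:
            candidates.append((arg.dtype, False))
    strong = [d for d, weak in candidates if not weak]
    pool = strong or [d for d, _ in candidates]
    out = pool[0]
    for d in pool[1:]:
        out = torch.promote_types(out, d)
    return out

print(promoted_dtype_sketch(torch.ones(4, dtype=torch.float16), 2.0))  # torch.float16
print(promoted_dtype_sketch(2, 3.0))                                   # torch.float32
```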
View File

@@ -63,6 +63,7 @@ from .utils import (
     sympy_dot,
     sympy_index_symbol,
     sympy_product,
+    triton_type_to_torch,
     unique,
 )
 from .virtualized import V
@@ -605,7 +606,12 @@ class TritonTemplateKernel(TritonKernel):
         if output_index == contiguous_index:
             output_index = sympy.Symbol("xindex", integer=True)
-        epilogue_args = [val]
+        acc_dtype = (
+            triton_type_to_torch(self.meta["ACC_TYPE"])
+            if "ACC_TYPE" in self.meta
+            else torch.float32
+        )
+        epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)]
         for input_node in itertools.chain(
             self.input_nodes[: self.prefix_args],
             self.input_nodes[len(self.input_nodes) - self.suffix_args :],

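A sketch of how the accumulator dtype is resolved in the hunk above (the `resolve_acc_dtype` wrapper is hypothetical; `triton_type_to_torch` is the helper added to `torch/_inductor/utils.py` in this PR): templates that declare `ACC_TYPE` in their meta get that dtype, everything else falls back to float32, and the resulting `CSEVariable` handed to epilogue fusion carries it as `dtype`.

```python
import torch
from torch._inductor.utils import triton_type_to_torch  # helper added in this PR

def resolve_acc_dtype(meta: dict) -> torch.dtype:
    # Mirrors the hunk above: prefer the template's ACC_TYPE, else default to fp32.
    return (
        triton_type_to_torch(meta["ACC_TYPE"]) if "ACC_TYPE" in meta else torch.float32
    )

print(resolve_acc_dtype({"ACC_TYPE": "tl.float32"}))  # torch.float32
print(resolve_acc_dtype({}))                          # torch.float32
print(resolve_acc_dtype({"ACC_TYPE": "tl.int32"}))    # torch.int32
```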
View File

@@ -2185,6 +2185,34 @@ def normalize_name(name: str) -> str:
     return re.sub(r"[^a-zA-Z0-9_]", "_", name)
+# correct cases where Triton types names don't match PyTorch
+_triton_type_mapping = {
+    "tl.bool": "tl.int1",
+    "tl.float8_e4m3fn": "tl.float8e4nv",
+    "tl.float8_e5m2": "tl.float8e5",
+    "tl.float8_e4m3fnuz": "tl.float8e4b8",
+    "tl.float8_e5m2fnuz": "tl.float8e5b16",
+}
+_torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
+_triton_type_re = re.compile(r"^.*[.]")
+def triton_type(dtype: torch.dtype) -> str:
+    """Convert torch.dtype to triton type"""
+    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
+    return _triton_type_mapping.get(triton_type_name, triton_type_name)
+def triton_type_to_torch(dtype: str) -> torch.dtype:
+    adjusted_type = _torch_triton_mapping.get(dtype, dtype)
+    type_name = adjusted_type.replace("tl.", "")
+    out_dtype = getattr(torch, type_name)
+    assert isinstance(out_dtype, torch.dtype)
+    return out_dtype
 def is_same_tensor(data: torch.Tensor, value: torch.Tensor):
     return (
         not data.is_mkldnn
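
A round-trip sketch of the two helpers added above (illustrative only): `triton_type` rewrites `torch.float32` into `tl.float32` and special-cases the names Triton spells differently, while `triton_type_to_torch` inverts both steps via `_torch_triton_mapping`.

```python
import torch
from torch._inductor.utils import triton_type, triton_type_to_torch

# Round-trip through the string form; the inverted table restores the special
# cases whose Triton names differ from PyTorch's (bool, the fp8 variants).
for dtype in (torch.float32, torch.bfloat16, torch.bool, torch.float8_e4m3fn):
    name = triton_type(dtype)
    assert triton_type_to_torch(name) == dtype
    print(f"{dtype} -> {name} -> {triton_type_to_torch(name)}")
```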