diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index f40e506388c9..6bf3104d1509 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -1563,7 +1563,9 @@ class CSE:
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
     ) -> CSEVariable:
-        assert name not in self.varname_map, "duplicate name"
+        torch._check_value(
+            name not in self.varname_map, lambda: f"duplicate name: {name}"
+        )
         var = V.kernel.create_cse_var(name, bounds, dtype)
         self.varname_map[name] = var
         return var
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 96abe257a070..c4d1ac56ee55 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -8,7 +8,6 @@ import functools
 import itertools
 import logging
 import os
-import re
 import textwrap
 from functools import lru_cache
 from typing import (
@@ -62,6 +61,7 @@ from ..utils import (
     is_welford_reduction,
     Placeholder,
     sympy_subs,
+    triton_type,
     upcast_compute_type,
 )
 from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
@@ -659,22 +659,6 @@ class TritonPrinter(PythonPrinter):
 texpr = TritonPrinter().doprint
 
 
-# correct cases where Triton types names don't match PyTorch
-_triton_type_mapping = {
-    "tl.bool": "tl.int1",
-    "tl.float8_e4m3fn": "tl.float8e4nv",
-    "tl.float8_e5m2": "tl.float8e5",
-    "tl.float8_e4m3fnuz": "tl.float8e4b8",
-    "tl.float8_e5m2fnuz": "tl.float8e5b16",
-}
-_triton_type_re = re.compile(r"^.*[.]")
-
-
-def triton_type(dtype: torch.dtype) -> str:
-    """Convert torch.dtype to triton type"""
-    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
-    return _triton_type_mapping.get(triton_type_name, triton_type_name)
-
-
 def triton_compute_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type and upcast [b]float16 to float32"""
diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py
index a7d419836036..21db66588639 100644
--- a/torch/_inductor/dtype_propagation.py
+++ b/torch/_inductor/dtype_propagation.py
@@ -19,8 +19,9 @@ if TYPE_CHECKING:
 
 import torch
 from torch._inductor.virtualized import V
-from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
 
+from . import config
 from .utils import upcast_compute_type
 from .virtualized import OpsValue
 
@@ -72,7 +73,12 @@ def promote_types(
 
     for arg in args:
         if isinstance(arg, str):
-            # comes from templates.. TODO
+            # TODO: fix the flex attention instances, enable internally
+            if not config.is_fbcode():
+                assert isinstance(
+                    V.get_ops_handler(),
+                    torch._inductor.select_algorithm.ModificationWrapper,
+                )
             continue
 
         if isinstance(arg, OpsValue):
@@ -80,7 +86,7 @@
 
         assert isinstance(arg, torch._prims_common.Number) or hasattr(arg, "dtype")
         if isinstance(arg, torch._prims_common.Number):
-            dtype_prop_candidates.append((torch.tensor(arg).dtype, True))
+            dtype_prop_candidates.append((type_to_dtype(type(arg)), True))
             continue
 
         dtype_prop_candidates.append((arg.dtype, False))
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index dad228e2123b..edf7f8eecd3a 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -63,6 +63,7 @@ from .utils import (
     sympy_dot,
     sympy_index_symbol,
     sympy_product,
+    triton_type_to_torch,
     unique,
 )
 from .virtualized import V
@@ -605,7 +606,12 @@ class TritonTemplateKernel(TritonKernel):
         if output_index == contiguous_index:
             output_index = sympy.Symbol("xindex", integer=True)
 
-        epilogue_args = [val]
+        acc_dtype = (
+            triton_type_to_torch(self.meta["ACC_TYPE"])
+            if "ACC_TYPE" in self.meta
+            else torch.float32
+        )
+        epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)]
         for input_node in itertools.chain(
             self.input_nodes[: self.prefix_args],
             self.input_nodes[len(self.input_nodes) - self.suffix_args :],
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index fd3fb1271e24..3f91ee87b3a5 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -2185,6 +2185,34 @@ def normalize_name(name: str) -> str:
     return re.sub(r"[^a-zA-Z0-9_]", "_", name)
 
 
+# correct cases where Triton types names don't match PyTorch
+_triton_type_mapping = {
+    "tl.bool": "tl.int1",
+    "tl.float8_e4m3fn": "tl.float8e4nv",
+    "tl.float8_e5m2": "tl.float8e5",
+    "tl.float8_e4m3fnuz": "tl.float8e4b8",
+    "tl.float8_e5m2fnuz": "tl.float8e5b16",
+}
+_torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
+
+
+_triton_type_re = re.compile(r"^.*[.]")
+
+
+def triton_type(dtype: torch.dtype) -> str:
+    """Convert torch.dtype to triton type"""
+    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
+    return _triton_type_mapping.get(triton_type_name, triton_type_name)
+
+
+def triton_type_to_torch(dtype: str) -> torch.dtype:
+    adjusted_type = _torch_triton_mapping.get(dtype, dtype)
+    type_name = adjusted_type.replace("tl.", "")
+    out_dtype = getattr(torch, type_name)
+    assert isinstance(out_dtype, torch.dtype)
+    return out_dtype
+
+
 def is_same_tensor(data: torch.Tensor, value: torch.Tensor):
     return (
         not data.is_mkldnn
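Illustration only, not part of the patch: a minimal sketch of the dtype round trip provided by the triton_type / triton_type_to_torch helpers that now live in torch/_inductor/utils.py, assuming this patch is applied so both are importable from there.

import torch
from torch._inductor.utils import triton_type, triton_type_to_torch

# torch dtype -> Triton type string, covering the float8 names Triton spells differently
assert triton_type(torch.float32) == "tl.float32"
assert triton_type(torch.float8_e4m3fn) == "tl.float8e4nv"

# Triton type string -> torch dtype, the inverse used to recover ACC_TYPE in select_algorithm.py
assert triton_type_to_torch("tl.float32") == torch.float32
assert triton_type_to_torch("tl.float8e4nv") == torch.float8_e4m3fn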