TritonTemplate dtype fixes (#141991)
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype.
- Update dtype propagation to use `type_to_dtype` instead of instantiating a tensor.
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing, which is NYI. Everything else is appropriately typed.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
Committed by: PyTorch MergeBot
Parent: 920e4364b7
Commit: fd35be2fd3
@@ -1563,7 +1563,9 @@ class CSE:
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
         dtype: Optional[torch.dtype] = None,
     ) -> CSEVariable:
-        assert name not in self.varname_map, "duplicate name"
+        torch._check_value(
+            name not in self.varname_map, lambda: f"duplicate name: {name}"
+        )
         var = V.kernel.create_cse_var(name, bounds, dtype)
         self.varname_map[name] = var
         return var
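For context, a minimal sketch (not part of the diff; the variable names here are made up) of what the torch._check_value swap buys over a bare assert: the message lambda is only evaluated on failure, the check is not stripped under python -O, and it raises ValueError rather than AssertionError.

import torch

varname_map = {"tmp0": object()}
name = "tmp0"
try:
    # the message callable is only invoked when the condition is False
    torch._check_value(
        name not in varname_map, lambda: f"duplicate name: {name}"
    )
except ValueError as e:
    print(e)  # duplicate name: tmp0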
@@ -8,7 +8,6 @@ import functools
 import itertools
 import logging
 import os
-import re
 import textwrap
 from functools import lru_cache
 from typing import (
@@ -62,6 +61,7 @@ from ..utils import (
     is_welford_reduction,
     Placeholder,
     sympy_subs,
+    triton_type,
     upcast_compute_type,
 )
 from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
@@ -659,22 +659,6 @@ class TritonPrinter(PythonPrinter):
 
 texpr = TritonPrinter().doprint
 
-# correct cases where Triton types names don't match PyTorch
-_triton_type_mapping = {
-    "tl.bool": "tl.int1",
-    "tl.float8_e4m3fn": "tl.float8e4nv",
-    "tl.float8_e5m2": "tl.float8e5",
-    "tl.float8_e4m3fnuz": "tl.float8e4b8",
-    "tl.float8_e5m2fnuz": "tl.float8e5b16",
-}
-_triton_type_re = re.compile(r"^.*[.]")
-
-
-def triton_type(dtype: torch.dtype) -> str:
-    """Convert torch.dtype to triton type"""
-    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
-    return _triton_type_mapping.get(triton_type_name, triton_type_name)
-
 
 def triton_compute_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type and upcast [b]float16 to float32"""
@@ -19,8 +19,9 @@ if TYPE_CHECKING:
 
 import torch
 from torch._inductor.virtualized import V
-from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
 
+from . import config
 from .utils import upcast_compute_type
 from .virtualized import OpsValue
 
@@ -72,7 +73,12 @@ def promote_types(
 
     for arg in args:
         if isinstance(arg, str):
-            # comes from templates.. TODO
+            # TODO: fix the flex attention instances, enable internally
+            if not config.is_fbcode():
+                assert isinstance(
+                    V.get_ops_handler(),
+                    torch._inductor.select_algorithm.ModificationWrapper,
+                )
             continue
 
         if isinstance(arg, OpsValue):
@@ -80,7 +86,7 @@ def promote_types(
         assert isinstance(arg, torch._prims_common.Number) or hasattr(arg, "dtype")
 
         if isinstance(arg, torch._prims_common.Number):
-            dtype_prop_candidates.append((torch.tensor(arg).dtype, True))
+            dtype_prop_candidates.append((type_to_dtype(type(arg)), True))
             continue
 
         dtype_prop_candidates.append((arg.dtype, False))
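A small sketch (not part of the diff) contrasting the old tensor-instantiation path with type_to_dtype; the float result assumes the stock torch.float32 default dtype.

import torch
from torch._prims_common import type_to_dtype

# old path: materialize a scalar tensor just to read off its dtype
assert torch.tensor(2.0).dtype == torch.float32
# new path: map the Python type directly, no tensor allocation
assert type_to_dtype(type(2.0)) == torch.get_default_dtype()  # torch.float32 by default
assert type_to_dtype(bool) == torch.bool
assert type_to_dtype(int) == torch.int64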
@@ -63,6 +63,7 @@ from .utils import (
     sympy_dot,
     sympy_index_symbol,
     sympy_product,
+    triton_type_to_torch,
     unique,
 )
 from .virtualized import V
@@ -605,7 +606,12 @@ class TritonTemplateKernel(TritonKernel):
         if output_index == contiguous_index:
             output_index = sympy.Symbol("xindex", integer=True)
 
-        epilogue_args = [val]
+        acc_dtype = (
+            triton_type_to_torch(self.meta["ACC_TYPE"])
+            if "ACC_TYPE" in self.meta
+            else torch.float32
+        )
+        epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)]
         for input_node in itertools.chain(
             self.input_nodes[: self.prefix_args],
             self.input_nodes[len(self.input_nodes) - self.suffix_args :],
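A rough sketch (with a hypothetical meta dict and a simplified local stand-in for triton_type_to_torch) of how the accumulator dtype is now derived for the epilogue args, falling back to torch.float32 when a template does not set ACC_TYPE.

import torch

def _tl_to_torch(tl_name: str) -> torch.dtype:
    # simplified stand-in; the real helper (added below) also un-maps
    # the float8/bool alias names before the getattr lookup
    return getattr(torch, tl_name.replace("tl.", ""))

meta = {"ACC_TYPE": "tl.float32"}  # hypothetical template meta
acc_dtype = _tl_to_torch(meta["ACC_TYPE"]) if "ACC_TYPE" in meta else torch.float32
assert acc_dtype == torch.float32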
@@ -2185,6 +2185,34 @@ def normalize_name(name: str) -> str:
     return re.sub(r"[^a-zA-Z0-9_]", "_", name)
 
 
+# correct cases where Triton types names don't match PyTorch
+_triton_type_mapping = {
+    "tl.bool": "tl.int1",
+    "tl.float8_e4m3fn": "tl.float8e4nv",
+    "tl.float8_e5m2": "tl.float8e5",
+    "tl.float8_e4m3fnuz": "tl.float8e4b8",
+    "tl.float8_e5m2fnuz": "tl.float8e5b16",
+}
+_torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
+
+
+_triton_type_re = re.compile(r"^.*[.]")
+
+
+def triton_type(dtype: torch.dtype) -> str:
+    """Convert torch.dtype to triton type"""
+    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
+    return _triton_type_mapping.get(triton_type_name, triton_type_name)
+
+
+def triton_type_to_torch(dtype: str) -> torch.dtype:
+    adjusted_type = _torch_triton_mapping.get(dtype, dtype)
+    type_name = adjusted_type.replace("tl.", "")
+    out_dtype = getattr(torch, type_name)
+    assert isinstance(out_dtype, torch.dtype)
+    return out_dtype
+
+
 def is_same_tensor(data: torch.Tensor, value: torch.Tensor):
     return (
         not data.is_mkldnn
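A quick round-trip check of the helpers added above (a sketch, assuming they live in torch._inductor.utils as this diff places them):

import torch
from torch._inductor.utils import triton_type, triton_type_to_torch

assert triton_type(torch.float16) == "tl.float16"
# float8 and bool names differ between PyTorch and Triton, so the alias table applies
assert triton_type(torch.float8_e4m3fn) == "tl.float8e4nv"
assert triton_type_to_torch("tl.float8e4nv") == torch.float8_e4m3fn
assert triton_type_to_torch("tl.int1") == torch.bool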