TritonTemplate dtype fixes (#141991)

- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating tensor
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
This commit is contained in:
eellison
2024-12-03 13:52:56 -08:00
committed by PyTorch MergeBot
parent 920e4364b7
commit fd35be2fd3
5 changed files with 48 additions and 22 deletions

View File

@ -2185,6 +2185,34 @@ def normalize_name(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
# correct cases where Triton types names don't match PyTorch
_triton_type_mapping = {
"tl.bool": "tl.int1",
"tl.float8_e4m3fn": "tl.float8e4nv",
"tl.float8_e5m2": "tl.float8e5",
"tl.float8_e4m3fnuz": "tl.float8e4b8",
"tl.float8_e5m2fnuz": "tl.float8e5b16",
}
_torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
_triton_type_re = re.compile(r"^.*[.]")
def triton_type(dtype: torch.dtype) -> str:
"""Convert torch.dtype to triton type"""
triton_type_name = _triton_type_re.sub("tl.", str(dtype))
return _triton_type_mapping.get(triton_type_name, triton_type_name)
def triton_type_to_torch(dtype: str) -> torch.dtype:
adjusted_type = _torch_triton_mapping.get(dtype, dtype)
type_name = adjusted_type.replace("tl.", "")
out_dtype = getattr(torch, type_name)
assert isinstance(out_dtype, torch.dtype)
return out_dtype
def is_same_tensor(data: torch.Tensor, value: torch.Tensor):
return (
not data.is_mkldnn