mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
TritonTemplate dtype fixes (#141991)
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype - Update dtype propagation to use `type_to_dtype` instead of instantiating tensor - Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ). Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991 Approved by: https://github.com/drisspg ghstack dependencies: #139945, #140057, #141495, #141882
This commit is contained in:
committed by
PyTorch MergeBot
parent
920e4364b7
commit
fd35be2fd3
@ -2185,6 +2185,34 @@ def normalize_name(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
|
||||
|
||||
# correct cases where Triton types names don't match PyTorch
|
||||
_triton_type_mapping = {
|
||||
"tl.bool": "tl.int1",
|
||||
"tl.float8_e4m3fn": "tl.float8e4nv",
|
||||
"tl.float8_e5m2": "tl.float8e5",
|
||||
"tl.float8_e4m3fnuz": "tl.float8e4b8",
|
||||
"tl.float8_e5m2fnuz": "tl.float8e5b16",
|
||||
}
|
||||
_torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
|
||||
|
||||
|
||||
_triton_type_re = re.compile(r"^.*[.]")
|
||||
|
||||
|
||||
def triton_type(dtype: torch.dtype) -> str:
|
||||
"""Convert torch.dtype to triton type"""
|
||||
triton_type_name = _triton_type_re.sub("tl.", str(dtype))
|
||||
return _triton_type_mapping.get(triton_type_name, triton_type_name)
|
||||
|
||||
|
||||
def triton_type_to_torch(dtype: str) -> torch.dtype:
|
||||
adjusted_type = _torch_triton_mapping.get(dtype, dtype)
|
||||
type_name = adjusted_type.replace("tl.", "")
|
||||
out_dtype = getattr(torch, type_name)
|
||||
assert isinstance(out_dtype, torch.dtype)
|
||||
return out_dtype
|
||||
|
||||
|
||||
def is_same_tensor(data: torch.Tensor, value: torch.Tensor):
|
||||
return (
|
||||
not data.is_mkldnn
|
||||
|
Reference in New Issue
Block a user