mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[ROCm] Add sparse mappings for CUDA->HIP translation (#67323)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67323 Applied patch proposed by Jeff https://github.com/pytorch/pytorch/pull/63948#issuecomment-952166982. In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip. The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance. Unfortunately, the `roc*` types and `hip*` types differ, i.e., `rocblas_float_complex` versus `hipComplex`. In the case of SPARSE, we must use the hip types for complex instead of the roc types, but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority. Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place. When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices. cc jeffdaily sunway513 jithunnair-amd ROCmSupport KyleCZH Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D31969246 Pulled By: cpuhrsch fbshipit-source-id: 4ce1b35eaf9ef0d146a0955ce70c354ddd8f4669
This commit is contained in:
committed by
Facebook GitHub Bot
parent
708f7b1209
commit
2267a984eb
@ -600,6 +600,11 @@ def is_pytorch_file(filepath):
|
||||
return False
|
||||
|
||||
|
||||
def is_cusparse_file(filepath):
|
||||
if is_pytorch_file(filepath):
|
||||
return "sparse" in filepath.lower()
|
||||
return False
|
||||
|
||||
def is_caffe2_gpu_file(filepath):
|
||||
if filepath.startswith("c10/cuda"):
|
||||
return True
|
||||
@ -673,7 +678,17 @@ class Trie():
|
||||
CAFFE2_TRIE = Trie()
|
||||
CAFFE2_MAP = {}
|
||||
PYTORCH_TRIE = Trie()
|
||||
PYTORCH_MAP = {}
|
||||
PYTORCH_MAP: Dict[str, object] = {}
|
||||
|
||||
# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
|
||||
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
|
||||
# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
|
||||
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
|
||||
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
|
||||
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
|
||||
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
|
||||
PYTORCH_SPARSE_MAP = {}
|
||||
|
||||
for mapping in CUDA_TO_HIP_MAPPINGS:
|
||||
assert isinstance(mapping, Mapping)
|
||||
for src, value in mapping.items():
|
||||
@ -681,7 +696,12 @@ for mapping in CUDA_TO_HIP_MAPPINGS:
|
||||
meta_data = value[1:]
|
||||
if constants.API_CAFFE2 not in meta_data:
|
||||
PYTORCH_TRIE.add(src)
|
||||
PYTORCH_MAP[src] = dst
|
||||
# if src is already in PYTORCH_MAP and dst belongs to API_SPARSE
|
||||
# do not overwrite PYTORCH_MAP, store dst separately
|
||||
if constants.API_SPARSE in meta_data and PYTORCH_MAP.get(src, ""):
|
||||
PYTORCH_SPARSE_MAP[src] = dst
|
||||
else:
|
||||
PYTORCH_MAP[src] = dst
|
||||
if constants.API_PYTORCH not in meta_data:
|
||||
CAFFE2_TRIE.add(src)
|
||||
CAFFE2_MAP[src] = dst
|
||||
@ -729,10 +749,16 @@ def preprocessor(
|
||||
def pt_repl(m):
|
||||
return PYTORCH_MAP[m.group(0)]
|
||||
|
||||
def pt_sparse_repl(m):
|
||||
# checks SPARSE map first, and if a miss occurs, falls back to pytorch mappings
|
||||
return PYTORCH_SPARSE_MAP.get(m.group(0), pt_repl(m))
|
||||
|
||||
if is_pytorch_extension:
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
else:
|
||||
if is_pytorch_file(filepath):
|
||||
if is_cusparse_file(filepath):
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source)
|
||||
elif is_pytorch_file(filepath):
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
else:
|
||||
def c2_repl(m):
|
||||
|
Reference in New Issue
Block a user