Enable hipSOLVER in ROCm builds (#97370)

Enables the hipSolver backend for ROCm builds
--------------------------------------------------------------------------

- Minimum ROCm version requirement - 5.3
- Introduces new macro USE_LINALG_SOLVER the controls enablement of both cuSOLVER and hipSOLVER
- Adds hipSOLVER API to hipification process
- combines hipSOLVER and hipSPARSE mappings into single SPECIAL map that takes priority among normal mappings
- Torch api to be moved to hipsolver backend (as opposed to magma) include: torch.svd(), torch.geqrf(), torch.orgqr(), torch.ormqr()
- Will enable 100+ linalg unit tests for ROCm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97370
Approved by: https://github.com/malfet
This commit is contained in:
Andres Lugo-Reyes
2023-05-31 16:53:23 +00:00
committed by PyTorch MergeBot
parent 46a925795e
commit eaffd98880
19 changed files with 568 additions and 251 deletions

View File

@ -49,7 +49,7 @@ PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file',
'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file',
'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify']
@ -625,6 +625,11 @@ def is_cusparse_file(rel_filepath):
return False
def is_special_file(rel_filepath):
if is_pytorch_file(rel_filepath):
return ("sparse" in rel_filepath.lower()) or ("linalg" in rel_filepath.lower())
return False
def is_caffe2_gpu_file(rel_filepath):
assert not os.path.isabs(rel_filepath)
if rel_filepath.startswith("c10/cuda"):
@ -708,7 +713,8 @@ PYTORCH_MAP: Dict[str, object] = {}
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
PYTORCH_SPARSE_MAP = {}
# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
PYTORCH_SPECIAL_MAP = {}
for mapping in CUDA_TO_HIP_MAPPINGS:
assert isinstance(mapping, Mapping)
@ -717,10 +723,10 @@ for mapping in CUDA_TO_HIP_MAPPINGS:
meta_data = value[1:]
if constants.API_CAFFE2 not in meta_data:
PYTORCH_TRIE.add(src)
# if src is already in PYTORCH_MAP and dst belongs to API_SPARSE
# if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL
# do not overwrite PYTORCH_MAP, store dst separately
if constants.API_SPARSE in meta_data and PYTORCH_MAP.get(src, ""):
PYTORCH_SPARSE_MAP[src] = dst
if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
PYTORCH_SPECIAL_MAP[src] = dst
else:
PYTORCH_MAP[src] = dst
if constants.API_PYTORCH not in meta_data:
@ -777,15 +783,16 @@ def preprocessor(
def pt_repl(m):
return PYTORCH_MAP[m.group(0)]
def pt_sparse_repl(m):
# checks SPARSE map first, and if a miss occurs, falls back to pytorch mappings
return PYTORCH_SPARSE_MAP.get(m.group(0), pt_repl(m))
def pt_special_repl(m):
# checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings
return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))
if is_pytorch_extension:
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
else:
if is_cusparse_file(rel_filepath):
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source)
if is_special_file(rel_filepath):
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source)
elif is_pytorch_file(rel_filepath):
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
else: