mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "[ROCm] remove caffe2 from hipify (#137157)"
This reverts commit 40d826074546558f6665a4c118335a7725503cac. Reverted https://github.com/pytorch/pytorch/pull/137157 on behalf of https://github.com/xw285cornell because this is breaking internal usage, where we still use caffe2 ([comment](https://github.com/pytorch/pytorch/pull/137157#issuecomment-2400466131))
This commit is contained in:
@ -31,6 +31,7 @@ import shutil
|
||||
import sys
|
||||
import os
|
||||
|
||||
from . import constants
|
||||
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
|
||||
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
|
||||
|
||||
@ -64,7 +65,7 @@ __all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_exte
|
||||
'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
|
||||
'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
|
||||
'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file',
|
||||
'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
|
||||
'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
|
||||
'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify']
|
||||
|
||||
|
||||
@ -185,6 +186,11 @@ def matched_files_iter(
|
||||
and (not _fnmatch(filepath, ignores))
|
||||
and (match_extensions(filepath, extensions) or filepath in exact_matches)
|
||||
):
|
||||
if not is_pytorch_extension: # for pytorch extensions, consider all files
|
||||
if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
|
||||
continue
|
||||
if out_of_place_only and not is_out_of_place(rel_filepath):
|
||||
continue
|
||||
yield filepath
|
||||
|
||||
|
||||
@ -553,8 +559,8 @@ def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
|
||||
# it gets a different name from the original filename, so
|
||||
# that we don't overwrite the original file
|
||||
#
|
||||
# There's a lot of different naming conventions across PyTorch,
|
||||
# but the general recipe is to convert occurrences
|
||||
# There's a lot of different naming conventions across PyTorch
|
||||
# and Caffe2, but the general recipe is to convert occurrences
|
||||
# of cuda/gpu to hip, and add hip if there are no occurrences
|
||||
# of cuda/gpu anywhere.
|
||||
#
|
||||
@ -617,7 +623,7 @@ def is_out_of_place(rel_filepath):
|
||||
return True
|
||||
|
||||
|
||||
# deprecated
|
||||
# Keep this synchronized with includes/ignores in build_amd.py
|
||||
def is_pytorch_file(rel_filepath):
|
||||
assert not os.path.isabs(rel_filepath)
|
||||
if rel_filepath.startswith("aten/"):
|
||||
@ -633,14 +639,12 @@ def is_pytorch_file(rel_filepath):
|
||||
return False
|
||||
|
||||
|
||||
# deprecated
|
||||
def is_cusparse_file(rel_filepath):
|
||||
if is_pytorch_file(rel_filepath):
|
||||
return "sparse" in rel_filepath.lower()
|
||||
return False
|
||||
|
||||
|
||||
# deprecated
|
||||
def is_special_file(rel_filepath):
|
||||
if is_pytorch_file(rel_filepath):
|
||||
if "sparse" in rel_filepath.lower():
|
||||
@ -651,8 +655,6 @@ def is_special_file(rel_filepath):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# deprecated
|
||||
def is_caffe2_gpu_file(rel_filepath):
|
||||
assert not os.path.isabs(rel_filepath)
|
||||
if rel_filepath.startswith("c10/cuda"):
|
||||
@ -661,7 +663,6 @@ def is_caffe2_gpu_file(rel_filepath):
|
||||
_, ext = os.path.splitext(filename)
|
||||
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
|
||||
|
||||
|
||||
class TrieNode:
|
||||
"""A Trie node whose children are represented as a dictionary of char: TrieNode.
|
||||
A special char '' represents end of word
|
||||
@ -670,7 +671,6 @@ class TrieNode:
|
||||
def __init__(self):
|
||||
self.children = {}
|
||||
|
||||
|
||||
class Trie:
|
||||
"""Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
|
||||
The corresponding Regex should match much faster than a simple Regex union."""
|
||||
@ -756,15 +756,38 @@ class Trie:
|
||||
"""Export the Trie to a regex pattern."""
|
||||
return self._pattern(self.root)
|
||||
|
||||
CAFFE2_TRIE = Trie()
|
||||
CAFFE2_MAP = {}
|
||||
PYTORCH_TRIE = Trie()
|
||||
PYTORCH_MAP: Dict[str, object] = {}
|
||||
|
||||
# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
|
||||
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
|
||||
# Unfortunately, the roc* types and hip* types differ, e.g., rocblas_float_complex versus hipComplex.
|
||||
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
|
||||
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
|
||||
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
|
||||
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
|
||||
# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
|
||||
PYTORCH_SPECIAL_MAP = {}
|
||||
|
||||
for mapping in CUDA_TO_HIP_MAPPINGS:
|
||||
assert isinstance(mapping, Mapping)
|
||||
for src, dst in mapping.items():
|
||||
PYTORCH_TRIE.add(src)
|
||||
PYTORCH_MAP[src] = dst
|
||||
|
||||
for src, value in mapping.items():
|
||||
dst = value[0]
|
||||
meta_data = value[1:]
|
||||
if constants.API_CAFFE2 not in meta_data:
|
||||
PYTORCH_TRIE.add(src)
|
||||
# if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL
|
||||
# do not overwrite PYTORCH_MAP, store dst separately
|
||||
if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
|
||||
PYTORCH_SPECIAL_MAP[src] = dst
|
||||
else:
|
||||
PYTORCH_MAP[src] = dst
|
||||
if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
|
||||
CAFFE2_TRIE.add(src)
|
||||
CAFFE2_MAP[src] = dst
|
||||
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
|
||||
RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)')
|
||||
|
||||
RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
|
||||
@ -824,7 +847,22 @@ def preprocessor(
|
||||
def pt_repl(m):
|
||||
return PYTORCH_MAP[m.group(0)]
|
||||
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
def pt_special_repl(m):
|
||||
# checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings
|
||||
return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))
|
||||
|
||||
|
||||
if is_pytorch_extension:
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
else:
|
||||
if is_special_file(rel_filepath):
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source)
|
||||
elif is_pytorch_file(rel_filepath):
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
else:
|
||||
def c2_repl(m):
|
||||
return CAFFE2_MAP[m.group(0)]
|
||||
output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
|
||||
|
||||
# Header rewrites
|
||||
def mk_repl(templ, include_current_dir=True):
|
||||
|
Reference in New Issue
Block a user