mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D24924736: [pytorch][PR] Hipify revamp
Test Plan: revert-hammer
Differential Revision:
D24924736 (10b490a3e0
)
Original commit changeset: 4af42b8ff4f2
fbshipit-source-id: 7f8f90d55d8a69a2890ec73622fcea559189e381
This commit is contained in:
committed by
Facebook GitHub Bot
parent
68a3a3f3b5
commit
8af9f2cc23
@ -1206,7 +1206,7 @@ if(USE_ROCM)
|
||||
endforeach()
|
||||
|
||||
set(Caffe2_HIP_INCLUDE
|
||||
$<INSTALL_INTERFACE:include> ${Caffe2_HIP_INCLUDE})
|
||||
${thrust_INCLUDE_DIRS} ${hipcub_INCLUDE_DIRS} ${rocprim_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${roctracer_INCLUDE_DIRS} ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} $<INSTALL_INTERFACE:include> ${Caffe2_HIP_INCLUDE})
|
||||
# This is needed for library added by hip_add_library (same for hip_add_executable)
|
||||
hip_include_directories(${Caffe2_HIP_INCLUDE})
|
||||
|
||||
|
@ -205,4 +205,9 @@ if(HIP_FOUND)
|
||||
# roctx is part of roctracer
|
||||
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
|
||||
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)
|
||||
|
||||
# Necessary includes for building PyTorch since we include HIP headers that depend on hcc/hsa headers.
|
||||
set(hcc_INCLUDE_DIRS ${HCC_PATH}/include)
|
||||
set(hsa_INCLUDE_DIRS ${HSA_PATH}/include)
|
||||
|
||||
endif()
|
||||
|
@ -29,7 +29,7 @@ ext_modules = [
|
||||
extra_compile_args=CXX_FLAGS),
|
||||
]
|
||||
|
||||
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.cuda', [
|
||||
'cuda_extension.cpp',
|
||||
@ -39,6 +39,22 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
elif torch.cuda.is_available() and ROCM_HOME is not None:
|
||||
from torch.utils.hipify import hipify_python
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
hipify_python.hipify(
|
||||
project_directory=this_dir,
|
||||
output_directory=this_dir,
|
||||
includes="./*",
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,)
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.cuda', [
|
||||
'cuda_extension.cpp',
|
||||
'hip/hip_extension_kernel.hip',
|
||||
'hip/hip_extension_kernel2.hip',
|
||||
])
|
||||
ext_modules.append(extension)
|
||||
|
||||
if not IS_WINDOWS: # MSVC has bug compiling this example
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
|
@ -830,27 +830,6 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
kwargs['libraries'] = libraries
|
||||
|
||||
include_dirs = kwargs.get('include_dirs', [])
|
||||
|
||||
if IS_HIP_EXTENSION:
|
||||
build_dir = os.getcwd()
|
||||
if not include_dirs:
|
||||
include_dirs = ['*']
|
||||
hipify_result = hipify_python.hipify(
|
||||
project_directory=build_dir,
|
||||
output_directory=build_dir,
|
||||
includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs],
|
||||
extra_files=[os.path.abspath(s) for s in sources],
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
)
|
||||
|
||||
hipified_sources = set()
|
||||
for source in sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
|
||||
|
||||
sources = list(hipified_sources)
|
||||
|
||||
include_dirs += include_paths(cuda=True)
|
||||
kwargs['include_dirs'] = include_dirs
|
||||
|
||||
@ -883,11 +862,9 @@ def include_paths(cuda: bool = False) -> List[str]:
|
||||
]
|
||||
if cuda and IS_HIP_EXTENSION:
|
||||
paths.append(os.path.join(lib_include, 'THH'))
|
||||
rocm_include_path = _join_rocm_home('include')
|
||||
paths.append(rocm_include_path)
|
||||
paths.append(_join_rocm_home('include'))
|
||||
if MIOPEN_HOME is not None:
|
||||
paths.append(os.path.join(MIOPEN_HOME, 'include'))
|
||||
paths.extend([f.path for f in os.scandir(rocm_include_path) if f.is_dir()])
|
||||
elif cuda:
|
||||
cuda_home_include = _join_cuda_home('include')
|
||||
# if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
|
||||
@ -1664,7 +1641,7 @@ def _write_ninja_file_to_build_library(path,
|
||||
cuda_flags += _get_rocm_arch_flags(cuda_flags)
|
||||
sources = [s if not _is_cuda_file(s) else
|
||||
os.path.abspath(os.path.join(
|
||||
path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
|
||||
path, get_hip_file_path(os.path.relpath(s, path))))
|
||||
for s in sources]
|
||||
elif with_cuda:
|
||||
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
|
||||
|
@ -1 +0,0 @@
|
||||
from .version import __version__
|
||||
|
@ -552,26 +552,26 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
|
||||
("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)),
|
||||
("cublas.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
|
||||
("cublas_v2.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
|
||||
("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
|
||||
("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete2.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_globals.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_lognormal.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mrg32k3a.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_host.h", ("hiprand/hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand.h", ("hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
|
||||
("curand_kernel.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete2.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_globals.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_lognormal.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mrg32k3a.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_host.h", ("hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_kernel.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
(
|
||||
"curand_mtgp32dc_p_11213.h",
|
||||
("rocrand/rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND),
|
||||
("rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND),
|
||||
),
|
||||
("curand_normal.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_normal_static.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_philox4x32_x.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_poisson.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_normal.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_normal_static.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_philox4x32_x.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_poisson.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_precalc.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_uniform.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("cusparse.h", ("hipsparse.h", CONV_INCLUDE, API_RAND)),
|
||||
("cufft.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)),
|
||||
("cufftXt.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)),
|
||||
@ -586,7 +586,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
|
||||
("cub/device/device_radix_sort.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
|
||||
("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
|
||||
("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
|
||||
("nvToolsExt.h", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)),
|
||||
("nvToolsExt.h", ("roctx.h", CONV_INCLUDE, API_ROCTX)),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -34,12 +34,8 @@ from . import constants
|
||||
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
|
||||
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
|
||||
|
||||
from typing import Dict, List, Iterator, Optional
|
||||
from collections.abc import Mapping, Iterable
|
||||
HipifyResult = Dict[str, Optional[str]]
|
||||
HipifyFinalResult = Dict[str, HipifyResult]
|
||||
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
|
||||
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
|
||||
from typing import Dict, List
|
||||
from collections.abc import Mapping
|
||||
|
||||
# Hardcode the PyTorch template map
|
||||
"""This dictionary provides the mapping from PyTorch kernel template types
|
||||
@ -113,20 +109,14 @@ class GeneratedFileCleaner:
|
||||
for d in self.dirs_to_clean[::-1]:
|
||||
os.rmdir(d)
|
||||
|
||||
def match_extensions(filename: str, extensions: Iterable) -> bool:
|
||||
"""Helper method to see if filename ends with certain extension"""
|
||||
return any(filename.endswith(e) for e in extensions)
|
||||
|
||||
def matched_files_iter(
|
||||
root_path: str,
|
||||
includes: Iterable = ('*',),
|
||||
ignores: Iterable = (),
|
||||
extensions: Iterable = (),
|
||||
out_of_place_only: bool = False,
|
||||
is_pytorch_extension: bool = False) -> Iterator[str]:
|
||||
def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), out_of_place_only=False, is_pytorch_extension=False):
|
||||
def _fnmatch(filepath, patterns):
|
||||
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
|
||||
|
||||
def match_extensions(filename):
|
||||
"""Helper method to see if filename ends with certain extension"""
|
||||
return any(filename.endswith(e) for e in extensions)
|
||||
|
||||
exact_matches = set(includes)
|
||||
|
||||
# This is a very rough heuristic; really, we want to avoid scanning
|
||||
@ -151,7 +141,7 @@ def matched_files_iter(
|
||||
if (
|
||||
_fnmatch(filepath, includes)
|
||||
and (not _fnmatch(filepath, ignores))
|
||||
and (match_extensions(filepath, extensions) or filepath in exact_matches)
|
||||
and (match_extensions(filepath) or filepath in exact_matches)
|
||||
):
|
||||
if not is_pytorch_extension: # for pytorch extensions, consider all files
|
||||
if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath):
|
||||
@ -161,39 +151,14 @@ def matched_files_iter(
|
||||
yield filepath
|
||||
|
||||
|
||||
def preprocess_file_and_save_result(
|
||||
output_directory: str,
|
||||
filepath: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
stats: Dict[str, List],
|
||||
hip_clang_launch: bool,
|
||||
is_pytorch_extension: bool,
|
||||
clean_ctx: GeneratedFileCleaner,
|
||||
show_progress: bool) -> None:
|
||||
result = preprocessor(output_directory, filepath, all_files, includes, stats,
|
||||
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
|
||||
|
||||
fin_path = os.path.join(output_directory, filepath)
|
||||
# Show what happened
|
||||
if show_progress:
|
||||
print(
|
||||
fin_path, "->",
|
||||
result["hipified_path"], result["status"])
|
||||
|
||||
if result["hipified_path"] is not None:
|
||||
HIPIFY_FINAL_RESULT[fin_path] = result
|
||||
|
||||
|
||||
def preprocess(
|
||||
output_directory: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
show_detailed: bool = False,
|
||||
show_progress: bool = True,
|
||||
hip_clang_launch: bool = False,
|
||||
is_pytorch_extension: bool = False,
|
||||
clean_ctx: GeneratedFileCleaner = None) -> HipifyFinalResult:
|
||||
output_directory,
|
||||
all_files,
|
||||
show_detailed=False,
|
||||
show_progress=True,
|
||||
hip_clang_launch=False,
|
||||
is_pytorch_extension=False,
|
||||
clean_ctx=None):
|
||||
"""
|
||||
Call preprocessor on selected files.
|
||||
|
||||
@ -208,8 +173,13 @@ def preprocess(
|
||||
stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
|
||||
|
||||
for filepath in all_files:
|
||||
preprocess_file_and_save_result(output_directory, filepath, all_files, includes, stats,
|
||||
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
|
||||
result = preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx)
|
||||
|
||||
# Show what happened
|
||||
if show_progress:
|
||||
print(
|
||||
filepath, "->",
|
||||
get_hip_file_path(filepath), result)
|
||||
|
||||
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
|
||||
|
||||
@ -217,8 +187,6 @@ def preprocess(
|
||||
if show_detailed:
|
||||
compute_stats(stats)
|
||||
|
||||
return HIPIFY_FINAL_RESULT
|
||||
|
||||
|
||||
def compute_stats(stats):
|
||||
unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}
|
||||
@ -509,13 +477,13 @@ def replace_extern_shared(input_string):
|
||||
return output_string
|
||||
|
||||
|
||||
def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
def get_hip_file_path(filepath):
|
||||
"""
|
||||
Returns the new name of the hipified file
|
||||
"""
|
||||
# At the moment, some PyTorch source files are HIPified in place. The predicate
|
||||
# At the moment, some files are HIPified in place. The predicate
|
||||
# is_out_of_place tells us if this is the case or not.
|
||||
if not is_pytorch_extension and not is_out_of_place(filepath):
|
||||
if not is_out_of_place(filepath):
|
||||
return filepath
|
||||
|
||||
dirpath, filename = os.path.split(filepath)
|
||||
@ -524,8 +492,10 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
# Here's the plan:
|
||||
#
|
||||
# In general, we need to disambiguate the HIPified filename so that
|
||||
# it gets a different name from the original filename, so
|
||||
# that we don't overwrite the original file
|
||||
# it gets a different name from the original Caffe2 filename, so
|
||||
# that we don't overwrite the original file. (Additionally,
|
||||
# hcc historically had a bug where if you had two files with
|
||||
# the same basename, they would clobber each other.)
|
||||
#
|
||||
# There's a lot of different naming conventions across PyTorch
|
||||
# and Caffe2, but the general recipe is to convert occurrences
|
||||
@ -539,18 +509,12 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
#
|
||||
# - If the file name contains "CUDA", replace it with "HIP", AND
|
||||
#
|
||||
# - ALWAYS replace '.cu' with '.hip', because those files
|
||||
# contain CUDA kernels that needs to be hipified and processed with
|
||||
# hip compiler
|
||||
# If NONE of the above occurred, then insert "hip" in the file path
|
||||
# as the direct parent folder of the file
|
||||
#
|
||||
# - If we are not hipifying a PyTorch extension, and the parent
|
||||
# directory name did not change as a result of the above
|
||||
# transformations, insert "hip" in the file path
|
||||
# as the direct parent folder of the file
|
||||
#
|
||||
# - If we are hipifying a PyTorch extension, and the parent directory
|
||||
# name as well as the filename (incl. extension) did not change as
|
||||
# a result of the above transformations, insert "_hip" in the filename
|
||||
# Furthermore, ALWAYS replace '.cu' with '.hip', because those files
|
||||
# contain CUDA kernels that needs to be hipified and processed with
|
||||
# hcc compiler
|
||||
#
|
||||
# This isn't set in stone; we might adjust this to support other
|
||||
# naming conventions.
|
||||
@ -558,7 +522,6 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
if ext == '.cu':
|
||||
ext = '.hip'
|
||||
|
||||
orig_filename = filename
|
||||
orig_dirpath = dirpath
|
||||
|
||||
dirpath = dirpath.replace('cuda', 'hip')
|
||||
@ -570,12 +533,9 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
if dirpath != "caffe2/core":
|
||||
root = root.replace('THC', 'THH')
|
||||
|
||||
if not is_pytorch_extension and dirpath == orig_dirpath:
|
||||
if dirpath == orig_dirpath:
|
||||
dirpath = os.path.join(dirpath, 'hip')
|
||||
|
||||
if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
|
||||
root = root + "_hip"
|
||||
|
||||
return os.path.join(dirpath, root + ext)
|
||||
|
||||
|
||||
@ -693,35 +653,13 @@ RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
|
||||
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
|
||||
RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
|
||||
|
||||
"""
|
||||
Returns a dict with the following keys:
|
||||
"hipified_path" : absolute path of hipified source file
|
||||
"status" : "ok" if hipified file was written out
|
||||
"skipped" if an identical hipified file already existed
|
||||
"ignored" if the source file was a hipified file itself
|
||||
"""
|
||||
def preprocessor(
|
||||
output_directory: str,
|
||||
filepath: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
stats: Dict[str, List],
|
||||
hip_clang_launch: bool,
|
||||
is_pytorch_extension: bool,
|
||||
clean_ctx: GeneratedFileCleaner,
|
||||
show_progress: bool) -> HipifyResult:
|
||||
def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx):
|
||||
""" Executes the CUDA -> HIP conversion on the specified file. """
|
||||
fin_path = os.path.join(output_directory, filepath)
|
||||
|
||||
with open(fin_path, 'r', encoding='utf-8') as fin:
|
||||
if fin.readline() == HIPIFY_C_BREADCRUMB:
|
||||
return {"hipified_path": None, "status": "ignored"}
|
||||
fin.seek(0)
|
||||
output_source = fin.read()
|
||||
|
||||
orig_output_source = output_source
|
||||
|
||||
fout_path = os.path.join(output_directory, get_hip_file_path(filepath, is_pytorch_extension))
|
||||
fout_path = os.path.join(output_directory, get_hip_file_path(filepath))
|
||||
if not os.path.exists(os.path.dirname(fout_path)):
|
||||
clean_ctx.makedirs(os.path.dirname(fout_path))
|
||||
|
||||
@ -740,10 +678,9 @@ def preprocessor(
|
||||
output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
|
||||
|
||||
# Header rewrites
|
||||
def mk_repl(templ, include_current_dir=True):
|
||||
def mk_repl(templ):
|
||||
def repl(m):
|
||||
f = m.group(1)
|
||||
dirpath, filename = os.path.split(f)
|
||||
if (
|
||||
f.startswith("ATen/cuda")
|
||||
or f.startswith("ATen/native/cuda")
|
||||
@ -753,41 +690,11 @@ def preprocessor(
|
||||
or f.startswith("THCUNN/")
|
||||
or (f.startswith("THC") and not f.startswith("THCP"))
|
||||
):
|
||||
return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
|
||||
# if filename is one of the files being hipified for this extension
|
||||
if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
|
||||
header_dir = None
|
||||
header_filepath = None
|
||||
# If include_current_dir True, look first in same dir as the including source file
|
||||
if include_current_dir:
|
||||
header_dir_to_check = os.path.dirname(fin_path)
|
||||
header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
|
||||
if os.path.exists(header_path_to_check):
|
||||
header_dir = header_dir_to_check
|
||||
header_filepath = header_path_to_check
|
||||
# If not found, look in include dirs one by one and first match wins
|
||||
if header_filepath is None:
|
||||
for include in includes:
|
||||
header_dir_to_check = os.path.join(output_directory, os.path.dirname(include))
|
||||
header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
|
||||
if os.path.exists(header_path_to_check):
|
||||
header_dir = header_dir_to_check
|
||||
header_filepath = header_path_to_check
|
||||
# If header file not found, keep as is
|
||||
if header_filepath is None:
|
||||
return m.group(0)
|
||||
# Hipify header file first if needed
|
||||
if header_filepath not in HIPIFY_FINAL_RESULT:
|
||||
preprocess_file_and_save_result(output_directory,
|
||||
os.path.relpath(header_filepath, output_directory),
|
||||
all_files, includes, stats, hip_clang_launch, is_pytorch_extension,
|
||||
clean_ctx, show_progress)
|
||||
return templ.format(os.path.relpath(HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"], header_dir))
|
||||
|
||||
return templ.format(get_hip_file_path(m.group(1)))
|
||||
return m.group(0)
|
||||
return repl
|
||||
output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
|
||||
output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
|
||||
output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"'), output_source)
|
||||
output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>'), output_source)
|
||||
output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)
|
||||
|
||||
# CMakeLists.txt rewrites
|
||||
@ -810,18 +717,6 @@ def preprocessor(
|
||||
# Replace the extern __shared__
|
||||
output_source = replace_extern_shared(output_source)
|
||||
|
||||
# Don't write out identical hipified files for extensions if dirpath has not changed
|
||||
if (
|
||||
is_pytorch_extension
|
||||
and orig_output_source == output_source
|
||||
and os.path.dirname(fin_path) == os.path.dirname(fout_path)
|
||||
):
|
||||
return {"hipified_path": fin_path, "status": "ok"}
|
||||
|
||||
# Add hipify breadcrumb for C-style files to avoid re-hipification
|
||||
if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
|
||||
output_source = HIPIFY_C_BREADCRUMB + output_source
|
||||
|
||||
do_write = True
|
||||
if os.path.exists(fout_path):
|
||||
with open(fout_path, 'r', encoding='utf-8') as fout_old:
|
||||
@ -829,9 +724,9 @@ def preprocessor(
|
||||
if do_write:
|
||||
with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
|
||||
fout.write(output_source)
|
||||
return {"hipified_path": fout_path, "status": "ok"}
|
||||
return "ok"
|
||||
else:
|
||||
return {"hipified_path": fout_path, "status": "skipped"}
|
||||
return "skipped"
|
||||
|
||||
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
|
||||
with openf(filepath, "r+") as f:
|
||||
@ -923,19 +818,19 @@ def str2bool(v):
|
||||
|
||||
|
||||
def hipify(
|
||||
project_directory: str,
|
||||
show_detailed: bool = False,
|
||||
extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
|
||||
output_directory: str = "",
|
||||
includes: Iterable = (),
|
||||
extra_files: Iterable = (),
|
||||
out_of_place_only: bool = False,
|
||||
ignores: Iterable = (),
|
||||
show_progress: bool = True,
|
||||
hip_clang_launch: bool = False,
|
||||
is_pytorch_extension: bool = False,
|
||||
clean_ctx: GeneratedFileCleaner = None
|
||||
) -> HipifyFinalResult:
|
||||
project_directory,
|
||||
show_detailed=False,
|
||||
extensions=(".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
|
||||
output_directory="",
|
||||
includes=(),
|
||||
extra_files=(),
|
||||
out_of_place_only=False,
|
||||
ignores=(),
|
||||
show_progress=True,
|
||||
hip_clang_launch=False,
|
||||
is_pytorch_extension=False,
|
||||
clean_ctx=None
|
||||
):
|
||||
if project_directory == "":
|
||||
project_directory = os.getcwd()
|
||||
|
||||
@ -958,17 +853,12 @@ def hipify(
|
||||
out_of_place_only=out_of_place_only,
|
||||
is_pytorch_extension=is_pytorch_extension))
|
||||
all_files_set = set(all_files)
|
||||
# Convert extra_files to relative paths since all_files has all relative paths
|
||||
for f in extra_files:
|
||||
f_rel = os.path.relpath(f, output_directory)
|
||||
if f_rel not in all_files_set:
|
||||
all_files.append(f_rel)
|
||||
all_files += [f for f in extra_files if f not in all_files_set]
|
||||
|
||||
# Start Preprocessor
|
||||
return preprocess(
|
||||
preprocess(
|
||||
output_directory,
|
||||
all_files,
|
||||
includes,
|
||||
show_detailed=show_detailed,
|
||||
show_progress=show_progress,
|
||||
hip_clang_launch=hip_clang_launch,
|
||||
|
@ -1 +0,0 @@
|
||||
__version__ = '1.0.0'
|
Reference in New Issue
Block a user