mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Hipify fixes for a successful DeepSpeed build
These commits are required to build DeepSpeed on ROCm without the hipify errors.a41829d9ed
663c718462
cc: @jeffdaily Pull Request resolved: https://github.com/pytorch/pytorch/pull/76141 Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
bbc263eb5d
commit
7422ccea8b
@ -99,6 +99,8 @@ includes = [
|
||||
"tools/autograd/templates/python_variable_methods.cpp",
|
||||
]
|
||||
|
||||
includes = [os.path.join(proj_dir, include) for include in includes]
|
||||
|
||||
for new_dir in args.extra_include_dir:
|
||||
abs_new_dir = os.path.join(proj_dir, new_dir)
|
||||
if os.path.exists(abs_new_dir):
|
||||
@ -122,6 +124,8 @@ ignores = [
|
||||
"torch/include/*",
|
||||
]
|
||||
|
||||
ignores = [os.path.join(proj_dir, ignore) for ignore in ignores]
|
||||
|
||||
# Check if the compiler is hip-clang.
|
||||
def is_hip_clang() -> bool:
|
||||
try:
|
||||
|
@ -17,7 +17,7 @@ import torch._appdirs
|
||||
from .file_baton import FileBaton
|
||||
from ._cpp_extension_versioner import ExtensionVersioner
|
||||
from .hipify import hipify_python
|
||||
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
|
||||
from .hipify.hipify_python import GeneratedFileCleaner
|
||||
from typing import List, Optional, Union, Tuple
|
||||
from torch.torch_version import TorchVersion
|
||||
|
||||
@ -1010,16 +1010,19 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
hipify_result = hipify_python.hipify(
|
||||
project_directory=build_dir,
|
||||
output_directory=build_dir,
|
||||
includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'],
|
||||
header_include_dirs=include_dirs,
|
||||
includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only
|
||||
extra_files=[os.path.abspath(s) for s in sources],
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
hipify_extra_files_only=True, # don't hipify everything in includes path
|
||||
)
|
||||
|
||||
hipified_sources = set()
|
||||
for source in sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
|
||||
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if (s_abs in hipify_result and
|
||||
hipify_result[s_abs]["hipified_path"] is not None) else s_abs)
|
||||
|
||||
sources = list(hipified_sources)
|
||||
|
||||
@ -1400,15 +1403,25 @@ def _jit_compile(name,
|
||||
try:
|
||||
with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
|
||||
if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
|
||||
hipify_python.hipify(
|
||||
hipify_result = hipify_python.hipify(
|
||||
project_directory=build_directory,
|
||||
output_directory=build_directory,
|
||||
includes=os.path.join(build_directory, '*'),
|
||||
header_include_dirs=(extra_include_paths if extra_include_paths is not None else []),
|
||||
extra_files=[os.path.abspath(s) for s in sources],
|
||||
ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers
|
||||
show_detailed=verbose,
|
||||
show_progress=verbose,
|
||||
is_pytorch_extension=True,
|
||||
clean_ctx=clean_ctx
|
||||
)
|
||||
|
||||
hipified_sources = set()
|
||||
for source in sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
|
||||
|
||||
sources = list(hipified_sources)
|
||||
|
||||
_write_ninja_file_and_build_library(
|
||||
name=name,
|
||||
sources=sources,
|
||||
@ -1904,10 +1917,6 @@ def _write_ninja_file_to_build_library(path,
|
||||
cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
|
||||
cuda_flags += extra_cuda_cflags
|
||||
cuda_flags += _get_rocm_arch_flags(cuda_flags)
|
||||
sources = [s if not _is_cuda_file(s) else
|
||||
os.path.abspath(os.path.join(
|
||||
path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
|
||||
for s in sources]
|
||||
elif with_cuda:
|
||||
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
|
||||
if IS_WINDOWS:
|
||||
@ -2018,6 +2027,8 @@ def _write_ninja_file(path,
|
||||
nvcc = _join_cuda_home('bin', 'nvcc')
|
||||
config.append(f'nvcc = {nvcc}')
|
||||
|
||||
if IS_HIP_EXTENSION:
|
||||
post_cflags = COMMON_HIP_FLAGS + post_cflags
|
||||
flags = [f'cflags = {" ".join(cflags)}']
|
||||
flags.append(f'post_cflags = {" ".join(post_cflags)}')
|
||||
if with_cuda:
|
||||
|
204
torch/utils/hipify/hipify_python.py
Normal file → Executable file
204
torch/utils/hipify/hipify_python.py
Normal file → Executable file
@ -117,15 +117,16 @@ def match_extensions(filename: str, extensions: Iterable) -> bool:
|
||||
"""Helper method to see if filename ends with certain extension"""
|
||||
return any(filename.endswith(e) for e in extensions)
|
||||
|
||||
def _fnmatch(filepath, patterns):
|
||||
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
|
||||
|
||||
def matched_files_iter(
|
||||
root_path: str,
|
||||
includes: Iterable = ('*',),
|
||||
includes: Iterable = (),
|
||||
ignores: Iterable = (),
|
||||
extensions: Iterable = (),
|
||||
out_of_place_only: bool = False,
|
||||
is_pytorch_extension: bool = False) -> Iterator[str]:
|
||||
def _fnmatch(filepath, patterns):
|
||||
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
|
||||
|
||||
exact_matches = set(includes)
|
||||
|
||||
@ -145,7 +146,8 @@ def matched_files_iter(
|
||||
if "third_party" in dirs:
|
||||
dirs.remove("third_party")
|
||||
for filename in filenames:
|
||||
filepath = os.path.join(rel_dirpath, filename)
|
||||
filepath = os.path.join(abs_dirpath, filename)
|
||||
rel_filepath = os.path.join(rel_dirpath, filename)
|
||||
# We respect extensions, UNLESS you wrote the entire
|
||||
# filename verbatim, in which case we always accept it
|
||||
if (
|
||||
@ -154,9 +156,9 @@ def matched_files_iter(
|
||||
and (match_extensions(filepath, extensions) or filepath in exact_matches)
|
||||
):
|
||||
if not is_pytorch_extension: # for pytorch extensions, consider all files
|
||||
if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath):
|
||||
if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
|
||||
continue
|
||||
if out_of_place_only and not is_out_of_place(filepath):
|
||||
if out_of_place_only and not is_out_of_place(rel_filepath):
|
||||
continue
|
||||
yield filepath
|
||||
|
||||
@ -165,59 +167,23 @@ def preprocess_file_and_save_result(
|
||||
output_directory: str,
|
||||
filepath: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
header_include_dirs: Iterable,
|
||||
stats: Dict[str, List],
|
||||
hip_clang_launch: bool,
|
||||
is_pytorch_extension: bool,
|
||||
clean_ctx: GeneratedFileCleaner,
|
||||
show_progress: bool) -> None:
|
||||
result = preprocessor(output_directory, filepath, all_files, includes, stats,
|
||||
result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
|
||||
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
|
||||
|
||||
fin_path = os.path.abspath(os.path.join(output_directory, filepath))
|
||||
# Show what happened
|
||||
if show_progress:
|
||||
if show_progress and "ignored" not in str(result["status"]):
|
||||
print(
|
||||
fin_path, "->",
|
||||
result["hipified_path"], result["status"])
|
||||
result["hipified_path"], result["status"], flush=True)
|
||||
|
||||
if result["hipified_path"] is not None:
|
||||
HIPIFY_FINAL_RESULT[fin_path] = result
|
||||
|
||||
|
||||
def preprocess(
|
||||
output_directory: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
show_detailed: bool = False,
|
||||
show_progress: bool = True,
|
||||
hip_clang_launch: bool = False,
|
||||
is_pytorch_extension: bool = False,
|
||||
clean_ctx: Optional[GeneratedFileCleaner] = None) -> HipifyFinalResult:
|
||||
"""
|
||||
Call preprocessor on selected files.
|
||||
|
||||
Arguments)
|
||||
show_detailed - Show a detailed summary of the transpilation process.
|
||||
"""
|
||||
|
||||
if clean_ctx is None:
|
||||
clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
|
||||
|
||||
# Preprocessing statistics.
|
||||
stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
|
||||
|
||||
for filepath in all_files:
|
||||
preprocess_file_and_save_result(output_directory, filepath, all_files, includes, stats,
|
||||
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
|
||||
|
||||
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
|
||||
|
||||
# Show detailed summary
|
||||
if show_detailed:
|
||||
compute_stats(stats)
|
||||
|
||||
return HIPIFY_FINAL_RESULT
|
||||
HIPIFY_FINAL_RESULT[fin_path] = result
|
||||
|
||||
|
||||
def compute_stats(stats):
|
||||
@ -544,16 +510,17 @@ def replace_extern_shared(input_string):
|
||||
return output_string
|
||||
|
||||
|
||||
def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
|
||||
"""
|
||||
Returns the new name of the hipified file
|
||||
"""
|
||||
# At the moment, some PyTorch source files are HIPified in place. The predicate
|
||||
# is_out_of_place tells us if this is the case or not.
|
||||
if not is_pytorch_extension and not is_out_of_place(filepath):
|
||||
return filepath
|
||||
assert(not os.path.isabs(rel_filepath))
|
||||
if not is_pytorch_extension and not is_out_of_place(rel_filepath):
|
||||
return rel_filepath
|
||||
|
||||
dirpath, filename = os.path.split(filepath)
|
||||
dirpath, filename = os.path.split(rel_filepath)
|
||||
root, ext = os.path.splitext(filename)
|
||||
|
||||
# Here's the plan:
|
||||
@ -597,6 +564,7 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
orig_dirpath = dirpath
|
||||
|
||||
dirpath = dirpath.replace('cuda', 'hip')
|
||||
dirpath = dirpath.replace('CUDA', 'HIP')
|
||||
dirpath = dirpath.replace('THC', 'THH')
|
||||
|
||||
root = root.replace('cuda', 'hip')
|
||||
@ -614,36 +582,39 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
return os.path.join(dirpath, root + ext)
|
||||
|
||||
|
||||
def is_out_of_place(filepath):
|
||||
if filepath.startswith("torch/"):
|
||||
def is_out_of_place(rel_filepath):
|
||||
assert(not os.path.isabs(rel_filepath))
|
||||
if rel_filepath.startswith("torch/"):
|
||||
return False
|
||||
if filepath.startswith("tools/autograd/templates/"):
|
||||
if rel_filepath.startswith("tools/autograd/templates/"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Keep this synchronized with includes/ignores in build_amd.py
|
||||
def is_pytorch_file(filepath):
|
||||
if filepath.startswith("aten/"):
|
||||
if filepath.startswith("aten/src/ATen/core/"):
|
||||
def is_pytorch_file(rel_filepath):
|
||||
assert(not os.path.isabs(rel_filepath))
|
||||
if rel_filepath.startswith("aten/"):
|
||||
if rel_filepath.startswith("aten/src/ATen/core/"):
|
||||
return False
|
||||
return True
|
||||
if filepath.startswith("torch/"):
|
||||
if rel_filepath.startswith("torch/"):
|
||||
return True
|
||||
if filepath.startswith("tools/autograd/templates/"):
|
||||
if rel_filepath.startswith("tools/autograd/templates/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_cusparse_file(filepath):
|
||||
if is_pytorch_file(filepath):
|
||||
return "sparse" in filepath.lower()
|
||||
def is_cusparse_file(rel_filepath):
|
||||
if is_pytorch_file(rel_filepath):
|
||||
return "sparse" in rel_filepath.lower()
|
||||
return False
|
||||
|
||||
def is_caffe2_gpu_file(filepath):
|
||||
if filepath.startswith("c10/cuda"):
|
||||
def is_caffe2_gpu_file(rel_filepath):
|
||||
assert(not os.path.isabs(rel_filepath))
|
||||
if rel_filepath.startswith("c10/cuda"):
|
||||
return True
|
||||
filename = os.path.basename(filepath)
|
||||
filename = os.path.basename(rel_filepath)
|
||||
_, ext = os.path.splitext(filename)
|
||||
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
|
||||
|
||||
@ -752,31 +723,36 @@ RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
|
||||
Returns a dict with the following keys:
|
||||
"hipified_path" : absolute path of hipified source file
|
||||
"status" : "ok" if hipified file was written out
|
||||
"skipped" if an identical hipified file already existed
|
||||
"ignored" if the source file was a hipified file itself
|
||||
"skipped" if an identical hipified file already existed or hipified file couldn't be written out
|
||||
"ignored" if the source file was a hipified file itself or not meant to be hipified
|
||||
"""
|
||||
def preprocessor(
|
||||
output_directory: str,
|
||||
filepath: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
header_include_dirs: Iterable,
|
||||
stats: Dict[str, List],
|
||||
hip_clang_launch: bool,
|
||||
is_pytorch_extension: bool,
|
||||
clean_ctx: GeneratedFileCleaner,
|
||||
show_progress: bool) -> HipifyResult:
|
||||
""" Executes the CUDA -> HIP conversion on the specified file. """
|
||||
if filepath not in all_files:
|
||||
return {"hipified_path": None, "status": "[ignored, not to be hipified]"}
|
||||
|
||||
fin_path = os.path.abspath(os.path.join(output_directory, filepath))
|
||||
rel_filepath = os.path.relpath(filepath, output_directory)
|
||||
|
||||
with open(fin_path, 'r', encoding='utf-8') as fin:
|
||||
if fin.readline() == HIPIFY_C_BREADCRUMB:
|
||||
return {"hipified_path": None, "status": "ignored"}
|
||||
return {"hipified_path": None, "status": "[ignored, input is hipified output]"}
|
||||
fin.seek(0)
|
||||
output_source = fin.read()
|
||||
|
||||
orig_output_source = output_source
|
||||
|
||||
fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(filepath, is_pytorch_extension)))
|
||||
# get_hip_file_path needs a relative path to work correctly
|
||||
fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
|
||||
if not os.path.exists(os.path.dirname(fout_path)):
|
||||
clean_ctx.makedirs(os.path.dirname(fout_path))
|
||||
|
||||
@ -791,9 +767,9 @@ def preprocessor(
|
||||
if is_pytorch_extension:
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
else:
|
||||
if is_cusparse_file(filepath):
|
||||
if is_cusparse_file(rel_filepath):
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source)
|
||||
elif is_pytorch_file(filepath):
|
||||
elif is_pytorch_file(rel_filepath):
|
||||
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
|
||||
else:
|
||||
def c2_repl(m):
|
||||
@ -829,8 +805,8 @@ def preprocessor(
|
||||
header_filepath = header_path_to_check
|
||||
# If not found, look in include dirs one by one and first match wins
|
||||
if header_filepath is None:
|
||||
for include in includes:
|
||||
header_dir_to_check = os.path.join(output_directory, os.path.dirname(include))
|
||||
for header_include_dir in header_include_dirs:
|
||||
header_dir_to_check = os.path.join(output_directory, header_include_dir)
|
||||
header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
|
||||
if os.path.exists(header_path_to_check):
|
||||
header_dir = header_dir_to_check
|
||||
@ -841,12 +817,12 @@ def preprocessor(
|
||||
# Hipify header file first if needed
|
||||
if header_filepath not in HIPIFY_FINAL_RESULT:
|
||||
preprocess_file_and_save_result(output_directory,
|
||||
os.path.relpath(header_filepath, output_directory),
|
||||
all_files, includes, stats, hip_clang_launch, is_pytorch_extension,
|
||||
clean_ctx, show_progress)
|
||||
value = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
|
||||
assert value is not None
|
||||
return templ.format(os.path.relpath(value, header_dir))
|
||||
header_filepath,
|
||||
all_files, header_include_dirs, stats, hip_clang_launch,
|
||||
is_pytorch_extension, clean_ctx, show_progress)
|
||||
hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
|
||||
return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
|
||||
else header_filepath, header_dir))
|
||||
|
||||
return m.group(0)
|
||||
return repl
|
||||
@ -880,7 +856,7 @@ def preprocessor(
|
||||
and orig_output_source == output_source
|
||||
and os.path.dirname(fin_path) == os.path.dirname(fout_path)
|
||||
):
|
||||
return {"hipified_path": fin_path, "status": "ok"}
|
||||
return {"hipified_path": fin_path, "status": "[skipped, no changes]"}
|
||||
|
||||
# Add hipify breadcrumb for C-style files to avoid re-hipification
|
||||
if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
|
||||
@ -894,13 +870,13 @@ def preprocessor(
|
||||
try:
|
||||
with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
|
||||
fout.write(output_source)
|
||||
return {"hipified_path": fout_path, "status": "ok"}
|
||||
return {"hipified_path": fout_path, "status": "[ok]"}
|
||||
except PermissionError as e:
|
||||
print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}",
|
||||
file=sys.stderr)
|
||||
return {"hipified_path": fin_path, "status": "skipped"}
|
||||
return {"hipified_path": fin_path, "status": "[skipped, no permissions]"}
|
||||
else:
|
||||
return {"hipified_path": fout_path, "status": "skipped"}
|
||||
return {"hipified_path": fout_path, "status": "[skipped, already hipified]"}
|
||||
|
||||
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
|
||||
with openf(filepath, "r+") as f:
|
||||
@ -995,14 +971,17 @@ def hipify(
|
||||
project_directory: str,
|
||||
show_detailed: bool = False,
|
||||
extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
|
||||
header_extensions: Iterable = (".cuh", ".h", ".hpp"),
|
||||
output_directory: str = "",
|
||||
includes: Iterable = (),
|
||||
header_include_dirs: Iterable = (),
|
||||
includes: Iterable = ('*',),
|
||||
extra_files: Iterable = (),
|
||||
out_of_place_only: bool = False,
|
||||
ignores: Iterable = (),
|
||||
show_progress: bool = True,
|
||||
hip_clang_launch: bool = False,
|
||||
is_pytorch_extension: bool = False,
|
||||
hipify_extra_files_only: bool = False,
|
||||
clean_ctx: Optional[GeneratedFileCleaner] = None
|
||||
) -> HipifyFinalResult:
|
||||
if project_directory == "":
|
||||
@ -1018,6 +997,10 @@ def hipify(
|
||||
project_directory.rstrip("/")
|
||||
output_directory = project_directory + "_amd"
|
||||
|
||||
if project_directory != output_directory:
|
||||
includes = [include.replace(project_directory, output_directory) for include in includes]
|
||||
ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]
|
||||
|
||||
# Copy from project directory to output directory if not done already.
|
||||
if not os.path.exists(output_directory):
|
||||
shutil.copytree(project_directory, output_directory)
|
||||
@ -1027,19 +1010,42 @@ def hipify(
|
||||
out_of_place_only=out_of_place_only,
|
||||
is_pytorch_extension=is_pytorch_extension))
|
||||
all_files_set = set(all_files)
|
||||
# Convert extra_files to relative paths since all_files has all relative paths
|
||||
for f in extra_files:
|
||||
f_rel = os.path.relpath(f, output_directory)
|
||||
if f_rel not in all_files_set:
|
||||
all_files.append(f_rel)
|
||||
if not os.path.isabs(f):
|
||||
f = os.path.join(output_directory, f)
|
||||
if f not in all_files_set:
|
||||
all_files.append(f)
|
||||
|
||||
# Start Preprocessor
|
||||
return preprocess(
|
||||
output_directory,
|
||||
all_files,
|
||||
includes,
|
||||
show_detailed=show_detailed,
|
||||
show_progress=show_progress,
|
||||
hip_clang_launch=hip_clang_launch,
|
||||
is_pytorch_extension=is_pytorch_extension,
|
||||
clean_ctx=clean_ctx)
|
||||
# List all files in header_include_paths to ensure they are hipified
|
||||
from pathlib import Path
|
||||
for header_include_dir in header_include_dirs:
|
||||
if os.path.isabs(header_include_dir):
|
||||
header_include_dir_path = Path(header_include_dir)
|
||||
else:
|
||||
header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
|
||||
for path in header_include_dir_path.rglob('*'):
|
||||
if (
|
||||
path.is_file()
|
||||
and _fnmatch(str(path), includes)
|
||||
and (not _fnmatch(str(path), ignores))
|
||||
and match_extensions(path.name, header_extensions)
|
||||
):
|
||||
all_files.append(str(path))
|
||||
|
||||
if clean_ctx is None:
|
||||
clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
|
||||
|
||||
# Preprocessing statistics.
|
||||
stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
|
||||
|
||||
for filepath in (all_files if not hipify_extra_files_only else extra_files):
|
||||
preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
|
||||
stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
|
||||
|
||||
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
|
||||
|
||||
# Show detailed summary
|
||||
if show_detailed:
|
||||
compute_stats(stats)
|
||||
|
||||
return HIPIFY_FINAL_RESULT
|
||||
|
Reference in New Issue
Block a user