Hipify fixes for a successful DeepSpeed build

These commits are required to build DeepSpeed on ROCm without hipify errors.

a41829d9ed
663c718462

cc: @jeffdaily

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76141
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/albanD
rraminen
2022-04-28 13:19:59 +00:00
committed by PyTorch MergeBot
parent bbc263eb5d
commit 7422ccea8b
3 changed files with 129 additions and 108 deletions

tools/amd_build/build_amd.py

@@ -99,6 +99,8 @@ includes = [
"tools/autograd/templates/python_variable_methods.cpp",
]
includes = [os.path.join(proj_dir, include) for include in includes]
for new_dir in args.extra_include_dir:
abs_new_dir = os.path.join(proj_dir, new_dir)
if os.path.exists(abs_new_dir):
@@ -122,6 +124,8 @@ ignores = [
"torch/include/*",
]
ignores = [os.path.join(proj_dir, ignore) for ignore in ignores]
# Check if the compiler is hip-clang.
def is_hip_clang() -> bool:
try:

torch/utils/cpp_extension.py

@@ -17,7 +17,7 @@ import torch._appdirs
from .file_baton import FileBaton
from ._cpp_extension_versioner import ExtensionVersioner
from .hipify import hipify_python
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
from .hipify.hipify_python import GeneratedFileCleaner
from typing import List, Optional, Union, Tuple
from torch.torch_version import TorchVersion
@@ -1010,16 +1010,19 @@ def CUDAExtension(name, sources, *args, **kwargs):
hipify_result = hipify_python.hipify(
project_directory=build_dir,
output_directory=build_dir,
includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'],
header_include_dirs=include_dirs,
includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only
extra_files=[os.path.abspath(s) for s in sources],
show_detailed=True,
is_pytorch_extension=True,
hipify_extra_files_only=True, # don't hipify everything in includes path
)
hipified_sources = set()
for source in sources:
s_abs = os.path.abspath(source)
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if (s_abs in hipify_result and
hipify_result[s_abs]["hipified_path"] is not None) else s_abs)
sources = list(hipified_sources)
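
For extension authors (DeepSpeed being the motivating case), the practical effect of the CUDAExtension change above is that headers under include_dirs are now hipified via header_include_dirs instead of being picked up through glob-based includes, and each source is swapped for its hipified counterpart only when one was actually produced. A minimal, hypothetical setup.py that exercises this path on a ROCm build of PyTorch (package, source, and directory names are placeholders):

```python
# Hypothetical setup.py; on a ROCm build, CUDAExtension() drives hipify over the
# listed sources and, after this change, over headers found in include_dirs too.
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

this_dir = os.path.dirname(os.path.abspath(__file__))

setup(
    name="my_rocm_ext",  # placeholder name
    ext_modules=[
        CUDAExtension(
            name="my_rocm_ext._C",
            sources=["csrc/my_ext.cpp", "csrc/my_kernels.cu"],  # placeholder files
            include_dirs=[os.path.join(this_dir, "csrc", "includes")],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```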
@@ -1400,15 +1403,25 @@ def _jit_compile(name,
try:
with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
hipify_python.hipify(
hipify_result = hipify_python.hipify(
project_directory=build_directory,
output_directory=build_directory,
includes=os.path.join(build_directory, '*'),
header_include_dirs=(extra_include_paths if extra_include_paths is not None else []),
extra_files=[os.path.abspath(s) for s in sources],
ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers
show_detailed=verbose,
show_progress=verbose,
is_pytorch_extension=True,
clean_ctx=clean_ctx
)
hipified_sources = set()
for source in sources:
s_abs = os.path.abspath(source)
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
sources = list(hipified_sources)
_write_ninja_file_and_build_library(
name=name,
sources=sources,
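
The _jit_compile change mirrors the CUDAExtension one for the JIT path: extra_include_paths are forwarded to hipify as header_include_dirs, and each source is replaced by its hipified counterpart before the ninja file is written. A hedged sketch of the user-facing call (names and paths are hypothetical):

```python
# JIT compilation on a ROCm build; extra_include_paths now reach hipify as
# header_include_dirs, so headers in that directory are hipified as well.
from torch.utils.cpp_extension import load

my_jit_ext = load(
    name="my_jit_ext",                                   # placeholder name
    sources=["csrc/my_ext.cpp", "csrc/my_kernels.cu"],   # placeholder files
    extra_include_paths=["csrc/includes"],               # placeholder directory
    verbose=True,
)
```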
@@ -1904,10 +1917,6 @@ def _write_ninja_file_to_build_library(path,
cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
cuda_flags += extra_cuda_cflags
cuda_flags += _get_rocm_arch_flags(cuda_flags)
sources = [s if not _is_cuda_file(s) else
os.path.abspath(os.path.join(
path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
for s in sources]
elif with_cuda:
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
if IS_WINDOWS:
@@ -2018,6 +2027,8 @@ def _write_ninja_file(path,
nvcc = _join_cuda_home('bin', 'nvcc')
config.append(f'nvcc = {nvcc}')
if IS_HIP_EXTENSION:
post_cflags = COMMON_HIP_FLAGS + post_cflags
flags = [f'cflags = {" ".join(cflags)}']
flags.append(f'post_cflags = {" ".join(post_cflags)}')
if with_cuda:

torch/utils/hipify/hipify_python.py (Normal file → Executable file)

@@ -117,15 +117,16 @@ def match_extensions(filename: str, extensions: Iterable) -> bool:
"""Helper method to see if filename ends with certain extension"""
return any(filename.endswith(e) for e in extensions)
def _fnmatch(filepath, patterns):
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
def matched_files_iter(
root_path: str,
includes: Iterable = ('*',),
includes: Iterable = (),
ignores: Iterable = (),
extensions: Iterable = (),
out_of_place_only: bool = False,
is_pytorch_extension: bool = False) -> Iterator[str]:
def _fnmatch(filepath, patterns):
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
exact_matches = set(includes)
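
_fnmatch is hoisted to module scope so that hipify() itself can reuse it when scanning header_include_dirs (see the last hunk below). A quick illustration of the glob matching it performs, with hypothetical paths:

```python
import fnmatch

def _fnmatch(filepath, patterns):
    return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)

# '*' in fnmatch patterns matches across path separators, so a single
# "<build_dir>/*" pattern covers the whole build tree.
print(_fnmatch("/tmp/my_ext_build/csrc/my_kernels.cu", ["/tmp/my_ext_build/*"]))  # True
print(_fnmatch("/opt/rocm/include/hip/hip_runtime.h", ["/tmp/my_ext_build/*"]))   # False
```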
@@ -145,7 +146,8 @@ def matched_files_iter(
if "third_party" in dirs:
dirs.remove("third_party")
for filename in filenames:
filepath = os.path.join(rel_dirpath, filename)
filepath = os.path.join(abs_dirpath, filename)
rel_filepath = os.path.join(rel_dirpath, filename)
# We respect extensions, UNLESS you wrote the entire
# filename verbatim, in which case we always accept it
if (
@@ -154,9 +156,9 @@
and (match_extensions(filepath, extensions) or filepath in exact_matches)
):
if not is_pytorch_extension: # for pytorch extensions, consider all files
if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath):
if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
continue
if out_of_place_only and not is_out_of_place(filepath):
if out_of_place_only and not is_out_of_place(rel_filepath):
continue
yield filepath
@@ -165,59 +167,23 @@ def preprocess_file_and_save_result(
output_directory: str,
filepath: str,
all_files: Iterable,
includes: Iterable,
header_include_dirs: Iterable,
stats: Dict[str, List],
hip_clang_launch: bool,
is_pytorch_extension: bool,
clean_ctx: GeneratedFileCleaner,
show_progress: bool) -> None:
result = preprocessor(output_directory, filepath, all_files, includes, stats,
result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
fin_path = os.path.abspath(os.path.join(output_directory, filepath))
# Show what happened
if show_progress:
if show_progress and "ignored" not in str(result["status"]):
print(
fin_path, "->",
result["hipified_path"], result["status"])
result["hipified_path"], result["status"], flush=True)
if result["hipified_path"] is not None:
HIPIFY_FINAL_RESULT[fin_path] = result
def preprocess(
output_directory: str,
all_files: Iterable,
includes: Iterable,
show_detailed: bool = False,
show_progress: bool = True,
hip_clang_launch: bool = False,
is_pytorch_extension: bool = False,
clean_ctx: Optional[GeneratedFileCleaner] = None) -> HipifyFinalResult:
"""
Call preprocessor on selected files.
Arguments)
show_detailed - Show a detailed summary of the transpilation process.
"""
if clean_ctx is None:
clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
# Preprocessing statistics.
stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
for filepath in all_files:
preprocess_file_and_save_result(output_directory, filepath, all_files, includes, stats,
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
# Show detailed summary
if show_detailed:
compute_stats(stats)
return HIPIFY_FINAL_RESULT
HIPIFY_FINAL_RESULT[fin_path] = result
def compute_stats(stats):
@@ -544,16 +510,17 @@ def replace_extern_shared(input_string):
return output_string
def get_hip_file_path(filepath, is_pytorch_extension=False):
def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
"""
Returns the new name of the hipified file
"""
# At the moment, some PyTorch source files are HIPified in place. The predicate
# is_out_of_place tells us if this is the case or not.
if not is_pytorch_extension and not is_out_of_place(filepath):
return filepath
assert(not os.path.isabs(rel_filepath))
if not is_pytorch_extension and not is_out_of_place(rel_filepath):
return rel_filepath
dirpath, filename = os.path.split(filepath)
dirpath, filename = os.path.split(rel_filepath)
root, ext = os.path.splitext(filename)
# Here's the plan:
@@ -597,6 +564,7 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
orig_dirpath = dirpath
dirpath = dirpath.replace('cuda', 'hip')
dirpath = dirpath.replace('CUDA', 'HIP')
dirpath = dirpath.replace('THC', 'THH')
root = root.replace('cuda', 'hip')
@@ -614,36 +582,39 @@ def get_hip_file_path(filepath, is_pytorch_extension=False):
return os.path.join(dirpath, root + ext)
def is_out_of_place(filepath):
if filepath.startswith("torch/"):
def is_out_of_place(rel_filepath):
assert(not os.path.isabs(rel_filepath))
if rel_filepath.startswith("torch/"):
return False
if filepath.startswith("tools/autograd/templates/"):
if rel_filepath.startswith("tools/autograd/templates/"):
return False
return True
# Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(filepath):
if filepath.startswith("aten/"):
if filepath.startswith("aten/src/ATen/core/"):
def is_pytorch_file(rel_filepath):
assert(not os.path.isabs(rel_filepath))
if rel_filepath.startswith("aten/"):
if rel_filepath.startswith("aten/src/ATen/core/"):
return False
return True
if filepath.startswith("torch/"):
if rel_filepath.startswith("torch/"):
return True
if filepath.startswith("tools/autograd/templates/"):
if rel_filepath.startswith("tools/autograd/templates/"):
return True
return False
def is_cusparse_file(filepath):
if is_pytorch_file(filepath):
return "sparse" in filepath.lower()
def is_cusparse_file(rel_filepath):
if is_pytorch_file(rel_filepath):
return "sparse" in rel_filepath.lower()
return False
def is_caffe2_gpu_file(filepath):
if filepath.startswith("c10/cuda"):
def is_caffe2_gpu_file(rel_filepath):
assert(not os.path.isabs(rel_filepath))
if rel_filepath.startswith("c10/cuda"):
return True
filename = os.path.basename(filepath)
filename = os.path.basename(rel_filepath)
_, ext = os.path.splitext(filename)
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
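
All of these predicates, plus get_hip_file_path, now assert that they are handed repository-relative paths; matched_files_iter yields absolute paths and the relative form is computed only for these checks. A small sketch of the contract (the in-tree path is just an example; the printed mapping follows the renaming rules shown above):

```python
import os
from torch.utils.hipify.hipify_python import get_hip_file_path, is_pytorch_file

rel = "aten/src/ATen/native/cuda/SoftMax.cu"   # example in-tree path
print(is_pytorch_file(rel))                    # True: under aten/ but not aten/src/ATen/core/
print(get_hip_file_path(rel))                  # 'cuda' path components get rewritten per the rules above

try:
    get_hip_file_path(os.path.join(os.sep, "abs", "path", "kernel.cu"))
except AssertionError:
    print("absolute paths are rejected by the new assert")
```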
@@ -752,31 +723,36 @@ RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
Returns a dict with the following keys:
"hipified_path" : absolute path of hipified source file
"status" : "ok" if hipified file was written out
"skipped" if an identical hipified file already existed
"ignored" if the source file was a hipified file itself
"skipped" if an identical hipified file already existed or hipified file couldn't be written out
"ignored" if the source file was a hipified file itself or not meant to be hipified
"""
def preprocessor(
output_directory: str,
filepath: str,
all_files: Iterable,
includes: Iterable,
header_include_dirs: Iterable,
stats: Dict[str, List],
hip_clang_launch: bool,
is_pytorch_extension: bool,
clean_ctx: GeneratedFileCleaner,
show_progress: bool) -> HipifyResult:
""" Executes the CUDA -> HIP conversion on the specified file. """
if filepath not in all_files:
return {"hipified_path": None, "status": "[ignored, not to be hipified]"}
fin_path = os.path.abspath(os.path.join(output_directory, filepath))
rel_filepath = os.path.relpath(filepath, output_directory)
with open(fin_path, 'r', encoding='utf-8') as fin:
if fin.readline() == HIPIFY_C_BREADCRUMB:
return {"hipified_path": None, "status": "ignored"}
return {"hipified_path": None, "status": "[ignored, input is hipified output]"}
fin.seek(0)
output_source = fin.read()
orig_output_source = output_source
fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(filepath, is_pytorch_extension)))
# get_hip_file_path needs a relative path to work correctly
fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
if not os.path.exists(os.path.dirname(fout_path)):
clean_ctx.makedirs(os.path.dirname(fout_path))
@@ -791,9 +767,9 @@ def preprocessor(
if is_pytorch_extension:
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
else:
if is_cusparse_file(filepath):
if is_cusparse_file(rel_filepath):
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source)
elif is_pytorch_file(filepath):
elif is_pytorch_file(rel_filepath):
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
else:
def c2_repl(m):
@@ -829,8 +805,8 @@ def preprocessor(
header_filepath = header_path_to_check
# If not found, look in include dirs one by one and first match wins
if header_filepath is None:
for include in includes:
header_dir_to_check = os.path.join(output_directory, os.path.dirname(include))
for header_include_dir in header_include_dirs:
header_dir_to_check = os.path.join(output_directory, header_include_dir)
header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
if os.path.exists(header_path_to_check):
header_dir = header_dir_to_check
@@ -841,12 +817,12 @@ def preprocessor(
# Hipify header file first if needed
if header_filepath not in HIPIFY_FINAL_RESULT:
preprocess_file_and_save_result(output_directory,
os.path.relpath(header_filepath, output_directory),
all_files, includes, stats, hip_clang_launch, is_pytorch_extension,
clean_ctx, show_progress)
value = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
assert value is not None
return templ.format(os.path.relpath(value, header_dir))
header_filepath,
all_files, header_include_dirs, stats, hip_clang_launch,
is_pytorch_extension, clean_ctx, show_progress)
hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
else header_filepath, header_dir))
return m.group(0)
return repl
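
The two hunks above change how #include targets are resolved and hipified: a header is looked for next to the file being processed first, then in each header_include_dirs entry in order, and the first hit is hipified recursively; if its hipified_path comes back None, the rewritten include falls back to the original header path. A standalone sketch of just the lookup order (not the actual preprocessor code, whose first candidate is computed in context not shown here):

```python
import os

def find_header(included_name, including_file_dir, output_directory, header_include_dirs):
    """Sketch: the directory of the including file wins, then each
    header_include_dir is tried in order; the first existing path is used."""
    candidate = os.path.abspath(os.path.join(including_file_dir, included_name))
    if os.path.exists(candidate):
        return candidate
    for header_include_dir in header_include_dirs:
        candidate = os.path.abspath(
            os.path.join(output_directory, header_include_dir, included_name))
        if os.path.exists(candidate):
            return candidate
    return None  # not found: leave the include directive untouched
```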
@@ -880,7 +856,7 @@ def preprocessor(
and orig_output_source == output_source
and os.path.dirname(fin_path) == os.path.dirname(fout_path)
):
return {"hipified_path": fin_path, "status": "ok"}
return {"hipified_path": fin_path, "status": "[skipped, no changes]"}
# Add hipify breadcrumb for C-style files to avoid re-hipification
if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
@@ -894,13 +870,13 @@ def preprocessor(
try:
with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
fout.write(output_source)
return {"hipified_path": fout_path, "status": "ok"}
return {"hipified_path": fout_path, "status": "[ok]"}
except PermissionError as e:
print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}",
file=sys.stderr)
return {"hipified_path": fin_path, "status": "skipped"}
return {"hipified_path": fin_path, "status": "[skipped, no permissions]"}
else:
return {"hipified_path": fout_path, "status": "skipped"}
return {"hipified_path": fout_path, "status": "[skipped, already hipified]"}
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
with openf(filepath, "r+") as f:
@@ -995,14 +971,17 @@ def hipify(
project_directory: str,
show_detailed: bool = False,
extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
header_extensions: Iterable = (".cuh", ".h", ".hpp"),
output_directory: str = "",
includes: Iterable = (),
header_include_dirs: Iterable = (),
includes: Iterable = ('*',),
extra_files: Iterable = (),
out_of_place_only: bool = False,
ignores: Iterable = (),
show_progress: bool = True,
hip_clang_launch: bool = False,
is_pytorch_extension: bool = False,
hipify_extra_files_only: bool = False,
clean_ctx: Optional[GeneratedFileCleaner] = None
) -> HipifyFinalResult:
if project_directory == "":
@@ -1018,6 +997,10 @@ def hipify(
project_directory.rstrip("/")
output_directory = project_directory + "_amd"
if project_directory != output_directory:
includes = [include.replace(project_directory, output_directory) for include in includes]
ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]
# Copy from project directory to output directory if not done already.
if not os.path.exists(output_directory):
shutil.copytree(project_directory, output_directory)
@@ -1027,19 +1010,42 @@ def hipify(
out_of_place_only=out_of_place_only,
is_pytorch_extension=is_pytorch_extension))
all_files_set = set(all_files)
# Convert extra_files to relative paths since all_files has all relative paths
for f in extra_files:
f_rel = os.path.relpath(f, output_directory)
if f_rel not in all_files_set:
all_files.append(f_rel)
if not os.path.isabs(f):
f = os.path.join(output_directory, f)
if f not in all_files_set:
all_files.append(f)
# Start Preprocessor
return preprocess(
output_directory,
all_files,
includes,
show_detailed=show_detailed,
show_progress=show_progress,
hip_clang_launch=hip_clang_launch,
is_pytorch_extension=is_pytorch_extension,
clean_ctx=clean_ctx)
# List all files in header_include_paths to ensure they are hipified
from pathlib import Path
for header_include_dir in header_include_dirs:
if os.path.isabs(header_include_dir):
header_include_dir_path = Path(header_include_dir)
else:
header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
for path in header_include_dir_path.rglob('*'):
if (
path.is_file()
and _fnmatch(str(path), includes)
and (not _fnmatch(str(path), ignores))
and match_extensions(path.name, header_extensions)
):
all_files.append(str(path))
if clean_ctx is None:
clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
# Preprocessing statistics.
stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
for filepath in (all_files if not hipify_extra_files_only else extra_files):
preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
# Show detailed summary
if show_detailed:
compute_stats(stats)
return HIPIFY_FINAL_RESULT
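
Putting the new hipify() surface together: callers pass header_include_dirs for headers that must be hipified, keep includes scoped to the build directory, and set hipify_extra_files_only to avoid walking everything the include globs match. A hedged sketch that mirrors the CUDAExtension call shown earlier (build and include paths are hypothetical):

```python
import os
from torch.utils.hipify import hipify_python

build_dir = "/tmp/my_ext_build"                          # placeholder build directory
include_dirs = [os.path.join(build_dir, "csrc", "includes")]
sources = [os.path.join(build_dir, "csrc", "my_kernels.cu")]

hipify_result = hipify_python.hipify(
    project_directory=build_dir,
    output_directory=build_dir,
    header_include_dirs=include_dirs,         # headers found here are hipified too
    includes=[os.path.join(build_dir, "*")],  # limit scope to the build dir
    extra_files=[os.path.abspath(s) for s in sources],
    show_detailed=True,
    is_pytorch_extension=True,
    hipify_extra_files_only=True,             # only hipify extra_files (plus needed headers)
)

# Result keys are absolute input paths; "hipified_path" can be None for ignored
# files, so fall back to the original source in that case.
for s in sources:
    entry = hipify_result.get(os.path.abspath(s))
    use = entry["hipified_path"] if entry and entry["hipified_path"] else s
    print(s, "->", use)
```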