Change hip filename extension to .hip (#14036)

Summary:

- To give HIP files a unique filename extension, we change them from _hip.cc to .hip (the only blessed option other than .cu in hipcc 3d51a1fb01/bin/hipcc (L552)).
- Use the host compiler to compile .cc/.cpp files. Previously we used hcc to compile them, which was unnecessary; see the sketch after this list for how sources are now split.
- Change the hipify script so it no longer replaces "gpu" with "hip" in the filenames of generated hipified files. We previously did this because hcc has a bug when linking files that share a filename; now that the host linker does the linking, the rename is no longer needed (a sketch of the new renaming follows the get_hip_file_path diff near the end of this commit).
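
A minimal sketch of the resulting source split, in Python for illustration (the file names are hypothetical; the globs mirror the file(GLOB hip/*.cc) and file(GLOB hip/*.hip) patterns in the CMake diffs below):

```python
import fnmatch

# Generated sources after hipify (hypothetical names):
generated = [
    'hip/create_db_op_gpu.cc',   # hipified host code -> host compiler
    'hip/math_gpu.hip',          # device code -> hipcc; .hip is the only
                                 # blessed extension besides .cu
    'hip/math_gpu_test.cc',      # test file -> excluded from the library
]

# Mirrors: file(GLOB tmp hip/*.cc), then excluding hip/*_test.cc
host_srcs = [f for f in fnmatch.filter(generated, 'hip/*.cc')
             if not fnmatch.fnmatch(f, 'hip/*_test.cc')]
# Mirrors: file(GLOB tmp hip/*.hip), later marked HIP_SOURCE_PROPERTY_FORMAT 1
device_srcs = fnmatch.filter(generated, 'hip/*.hip')

assert host_srcs == ['hip/create_db_op_gpu.cc']
assert device_srcs == ['hip/math_gpu.hip']
```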
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14036

Reviewed By: xw285cornell

Differential Revision: D13091813

Pulled By: bddppq

fbshipit-source-id: ea3d887751d8abb39d75f5d5104aa66ce66b9ee0
Author: Junjie Bai
Date: 2018-11-16 11:50:29 -08:00
Committed by: Facebook Github Bot
Parent: 30018fcd0b
Commit: 0d7a986da1
35 changed files with 145 additions and 160 deletions

View File

@@ -24,13 +24,13 @@ if (USE_CUDA)
endif()
if (USE_ROCM)
caffe2_hip_binary_target("hip/inspect_hip.cc")
caffe2_hip_binary_target("hip/print_core_object_sizes_hip.cc")
caffe2_hip_binary_target("hip/inspect_gpu.cc")
caffe2_hip_binary_target("hip/print_core_object_sizes_gpu.cc")
if (BUILD_TEST)
# Core overhead benchmark
caffe2_hip_binary_target("hip/core_overhead_benchmark_hip.cc")
target_link_libraries(core_overhead_benchmark_hip benchmark)
caffe2_hip_binary_target("hip/core_overhead_benchmark_gpu.cc")
target_link_libraries(core_overhead_benchmark_gpu benchmark)
endif()
endif()

View File

@@ -374,14 +374,14 @@ if(USE_ROCM)
# Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs.
hip_include_directories(${Caffe2_HIP_INCLUDES})
filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cc|cpp|cu)$")
filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$")
set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# FindHIP.CMake checks if the SHARED flag is set and adds extra logic accordingly.
hip_add_library(caffe2_hip ${Caffe2_HIP_SRCS})
# Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added.
target_compile_options(caffe2_hip PRIVATE ${HIP_HCC_FLAGS})
target_compile_options(caffe2_hip PRIVATE ${HIP_CXX_FLAGS})
target_link_libraries(caffe2_hip PUBLIC caffe2)
target_link_libraries(caffe2_hip PUBLIC ${Caffe2_HIP_DEPENDENCY_LIBS})
@@ -435,7 +435,6 @@ if (BUILD_TEST)
if(USE_ROCM)
foreach(test_src ${Caffe2_HIP_TEST_SRCS})
set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} "${test_src}")
target_link_libraries(${test_name} ${Caffe2_MAIN_LIBS} gtest_main)

View File

@@ -14,14 +14,14 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})
# ---[ HIP files
file(GLOB_RECURSE tmp *_hip.cc)
# ---[ general HIP files
file(GLOB tmp hip/*.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ------[ MIOpen files
file(GLOB_RECURSE tmp *_miopen.cc)
# ------[ HIP sources
file(GLOB tmp hip/*.hip)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB_RECURSE tmp *_test.cc)
file(GLOB tmp hip/*_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})
# ---[ CPU files.
@@ -42,7 +42,7 @@ file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})
# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
file(GLOB tmp hip/*_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ CPU test files

View File

@@ -194,7 +194,7 @@ void DeviceQuery(const int device) {
<< std::endl;
ss << "Total registers per block: " << prop.regsPerBlock << std::endl;
ss << "Warp size: " << prop.warpSize << std::endl;
#ifndef __HIPCC__
#ifndef __HIP_PLATFORM_HCC__
ss << "Maximum memory pitch: " << prop.memPitch << std::endl;
#endif
ss << "Maximum threads per block: " << prop.maxThreadsPerBlock
@@ -207,14 +207,14 @@ void DeviceQuery(const int device) {
<< prop.maxGridSize[2] << std::endl;
ss << "Clock rate: " << prop.clockRate << std::endl;
ss << "Total constant memory: " << prop.totalConstMem << std::endl;
#ifndef __HIPCC__
#ifndef __HIP_PLATFORM_HCC__
ss << "Texture alignment: " << prop.textureAlignment << std::endl;
ss << "Concurrent copy and execution: "
<< (prop.deviceOverlap ? "Yes" : "No") << std::endl;
#endif
ss << "Number of multiprocessors: " << prop.multiProcessorCount
<< std::endl;
#ifndef __HIPCC__
#ifndef __HIP_PLATFORM_HCC__
ss << "Kernel execution timeout: "
<< (prop.kernelExecTimeoutEnabled ? "Yes" : "No") << std::endl;
#endif
@@ -266,7 +266,7 @@ const char* cublasGetErrorString(cublasStatus_t error) {
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
#ifndef __HIPCC__
#ifndef __HIP_PLATFORM_HCC__
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
@@ -282,7 +282,7 @@ const char* cublasGetErrorString(cublasStatus_t error) {
return "CUBLAS_STATUS_LICENSE_ERROR";
#endif // CUDA_VERSION >= 6050
#endif // CUDA_VERSION >= 6000
#ifdef __HIPCC__
#ifdef __HIP_PLATFORM_HCC__
case rocblas_status_invalid_size:
return "rocblas_status_invalid_size";
#endif
@@ -319,7 +319,7 @@ const char* curandGetErrorString(curandStatus_t error) {
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
#ifdef __HIPCC__
#ifdef __HIP_PLATFORM_HCC__
case HIPRAND_STATUS_NOT_IMPLEMENTED:
return "HIPRAND_STATUS_NOT_IMPLEMENTED";
#endif

View File

@@ -282,7 +282,7 @@ CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error);
// CUDA_KERNEL_ASSERT is a macro that wraps an assert() call inside cuda
// kernels. This is not supported by Apple platforms so we special case it.
// See http://docs.nvidia.com/cuda/cuda-c-programming-guide/#assertion
#if defined(__APPLE__) || defined(__HIPCC__)
#if defined(__APPLE__) || defined(__HIP_PLATFORM_HCC__)
#define CUDA_KERNEL_ASSERT(...)
#else // __APPLE__
#define CUDA_KERNEL_ASSERT(...) assert(__VA_ARGS__)

View File

@@ -3,7 +3,7 @@
#define CAFFE2_CORE_MIOPEN_WRAPPERS_H_
#include "caffe2/core/hip/common_miopen.h"
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
namespace caffe2 {

View File

@@ -6,13 +6,13 @@ set(Caffe2_DB_COMMON_GPU_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/create_db_op_gpu.cc"
)
set(Caffe2_DB_COMMON_HIP_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/hip/create_db_op_hip.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/hip/create_db_op_gpu.cc"
)
# Common files that are always going to be included.
list(APPEND Caffe2_CPU_SRCS ${Caffe2_DB_COMMON_CPU_SRC})
list(APPEND Caffe2_GPU_SRCS ${Caffe2_DB_COMMON_GPU_SRC})
list(APPEND Caffe2_HIP_SRCS ${Caffe2_DB_COMMON_HIP_SRC})
# DB specific files
if (USE_LMDB)

View File

@@ -12,13 +12,16 @@ if(USE_OPENCV AND OpenCV_FOUND)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})
# ---[ HIP files
# ------[ general hip
file(GLOB_RECURSE tmp *_hip.cc)
# ------[ general HIP
file(GLOB tmp hip/*.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ------[ HIP sources
file(GLOB tmp hip/*.hip)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB_RECURSE tmp *_test.cc)
file(GLOB tmp hip/*_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})
# ---[ CPU files.
file(GLOB tmp *.cc)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
@@ -33,9 +36,9 @@ if(USE_OPENCV AND OpenCV_FOUND)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})
# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
file(GLOB tmp hip/*_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ CPU test files
file(GLOB tmp *_test.cc)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})

View File

@@ -23,15 +23,15 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})
# ------[ general HIP
file(GLOB tmp hip/*.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ------[ HIP sources
file(GLOB_RECURSE tmp *_hip.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ------[ HIP device sources
file(GLOB_RECURSE tmp *_hipdev.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ---[ MIOPEN files
file(GLOB_RECURSE tmp *_miopen.cc)
file(GLOB tmp hip/*.hip)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB tmp hip/*_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})
# ---[ CPU files.
file(GLOB tmp *.cc)
@@ -58,10 +58,7 @@ file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})
# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ MIOPEN test files
file(GLOB_RECURSE tmp *_miopen_test.cc)
file(GLOB tmp hip/*_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ CPU test files

View File

@@ -1,7 +1,7 @@
#ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_MIOPEN_H_
#define CAFFE2_OPERATORS_ACTIVATION_OPS_MIOPEN_H_
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/operators/conv_op.h"
#include "caffe2/operators/conv_pool_op_base.h"

View File

@@ -1,4 +1,4 @@
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/operators/conv_transpose_op.h"

View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/operators/conv_pool_op_base.h"

View File

@@ -15,10 +15,10 @@
*/
#include <cfloat>
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/operators/spatial_batch_norm_op.h"
#include "caffe2/operators/hip/spatial_batch_norm_op_hip_impl.cuh"
#include "caffe2/operators/hip/spatial_batch_norm_op_gpu_impl.cuh"
#include "caffe2/utils/math.h"
const double MIOPEN_BN_MIN_EPSILON = 1e-6;

View File

@@ -14,6 +14,16 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})
# ------[ general HIP
file(GLOB tmp hip/*.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ------[ HIP sources
file(GLOB tmp hip/*.hip)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB tmp hip/*_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})
# ---[ CPU files.
file(GLOB tmp *.cc)
# Manually remove the cudnn files since we might be using USE_CUDNN=OFF
@@ -24,7 +34,7 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
# exclude test files and gpu files
file(GLOB tmp *_test.cc)
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp})
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS})
exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS} ${Caffe2_HIP_SRCS})
# ---[ GPU test files
# ------[ cuDNN
@@ -36,13 +46,19 @@ endif()
file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})
# ------[ HIP test files
file(GLOB tmp hip/*_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ CPU test files
file(GLOB tmp *_test.cc)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS})
exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS} ${Caffe2_HIP_TEST_SRCS})
# ---[ Send the lists to the parent scope.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} PARENT_SCOPE)

View File

@@ -2,7 +2,7 @@
#define CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
#include "caffe2/core/context.h"
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

View File

@@ -6,8 +6,8 @@
#include <pybind11/stl.h>
#include "caffe2/core/hip/common_miopen.h"
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/operators/hip/operator_fallback_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/operators/hip/operator_fallback_gpu.h"
#include "caffe2/python/pybind_state_registry.h"
namespace caffe2 {

View File

@@ -9,11 +9,14 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})
# ------[ general HIP
file(GLOB tmp hip/*.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ---[ HIP files.
file(GLOB_RECURSE tmp *_hip.cc)
file(GLOB tmp hip/*.hip)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB tmp *_test.cc)
file(GLOB tmp hip/*_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})
# ---[ CPU files.
@@ -29,7 +32,7 @@ file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})
# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
file(GLOB tmp hip/*_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ CPU test files
@@ -44,4 +47,3 @@ set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} PARENT_SCOPE)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} PARENT_SCOPE)

View File

@@ -9,12 +9,14 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${tmp})
file(GLOB tmp *_test.cc)
exclude(Caffe2_GPU_SRCS "${Caffe2_GPU_SRCS}" ${tmp})
# ------[ general HIP
file(GLOB tmp hip/*.cc)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# ---[ HIP files
# ------[ general GPU
file(GLOB_RECURSE tmp *_hip.cc)
file(GLOB tmp hip/*.hip)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${tmp})
# exclude test files
file(GLOB_RECURSE tmp *_test.cc)
file(GLOB tmp hip/*_test.cc)
exclude(Caffe2_HIP_SRCS "${Caffe2_HIP_SRCS}" ${tmp})
# ---[ CPU files.
@@ -31,7 +33,7 @@ file(GLOB tmp *_gpu_test.cc)
set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} ${tmp})
# ---[ HIP test files
file(GLOB_RECURSE tmp *_hip_test.cc)
file(GLOB tmp hip/*_test.cc)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS} ${tmp})
# ---[ CPU test files

View File

@@ -28,7 +28,7 @@ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS}
)
set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS}
utils/hip/math_hip.cc
utils/hip/math_gpu.hip
)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS}
@@ -47,8 +47,8 @@ set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS}
)
set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS}
utils/hip/math_hip_test.cc
utils/hip/math_blas_hip_test.cc
utils/hip/math_gpu_test.cc
utils/hip/math_blas_gpu_test.cc
)
# TODO Remove the CMake_xxx variables above and add them to the variables for the local library target below instead

View File

@@ -8,7 +8,7 @@
#include <caffe2/core/common_gpu.h>
#endif
#if __HIP_DEVICE_COMPILE__
#include <caffe2/core/hip/common_hip.h>
#include <caffe2/core/hip/common_gpu.h>
#endif
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)

View File

@@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include "caffe2/core/blob.h"
#include "caffe2/core/context.h"
#include "caffe2/core/hip/context_hip.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/utility_ops.h"
#include "caffe2/proto/caffe2_pb.h"

View File

@@ -18,7 +18,7 @@
#include "caffe2/utils/fixed_divisor.h"
// TODO: Move this to fixed_divisor.h
#ifdef __HIPCC__
#ifdef __HIP_PLATFORM_HCC__
#define FIXED_DIVISOR int32_t
#define FIXED_DIVISOR_DIV(d, n) (n / d)
#define FIXED_DIVISOR_MOD(d, n) (n % d)
@@ -28,12 +28,12 @@
*q = n_copy / d; \
*r = n_copy % d; \
} while (0)
#else // __HIPCC__
#else // __HIP_PLATFORM_HCC__
#define FIXED_DIVISOR FixedDivisor<int32_t>
#define FIXED_DIVISOR_DIV(d, n) (d.Div(n))
#define FIXED_DIVISOR_MOD(d, n) (d.Mod(n))
#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) (d.DivMod(n, q, r))
#endif // __HIPCC__
#endif // __HIP_PLATFORM_HCC__
#include "caffe2/utils/math_utils.h"
@@ -743,7 +743,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
at::Half* C,
CUDAContext* context,
TensorProto::DataType math_type) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
// Note that cublas follows fortran order, so the order is different from
@@ -841,7 +841,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<float, CUDAContext>(
float** C,
CUDAContext* context,
TensorProto::DataType math_type) {
#if __CUDACC_VER_MAJOR__ < 8 || defined(__HIPCC__)
#if __CUDACC_VER_MAJOR__ < 8 || defined(__HIP_PLATFORM_HCC__)
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
Gemm<float, CUDAContext>(
@@ -910,7 +910,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<float, CUDAContext>(
const int C_stride,
CUDAContext* context,
TensorProto::DataType math_type) {
#if __CUDACC_VER_MAJOR__ < 8 && !defined(__HIPCC__)
#if __CUDACC_VER_MAJOR__ < 8 && !defined(__HIP_PLATFORM_HCC__)
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
Gemm<float, CUDAContext>(
@@ -968,7 +968,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
at::Half** C,
CUDAContext* context,
TensorProto::DataType math_type) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
#if __CUDACC_VER_MAJOR__ < 9
@@ -1104,7 +1104,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
const int C_stride,
CUDAContext* context,
TensorProto::DataType math_type) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
#if __CUDACC_VER_MAJOR__ < 8
@@ -1479,7 +1479,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
at::Half* y,
CUDAContext* context,
TensorProto::DataType math_type) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
const cublasOperation_t cu_trans_A =
@@ -1727,7 +1727,7 @@ CAFFE2_CUDA_EXPORT void Dot<at::Half, CUDAContext>(
const at::Half* b,
at::Half* y,
CUDAContext* context) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
// execute with 32-bit math
@@ -1997,7 +1997,7 @@ template <typename TAlpha, typename TData>
__global__ void
ScaleCUDAKernel(const int n, const TAlpha alpha, const TData* x, TData* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
y[i] = __ldg(x + i) * static_cast<TData>(alpha);
#else
y[i] = x[i] * static_cast<TData>(alpha);
@@ -2009,7 +2009,7 @@ template <typename TAlpha, typename TData>
__global__ void
ScaleCUDAKernel(const int n, const TAlpha* alpha, const TData* x, TData* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
y[i] = __ldg(x + i) * static_cast<TData>(__ldg(alpha));
#else
y[i] = x[i] * static_cast<TData>(*alpha);
@@ -2138,7 +2138,7 @@ DELEGATE_CUBLAS_SCALE_FUNCTION(double, double, cublasDscal)
CAFFE2_SPECIALIZED_CUDA_SCALE(std::int32_t, std::int32_t)
CAFFE2_SPECIALIZED_CUDA_SCALE(std::int64_t, std::int64_t)
#ifndef __HIPCC__
#ifndef __HIP_PLATFORM_HCC__
template <>
CAFFE2_CUDA_EXPORT void Scale<at::Half, at::Half, CUDAContext>(
const int N,
@@ -2265,7 +2265,7 @@ CAFFE2_CUDA_EXPORT void Scale<float, at::Half, CUDAContext>(
CUDA_R_32F));
}
#else // __HIPCC__
#else // __HIP_PLATFORM_HCC__
namespace {
template <>
@@ -2321,7 +2321,7 @@ __global__ void ScaleCUDAKernel<float, at::Half>(
CAFFE2_SPECIALIZED_HIP_SCALE(at::Half, at::Half)
CAFFE2_SPECIALIZED_HIP_SCALE(float, at::Half)
#endif // __HIPCC__
#endif // __HIP_PLATFORM_HCC__
#undef CAFFE2_SPECIALIZED_CUDA_SCALE
@@ -2358,7 +2358,7 @@ CAFFE2_CUDA_EXPORT void Axpy<at::Half, CUDAContext>(
const at::Half* X,
at::Half* Y,
CUDAContext* context) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
CUBLAS_ENFORCE(
@@ -2397,7 +2397,7 @@ CAFFE2_CUDA_EXPORT void Axpy<at::Half, CUDAContext>(
const at::Half* X,
at::Half* Y,
CUDAContext* context) {
#if defined(__HIPCC__) && !ROCBLAS_FP16
#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
CAFFE_THROW("HIP currently does not support FP16 yet.");
#else
CUBLAS_ENFORCE(cublasSetPointerMode(
@@ -2558,7 +2558,7 @@ __global__ void Im2ColNCHWCUDAKernel(
for (int j = 0; j < kernel_w; ++j) {
const int h = h_in + dh;
const int w = w_in + dw;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
*col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) &&
utils::IsAGeZeroAndALtB(w, input_w)
? __ldg(img_data_ptr + dh * input_w + dw)
@@ -2609,7 +2609,7 @@ __global__ void Im2ColNHWCCUDAKernel(
for (int j = 0; j < kernel_w; ++j) {
const int h = h_in + dh;
const int w = w_in + dw;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
*col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) &&
utils::IsAGeZeroAndALtB(w, input_w)
? __ldg(img_data + (h * input_w + w) * channels + channel_in)
@@ -2671,7 +2671,7 @@ __global__ void Col2ImNCHWCUDAKernel(
(((c * patch_h + h_k) * patch_w + w_k) * output_h + h_col) *
output_w +
w_col;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
val += __ldg(col_data + col_data_index);
#else
val += col_data[col_data_index];
@@ -2723,7 +2723,7 @@ __global__ void Col2ImNHWCCUDAKernel(
h_k /= dilation_h;
w_k /= dilation_w;
const int c_col = (h_k * patch_w + w_k) * channels + c;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
val += __ldg(
col_data + (h_col * output_w + w_col) * channels_col + c_col);
#else
@@ -2775,7 +2775,7 @@ __global__ void Im2ColNdNCHWCUDAKernel(
is_padding |= !utils::IsAGeZeroAndALtB(d_img, img_shape.data[d_i + 1]);
img_index = img_index * img_shape.data[d_i + 1] + d_img;
}
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
if (!kCol2Im) {
Y_data[col_index] = is_padding ? 0 : __ldg(X_data + img_index);
} else if (!is_padding) {
@@ -3638,7 +3638,7 @@ __global__ void BroadcastCUDAKernel(
FIXED_DIVISOR_DIV_MOD(Y_dims.data[i], Y_index_val, &Y_index_val, &d);
X_index += d * X_strides.data[i];
}
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
Y[Y_index] = __ldg(X + X_index) * alpha;
#else
Y[Y_index] = X[X_index] * alpha;
@@ -3730,7 +3730,7 @@ __global__ void RowwiseMomentsCUDAKernel(
T v_val = 0;
for (int j = threadIdx.x; j < cols; j += blockDim.x) {
const int X_index = i * cols + j;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
m_val += __ldg(X + X_index);
v_val += __ldg(X + X_index) * __ldg(X + X_index);
#else
@@ -3764,7 +3764,7 @@ __global__ void ColwiseMomentsCUDAKernel(
T v_val = 0;
for (int j = threadIdx.x; j < rows; j += blockDim.x) {
const int X_index = j * cols + i;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
m_val += __ldg(X + X_index);
v_val += __ldg(X + X_index) * __ldg(X + X_index);
#else
@@ -3807,7 +3807,7 @@ __global__ void MomentsCUDAKernel(
FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], Y_index, &Y_index, &r);
X_index += r * X_strides.data[d];
}
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
m_val += __ldg(X + X_index);
v_val += __ldg(X + X_index) * __ldg(X + X_index);
#else
@@ -4014,7 +4014,7 @@ __global__ void BatchTranspose2DCUDAKernel(
int y = r * kTileDim + threadIdx.y;
if (x < W) {
for (int j = 0; j < kTileDim && y + j < H; j += kBlockRows) {
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
tile[threadIdx.y + j][threadIdx.x] =
__ldg(X + offset + (y + j) * W + x);
#else
@@ -4050,7 +4050,7 @@ __global__ void TransposeCUDAKernel(
FIXED_DIVISOR_DIV_MOD(Y_dims.data[i], Y_index_val, &Y_index_val, &d);
X_index += d * X_strides.data[i];
}
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
Y[Y_index] = __ldg(X + X_index);
#else
Y[Y_index] = X[X_index];
@@ -4135,7 +4135,7 @@ __global__ void AffineChannelCUDAKernel(
T* Y) {
CUDA_1D_KERNEL_LOOP(i, size) {
const int c = kOrder == StorageOrder::NCHW ? i / HxW % C : i % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIPCC__)
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
Y[i] = __ldg(scale + c) * __ldg(X + i) + __ldg(bias + c);
#else
Y[i] = scale[c] * X[i] + bias[c];

View File

@@ -128,17 +128,8 @@ IF(HIP_FOUND)
message("\n***** Library versions from cmake find_package *****\n")
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
# https://github.com/ROCm-Developer-Tools/HIP/pull/558 #
set(CMAKE_SHARED_LIBRARY_SONAME_HIP_FLAG ${CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG})
set(CMAKE_HIP_LINK_EXECUTABLE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_PATH} <FLAGS> <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>" )
set(CMAKE_HIP_CREATE_SHARED_LIBRARY "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_PATH} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> <SONAME_FLAG><TARGET_SONAME> -o <TARGET> <LINK_LIBRARIES> -shared" )
set(CMAKE_HIP_CREATE_SHARED_MODULE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_PATH} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> <SONAME_FLAG><TARGET_SONAME> -o <TARGET> <LINK_LIBRARIES> -shared" )
set(CMAKE_HIP_ARCHIVE_CREATE ${CMAKE_CXX_ARCHIVE_CREATE})
set(CMAKE_HIP_ARCHIVE_APPEND ${CMAKE_CXX_ARCHIVE_APPEND})
set(CMAKE_HIP_ARCHIVE_FINISH ${CMAKE_CXX_ARCHIVE_FINISH})
SET(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
SET(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)

View File

@@ -113,13 +113,16 @@ function(caffe2_binary_target target_name_or_src)
endfunction()
function(caffe2_hip_binary_target target_name_or_src)
caffe2_binary_target(${target_name_or_src})
if (ARGC GREATER 1)
set(__target ${target_name_or_src})
prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}")
else()
get_filename_component(__target ${target_name_or_src} NAME_WE)
prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${target_name_or_src}")
endif()
caffe2_binary_target(${target_name_or_src})
target_compile_options(${__target} PRIVATE ${HIP_CXX_FLAGS})
target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDES})
endfunction()

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import argparse
import os
import sys
@@ -35,18 +34,6 @@ ignores = [
file_extensions = ['.cc', '.cu', '.h', '.cuh']
parser = argparse.ArgumentParser(
description="The Script to Hipify Caffe2")
parser.add_argument(
'--hip-suffix',
type=str,
default='cc',
help="The suffix for the hipified files",
required=False)
args = parser.parse_args()
hipify_python.hipify(
project_directory=proj_dir,
output_directory=proj_dir,
@@ -54,5 +41,4 @@ hipify_python.hipify(
extensions=file_extensions,
ignores=ignores,
hipify_caffe2=True,
add_static_casts_option=True,
hip_suffix=args.hip_suffix)
add_static_casts_option=True)

View File

@@ -2216,19 +2216,19 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict([
CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([
("cuda_stream" , ("hip_stream", API_CAFFE2)),
("context_gpu" , ("hip/context_hip", API_CAFFE2)),
("common_gpu" , ("hip/common_hip", API_CAFFE2)),
("mixed_utils" , ("hip/mixed_utils_hip", API_CAFFE2)),
("operator_fallback_gpu" , ("hip/operator_fallback_hip", API_CAFFE2)),
("spatial_batch_norm_op_gpu_impl" , ("hip/spatial_batch_norm_op_hip_impl", API_CAFFE2)),
("recurrent_network_executor_gpu" , ("hip/recurrent_network_executor_hip", API_CAFFE2)),
("max_pool_with_index_gpu", ("hip/max_pool_with_index_hip", API_CAFFE2)),
("THCCachingAllocator_gpu", ("hip/THCCachingAllocator_hip", API_CAFFE2)),
("top_k_heap_selection", ("hip/top_k_heap_selection_hip", API_CAFFE2)),
("top_k_radix_selection", ("hip/top_k_radix_selection_hip", API_CAFFE2)),
("GpuDefs", ("hip/GpuDefs_hip", API_CAFFE2)),
("GpuScanUtils", ("hip/GpuScanUtils_hip", API_CAFFE2)),
("GpuBitonicSort", ("hip/GpuBitonicSort_hip", API_CAFFE2)),
("/context_gpu" , ("/hip/context_gpu", API_CAFFE2)),
("/common_gpu" , ("/hip/common_gpu", API_CAFFE2)),
("/mixed_utils" , ("/hip/mixed_utils", API_CAFFE2)),
("/operator_fallback_gpu" , ("/hip/operator_fallback_gpu", API_CAFFE2)),
("/spatial_batch_norm_op_gpu_impl" , ("/hip/spatial_batch_norm_op_gpu_impl", API_CAFFE2)),
("/recurrent_network_executor_gpu" , ("/hip/recurrent_network_executor_gpu", API_CAFFE2)),
("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)),
("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)),
("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)),
("/top_k_radix_selection", ("/hip/top_k_radix_selection", API_CAFFE2)),
("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)),
("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)),
("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)),
("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)),
("REGISTER_CUDA_OPERATOR" , ("REGISTER_HIP_OPERATOR", API_CAFFE2)),
("CUDA_1D_KERNEL_LOOP" , ("HIP_1D_KERNEL_LOOP", API_CAFFE2)),

View File

@@ -242,8 +242,7 @@ def preprocess(
all_files,
show_detailed=False,
show_progress=True,
hipify_caffe2=False,
hip_suffix='cc'):
hipify_caffe2=False):
"""
Call preprocessor on selected files.
@@ -259,15 +258,14 @@
stats = {"unsupported_calls": [], "kernel_launches": []}
for filepath in all_files:
preprocessor(output_directory, filepath, stats, hipify_caffe2, hip_suffix)
preprocessor(output_directory, filepath, stats, hipify_caffe2)
# Show what happened
if show_progress:
print(
filepath, "->",
get_hip_file_path(
filepath,
hipify_caffe2=hipify_caffe2,
hip_suffix=hip_suffix))
hipify_caffe2=hipify_caffe2))
finished_count += 1
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
@@ -710,7 +708,7 @@ def disable_function(input_string, function, replace_style):
return output_string
def get_hip_file_path(filepath, hipify_caffe2, hip_suffix):
def get_hip_file_path(filepath, hipify_caffe2):
"""
Returns the new name of the hipified file
"""
@@ -744,13 +742,10 @@ def get_hip_file_path(filepath, hipify_caffe2, hip_suffix):
#
# - If the file name contains "CUDA", replace it with "HIP", AND
#
# - If the file name contains "gpu", replace it with "hip".
#
# If NONE of the above occurred, then append "_hip" to the end of
# the filename (before the extension).
#
# Furthermore, ALWAYS replace '.cu' with '.cc', to appease the hcc
# compiler.
# Furthermore, ALWAYS replace '.cu' with '.hip'.
#
# This isn't set in stone; we might adjust this to support other
# naming conventions.
@@ -760,21 +755,16 @@ def get_hip_file_path(filepath, hipify_caffe2, hip_suffix):
# currently support this file extension.
if ext == '.cu':
ext = '.' + hip_suffix
ext = '.hip'
orig_dirpath = dirpath
orig_root = root
dirpath = dirpath.replace('cuda', 'hip')
root = root.replace('gpu', 'hip')
root = root.replace('CUDA', 'HIP')
if dirpath == orig_dirpath:
dirpath = os.path.join(dirpath, 'hip')
if root == orig_root:
root += "_hip"
return os.path.join(dirpath, root + ext)
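
A minimal sketch of the simplified renaming implemented above (new_hip_path is a hypothetical stand-in for get_hip_file_path; the expected outputs match the utils/hip paths in the CMake diffs earlier in this commit):

```python
import os

def new_hip_path(filepath):
    # Keep "gpu" in the file name, rewrite .cu -> .hip, and nest the
    # output under a hip/ subdirectory when no cuda/ path component
    # exists to rewrite.
    dirpath, filename = os.path.split(filepath)
    root, ext = os.path.splitext(filename)
    if ext == '.cu':
        ext = '.hip'
    new_dir = dirpath.replace('cuda', 'hip')
    root = root.replace('CUDA', 'HIP')
    if new_dir == dirpath:
        new_dir = os.path.join(dirpath, 'hip')
    return os.path.join(new_dir, root + ext)

# Previously math_gpu.cu hipified to caffe2/utils/hip/math_hip.cc; now:
assert new_hip_path('caffe2/utils/math_gpu.cu') == 'caffe2/utils/hip/math_gpu.hip'
assert new_hip_path('caffe2/utils/math_gpu_test.cc') == 'caffe2/utils/hip/math_gpu_test.cc'
```
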
@@ -786,13 +776,13 @@ def is_caffe2_gpu_file(filepath):
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
def preprocessor(output_directory, filepath, stats, hipify_caffe2, hip_suffix):
def preprocessor(output_directory, filepath, stats, hipify_caffe2):
""" Executes the CUDA -> HIP conversion on the specified file. """
fin_path = os.path.join(output_directory, filepath)
with open(fin_path, 'r') as fin:
output_source = fin.read()
fout_path = os.path.join(output_directory, get_hip_file_path(filepath, hipify_caffe2, hip_suffix))
fout_path = os.path.join(output_directory, get_hip_file_path(filepath, hipify_caffe2))
if not os.path.exists(os.path.dirname(fout_path)):
os.makedirs(os.path.dirname(fout_path))
@@ -1257,8 +1247,7 @@ def main():
add_static_casts_option=args.add_static_casts,
hipify_caffe2=args.hipify_caffe2,
ignores=args.ignores,
show_progress=args.show_progress,
hip_suffix=args.hip_suffix)
show_progress=args.show_progress)
def hipify(
@@ -1272,7 +1261,6 @@ def hipify(
hipify_caffe2=False,
ignores=(),
show_progress=True,
hip_suffix='cc',
):
if project_directory == "":
project_directory = os.getcwd()
@@ -1389,8 +1377,7 @@ def hipify(
all_files,
show_detailed=show_detailed,
show_progress=show_progress,
hipify_caffe2=hipify_caffe2,
hip_suffix=hip_suffix)
hipify_caffe2=hipify_caffe2)
# Extract all of the kernel parameter and template type information.
if add_static_casts_option:
@@ -1409,8 +1396,7 @@ def hipify(
output_directory,
get_hip_file_path(
filepath,
hipify_caffe2=hipify_caffe2,
hip_suffix=hip_suffix)),
hipify_caffe2=hipify_caffe2)),
KernelTemplateParams)
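
For reference, a hedged sketch of calling the updated entry point once the hip_suffix parameter is gone (mirroring the hipify_caffe2.py call shown earlier; proj_dir and ignores are placeholders):

```python
import hipify_python

proj_dir = '/path/to/pytorch'  # placeholder
ignores = []                   # placeholder

# No hip_suffix argument anymore: .cu inputs always hipify to .hip.
hipify_python.hipify(
    project_directory=proj_dir,
    output_directory=proj_dir,
    extensions=['.cc', '.cu', '.h', '.cuh'],
    ignores=ignores,
    hipify_caffe2=True,
    add_static_casts_option=True)
```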