pytorch/cmake/Codegen.cmake
AaronWang04 772d590415 [CUTLASS] [CUDA] SM100 GroupMM (#156203)
Closes https://github.com/pytorch/pytorch/issues/156202

This PR adds Blackwell (SM100) support for GroupMM.

Most of the code used for SM90 can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html

Did some preliminary benchmarking of H200 vs B200

Script
```py
import torch
print(torch.__file__)
device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168)
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for i in range(5): c = torch._grouped_mm(a, b)  # warmup iterations

    num_iter = 50
    start_event.record()

    for i in range(num_iter): c = torch._grouped_mm(a, b)  # timed iterations
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")
```

On H200
```
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 298.6668 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 4.1462 ms
```

On B200
```
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      190.7458 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      3.0680 ms
```
nsys nvprof output
```
root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 192.6420 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)    StdDev (ns)                 Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…
```

The kernel names are too long to be shown in the nvprof output, so I pasted them from Nsight Systems:
```
small kernel 1SM
100.0%	1.286 ms	1	1.286 ms	1.286 ms	1.286 ms	1.286 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%	194.178 ms	1	194.178 ms	194.178 ms	194.178 ms	194.178 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156203
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-06-28 23:02:00 +00:00


# This ill-named file does a number of things:
# - Installs Caffe2 header files (this has nothing to do with code generation)
# - Configures caffe2/core/macros.h
# - Creates an ATen target for its generated C++ files and adds it
# as a dependency
# - Reads build lists defined in build_variables.bzl
################################################################################
# Helper functions
################################################################################
function(filter_list output input)
  unset(result)
  foreach(filename ${${input}})
    foreach(pattern ${ARGN})
      if("${filename}" MATCHES "${pattern}")
        list(APPEND result "${filename}")
      endif()
    endforeach()
  endforeach()
  set(${output} ${result} PARENT_SCOPE)
endfunction()
function(filter_list_exclude output input)
  unset(result)
  foreach(filename ${${input}})
    foreach(pattern ${ARGN})
      if(NOT "${filename}" MATCHES "${pattern}")
        list(APPEND result "${filename}")
      endif()
    endforeach()
  endforeach()
  set(${output} ${result} PARENT_SCOPE)
endfunction()
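# Illustrative usage of the helpers above (not part of the build; the list and
# variable names here are made up): keep or drop entries of a list by regex.
#
#   set(example_srcs foo.cpp bar.cu baz.cpp)
#   filter_list(example_cuda_srcs example_srcs "\\.cu$")         # -> bar.cu
#   filter_list_exclude(example_cpp_srcs example_srcs "\\.cu$")  # -> foo.cpp;baz.cpp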
################################################################################
# ---[ Determine commit hash
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "from tools.generate_torch_version import get_sha;print(get_sha('.'), end='')"
  OUTPUT_VARIABLE COMMIT_SHA
  WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
)
# ---[ Write the macros file
configure_file(
  ${CMAKE_CURRENT_LIST_DIR}/../caffe2/core/macros.h.in
  ${CMAKE_BINARY_DIR}/caffe2/core/macros.h)
# ---[ Installing the header files
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/../caffe2
        DESTINATION include
        FILES_MATCHING PATTERN "*.h")
if(NOT INTERN_BUILD_ATEN_OPS)
  install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core
          DESTINATION include/ATen
          FILES_MATCHING PATTERN "*.h")
endif()
install(FILES ${CMAKE_BINARY_DIR}/caffe2/core/macros.h
        DESTINATION include/caffe2/core)
# ---[ ATen specific
if(INTERN_BUILD_ATEN_OPS)
  if(MSVC)
    set(OPT_FLAG "/fp:strict ")
  else(MSVC)
    set(OPT_FLAG "-O3 ")
    if("${CMAKE_BUILD_TYPE}" MATCHES "Debug")
      set(OPT_FLAG " ")
    endif()
  endif(MSVC)
  if(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
    set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/MapAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp")
  endif()
  file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")
  # Handle files that may need sm89/sm90a/sm100a/sm120a flags (stable/nightly
  # builds are not built for these archs).
  if(USE_CUDA)
    # The stable/nightly builds do not enable some SM architectures,
    # like 89/90a/100a. Still, some files need to be built for these
    # architectures specifically. This function makes it possible to
    # enable building a given file for a specific such architecture
    # when PyTorch is built for the corresponding base architecture;
    # for example, it enables building for SM 90a when PyTorch is
    # built for SM 90, etc. For examples of how to use the function,
    # see the calls below the function itself.
    function(_BUILD_FOR_ADDITIONAL_ARCHS file archs)
      torch_cuda_get_nvcc_gencode_flag(_existing_arch_flags)
      set(_file_compile_flags "")
      if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
        foreach(_arch ${archs})
          if("${_arch}" STREQUAL "89")
            if(_existing_arch_flags MATCHES ".*compute_86.*")
              list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89")
            endif()
          endif()
          if("${_arch}" STREQUAL "90a")
            if(_existing_arch_flags MATCHES ".*compute_90.*")
              list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a")
            endif()
          endif()
          if("${_arch}" STREQUAL "100a")
            if(_existing_arch_flags MATCHES ".*compute_100.*")
              list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a")
            endif()
          endif()
          if("${_arch}" STREQUAL "120a")
            if(_existing_arch_flags MATCHES ".*compute_120.*")
              list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
            endif()
          endif()
        endforeach()
      endif()
      list(JOIN _file_compile_flags " " _file_compile_flags)
      set_source_files_properties(${file} PROPERTIES COMPILE_FLAGS "${_file_compile_flags}")
    endfunction()
    _BUILD_FOR_ADDITIONAL_ARCHS(
      "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
      "89;90a;100a;120a")
    _BUILD_FOR_ADDITIONAL_ARCHS(
      "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
      "90a")
    _BUILD_FOR_ADDITIONAL_ARCHS(
      "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
      "90a;100a")
  endif()
  set(GEN_ROCM_FLAG)
  if(USE_ROCM)
    set(GEN_ROCM_FLAG --rocm)
  endif()
  set(GEN_MPS_FLAG)
  if(USE_MPS)
    set(GEN_MPS_FLAG --mps)
  endif()
  set(GEN_XPU_FLAG)
  if(USE_XPU)
    set(GEN_XPU_FLAG --xpu)
  endif()
  set(CUSTOM_BUILD_FLAGS)
  if(INTERN_BUILD_MOBILE)
    if(USE_VULKAN)
      list(APPEND CUSTOM_BUILD_FLAGS --backend_whitelist CPU QuantizedCPU Vulkan)
    else()
      list(APPEND CUSTOM_BUILD_FLAGS --backend_whitelist CPU QuantizedCPU)
    endif()
  endif()
  if(SELECTED_OP_LIST)
    if(TRACING_BASED)
      message(STATUS "Running tracing-based selective build given operator list: ${SELECTED_OP_LIST}")
      list(APPEND CUSTOM_BUILD_FLAGS
        --op_selection_yaml_path ${SELECTED_OP_LIST})
    elseif(NOT STATIC_DISPATCH_BACKEND)
      message(WARNING
        "You have to run tracing-based selective build with dynamic dispatch.\n"
        "Switching to STATIC_DISPATCH_BACKEND=CPU."
      )
      set(STATIC_DISPATCH_BACKEND CPU)
    endif()
  endif()
  if(STATIC_DISPATCH_BACKEND)
    message(STATUS "Custom build with static dispatch backends: ${STATIC_DISPATCH_BACKEND}")
    list(LENGTH STATIC_DISPATCH_BACKEND len)
    list(APPEND CUSTOM_BUILD_FLAGS
      --static_dispatch_backend ${STATIC_DISPATCH_BACKEND})
  endif()
  # Codegen unboxing
  if(USE_LIGHTWEIGHT_DISPATCH)
    file(GLOB_RECURSE all_unboxing_script "${CMAKE_CURRENT_LIST_DIR}/../tools/jit/*.py")
    list(APPEND CUSTOM_BUILD_FLAGS --skip_dispatcher_op_registration)
    set(GEN_UNBOXING_COMMAND
      "${Python_EXECUTABLE}" -m tools.jit.gen_unboxing
      --source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen
      --install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen
    )
    if(SELECTED_OP_LIST)
      list(APPEND GEN_UNBOXING_COMMAND
        --TEST_ONLY_op_registration_allowlist_yaml_path "${SELECTED_OP_LIST}")
    endif()
    set("GEN_UNBOXING_COMMAND_sources"
      ${GEN_UNBOXING_COMMAND}
      --output-dependencies ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_unboxing_sources.cmake
    )
    message(STATUS "Generating sources for lightweight dispatch")
    execute_process(
      COMMAND ${GEN_UNBOXING_COMMAND_sources} --dry-run
      RESULT_VARIABLE RETURN_VALUE
      WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
    )
    if(NOT RETURN_VALUE EQUAL 0)
      message(FATAL_ERROR "Failed to get generated_unboxing_sources list")
    endif()
    include("${CMAKE_BINARY_DIR}/aten/src/ATen/generated_unboxing_sources.cmake")
    add_custom_command(
      COMMENT "Generating ATen unboxing sources"
      OUTPUT
        ${generated_unboxing_sources}
        ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_unboxing_sources.cmake
      COMMAND ${GEN_UNBOXING_COMMAND_sources}
      DEPENDS ${all_unboxing_script} ${sources_templates}
        ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
        ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/tags.yaml
      WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
    )
  else() # Otherwise do not generate or include sources into build.
    set(generated_unboxing_sources "")
  endif()
  set(GEN_PER_OPERATOR_FLAG)
  if(USE_PER_OPERATOR_HEADERS)
    list(APPEND GEN_PER_OPERATOR_FLAG "--per-operator-headers")
  endif()
  set(GEN_COMMAND
    "${Python_EXECUTABLE}" -m torchgen.gen
    --source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen
    --install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen
    ${GEN_PER_OPERATOR_FLAG}
    ${GEN_ROCM_FLAG}
    ${GEN_MPS_FLAG}
    ${GEN_XPU_FLAG}
    ${CUSTOM_BUILD_FLAGS}
  )
  file(GLOB_RECURSE headers_templates "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/templates/*\.h")
  file(GLOB_RECURSE sources_templates "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/templates/*\.cpp")
  set(declarations_yaml_templates "")
  foreach(gen_type "headers" "sources" "declarations_yaml")
    # The codegen outputs may change dynamically as PyTorch is
    # developed, but add_custom_command only supports dynamic inputs.
    #
    # We work around this by generating a .cmake file which is
    # included below to set the list of output files. If that file
    # ever changes then cmake will be re-run automatically because it
    # was included and so we get fully dynamic outputs.
    set("GEN_COMMAND_${gen_type}"
      ${GEN_COMMAND}
      --generate ${gen_type}
      --output-dependencies ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
    )
    # Dry run to bootstrap the output variables
    execute_process(
      COMMAND ${GEN_COMMAND_${gen_type}} --dry-run
      RESULT_VARIABLE RETURN_VALUE
      WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
    )
    if(NOT RETURN_VALUE EQUAL 0)
      message(FATAL_ERROR "Failed to get generated_${gen_type} list")
    endif()
    include("${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake")
    include("${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake")
    include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake")
    include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake")
    include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake")
    if(USE_XPU)
      include("${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake")
    endif()
    message(STATUS "${gen_type} outputs: ${gen_outputs}")
    set(OUTPUT_LIST
      ${generated_${gen_type}}
      ${cuda_generated_${gen_type}}
      ${core_generated_${gen_type}}
      ${cpu_vec_generated_${gen_type}}
      ${ops_generated_${gen_type}}
      ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
      ${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
      ${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
      ${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
      ${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake)
    if(USE_XPU)
      list(APPEND OUTPUT_LIST
        ${xpu_generated_${gen_type}}
        ${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake
      )
    endif()
    add_custom_command(
      COMMENT "Generating ATen ${gen_type}"
      OUTPUT ${OUTPUT_LIST}
      COMMAND ${GEN_COMMAND_${gen_type}}
      DEPENDS ${all_python} ${${gen_type}_templates}
        ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
        ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/tags.yaml
      WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
    )
  endforeach()
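  # The pattern used in the loop above, distilled (illustrative only; "my_gen"
  # stands in for a hypothetical generator that, when passed
  # --output-dependencies <file>, writes the list of files it will produce
  # into <file> as a set() command):
  #
  #   execute_process(COMMAND my_gen --output-dependencies out.cmake --dry-run)  # configure time
  #   include(out.cmake)           # defines the output list; cmake re-runs if out.cmake changes
  #   add_custom_command(
  #     OUTPUT ${my_outputs} out.cmake
  #     COMMAND my_gen --output-dependencies out.cmake)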
  # Generated headers used from a CUDA (.cu) file are
  # not tracked correctly in CMake. We make the libATen.so depend explicitly
  # on building the generated ATen files as a workaround.
  add_custom_target(ATEN_CPU_FILES_GEN_TARGET DEPENDS
    ${generated_headers} ${core_generated_headers} ${cpu_vec_generated_headers} ${ops_generated_headers}
    ${generated_sources} ${core_generated_sources} ${cpu_vec_generated_sources} ${ops_generated_sources}
    ${generated_declarations_yaml} ${generated_unboxing_sources})
  add_custom_target(ATEN_CUDA_FILES_GEN_TARGET DEPENDS
    ${cuda_generated_headers} ${cuda_generated_sources})
  add_library(ATEN_CPU_FILES_GEN_LIB INTERFACE)
  add_library(ATEN_CUDA_FILES_GEN_LIB INTERFACE)
  add_dependencies(ATEN_CPU_FILES_GEN_LIB ATEN_CPU_FILES_GEN_TARGET)
  add_dependencies(ATEN_CUDA_FILES_GEN_LIB ATEN_CUDA_FILES_GEN_TARGET)
  if(USE_PER_OPERATOR_HEADERS)
    target_compile_definitions(ATEN_CPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
    target_compile_definitions(ATEN_CUDA_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
  endif()
  if(USE_XPU)
    add_custom_target(ATEN_XPU_FILES_GEN_TARGET DEPENDS
      ${xpu_generated_headers} ${xpu_generated_sources})
    add_library(ATEN_XPU_FILES_GEN_LIB INTERFACE)
    add_dependencies(ATEN_XPU_FILES_GEN_LIB ATEN_XPU_FILES_GEN_TARGET)
    if(USE_PER_OPERATOR_HEADERS)
      target_compile_definitions(ATEN_XPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
    endif()
  endif()
  # Handle source files that need to be compiled multiple times for
  # different vectorization options
  file(GLOB cpu_kernel_cpp_in "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/cpu/*.cpp" "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/kernels/*.cpp")
  list(APPEND CPU_CAPABILITY_NAMES "DEFAULT")
  list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}")
  if(CXX_AVX512_FOUND)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_CPU_DEFINITION")
    list(APPEND CPU_CAPABILITY_NAMES "AVX512")
    if(MSVC)
      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
    else(MSVC)
      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma")
    endif(MSVC)
  endif(CXX_AVX512_FOUND)
  if(CXX_AVX2_FOUND)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX2_CPU_DEFINITION")
    # Some versions of GCC pessimistically split unaligned load and store
    # instructions when using the default tuning. This is a bad choice on
    # new Intel and AMD processors so we disable it when compiling with AVX2.
    # See https://stackoverflow.com/questions/52626726/why-doesnt-gcc-resolve-mm256-loadu-pd-as-single-vmovupd#tab-top
    check_cxx_compiler_flag("-mno-avx256-split-unaligned-load -mno-avx256-split-unaligned-store" COMPILER_SUPPORTS_NO_AVX256_SPLIT)
    if(COMPILER_SUPPORTS_NO_AVX256_SPLIT)
      set(CPU_NO_AVX256_SPLIT_FLAGS "-mno-avx256-split-unaligned-load -mno-avx256-split-unaligned-store")
    endif(COMPILER_SUPPORTS_NO_AVX256_SPLIT)
    list(APPEND CPU_CAPABILITY_NAMES "AVX2")
    if(DEFINED ENV{ATEN_AVX512_256})
      if($ENV{ATEN_AVX512_256} MATCHES "TRUE")
        if(CXX_AVX512_FOUND)
          message("-- ATen AVX2 kernels will use 32 ymm registers")
          if(MSVC)
            list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
          else(MSVC)
            list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=native ${CPU_NO_AVX256_SPLIT_FLAGS}")
          endif(MSVC)
        endif(CXX_AVX512_FOUND)
      endif()
    else()
      if(MSVC)
        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2")
      else(MSVC)
        list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma -mf16c ${CPU_NO_AVX256_SPLIT_FLAGS}")
      endif(MSVC)
    endif()
  endif(CXX_AVX2_FOUND)
  if(CXX_VSX_FOUND)
    SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_VSX_CPU_DEFINITION")
    LIST(APPEND CPU_CAPABILITY_NAMES "VSX")
    LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_VSX_FLAGS}")
  endif(CXX_VSX_FOUND)
  if(CXX_ZVECTOR_FOUND)
    SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_ZVECTOR_CPU_DEFINITION")
    LIST(APPEND CPU_CAPABILITY_NAMES "ZVECTOR")
    LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
  endif(CXX_ZVECTOR_FOUND)
  if(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND)
    list(APPEND CPU_CAPABILITY_NAMES "SVE256")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
    if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
    else()
      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
    endif()
  endif()
  list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
  math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")
  # The sources list might get reordered later based on the capabilities.
  # See NOTE [ Linking AVX and non-AVX files ]
  foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
    function(process_vec NAME)
      list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY)
      set(NEW_IMPL ${CMAKE_BINARY_DIR}/aten/src/ATen/${NAME}.${CPU_CAPABILITY}.cpp)
      configure_file("${PROJECT_SOURCE_DIR}/cmake/IncludeSource.cpp.in" ${NEW_IMPL})
      set(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp} PARENT_SCOPE) # Create list of copies
      list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS)
      if(MSVC)
        set(EXTRA_FLAGS "/DCPU_CAPABILITY=${CPU_CAPABILITY} /DCPU_CAPABILITY_${CPU_CAPABILITY}")
      else(MSVC)
        set(EXTRA_FLAGS "-DCPU_CAPABILITY=${CPU_CAPABILITY} -DCPU_CAPABILITY_${CPU_CAPABILITY}")
      endif(MSVC)
      # Only parallelize the SortingKernel for now to avoid side effects
      if(${NAME} STREQUAL "native/cpu/SortingKernel.cpp" AND NOT MSVC AND USE_OMP)
        string(APPEND EXTRA_FLAGS " -D_GLIBCXX_PARALLEL")
      endif()
      # Disable certain warnings for GCC-9.X
      if(CMAKE_COMPILER_IS_GNUCXX)
        if(("${NAME}" STREQUAL "native/cpu/GridSamplerKernel.cpp") AND ("${CPU_CAPABILITY}" STREQUAL "DEFAULT"))
          # See https://github.com/pytorch/pytorch/issues/38855
          set(EXTRA_FLAGS "${EXTRA_FLAGS} -Wno-uninitialized")
        endif()
        if("${NAME}" STREQUAL "native/quantized/cpu/kernels/QuantizedOpKernels.cpp")
          # See https://github.com/pytorch/pytorch/issues/38854
          set(EXTRA_FLAGS "${EXTRA_FLAGS} -Wno-deprecated-copy")
        endif()
      endif()
      set_source_files_properties(${NEW_IMPL} PROPERTIES COMPILE_FLAGS "${FLAGS} ${EXTRA_FLAGS}")
    endfunction()
    foreach(IMPL ${cpu_kernel_cpp_in})
      file(RELATIVE_PATH NAME "${PROJECT_SOURCE_DIR}/aten/src/ATen/" "${IMPL}")
      process_vec("${NAME}")
    endforeach()
    foreach(IMPL ${cpu_vec_generated_sources})
      file(RELATIVE_PATH NAME "${CMAKE_BINARY_DIR}/aten/src/ATen/" "${IMPL}")
      process_vec("${NAME}")
    endforeach()
  endforeach()
  list(APPEND ATen_CPU_SRCS ${cpu_kernel_cpp})
endif()
function(append_filelist name outputvar)
  set(_rootdir "${Torch_SOURCE_DIR}/")
  # configure_file adds its input to the list of CMAKE_RERUN dependencies
  configure_file(
    ${PROJECT_SOURCE_DIR}/build_variables.bzl
    ${PROJECT_BINARY_DIR}/caffe2/build_variables.bzl)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -c
    "exec(open('${PROJECT_SOURCE_DIR}/build_variables.bzl').read());print(';'.join(['${_rootdir}' + x for x in ${name}]))"
    WORKING_DIRECTORY "${_rootdir}"
    RESULT_VARIABLE _retval
    OUTPUT_VARIABLE _tempvar)
  if(NOT _retval EQUAL 0)
    message(FATAL_ERROR "Failed to fetch filelist ${name} from build_variables.bzl")
  endif()
  string(REPLACE "\n" "" _tempvar "${_tempvar}")
  list(APPEND ${outputvar} ${_tempvar})
  set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
endfunction()
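# Illustrative usage (not part of the build; "example_source_list" is a
# hypothetical list name that would have to exist in build_variables.bzl):
#
#   append_filelist("example_source_list" EXAMPLE_SRCS)
#   message(STATUS "Files from build_variables.bzl: ${EXAMPLE_SRCS}")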
set(NUM_CPU_CAPABILITY_NAMES ${NUM_CPU_CAPABILITY_NAMES} PARENT_SCOPE)
set(CPU_CAPABILITY_FLAGS ${CPU_CAPABILITY_FLAGS} PARENT_SCOPE)