[CUTLASS] [CUDA] SM100 GroupMM (#156203)
Closes https://github.com/pytorch/pytorch/issues/156202

This PR adds Blackwell (SM100) support for GroupMM. Most of the code used for SM90 can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html.

Did some preliminary benchmarking of H200 vs B200.

Script:

```py
import torch

print(torch.__file__)

device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168),
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warmup iterations before timing.
    for i in range(5):
        c = torch._grouped_mm(a, b)

    num_iter = 50
    start_event.record()
    for i in range(num_iter):
        c = torch._grouped_mm(a, b)
    end_event.record()
    torch.cuda.synchronize()

    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")
```
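The script uses the usual manual pattern (warmup iterations, CUDA events, a final synchronize). As a cross-check, `torch.utils.benchmark.Timer` handles warmup and CUDA synchronization itself; a minimal sketch of the same measurement for the first shape (not part of the PR):

```py
import torch
from torch.utils import benchmark

device = torch.device("cuda")
dtype = torch.bfloat16

batch, M, N, K = 16, 128000, 7168, 7168
a = torch.randn(batch, M, K, device=device, dtype=dtype)
b = torch.randn(batch, N, K, device=device, dtype=dtype)

# Timer synchronizes around CUDA work and performs a warmup run
# internally, so no manual event bookkeeping is needed.
timer = benchmark.Timer(
    stmt="torch._grouped_mm(a, b)",
    globals={"torch": torch, "a": a, "b": b},
)
print(timer.timeit(50))  # Measurement with per-iteration statistics
```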
On H200:

```
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      298.6668 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      4.1462 ms
```

On B200:

```
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      190.7458 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      3.0680 ms
```

nsys nvprof:

```
root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      192.6420 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)    Max (ns)    StdDev (ns)               Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)     Min (ns)   Max (ns)   StdDev (ns)                                                 Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…
```

The two grouped-GEMM kernel lines show 55 instances each: the 5 warmup plus 50 timed iterations for each of the two shapes.
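To put the timings in perspective, they can be converted to achieved TFLOP/s. A back-of-the-envelope sketch; `achieved_tflops` is just an illustrative helper, and the millisecond figures are the ones reported above:

```py
# FLOP count for one batched GEMM: 2 * batch * M * N * K.
def achieved_tflops(batch, M, N, K, ms):
    return 2 * batch * M * N * K / (ms * 1e-3) / 1e12

# Large shape (16, 128000, 7168, 7168):
print(f"H200: {achieved_tflops(16, 128000, 7168, 7168, 298.6668):.1f} TFLOP/s")  # ~704.6
print(f"B200: {achieved_tflops(16, 128000, 7168, 7168, 190.7458):.1f} TFLOP/s")  # ~1103.3

# Speedups from the reported times: ~1.57x on the large shape,
# ~1.35x on the small (128, 1, 2048, 7168) shape.
print(f"{298.6668 / 190.7458:.2f}x  {4.1462 / 3.0680:.2f}x")
```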
The kernel names are too long to be shown in full by nvprof, so the following is pasted from Nsight Systems. Reading off the template arguments: the small shape runs a 1SM kernel (128x256x64 tile, `SM100_MMA_F16BF16_SS`), while the large shape runs a 2SM kernel (256x256x64 tile, `SM100_MMA_F16BF16_2x1SM_SS`, with the MMA spanning a pair of SMs in the cluster).

```
small kernel 1SM
100.0%  1.286 ms  1  1.286 ms  1.286 ms  1.286 ms  1.286 ms  0 ns
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%  194.178 ms  1  194.178 ms  194.178 ms  194.178 ms  194.178 ms  0 ns
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156203
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
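As a side note on dispatch: whether the new kernels are even eligible depends on the device's compute capability (B200 reports 10.0; H100/H200 report 9.0). A minimal runtime probe, as an illustrative sketch (PyTorch's actual gating happens in C++ inside the grouped-MM dispatch, not in Python):

```py
import torch

# Compute capability of the current CUDA device, e.g. (9, 0) on
# H100/H200 and (10, 0) on B200.
major, minor = torch.cuda.get_device_capability()

if (major, minor) >= (10, 0):
    print("SM100+: Blackwell grouped-GEMM kernels apply")
elif (major, minor) >= (9, 0):
    print("SM90: Hopper grouped-GEMM kernels apply")
else:
    print("No CUTLASS grouped-GEMM path for this device")
```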
Committed by: PyTorch MergeBot
Parent: 996206e66f
Commit: 772d590415
```diff
@@ -128,7 +128,7 @@ if(INTERN_BUILD_ATEN_OPS)
       "90a")
     _BUILD_FOR_ADDITIONAL_ARCHS(
       "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
-      "90a")
+      "90a;100a")
   endif()
```
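`90a;100a` is the list of arch-conditional targets GroupMM.cu is now compiled for (`sm_90a` plus `sm_100a`). To inspect what a given PyTorch binary was built for, one can query the arch list at runtime; a small sketch (whether the `a`-suffixed, per-file variants show up in this list depends on how the build was configured, so treat it as a rough probe):

```py
import torch

# Architectures the linked CUDA binaries were compiled for,
# e.g. ['sm_80', 'sm_90', ...]; arch-guarded files like GroupMM.cu
# additionally need the 'a' variants (sm_90a / sm_100a) from the
# CMake change above.
print(torch.cuda.get_arch_list())
```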