Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-03 23:45:05 +08:00
02c7ab2f9baac05bd199392b70bc016d55f99b13
22 Commits
| Author | SHA1 | Message | Date |
|---|---|---|---|
| a6084b71ed |
[BE][1/X] Phase out usage of use_max_autotune() (#155847)
`use_max_autotune()` is likely not what people expect it to be. Originally, `use_max_autotune()` was set up to decide when Triton templates should be included as choices in GEMM autotuning; as expected, it returns `True` if `max_autotune=True` or `max_autotune_gemm=True`. However, since the offline GEMM autotuning cache was added two years ago, it also returns `True` when `search_autotune_cache=True`, even though `search_autotune_cache=True` should never trigger autotuning. Over time, people have used `use_max_autotune()` without realizing that it behaves unexpectedly when `search_autotune_cache=True`. We could rename the method to be clearer, but prefer to phase it out entirely for maximal clarity. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155847 Approved by: https://github.com/jingsh, https://github.com/masnesral |
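For context, a minimal sketch of the behavior described above, assuming the helper simply OR-ed the relevant Inductor config flags (the real source may differ):

```python
# Hypothetical reconstruction of the helper's behavior, not the actual source.
from torch._inductor import config

def use_max_autotune_sketch() -> bool:
    return (
        config.max_autotune              # explicit max-autotune request
        or config.max_autotune_gemm      # GEMM-only max-autotune request
        or config.search_autotune_cache  # offline cache lookup: should NOT imply autotuning
    )
```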
|||
| cb56df55dc |
[Inductor]Cleanup autotune_fallback_to_aten post-deprecation (#154331)
Fixes #153298. This PR is the third and final step of #147479. All references to `autotune_fallback_to_aten` have been removed, and the feature is now deprecated. All calls to `should_fallback_to_aten()` were also removed, as they were deemed unnecessary. [henrylhtsang](https://github.com/henrylhtsang) Pull Request resolved: https://github.com/pytorch/pytorch/pull/154331 Approved by: https://github.com/henrylhtsang |
|||
| c1dd75e4dc |
Add AOTI shim for _weight_int4pack_mm_cpu_tensor (#149031)
**Summary** The previous implementation of the shim did not align with the design and was removed by https://github.com/pytorch/pytorch/pull/148907. This PR adds it back in the MKLDNN backend files and re-enables the CPP wrapper UT.
**Test plan**
```
pytest -s test/inductor/test_cpu_cpp_wrapper.py -k test_woq_int4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149031 Approved by: https://github.com/leslie-fang-intel, https://github.com/EikanWang, https://github.com/desertfire |
|||
| ab81ca5053 |
[Inductor][CPU] Add GEMM templates for _weight_int4pack_mm_for_cpu with AVX512 (#146756)
**Summary** This is part of the task to enable max-autotune with the GEMM template for WoQ INT4 GEMM on CPU. This PR adds GEMM templates for `torch.ops.aten._weight_int4pack_mm_for_cpu`. The micro-kernel used by the templates is AVX512-based and is a copy of the ATen implementation of `torch.ops.aten._weight_int4pack_mm_for_cpu` with minor changes. Thanks to better blocking and loop scheduling, the GEMM-template-based implementation outperforms the ATen implementation in all cases we tested.
**Test plan**
```
python test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm_avx512
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146756 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel |
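A short, hedged sketch of how the new template might be exercised; the op's argument order (activation, packed int4 weight, group size, scales-and-zeros) and the omitted tensor-preparation step are assumptions, not taken from this PR:

```python
# Sketch only: preparing `packed_w` (uint8, two 4-bit values per byte) and
# `scales_and_zeros` is omitted, and the argument order is an assumption.
import torch

def woq_int4_mm(x, packed_w, group_size, scales_and_zeros):
    # The op this PR adds an AVX512 GEMM template for.
    return torch.ops.aten._weight_int4pack_mm_for_cpu(
        x, packed_w, group_size, scales_and_zeros
    )

# With max-autotune, Inductor can pick between the ATen kernel and the
# new template when lowering this op.
compiled = torch.compile(woq_int4_mm, mode="max-autotune")
```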
|||
| ca3aabc8e6 |
[Inductor][CPU] Add a lowering pass for _weight_int4pack_mm_for_cpu (#145250)
**Summary** This is part of the task to enable max-autotune with the GEMM template for WoQ INT4 GEMM on CPU. This PR adds a lowering pass for `torch.ops.aten._weight_int4pack_mm_for_cpu`, the op used for WoQ INT4 in torchao. The lowering pass is a prerequisite for max-autotune, which is planned to be enabled for this op in subsequent PRs.
**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_woq_int4
python test/inductor/test_cpu_cpp_wrapper.py -k test_woq_int4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145250 Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168 ghstack dependencies: #145245 |
|||
| 41b38f755c |
Revert "Reverting the PR adding Kleidiai-based int4 kernels (#145392)" (#145505)
https://github.com/pytorch/pytorch/pull/134124 was reverted by https://github.com/pytorch/pytorch/pull/145392 due to a KleidiAI clone issue. 1. This reverts commit 0940eb6d44f3cf69dd840db990245cbe1f78e770 (https://github.com/pytorch/pytorch/pull/145392) and fixes the KleidiAI mirror issue. 2. KleidiAI is now cloned from the GitHub mirror instead of Arm's GitLab. Change-Id: I7d6eee7214cd117d3057d615936fcc3ee6052fa2 Fixes https://github.com/pytorch/pytorch/issues/145273 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145505 Approved by: https://github.com/malfet |
|||
| 0940eb6d44 |
Reverting the PR adding Kleidiai-based int4 kernels (#145392)
Mitigation for https://github.com/pytorch/pytorch/issues/145273 Reverting https://github.com/pytorch/pytorch/pull/134124 and https://github.com/pytorch/pytorch/pull/144074 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145392 Approved by: https://github.com/ZainRizvi, https://github.com/malfet, https://github.com/atalman, https://github.com/digantdesai |
|||
| 94737e8a2a |
[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description:
1. Quantize Linear layer weights to 4 bits: quantize the weights of the Linear layer to 4 bits using symmetric quantization. Pack two 4-bit weights into one uint8 container. Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.
2. Prepare quantized weights, scales, and optional bias: after quantizing, obtain the quantized_weights, scales, and groupsize. If the original Linear layer has a bias, prepare it as well.
3. Pack the weights efficiently: use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
   ```python
   packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
   ```
   Input parameters should include in_features and out_features (the same as the Linear layer's corresponding parameters).
4. Perform dynamic quantized matrix multiplication: use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with the quantized weights.
   ```python
   output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
   ```
   Inputs required include the input tensor, packed_weights, groupsize, and the in_features and out_features.

API usage: https://github.com/pytorch/pytorch/issues/143289

Model perf:
- 7B Transformer model: Prefill 340 t/s, Decode 40 t/s
- 2B Transformer model: Prefill 747 t/s, Decode 80 t/s

Tests:
- `python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight` (Ran 1 test in 0.016s, OK)
- `python test/test_linalg.py -k test__dyn_quant_matmul_4bit` (Ran 8 tests in 0.077s, OK)
- `python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit` (Ran 8 tests in 11.454s)

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124 Approved by: https://github.com/digantdesai, https://github.com/malfet |
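A compact, hedged sketch that strings the two documented ops together; quantizing the weight (symmetric 4-bit, two nibbles per uint8, group size a multiple of 32) and building `scales_and_zeros` are assumed to have been done already:

```python
# Only the two op calls quoted in the description are shown; tensor
# preparation (4-bit quantization, nibble packing, scales) is assumed done.
import torch

def pack_4bit_linear(weight_q, scales_and_zeros, bias, groupsize, in_features, out_features):
    # One-time packing of quantized weights, scales, and optional bias.
    return torch.ops.aten._dyn_quant_pack_4bit_weight(
        weight_q, scales_and_zeros, bias, groupsize, in_features, out_features
    )

def dyn_quant_linear(x, packed_weights, groupsize, in_features, out_features):
    # Dynamic quantized matmul against the packed 4-bit weights.
    return torch.ops.aten._dyn_quant_matmul_4bit(
        x, packed_weights, groupsize, in_features, out_features
    )
```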
|||
| 8136daff5a |
Revert "[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)"
This reverts commit 4b82251011f85f9d1395b451d61e976af844d9b1. Reverted https://github.com/pytorch/pytorch/pull/134124 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it breaks lots of internal build ([comment](https://github.com/pytorch/pytorch/pull/134124#issuecomment-2555953189)) |
|||
| 4b82251011 |
[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description: 1. Quantize Linear Layer Weights to 4-bits: Quantize the weights of the Linear layer to 4 bits, using symmetric quantization. Pack two 4-bit weights into one uint8 container. Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32. 2. Prepare Quantized Weights, Scales, and Optional Bias: After quantizing, obtain the quantized_weights, scales, and groupsize. If the original Linear layer has a bias, prepare it as well. 3. Pack the Weights Efficiently: Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias. ```python packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features) ``` Input parameters should include: in_features and out_features (the same as the Linear layer’s corresponding parameters). 4. Perform Dynamic Quantized Matrix Multiplication: Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights. ```python output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features) ``` Inputs required include: The input tensor, packed_weights , groupsize, and the in_features and out_features. API Usage: https://github.com/pytorch/pytorch/issues/143289 Model Perf : 7B Transformer model: Prefill : 340 t/s Decode : 40 t/s 2B Transformer model Prefill : 747 t/s Decode : 80 t/s Tests: python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight Ran 1 test in 0.016s OK python test/test_linalg.py -k test__dyn_quant_matmul_4bit Ran 8 tests in 0.077s OK python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit Ran 8 tests in 11.454s Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124 Approved by: https://github.com/digantdesai, https://github.com/malfet |
|||
| 14fe1f7190 |
Revert "[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)"
This reverts commit d3ff2d42c28a2c187cbedfd8f60b84a4dfa2d6bf. Reverted https://github.com/pytorch/pytorch/pull/134124 on behalf of https://github.com/malfet due to This broke S390 builds, includes cpuinfo unconditionally ([comment](https://github.com/pytorch/pytorch/pull/134124#issuecomment-2552560208)) |
|||
| d3ff2d42c2 |
[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description: 1. Quantize Linear Layer Weights to 4-bits: Quantize the weights of the Linear layer to 4 bits, using symmetric quantization. Pack two 4-bit weights into one uint8 container. Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32. 2. Prepare Quantized Weights, Scales, and Optional Bias: After quantizing, obtain the quantized_weights, scales, and groupsize. If the original Linear layer has a bias, prepare it as well. 3. Pack the Weights Efficiently: Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias. ```python packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features) ``` Input parameters should include: in_features and out_features (the same as the Linear layer’s corresponding parameters). 4. Perform Dynamic Quantized Matrix Multiplication: Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights. ```python output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features) ``` Inputs required include: The input tensor, packed_weights , groupsize, and the in_features and out_features. API Usage: https://github.com/pytorch/pytorch/issues/143289 Model Perf : 7B Transformer model: Prefill : 340 t/s Decode : 40 t/s 2B Transformer model Prefill : 747 t/s Decode : 80 t/s Tests: python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight Ran 1 test in 0.016s OK python test/test_linalg.py -k test__dyn_quant_matmul_4bit Ran 8 tests in 0.077s OK python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit Ran 8 tests in 11.454s Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124 Approved by: https://github.com/digantdesai, https://github.com/malfet |
|||
| 20f24e3fbd |
[inductor][cpp] Add BMM kernel template for autotuning (#129772)
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.
1. Adds a `CppBmmTemplate` class which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and batch size `B`, execute the Gemm kernel. For the first `B - (B % num_threads)` batch inputs, run one sub-gemm problem per thread. Then for the remaining `B % num_threads` sub-gemms, execute each subproblem using the parallelized Gemm kernel. To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered two different times, once with a single thread and once including the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for a child class. The `GEMM_TEMPLATE` is separated into different strings to allow for rendering by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.

BMM within the `dlrm` benchmark has a single input buffer which is used for both X and W inputs. This is currently not supported in this PR.

### Performance
On Granite/Sapphire Rapids, the cpp_bmm template code uses AMX, which requires an expensive transpose operation, so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without the BMM code. Pass rate matches the rates for main exactly.

#### Test Summary on Granite Rapids
| Test Scenario | Comp Item | Date | Compiler | torchbench | huggingface | timm_models |
| -- | -- | -- | -- | -- | -- | -- |
| Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x |
| | | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x |
| Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x |
| | | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x |

This is not the case on an older Skylake Xeon machine. For the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.
#### BF16 28-core Skylake Xeon | Model | Inductor | GemmAutotune | Gemm+BMM Autotune | |--------|--------|--------|--------| | BERT_pytorch | 1.233x | 2.597x | 2.608x | | hf_DistilBert | 1.128x | 2.242x | 2.368x | | hf_Reformer | 1.124x | 1.419x | 1.590x | | hf_T5_base | 1.012x | 1.257x | 1.382x | | hf_T5_large | 1.085x | 2.228x | 2.345x | ## Example BMM Code ``` #include <c10/util/Unroll.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h> template <bool accum> inline void cpp_bmm_micro_gemm_amx_kernel_32_2( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows ) { // TODO(jgong5): add prefetch hint for A, B, C auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); }; const auto last_k_offset = K / 32 * 32; const auto tail_k_size = K - last_k_offset; if C10_LIKELY (last_k_offset > 0) { amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig); } else { amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); } auto load_c = [&]() { _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float)); _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float)); }; auto zero_c = [&]() { _tile_zero(0); _tile_zero(1); _tile_zero(2); _tile_zero(3); }; if constexpr (accum) { load_c(); } else { zero_c(); } auto compute = [&](int k) { _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16)); _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(0, 4, 6); _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(1, 4, 7); _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16)); _tile_dpbf16ps(2, 5, 6); _tile_dpbf16ps(3, 5, 7); }; #pragma GCC unroll 4 for (int k = 0; k < last_k_offset; k += 32) { compute(k); } auto store_c = [&]() { // store to C _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float)); _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float)); }; // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead if C10_UNLIKELY (tail_k_size > 0) { if C10_LIKELY (last_k_offset > 0) { store_c(); amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); load_c(); } compute(last_k_offset); } store_c(); } template <bool accum> inline void cpp_bmm_micro_gemm_amx_kernel_16_2( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows ) { // TODO(jgong5): add prefetch hint for A, B, C auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); }; const auto last_k_offset = K / 32 * 32; const auto tail_k_size = K - last_k_offset; if C10_LIKELY (last_k_offset > 0) { amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig); } else { amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); } auto load_c = [&]() { _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); }; auto zero_c = [&]() { _tile_zero(0); _tile_zero(1); }; if constexpr (accum) { load_c(); } else { zero_c(); } auto compute = [&](int k) { _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16)); _tile_loadd(3, B + k 
* ldb + 0, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(0, 2, 3); _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(1, 2, 4); }; #pragma GCC unroll 4 for (int k = 0; k < last_k_offset; k += 32) { compute(k); } auto store_c = [&]() { // store to C _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); }; // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead if C10_UNLIKELY (tail_k_size > 0) { if C10_LIKELY (last_k_offset > 0) { store_c(); amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); load_c(); } compute(last_k_offset); } store_c(); } template <bool accum> inline void cpp_bmm_micro_gemm( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc ) { AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32"); AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t n = 0; n < N; n += 32) { for (int64_t m = 0; m < M; m += 32) { int64_t block_m = std::min<int64_t>(M - m, 32); int64_t m_tail = m; if (block_m >= 32) { cpp_bmm_micro_gemm_amx_kernel_32_2<accum>( amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16 ); block_m -= 32; m_tail += 32; } else if (block_m >= 16) { cpp_bmm_micro_gemm_amx_kernel_16_2<accum>( amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16 ); block_m -= 16; m_tail += 16; } if (block_m > 0) { cpp_bmm_micro_gemm_amx_kernel_16_2<accum>( amx_state, A + m_tail * lda, B + n, C + m_tail * ldc + n, K, lda, ldb, ldc, block_m ); } } } } void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 48; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 1; constexpr int64_t Nt_blocks = 1; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 1; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned." 
); #pragma omp parallel num_threads(48) { const int tid = omp_get_thread_num(); const int64_t k_group_id = tid / num_Kt_blocks; const int64_t k_slice_id = tid % num_Kt_blocks; const int64_t n_group_id = k_group_id / num_Nt_blocks; const int64_t n_slice_id = k_group_id % num_Nt_blocks; const int64_t k_block_start = k_slice_id * Kt_blocks; const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); const int64_t n_block_start = n_slice_id * Nt_blocks; const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + 
(Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 1; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 12; constexpr int64_t Nt_blocks = 2; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 12; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned." 
); { constexpr int tid = 0; constexpr int64_t k_group_id = 0; constexpr int64_t k_slice_id = 0; constexpr int64_t n_group_id = 0; constexpr int64_t n_slice_id = 0; constexpr int64_t m_block_start = 0; constexpr int64_t n_block_start = 0; constexpr int64_t n_block_end = Nr_blocks; constexpr int64_t k_block_start = 0; constexpr int64_t k_block_end = Kr_blocks; constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; constexpr int64_t m_block_end = Mr_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), 
static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } extern "C" void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y) { const int64_t B = static_cast<int64_t>(5L); constexpr int64_t num_threads = 48; int64_t B_single_thread_block = (B / num_threads) * num_threads; #pragma omp parallel for num_threads(48) for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { single_thread_mm(X, W, Y, b_start); } for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { threaded_mm(X, W, Y, b_start); } } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129772 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel |
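A small usage sketch for exercising the BMM template described in this entry, assuming the Inductor knobs named below (`max_autotune`, `max_autotune_gemm_backends`) are the ones that gate the C++ template; this is an illustration, not part of the PR:

```python
# Hedged sketch: compile a bf16 batched matmul with max-autotune so Inductor
# can autotune between the ATen bmm kernel and the cpp_bmm template.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"  # allow the C++ template

def bmm(x, w):
    return torch.bmm(x, w)

compiled_bmm = torch.compile(bmm)

B, M, K, N = 5, 384, 96, 64  # shapes matching the generated example above
x = torch.randn(B, M, K, dtype=torch.bfloat16)
w = torch.randn(B, K, N, dtype=torch.bfloat16)
y = compiled_bmm(x, w)
```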
|||
| d67b4f9e5f |
type _inductor/quantized_lowerings.py (#137598)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137598 Approved by: https://github.com/Skylion007 |
|||
| f951fcd1d7 |
Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)
## Summary
As part of #125683, this PR modifies the existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond).

WoQ GEMM takes FP16/BF16 activations, int8 weights, and a scale of the same dtype as the activations. The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights will be considered constant & cached, so this implementation is suitable for inference, not QAT. `scale` is supported as a `mul` epilogue.

Only BF16 activations are supported in this PR because for FP16 & FP32, the weight is dequantized during the constant-folding pass of freezing, and then, after auto-tuning, performance with a large `M` dimension may be better than either `torch.ops.aten._weight_int8pack_mm` or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel. While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now they use the WoQ support in the GEMM templates for auto-tuning, and a subsequent PR will add logic for deciding whether or not to dequantize weights beforehand.

### Performance
#### AMX
Op-level speedup due to the AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (Xeon 4th generation series, codenamed Sapphire Rapids) vs. the ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded. In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel; this is unrelated to this PR, though, since `test_linear_amx` also exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than the AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation or the ATen kernel would be chosen instead. Benchmarked with unit tests. Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442 (the AVX512 micro-kernel was disabled to collect data for the AMX micro-kernel).

#### AVX2/AVX512 micro-kernels
Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up
1. An int4 WoQ GEMM micro-kernel will be added in a separate PR.
2. A subsequent PR will add logic for deciding whether or not to dequantize weights beforehand. E2E perf measurement should be done with #131310.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel |
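A brief sketch of the equivalence stated above, assuming `torch.ops.aten._weight_int8pack_mm` takes `(activation, int8_weight, scales)` in that order; treat the call layout as an assumption rather than a documented contract:

```python
# Hedged sketch of the WoQ int8 GEMM semantics described above:
# out ~= F.linear(x, w.to(x.dtype)) * scale, with `scale` applied as a
# per-output-channel `mul` epilogue. Argument order of the ATen op is assumed.
import torch
import torch.nn.functional as F

M, K, N = 8, 64, 32
x = torch.randn(M, K, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (N, K), dtype=torch.int8)
scale = torch.rand(N, dtype=torch.bfloat16)

ref = F.linear(x, w_int8.to(x.dtype)) * scale
out = torch.ops.aten._weight_int8pack_mm(x, w_int8, scale)
print((out.float() - ref.float()).abs().max())  # expected to be small (bf16 rounding)
```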
|||
| 89670d5bdd |
Revert "Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)"
This reverts commit 8fbd7d92a81b61d41363edb1b3902ba7701d5a27. Reverted https://github.com/pytorch/pytorch/pull/131887 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/131887#issuecomment-2285082401)) |
|||
| 8fbd7d92a8 |
Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)
## Summary As part of #125683, this PR modifies existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond). WoQ GEMM takes FP16/BF16 activations, int8 weights, and scale of the same dtype as activations. The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights will be considered constant & cached, so this implementation is suitable for inference, and not QAT. `scale` is supported as a `mul` epilogue. Only BF16 activations have been supported in this PR because for FP16 & FP32, weight is dequantized during constant-folding pass of freezing, and then after auto-tuning, performance with a large `M` dimension may be better than either torch.ops.aten._weight_int8pack_mm, or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel. While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now, they would use WoQ support in GEMM templates for auto-tuning, and then a subsequent PR would add logic for deciding whether or not to dequantize weights beforehand. ### Performance #### AMX Op-level speedup due to AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (of Xeon 4th generation series, codenamed Sapphire Rapids) vs. ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded. In a few cases with an odd `K`, the implementation being added in this PR may not perform as well as the ATen kernel, which is unrelated to this PR, though, since `test_linear_amx` also exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation, or the ATen kernel would be chosen instead. Benchmarked with unit-tests. Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442 The AVX512 micro-kernel was disabled to collect data for AMX micro-kernel. #### AVX2/AVX512 micro-kernels Tabular data at at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437 ### Follow-up 1. int4 WoQ GEMM micro-kernel will also be added in a separate PR. 2. A subsequent PR would add logic for deciding whether or not to dequantize weights beforehand. E2E perf measurement should be done with #131310. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel |
|||
| b6d477fd56 |
[BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768 Approved by: https://github.com/jansel |
|||
| ea614fb2b1 |
Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839 Approved by: https://github.com/oulgen |
|||
| 4b725e1619 |
[AOTInductor] Support quantized linear on CPU with fbgemm (#123069)
Summary: Added support for quantized linear on CPU with fbgemm. Specifically, `torch.ops.quantized.linear_unpacked_dynamic_fp16` is decomposed into two steps: packing the weight, then calling fbgemm's qlinear with the packed weight. Test Plan: included in the commit (test_aot_inductor::test_quantized_linear). Differential Revision: [D55577959](https://our.internmc.facebook.com/intern/diff/D55577959) Pull Request resolved: https://github.com/pytorch/pytorch/pull/123069 Approved by: https://github.com/hl475 |
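A hedged sketch of the two-step decomposition described above; the choice of `linear_prepack_fp16` / `linear_dynamic_fp16` as the packed-weight pair is an assumption for illustration, not a statement of the exact ops this PR emits:

```python
# Illustration of decomposing an unpacked dynamic fp16 linear into
# (1) a one-time weight-packing step and (2) a packed-weight qlinear call.
# The specific quantized ops used here are assumptions, not from the PR diff.
import torch

def unpacked_dynamic_fp16_linear(x, weight, bias):
    # Step 1: pack the weight into fbgemm's fp16 packed format
    # (constant, so AOTInductor can hoist/cache it).
    packed = torch.ops.quantized.linear_prepack_fp16(weight, bias)
    # Step 2: run fbgemm's dynamic fp16 qlinear against the packed weight.
    return torch.ops.quantized.linear_dynamic_fp16(x, packed)
```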
|||
| 19d6004b97 |
add int8 woq mm pattern matcher (#120985)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120985 Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/eellison |
|||
| 9319dd1c7c |
[Quant][Inductor] Enable the lowering of quantized maxpool2d (#105906)
**Summary** Enable the `dq-maxpool2d-q` pattern match and lower it into `torch.ops.quantized.max_pool2d`.
**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qmaxpool2d
python -m pytest test_quantized_op.py -k test_max_pool2d_pt2e
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105906 Approved by: https://github.com/jgong5, https://github.com/eellison ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639 |
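A minimal eager-mode sketch of the pattern being matched; the pattern-matcher internals are not shown and the recipe below is only illustrative of the `dq -> max_pool2d -> q` shape it targets:

```python
# Eager-mode illustration of the dq -> max_pool2d -> q pattern that the
# Inductor lowering replaces with a single quantized max_pool2d call.
import torch

x = torch.randn(1, 3, 8, 8)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

# Pattern the matcher looks for: dequantize, float max_pool2d, requantize.
y_ref = torch.quantize_per_tensor(
    torch.nn.functional.max_pool2d(qx.dequantize(), kernel_size=2),
    scale=0.1, zero_point=0, dtype=torch.quint8,
)

# Fused equivalent it is lowered to (max_pool2d preserves qparams, so the
# output scale/zero_point match the input's).
y_fused = torch.ops.quantized.max_pool2d(
    qx, kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=False
)
```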