With this PR, we can turn on the inductor UTs on Windows CPU.
Changes:
1. Turn on inductor UTs on Windows CPU.
2. Add a shard to balance the added UTs; otherwise the job would time out.
3. Fixed `test_invalid_artifact_flag_error_msg`.
4. Skipped `test_distributed_rank_logging` and `test_disable_recursive_false`.
5. Skipped the whole UT file `test_cpu_select_algorithm.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160161
Approved by: https://github.com/jansel
**Summary**
Add a configuration option to enable a smaller dequantization buffer for WOQ INT4 CPP GEMM template. This can improve the performance of the WOQ INT4 GEMM template in cases where M is small. In such scenarios, matrix B cannot be effectively reused across matrix A, and we found that reducing the Kc block size can lead to better performance.
**Test Plan**
```
python test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm_with_small_buffer_config
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156395
Approved by: https://github.com/jansel
ghstack dependencies: #156407, #156387
### Summary
int8 WoQ GEMM concat-linear optimization for the case where the same activation is applied to 3 sets of weights of the same shape.
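A minimal sketch of the pattern this optimization targets (shapes and module are illustrative; in GPT-J these are the Q/K/V projections that share one activation):
```
import torch

class QKV(torch.nn.Module):
    def __init__(self, hidden: int = 4096):
        super().__init__()
        # 3 linear layers with same-shape (int8-quantized) weights, all fed the same x
        self.q = torch.nn.Linear(hidden, hidden, bias=False)
        self.k = torch.nn.Linear(hidden, hidden, bias=False)
        self.v = torch.nn.Linear(hidden, hidden, bias=False)

    def forward(self, x):
        # concat linear computes these 3 GEMMs over the same activation as one GEMM
        return self.q(x), self.k(x), self.v(x)
```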
### Perf data
GPT-J 128 input tokens, 128 output tokens.
32 physical cores of one socket of Intel(R) Xeon(R) 6972P (Xeon Gen 5). tcmalloc & Intel OpenMP were preloaded.
| First token latency with May 8 nightly | First token latency with this implementation | Rest token latency with May 8 nightly | Rest token latency with this implementation combined with #149373 |
|---|---|---|---|
|202 ms | 190 ms | 33 ms | 30 ms|
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153004
Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/jansel
Co-authored-by: Anthony Shoumikhin <anthony@shoumikh.in>
**Summary**
There is an accuracy issue when `Nc_block` is greater than 1 in WOQ int4 GEMM. Previously, we used the slice `{%- set tile_W = kernel.slice_nd(W, [("n_start", "n_start + n_size"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %}`, which means that each `ni` in `Nc_block` takes the exact same N slice from `n_start` to `n_start + n_size`, leading to the accuracy problem. The issue was exposed by [PR #156174](https://github.com/pytorch/pytorch/pull/156174), which changes `block_N` from 64 to 32 and thereby increases the likelihood of `Nc_block` being greater than 1, making the issue more likely to trigger. This PR fixes the accuracy issue.
**Test Plan**
```
python test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm_amx_Nc_larger_than_one
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156407
Approved by: https://github.com/CaoE
**Summary**
GEMM templates for INT4 weights are used for lowering `aten._weight_int4pack_mm_for_cpu` with Inductor when max-autotune is on. Currently, AMX-based microkernels are used only when M >= 16 if the input tensor has shape [M, K]. However, we find that the AMX kernel brings a performance benefit even when 4 < M < 16. For example, on a 6th-gen Intel(R) Xeon(R) platform, E2E latency can be improved by more than 20% when running Llama-3.1-8B on 32 cores with M = 8. So, this PR changes the threshold so that AMX is used when M > 4.
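A sketch of the dispatch rule this PR changes (illustrative only; the real selection happens inside Inductor's GEMM template choices):
```
def pick_int4_microgemm(M: int, amx_available: bool) -> str:
    # previously the AMX micro-kernel was only used when M >= 16
    if amx_available and M > 4:
        return "amx"
    return "avx512"
```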
**Test plan**
```
pytest test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155444
Approved by: https://github.com/sanchitintel, https://github.com/leslie-fang-intel
### Summary
Fixes #148494
Explicitly prefetch the cache lines of the next `B` block to accelerate int8 WoQ (BF16 activation, int8 statically quantized weights) GEMM for small `M` dimension.
Some of this code (the outer loops of the GEMM) is being ported over from Intel Extension for PyTorch. The macro-kernel* and the micro-kernel* are essentially the same, except that they optionally prefetch a block of B. Templatization is used so that the no-prefetch path does not pay for branching or unnecessary prefetching.
\* - in [BLIS](https://dl.acm.org/doi/10.1145/2764454) parlance
### Performance data with BS 1
Machine: 32 cores of one socket of an Intel Xeon SP Gen 5 machine
| Model | input tokens | output tokens | next-token latency before this PR | Next-token latency after this change | Speedup |
|-----------|-------------|-----------------|--------------------------------------|------------------------------------------|-----------|
|GPT-J | 128 | 128 | 42 ms | 38 ms | 9.52 % |
| GPT-J | 1024 | 1024 | 48 ms | 45 ms | 6.25 % |
|LLaMA 3.1 8B Instruct | 128 | 128 | 52 ms | 47 ms| 9.61% |
|LLaMA 3.1 8B Instruct | 1024 | 1024 | 57 ms | 53 ms| 7.01% |
While the input shapes of the GEMMs corresponding to the linear layers for next-token computation remain the same across different numbers of input & output tokens, the difference in next-token latency between these cases comes from attention.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149373
Approved by: https://github.com/leslie-fang-intel, https://github.com/Xia-Weiwen
Co-authored-by: Xia Weiwen <xia.weiwen@hotmail.com>
Without MKL there is only 1 epilogue, not 2, because `addmm` is used instead of `packed_linear/_mkl_linear`.
This fails first at `TestSelectAlgorithmCPU.test_linear_with_in_out_buffer_batch_size_8_in_features_3_in_features2_192_image_size_224_out_features_64_bias_True_cpu_float32`
Instead of skipping the whole test, just adjust the count for the single check.
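A sketch of the adjusted check; the counter name below is an assumption about what the test asserts on, and only `TEST_MKL` is a known helper:
```
from torch.testing._internal.common_utils import TEST_MKL

# With MKL, packed_linear/_mkl_linear fuses 2 epilogues; without MKL, addmm is
# used and only 1 epilogue is fused.
expected_epilogues = 2 if TEST_MKL else 1
# e.g. assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], expected_epilogues)
```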
Final numbers of `test/inductor/test_cpu_select_algorithm.py` without MKL:
```
Ran 1337 tests
OK (skipped=1211)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151548
Approved by: https://github.com/jansel
**Summary**
It's part of the task to enable max-autotune with GEMM template for WoQ INT4 GEMM on CPU.
This PR adds AMX-based GEMM templates for `torch.ops.aten._weight_int4pack_mm_for_cpu`. It brings performance benefits on platforms where AMX is available.
**Validation results**
We have run GPT-J-6B and Llama-3-8B-Instruct on a 6th-gen Xeon with 96 cores. Results show that the AMX-based microkernel outperforms the AVX512-based one by >5x for the prefill stage with 1024 input length.
**Test plan**
```
python test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm_amx
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150603
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
### Summary
When the block-size for `N` dimension is `48` for the AMX GEMM micro-kernel for int8 WoQ (BF16 activation, int8 statically quantized weights), the logic for handling the tail is incorrect - we can't always dequantize 32 elements of weights at a time because we may need to dequantize `32` followed by `16` when `block_n` is `48` (for each `K`).
This PR fixes that logic, which was initially exposed with `M=17, N=1024, K=1024`.
This PR also fixes the case of `block_n` being 16.
I had introduced [this bug](ca9813ea14) after misreading the GEMM blockings as `["block_m", "block_k", "block_n"]` instead of `["block_m", "block_n", "block_k"]` (so I had wrongly assumed that `block_n` was always 32).
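An illustrative helper for the tail handling described above (not the kernel code itself): with a 32-element dequantization width, `block_n = 48` must be processed as 32 + 16 columns, and `block_n = 16` as a single 16-wide chunk.
```
def dequant_col_chunks(block_n: int, vec_width: int = 32):
    # split block_n columns into dequantization chunks no wider than vec_width
    chunks, remaining = [], block_n
    while remaining > 0:
        step = min(vec_width, remaining)
        chunks.append(step)
        remaining -= step
    return chunks

assert dequant_col_chunks(48) == [32, 16]
assert dequant_col_chunks(32) == [32]
assert dequant_col_chunks(16) == [16]
```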
### Future work
While this PR simply fixes a bug, it's possible to optimize the code pertaining to dequantizing & caching the B buffer - for `block_n` being `16` or `48`, `K` would always be a multiple of 2, so `K * block_n` will always be a multiple of 32. Since `dequantized_B_buf` stores rows contiguously, when `block_n` would be `16` or `48`, we could store 32 BF16 elements at a time instead of storing `16` at a time (when `block_n` is 16), or `32` followed by `16` at a time (when `block_n` is 48). Such an optimization would lower `register -> memory` data movements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149359
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
For int8 dynamically quantized activation & int8 quantized weights, add a workaround for an indexing issue in the epilogue creator that expected an empty index (i.e., a 0D tensor) when the activation scale was sized [1, 1], by converting that scale into a 0D tensor.
The issue was discovered while running LLaMA2 quantized with torchao's `int8_dynamic_activation_int8_weight` quantization on CPU with max-autotune enabled (although this error would've occurred regardless).
The final hidden-states tensor, which is the activation to the LM head, has shape `[batch_size, sequence_length, hidden_dim]` during decoding. When decoding one token at a time with batch size 1, sequence length is 1. The activation scale is shaped `[1, 1]` (reshaped from `[1, 1, 1]`). However, the Inductor epilogue creator expects a 0D tensor in this case (my guess is that the corresponding logic in Inductor expects a 0D tensor whenever a tensor has only one element, even if it's 1D).
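A sketch of the workaround (variable names are illustrative): a `[1, 1]` activation scale is squeezed to a 0D tensor so the epilogue creator sees the empty index it expects.
```
import torch

act_scale = torch.tensor([[0.02]])        # shape [1, 1], as produced during decoding
if act_scale.numel() == 1 and act_scale.dim() > 0:
    act_scale = act_scale.reshape(())     # now a 0D scalar tensor
```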
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147033
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel
**Summary**
It's part of the task to enable max-autotune with GEMM template for WoQ INT4 GEMM on CPU.
This PR adds GEMM templates for `torch.ops.aten._weight_int4pack_mm_for_cpu`. The micro-kernel used for the templates is based on AVX512 and is a copy of the ATen implementation of `torch.ops.aten._weight_int4pack_mm_for_cpu` with minor changes.
Thanks to better blocking and loop scheduling, the GEMM-template-based implementation outperforms the ATen implementation in all cases we tested.
**Test plan**
```
python test/inductor/test_cpu_select_algorithm.py -k test_int4_woq_mm_avx512
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146756
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
#146843 broke int8 WoQ GEMM's (for BF16 activation) AMX ISA implementation in the main branch.
UT: `python test/inductor/test_cpu_select_algorithm.py -v -k woq`
The issue remained undetected because, in case of a templated-kernel compilation failure, the auto-tuning infra marks its runtime as `inf` and falls back to the op it was being benchmarked against, so UTs didn't fail even on machines that support the AMX ISA.
`test/inductor/test_cpu_select_algorithm.py` UTs checked the value of the `select_algorithm_autotune` counter, which only counts how many ops were selected for autotuning against their templated codegened counterparts.
@leslie-fang-intel advised using a new counter. I added `counters["inductor"]["cpp_templated_kernel_counter"]`, which is incremented after a codegened kernel's compilation, so it'd help catch breakage scenarios in which a templated kernel could not be codegened due to a compilation failure.
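A rough sketch of the kind of check the new counter enables; the toy model and shapes are illustrative, and the counter names are the ones mentioned above:
```
import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters

m = torch.nn.Linear(64, 64).eval()
x = torch.randn(16, 64)
counters.clear()
with torch.no_grad(), inductor_config.patch(max_autotune=True):
    torch.compile(m)(x)
# Old signal: how many ops were selected for autotuning against a template.
autotuned = counters["inductor"]["select_algorithm_autotune"]
# New signal: how many templated kernels actually compiled, so a compilation
# failure (runtime marked as inf during autotuning) no longer goes unnoticed.
compiled_templates = counters["inductor"]["cpp_templated_kernel_counter"]
```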
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147895
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
Currently, the bfloat16 microkernel that uses AMX vectorization requires the weights to be in an interleaved VNNI format. For GEMM code this hasn't been an issue, because GEMM currently only supports constant weights, so the VNNI weight packing is done at compile time and saved as a constant tensor in the graph. But for BMM ops, where weights are not required to be constant, the current code does an expensive reshape/VNNI packing for all BMM weights.
This PR removes the need for the reshape/packing for non-constant inputs by moving VNNI packing inside the AMX microkernel. A new `K * block_n` buffer is used to store the temporary packed weights. Weight packing involves interleaving 2 rows of weights.
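An illustrative PyTorch reference (not the micro-kernel itself) for the VNNI-2 interleaving described above: pairs of K rows are interleaved so that, for each column, two consecutive K elements end up adjacent in the packed buffer.
```
import torch

def vnni2_pack(w: torch.Tensor) -> torch.Tensor:
    # w: [K, N] bf16 weights, K must be even
    K, N = w.shape
    assert K % 2 == 0
    return w.reshape(K // 2, 2, N).transpose(1, 2).reshape(K // 2, 2 * N)
```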
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146843
Approved by: https://github.com/jgong5, https://github.com/sanchitintel, https://github.com/leslie-fang-intel, https://github.com/jansel
**Summary**
This is a regression issue caused by a change in the FX node name. In commit 71010bf0972834e35a155e6a187e5c6649a5a36b, both the node name and target for the `get_attr` node in `V.graph.graph.nodes` were `_frozen_param2`. However, in the latest main, the node name has changed to `_reorder_linear_weight`. This PR fixes the regression by using the node's target instead of its name.
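A sketch of the fix on a plain FX graph (the attribute name is the one from the example above): look the `get_attr` node up by its stable `target` rather than by its display `name`.
```
import torch

def find_get_attr_by_target(graph: torch.fx.Graph, target: str):
    for node in graph.nodes:
        # node.name may be rewritten (e.g. to _reorder_linear_weight),
        # but node.target still points at the frozen parameter attribute
        if node.op == "get_attr" and node.target == target:
            return node
    return None

# e.g. find_get_attr_by_target(V.graph.graph, "_frozen_param2")
```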
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_cpp_weight_prune
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147056
Approved by: https://github.com/jgong5
**Summary**
Issue found when fixing https://github.com/pytorch/ao/issues/1662. A FP32 GEMM with an epilogue node `to_fp16` resulted in [generated code](https://gist.github.com/leslie-fang-intel/464fb112abdb105818ae09b057350e84), which failed to compile. The root cause is that we used the slice of global buffer `Y` as the output of micro GEMM instead of a `local buffer`. However, due to the `to_fp16` epilogue node, the global buffer `Y` has a float16 data type, leading to the failure. This fix will ensure the use of a local buffer in such cases.
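A minimal sketch of the failing pattern (module and shapes are illustrative; the actual repro is the generated code linked above and the UT below): an FP32 linear whose output is cast to fp16, so the `to_fp16` node becomes the GEMM template's epilogue.
```
import torch
import torch._inductor.config as inductor_config

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 64)

    def forward(self, x):
        # fp16 cast epilogue on an fp32 GEMM output
        return self.linear(x).to(torch.float16)

with torch.no_grad(), inductor_config.patch(max_autotune=True):
    m = M().eval()
    x = torch.randn(4, 128)
    torch.compile(m)(x)
```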
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_linear_to_lowp_fp
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146958
Approved by: https://github.com/jgong5
Summary:
This commit fixes a crash in the gemm template lowering caused by hitting an [assert](fd515e4f59/torch/_inductor/codegen/common.py (L1181)) that a buffer was previously removed.
The assert triggers because in the first gemm lowering we use a local accumulation buffer, which causes the original buffer name to be added to the `removed_buffers` set. Then in the next gemm lowering we use the global buffer for accumulation, but that buffer name is already in the `removed_buffers` set.
The fix is to add a unique suffix to the buffer name to avoid triggering the assert from different gemm lowerings.
Differential Revision: D68814625
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146353
Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
Summary:
The bmm template generates code like this
```
template<bool accum>
void cpp_fused_bmm_66_micro_gemm(...) {
...
}
void single_thread_mm() {
...
cpp_fused_bmm_66_micro_gemm(...)
...
}
void threaded_mm() {
...
cpp_fused_bmm_66_micro_gemm(...)
...
}
void cpp_fused_bmm_66(...)
{
...
single_thread_mm(...);
...
threaded_mm(...);
...
}
```
The generated `fused_bmm` and `fused_bmm_micro_gemm` functions both have unique identifiers added to their names, but `single_thread_mm` and `threaded_mm` do not.
This diff adds unique identifiers to those generated functions as well. The identifier is based on the kernel name. So for the example above we would generate a bmm template name like `cpp_fused_bmm_66_single_thread_mm()`.
Differential Revision: D68364772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145303
Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
## Summary
Templated `int8xint8->int32` GEMM that uses the AMX ISA (present on Intel Xeon Gen 4 & above). Any epilogues such as weight scale, activation scale, and bias are applied per output block in a fused manner.
Performs well for large values of `M` dimension (assuming canonical dimensions [`M, K`] and [`K, N`] for the activation & weight matrices'/tensors' sizes) when the activation is quantized per-token.
Also supports SmoothQuant GEMM pattern when activation is quantized per-tensor (scalar scale) or per-token (vector scale is applied as an epilogue in this case).
Also increased coverage of the GEMM template for uint8-activation, int8-weight GEMM UTs for the case where the activation zero point is a 1D tensor (the existing implementation only accepted 0D tensors). However, some such UTs have to be explicitly enabled with the `max-autotune` Inductor config.
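A plain-PyTorch reference of the fused computation described above (not the template itself; names and shapes are illustrative):
```
import torch

def qgemm_reference(a_int8, w_int8, a_scale, w_scale, bias=None):
    # int8 x int8 -> int32 accumulation
    acc = torch.matmul(a_int8.to(torch.int32), w_int8.to(torch.int32).t())
    out = acc.to(torch.float32) * w_scale   # per-channel weight scale, shape [N]
    out = out * a_scale                     # scalar (per-tensor) or [M, 1] (per-token)
    if bias is not None:
        out = out + bias
    return out
```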
## Performance data
The templated codegened fused GEMM with M=32, K=4096, N=14336 used in LLaMA3 exhibits more than 2x perf-gain compared to oneDNN qlinear + mul (for activation's scale) with 48 cores of one socket of Xeon SP 4th gen Platinum 8468 when per-token quantization is used.
For M=1, K=4096, N=14336, regardless of whether per-tensor quantization was used for activation or per-token, the perf gain was more than 3x.
Intel OpenMP & libtcmalloc had been preloaded. All cores used by the workload corresponded to distinct physical cores.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143187
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel, https://github.com/jgong5
Co-authored-by: Leslie Fang <leslie.fang@intel.com>
**Summary**
In this PR, we enable epilogue fusion and code generation for Grouped GEMM. Here is a high-level description of how we implement it.
**Fusion**
- The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM.
- During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design.
- In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles.
**Code Gen**
We maintain a list of epilogues and codegen them one by one (a toy sketch follows the list below).
- If any of the GEMMs has a bias, we create an extra `bias_add` epilogue and prepend it to the front of the epilogue list.
- If any of the GEMMs has no epilogue, we create a `to_bf16` copy epilogue and append it to the end of the epilogue list.
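All names in this sketch are illustrative, and epilogues are represented as plain tuples rather than Inductor IR:
```
def assemble_epilogues(per_gemm_epilogues, biases):
    result = []
    for epilogues, bias in zip(per_gemm_epilogues, biases):
        epilogues = list(epilogues)
        if bias is not None:
            epilogues.insert(0, ("bias_add", bias))   # prepended to the front
        if not epilogues:
            epilogues.append(("to_bf16_copy", None))  # fallback copy epilogue
        result.append(epilogues)
    return result
```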
**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #143796
**Summary**
Enable the CPP Grouped GEMM Fusion, lowering and Grouped GEMM Template following the RFC: https://github.com/pytorch/pytorch/issues/144012
- Support a flexible number of GEMMs
- Share the activation across GEMMs
  - The Grouped GEMM Template supports independent activations
  - However, the pattern matcher requires an anchor node, which serves as the shared activation across GEMMs
- Each GEMM can have a unique weight, but all weights must have the same sizes
- Each GEMM can have a unique bias or None
  - The current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM can have its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```
**Example**
Here is the example and generated code
```
import torch

batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=False).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```
Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
**Next Step**
- Support Epilogue fusion
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
Fixes #143102
Addresses 2 problems related to dynamic batch size in the BMM autotuner:
1. With a dynamic batch size, the input size can be a sympy `Mul` expression such as `s0*8`, which occurs in many dynamo benchmark models. We address this by using `size_hints` to resolve such expressions (a toy sketch follows below). This is safe since this section of the code is only called to generate inputs for benchmarking.
2. Some epilogue nodes may use the dynamic batch size as part of codegen, for example when an input to the epilogue node is transposed and has the dynamic batch size in the strides of other dimensions. When such epilogue nodes exist and the sizevar is not already present in `kernel.args`, a new sizevar with a new name is created. Subsequent calls to `def_kernel` could overwrite this variable name, so to avoid this we pass all the sizevars as `extra_sizevars` to the `def_kernel` calls for the GEMM functions, so that no variable renaming happens later in the BMM definition.
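A toy illustration of item 1 (names are illustrative; the real code goes through Inductor's size hints for the free symbols):
```
import sympy

s0 = sympy.Symbol("s0")
dynamic_batch = s0 * 8                             # symbolic batch size from dynamo
concrete_batch = int(dynamic_batch.subs({s0: 4}))  # -> 32, using the size hint for s0
```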
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143141
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel, https://github.com/jgong5
# Summary
The AMX ISA based GEMM micro-kernel template for int8 weight-only quantization (BF16 activation, int8 weights) should cache dequantized weights (int8 -> int32 -> fp32 -> bf16) so that they would not have to be dequantized again in subsequent calls to the _inner-kernel_ that uses the same weights.
This change leverages the fact that even for BF16 x BF16 GEMM template, cache-blocking ensures that `Nr * Kc` weight elements are cached in L1D cache (more info [here](https://static.sched.com/hosted_files/pytorch2024/59/TorchInductor%20CPU%20Backend%20Advancements%20-%20New%20Features%20and%20Performance%20Improvements_20240915.pdf)). Here, `Nr` is the register blocking size for `N` dimension (at the granularity of the GEMM micro-kernel, it's currently also the cache blocking size for `N` dimension, although that may change in the future), and `Kc` is the cache blocking size for `K` dimension.
The figure below is from the document linked above -
<img width="476" alt="image" src="https://github.com/user-attachments/assets/e23e5476-d910-46d1-a9b3-cbf77de76d94">
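A reference sketch of the dequantization chain that gets cached per `Nr * Kc` block (names are illustrative; `scale` stands in for the per-channel weight scale):
```
import torch

def dequant_int8_block(w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # int8 -> int32 -> fp32 -> bf16; the result is reused by subsequent
    # inner-kernel calls over the same weight block instead of being recomputed
    w_fp32 = w_int8.to(torch.int32).to(torch.float32) * scale
    return w_fp32.to(torch.bfloat16)
```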
## Performance data
Collected on 48 physical cores of one socket of Intel Xeon Platinum 8468H (Xeon SP 4th gen). Intel OpenMP & tcmalloc were preloaded.
|M | N | K | Latency with ATen _weight_int8pack_mm | Latency with codegened templated GEMM (current main branch) | Latency with codegened templated GEMM (this PR) |
|-----|-----|-----|------|----------|----|
|4096|4096|4096| 45.844 ms | 9.322 ms| 5.2181 ms |
|4096|11008|4096| 127.618 ms |24.6258 ms | 13.6046 ms|
|4096|4096|11008| 121.953 ms | 25.4692 ms | 10.2669 ms |
|4096|32000|4096| 478.450 ms| 75.3942 ms | 48.21 ms |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136688
Approved by: https://github.com/jgong5
**Summary**
Fix issue https://github.com/pytorch/pytorch/issues/135106. For AOTI, the Inductor IR of the weight is
```
ReinterpretView(
StorageBox(
ConstantBuffer(name='L__self___mlp_0_weight', layout=FixedLayout('cpu', torch.float32, size=[64, 128], stride=[128, 1]))
),
FixedLayout('cpu', torch.float32, size=[128, 64], stride=[1, 128]),
origins=OrderedSet([addmm])
)
```
In the post-processing step of the GEMM template, the weight that was used was the one from before the permutation, leading to correctness issues. In this PR, we address this by reshaping the weight to the expected size and stride before the weight prepacking.
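A toy illustration of the layout mismatch in the IR above (values are random; only shapes and strides matter): the stored constant is `[64, 128]` contiguous, while the GEMM consumes its transposed view of size `[128, 64]` with stride `[1, 128]`, so the weight is materialized in the expected layout before prepacking.
```
import torch

w_const = torch.randn(64, 128)       # FixedLayout size [64, 128], stride [128, 1]
w_view = w_const.t()                 # ReinterpretView size [128, 64], stride [1, 128]
w_for_prepack = w_view.contiguous()  # same values, now in the layout prepacking expects
```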
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_aot_inductor.py -k test_misc_1_max_autotune_True_non_abi_compatible_cpu
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_aoti_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_aoti_linear_multi_view_operations
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136421
Approved by: https://github.com/jgong5, https://github.com/desertfire
Fixes the max-autotune failure of `soft_actor_critic` of Torchbench in FP32 single thread dynamic shape case:
```log
File "/home/user/inductor/pytorch/torch/_inductor/codegen/cpp_micro_gemm.py", line 136, in codegen_call
C_ptr = f"&({kernel.index(C, [0, 0])})"
File "/home/user/inductor/pytorch/torch/_inductor/codegen/cpp_template_kernel.py", line 135, in index
else self.args.input(node.get_name())
File "/home/user/inductor/pytorch/torch/_inductor/codegen/common.py", line 1251, in input
assert name not in V.graph.removed_buffers, name
AssertionError: buf_GemmOut
```
The 1st and 2nd linears do not need to use a local buffer, while the 3rd linear does.
The 3rd linear, which uses a local buffer, adds its global buffer (named `buf_GemmOut`) to `V.graph.removed_buffers`.
When scheduling the nodes, the 1st linear (which won't use a local buffer) gets its output buffer (also named `buf_GemmOut`) from the input, finds that it is in `V.graph.removed_buffers`, and raises the AssertionError. The issue is that the output buffers of all these linears are all named `buf_GemmOut`, which causes the conflict.
Rename these buffers by adding the name of the `template_buffer` as a prefix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136419
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: #136418, #136518
Fixes the compilation error of max-autotune for `maml_omniglot` (AMP and FP32) and `soft_actor_critic` (AMP) in Torchbench for single-thread dynamic shapes case:
```
/tmp/torchinductor_user/uv/cuvq6wenwp7us423onuvntkfx4cspmagha5beiknob7tiebzhupa.cpp: In function ‘void kernel(const bfloat16*, const bfloat16*, const bfloat16*, bfloat16*, int64_t)’:
/tmp/torchinductor_user/uv/cuvq6wenwp7us423onuvntkfx4cspmagha5beiknob7tiebzhupa.cpp:279:41: error: the value of ‘Mr_blocks’ is not usable in a constant expression
279 | constexpr int64_t m_block_end = Mr_blocks;
| ^~~~~~~~~
/tmp/torchinductor_user/uv/cuvq6wenwp7us423onuvntkfx4cspmagha5beiknob7tiebzhupa.cpp:237:19: note: ‘Mr_blocks’ was not initialized with a constant expression
237 | const int64_t Mr_blocks = (M + Mr - 1) / Mr;
| ^~~~~~~~~
```
The PR also updates the UT to add a test with `BS`=512 in single thread.
The previous case had `BS`=1024, equal to the `K` and `N` values; the generated code did not have symbolic shapes and thus failed to capture the above issue.
With a case of `BS`=512, the generated code has a symbolic shape for the M dim and is able to reproduce the issue that this PR is addressing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136418
Approved by: https://github.com/jgong5