Add unique identifier to bmm thread_mm functions (#145303)

Summary:
The bmm template generates code like this

```
template<bool accum>
void cpp_fused_bmm_66_micro_gemm(...) {
    ...
}

void single_thread_mm() {
    ...
    cpp_fused_bmm_66_micro_gemm(...)
    ...
}

void threaded_mm() {
    ...
    cpp_fused_bmm_66_micro_gemm(...)
    ...
}

void cpp_fused_bmm_66(...)
{
    ...
    single_thread_mm(...);
    ...
    threaded_mm(...);
    ...
}
```

The generated `fused_bmm` and `fused_bmm_microgemm` functions both have unique identifiers added to their names, but the `single_thread_mm` and `threaded_mm` functions do not.

This diff adds unique identifiers to those generated functions as well. The identifier is based on the kernel name. So for the example above we would generate a bmm template name like `cpp_fused_bmm_66_single_thread_mm()`.

Differential Revision: D68364772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145303
Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
This commit is contained in:
David Peixotto
2025-01-24 17:35:50 +00:00
committed by PyTorch MergeBot
parent 547c18ee9f
commit 97c0b7cb0a
2 changed files with 40 additions and 4 deletions

View File

@ -2264,6 +2264,41 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@torch.no_grad
@dtypes(torch.float)
def test_aoti_bmm_unique_identifiers(self, dtype):
    """Run a two-bmm model through AOTI on CPU and check it matches eager.

    Also asserts that select_algorithm autotuning fired once per matmul
    (two in total), i.e. both bmm kernels were generated.
    """
    # The AOTI runner helpers live next to this test file. Prefer the
    # package-relative import, fall back to a plain one when the test is
    # executed as a standalone script; bail out quietly if neither works.
    try:
        try:
            from . import test_aot_inductor_utils
        except ImportError:
            import test_aot_inductor_utils
    except Exception:
        # skip this UT if import failed
        return

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, w):
            y = x @ w
            return y @ w

    counters.clear()
    lhs = torch.randn(3, 64, 64).to(dtype=dtype)
    rhs = torch.randn(3, 64, 64).to(dtype=dtype)
    model = M().to(dtype=dtype).eval()
    with verify(dtype) as (atol, rtol), torch.no_grad():
        reference = model(lhs, rhs)
        compiled = test_aot_inductor_utils.AOTIRunnerUtil.run(
            "cpu",
            model,
            (lhs, rhs),
        )
        self.assertEqual(compiled, reference, atol=atol, rtol=rtol)
    # One autotune event per bmm in the model.
    self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):