mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 17:54:55 +08:00
Add unique identifier to bmm thread_mm functions (#145303)
Summary:
The bmm template generates code like this
```
template<bool accum>
void cpp_fused_bmm_66_micro_gemm(...) {
...
}
void single_thread_mm() {
...
cpp_fused_bmm_66_micro_gemm(...)
...
}
void threaded_mm() {
...
cpp_fused_bmm_66_micro_gemm(...)
...
}
void cpp_fused_bmm_66(...)
{
...
single_thread_mm(...);
...
threaded_mm(...);
...
}
```
The generated `fused_bmm` and `fused_bmm_microgemm` functions both have unique identifiers added to their names, but the `single_threaded_mm` and `threaded_mm` do not.
This diff adds unique identifiers to those generated functions as well. The identifier is based on the kernel name. So for the example above we would generate a bmm template name like `cpp_fused_bmm_66_single_thread_mm()`.
Differential Revision: D68364772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145303
Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
This commit is contained in:
committed by
PyTorch MergeBot
parent
547c18ee9f
commit
97c0b7cb0a
@ -2264,6 +2264,41 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@patches
|
||||
@torch.no_grad
|
||||
@dtypes(torch.float)
|
||||
def test_aoti_bmm_unique_identifiers(self, dtype):
|
||||
try:
|
||||
try:
|
||||
from . import test_aot_inductor_utils
|
||||
except ImportError:
|
||||
import test_aot_inductor_utils
|
||||
except Exception:
|
||||
# skip this UT if import failed
|
||||
return
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, w):
|
||||
y = x @ w
|
||||
return y @ w
|
||||
|
||||
counters.clear()
|
||||
x = torch.randn(3, 64, 64).to(dtype=dtype)
|
||||
w = torch.randn(3, 64, 64).to(dtype=dtype)
|
||||
mod = M().to(dtype=dtype).eval()
|
||||
with verify(dtype) as (atol, rtol), torch.no_grad():
|
||||
expected = mod(x, w)
|
||||
actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
|
||||
"cpu",
|
||||
mod,
|
||||
(x, w),
|
||||
)
|
||||
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
|
||||
|
||||
|
||||
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
|
||||
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
|
||||
|
||||
Reference in New Issue
Block a user