improve shape checks for grouped_mm (#159666)

Check that contraction dimension matches between tensors if it's known, and do device-side checks for correct offsets
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159666
Approved by: https://github.com/danielvegamyhre, https://github.com/eqy
This commit is contained in:
Natalia Gimelshein
2025-08-02 00:12:21 +00:00
committed by PyTorch MergeBot
parent 465fe4d9f7
commit a81ffbc5f5
5 changed files with 30 additions and 4 deletions

View File

@ -7369,6 +7369,12 @@ def _meta_grouped_mm_common(
mat_a_is_2d = mat_a.dim() == 2
mat_b_is_2d = mat_b.dim() == 2
if not mat_a_is_2d or not mat_b_is_2d:
torch._check(
mat_a.size(-1) == mat_b.size(-2),
"contraction dimension of mat_a and mat_b must match",
)
if scaled:
def is_row_major(mat):