mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
improve shape checks for grouped_mm (#159666)
Check that contraction dimension matches between tensors if it's known, and do device-side checks for correct offsets Pull Request resolved: https://github.com/pytorch/pytorch/pull/159666 Approved by: https://github.com/danielvegamyhre, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
465fe4d9f7
commit
a81ffbc5f5
@ -7369,6 +7369,12 @@ def _meta_grouped_mm_common(
|
||||
mat_a_is_2d = mat_a.dim() == 2
|
||||
mat_b_is_2d = mat_b.dim() == 2
|
||||
|
||||
if not mat_a_is_2d or not mat_b_is_2d:
|
||||
torch._check(
|
||||
mat_a.size(-1) == mat_b.size(-2),
|
||||
"contraction dimension of mat_a and mat_b must match",
|
||||
)
|
||||
|
||||
if scaled:
|
||||
|
||||
def is_row_major(mat):
|
||||
|
Reference in New Issue
Block a user