Use sym_eq and sym_and on symbolic shapes in common_meta_baddbmm_bmm (#164781)

Differential Revision: D84005053

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164781
Approved by: https://github.com/Skylion007
Author: Colin Peppler
Date: 2025-10-07 18:25:00 +00:00
Committed by: PyTorch MergeBot
Parent: 9ecd092bd9
Commit: 2855a045b3

@@ -4351,6 +4351,8 @@ def meta_index_put_(self, indices, values, accumulate=False):
 def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None):
+    from torch.fx.experimental.symbolic_shapes import sym_and, sym_eq
+
     torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
     torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
@@ -4364,7 +4366,7 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype
     output_size = (bs, res_rows, res_cols)
     torch._check(
-        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
+        sym_and(sym_eq(batch2_sizes[0], bs), sym_eq(batch2_sizes[1], contraction_size)),
         lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
         f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
     )
@@ -4384,7 +4386,7 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype
     if not is_bmm and self_baddbmm is not None:
         torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
         torch._check(
-            self_baddbmm.size() == output_size,
+            sym_eq(self_baddbmm.size(), output_size),
             lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
         )
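
For context, a minimal sketch of the pattern this commit adopts; the checked_bmm helper and the shapes below are illustrative and not part of the commit. Plain == on a torch.Size and chaining clauses with `and` force each comparison down to a concrete Python bool, which guards on the symbolic sizes and can fail for unbacked SymInts; sym_eq and sym_and keep the whole condition symbolic so torch._check can record it instead. Assumes a PyTorch build where sym_and and sym_eq are available from torch.fx.experimental.symbolic_shapes, as imported in the diff above.

import torch
from torch.fx.experimental.symbolic_shapes import sym_and, sym_eq

# Hypothetical helper (not from the commit) mirroring the check pattern in
# common_meta_baddbmm_bmm: sym_eq compares SymInts without guarding, and
# sym_and combines the resulting SymBools without casting them to bool.
def checked_bmm(batch1, batch2):
    b1, b2 = batch1.size(), batch2.size()
    torch._check(
        sym_and(sym_eq(b2[0], b1[0]), sym_eq(b2[1], b1[2])),
        lambda: f"Expected batch2 to start with [{b1[0]}, {b1[2]}] but got [{b2[0]}, {b2[1]}]",
    )
    return torch.bmm(batch1, batch2)

# Under torch.compile with dynamic shapes, the sizes become SymInts and the
# check is handled symbolically rather than collapsing to an eager bool.
compiled = torch.compile(checked_bmm, dynamic=True)
print(compiled(torch.randn(4, 2, 3), torch.randn(4, 3, 5)).shape)  # torch.Size([4, 2, 5])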