mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use sym_eq and sym_and on symbolic shapes in common_meta_baddbmm_bmm (#164781)
Differential Revision: D84005053 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164781 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
9ecd092bd9
commit
2855a045b3
@ -4351,6 +4351,8 @@ def meta_index_put_(self, indices, values, accumulate=False):
|
||||
|
||||
|
||||
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None):
|
||||
from torch.fx.experimental.symbolic_shapes import sym_and, sym_eq
|
||||
|
||||
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
|
||||
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
|
||||
|
||||
@ -4364,7 +4366,7 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype
|
||||
output_size = (bs, res_rows, res_cols)
|
||||
|
||||
torch._check(
|
||||
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
|
||||
sym_and(sym_eq(batch2_sizes[0], bs), sym_eq(batch2_sizes[1], contraction_size)),
|
||||
lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
|
||||
f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
|
||||
)
|
||||
@ -4384,7 +4386,7 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype
|
||||
if not is_bmm and self_baddbmm is not None:
|
||||
torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
|
||||
torch._check(
|
||||
self_baddbmm.size() == output_size,
|
||||
sym_eq(self_baddbmm.size(), output_size),
|
||||
lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user