[EASY] Support SymInt tracing on broadcast_shapes (#113877)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113877
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang
2023-11-16 17:06:15 -08:00
committed by PyTorch MergeBot
parent e8ee14292e
commit 4979f9c0d7
2 changed files with 12 additions and 2 deletions

View File

@ -106,7 +106,7 @@ def broadcast_shapes(*shapes):
if not torch.jit.is_tracing():
max_len = 0
for shape in shapes:
if isinstance(shape, int):
if isinstance(shape, (int, torch.SymInt)):
if max_len < 1:
max_len = 1
elif isinstance(shape, (tuple, list)):
@ -115,7 +115,7 @@ def broadcast_shapes(*shapes):
max_len = s
result = [1] * max_len
for shape in shapes:
if isinstance(shape, int):
if isinstance(shape, (int, torch.SymInt)):
shape = (shape,)
if isinstance(shape, (tuple, list)):
for i in range(-1, -1 - len(shape), -1):