mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[EASY] Support SymInt tracing on broadcast_shapes (#113877)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/113877 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
e8ee14292e
commit
4979f9c0d7
@ -106,7 +106,7 @@ def broadcast_shapes(*shapes):
|
||||
if not torch.jit.is_tracing():
|
||||
max_len = 0
|
||||
for shape in shapes:
|
||||
if isinstance(shape, int):
|
||||
if isinstance(shape, (int, torch.SymInt)):
|
||||
if max_len < 1:
|
||||
max_len = 1
|
||||
elif isinstance(shape, (tuple, list)):
|
||||
@ -115,7 +115,7 @@ def broadcast_shapes(*shapes):
|
||||
max_len = s
|
||||
result = [1] * max_len
|
||||
for shape in shapes:
|
||||
if isinstance(shape, int):
|
||||
if isinstance(shape, (int, torch.SymInt)):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, (tuple, list)):
|
||||
for i in range(-1, -1 - len(shape), -1):
|
||||
|
Reference in New Issue
Block a user