mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix meta for constant_pad_nd (#159878)
Fixes https://github.com/pytorch/pytorch/issues/144187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159878 Approved by: https://github.com/Skylion007, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4de93f6a3
commit
781e9a7724
@ -7638,6 +7638,20 @@ def _constant_pad_nd_meta(input, pad, value=0):
|
||||
f"{l_inp} dimensions.",
|
||||
)
|
||||
|
||||
if all(isinstance(p, utils.IntWithoutSymInt) and p <= 0 for p in pad):
|
||||
c_input = input
|
||||
for i in range(l_diff, l_inp):
|
||||
pad_idx = 2 * (l_inp - i - 1)
|
||||
if pad[pad_idx] < 0:
|
||||
c_input = c_input.narrow(
|
||||
i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]
|
||||
)
|
||||
|
||||
if pad[pad_idx + 1] < 0:
|
||||
c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
|
||||
|
||||
return c_input.clone()
|
||||
|
||||
new_shape = list(input_sizes[:l_diff])
|
||||
for i in range(l_pad):
|
||||
pad_idx = len(pad) - ((i + 1) * 2)
|
||||
|
Reference in New Issue
Block a user