mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix meta for constant_pad_nd (#159878)
Fixes https://github.com/pytorch/pytorch/issues/144187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159878 Approved by: https://github.com/Skylion007, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4de93f6a3
commit
781e9a7724
@ -7457,6 +7457,14 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||||||
fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
|
fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_constant_pad_2d_strides_nonpositive(self):
|
||||||
|
def fn(a):
|
||||||
|
return torch.constant_pad_nd(a, [0, 0, 0, -2, 0, 0])
|
||||||
|
|
||||||
|
self.common(
|
||||||
|
fn, (torch.empty_strided((2, 4, 5), (20, 1, 4), dtype=torch.float32),)
|
||||||
|
)
|
||||||
|
|
||||||
@skip_if_gpu_halide # misaligned address
|
@skip_if_gpu_halide # misaligned address
|
||||||
def test_constant_pad_3d(self):
|
def test_constant_pad_3d(self):
|
||||||
def fn(a):
|
def fn(a):
|
||||||
|
@ -7638,6 +7638,20 @@ def _constant_pad_nd_meta(input, pad, value=0):
|
|||||||
f"{l_inp} dimensions.",
|
f"{l_inp} dimensions.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if all(isinstance(p, utils.IntWithoutSymInt) and p <= 0 for p in pad):
|
||||||
|
c_input = input
|
||||||
|
for i in range(l_diff, l_inp):
|
||||||
|
pad_idx = 2 * (l_inp - i - 1)
|
||||||
|
if pad[pad_idx] < 0:
|
||||||
|
c_input = c_input.narrow(
|
||||||
|
i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]
|
||||||
|
)
|
||||||
|
|
||||||
|
if pad[pad_idx + 1] < 0:
|
||||||
|
c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
|
||||||
|
|
||||||
|
return c_input.clone()
|
||||||
|
|
||||||
new_shape = list(input_sizes[:l_diff])
|
new_shape = list(input_sizes[:l_diff])
|
||||||
for i in range(l_pad):
|
for i in range(l_pad):
|
||||||
pad_idx = len(pad) - ((i + 1) * 2)
|
pad_idx = len(pad) - ((i + 1) * 2)
|
||||||
|
Reference in New Issue
Block a user