Isuru Fernando
2025-08-14 11:44:56 +00:00
committed by PyTorch MergeBot
parent e4de93f6a3
commit 781e9a7724
2 changed files with 22 additions and 0 deletions

View File

@ -7457,6 +7457,14 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
) )
def test_constant_pad_2d_strides_nonpositive(self):
def fn(a):
return torch.constant_pad_nd(a, [0, 0, 0, -2, 0, 0])
self.common(
fn, (torch.empty_strided((2, 4, 5), (20, 1, 4), dtype=torch.float32),)
)
@skip_if_gpu_halide # misaligned address @skip_if_gpu_halide # misaligned address
def test_constant_pad_3d(self): def test_constant_pad_3d(self):
def fn(a): def fn(a):

View File

@ -7638,6 +7638,20 @@ def _constant_pad_nd_meta(input, pad, value=0):
f"{l_inp} dimensions.", f"{l_inp} dimensions.",
) )
if all(isinstance(p, utils.IntWithoutSymInt) and p <= 0 for p in pad):
c_input = input
for i in range(l_diff, l_inp):
pad_idx = 2 * (l_inp - i - 1)
if pad[pad_idx] < 0:
c_input = c_input.narrow(
i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]
)
if pad[pad_idx + 1] < 0:
c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
return c_input.clone()
new_shape = list(input_sizes[:l_diff]) new_shape = list(input_sizes[:l_diff])
for i in range(l_pad): for i in range(l_pad):
pad_idx = len(pad) - ((i + 1) * 2) pad_idx = len(pad) - ((i + 1) * 2)