[inductor] Don't require_dense for grid_sampler_2d_backward (#163415)

Fixes #163372

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163415
Approved by: https://github.com/Skylion007
ghstack dependencies: #163386, #163398, #163387, #163414
This commit is contained in:
Jason Ansel
2025-09-22 08:12:16 -07:00
committed by PyTorch MergeBot
parent c8fd2b45e5
commit 4fc271e559
2 changed files with 51 additions and 1 deletions

View File

@ -7409,6 +7409,56 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
rtol=1.3e-06,
)
@requires_gpu()
def test_grid_sampler_expand_preserves_view(self):
if not self.device.startswith("cuda"):
self.skipTest("requires CUDA")
torch.manual_seed(0)
torch._dynamo.reset()
repeats = 9000
batch = 48
channels = 3
img = 224
grid_size = 13
device = self.device
class ExpandGridSampler(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.grid = torch.nn.Parameter(
torch.randn(repeats, grid_size, grid_size, 2, device=device)
)
self.fc = torch.nn.Linear(grid_size * grid_size * channels, 16).to(
device
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
per_channel = []
for i in range(channels):
channel = x[:, i, ...].expand(repeats, -1, -1, -1)
patch = torch.nn.functional.grid_sample(
channel,
self.grid,
mode="bilinear",
align_corners=False,
padding_mode="border",
)
patch = patch.transpose(0, 1).flatten(start_dim=2)
per_channel.append(patch)
x = torch.cat(per_channel, dim=2)
return self.fc(x)
model = ExpandGridSampler().to(device)
compiled = torch.compile(model, backend="inductor")
inp = torch.randn(batch, channels, img, img, device=device)
out = compiled(inp)
out.sum().backward()
self.assertIsNotNone(model.grid.grad)
def test_upsample_bicubic2d(self):
def fn(a):
return (

View File

@ -2795,7 +2795,7 @@ make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten.upsample_trilinear3d_backward)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten.grid_sampler_2d_backward)
make_fallback(aten._pdist_backward)