[inductor] Don't require_dense for grid_sampler_2d_backward (#163415)
Fixes #163372

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163415
Approved by: https://github.com/Skylion007
ghstack dependencies: #163386, #163398, #163387, #163414
Committed by: PyTorch MergeBot
Parent: c8fd2b45e5
Commit: 4fc271e559
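For context: the `require_dense` layout constraint makes Inductor copy each input to the `grid_sampler_2d_backward` fallback into a densely packed tensor before invoking the eager kernel. When the sampled input is a broadcasted view created by `expand`, that copy materializes the broadcast, which appears to be the failure mode behind #163372 given the regression test below. A minimal sketch of the cost, using the same sizes as the test (the ~86 GB figure is arithmetic on the tensor sizes, not a measurement):

import torch

# expand() creates a broadcasted view: the new leading dim has stride 0,
# so all 9000 "repeats" alias the same 48x224x224 storage and no extra
# memory is allocated.
x = torch.randn(48, 224, 224)
view = x.expand(9000, -1, -1, -1)

print(view.shape)     # torch.Size([9000, 48, 224, 224])
print(view.stride())  # (0, 50176, 224, 1) -- the stride-0 dim makes it non-dense

# Forcing a dense layout would copy the broadcast into real memory:
# 9000 * 48 * 224 * 224 float32 elements is roughly 86 GB.
# dense = view.contiguous()  # deliberately left commented out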
@@ -7409,6 +7409,56 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
             rtol=1.3e-06,
         )
 
+    @requires_gpu()
+    def test_grid_sampler_expand_preserves_view(self):
+        if not self.device.startswith("cuda"):
+            self.skipTest("requires CUDA")
+
+        torch.manual_seed(0)
+        torch._dynamo.reset()
+
+        repeats = 9000
+        batch = 48
+        channels = 3
+        img = 224
+        grid_size = 13
+        device = self.device
+
+        class ExpandGridSampler(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.grid = torch.nn.Parameter(
+                    torch.randn(repeats, grid_size, grid_size, 2, device=device)
+                )
+                self.fc = torch.nn.Linear(grid_size * grid_size * channels, 16).to(
+                    device
+                )
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                per_channel = []
+                for i in range(channels):
+                    channel = x[:, i, ...].expand(repeats, -1, -1, -1)
+                    patch = torch.nn.functional.grid_sample(
+                        channel,
+                        self.grid,
+                        mode="bilinear",
+                        align_corners=False,
+                        padding_mode="border",
+                    )
+                    patch = patch.transpose(0, 1).flatten(start_dim=2)
+                    per_channel.append(patch)
+                x = torch.cat(per_channel, dim=2)
+                return self.fc(x)
+
+        model = ExpandGridSampler().to(device)
+        compiled = torch.compile(model, backend="inductor")
+        inp = torch.randn(batch, channels, img, img, device=device)
+
+        out = compiled(inp)
+        out.sum().backward()
+
+        self.assertIsNotNone(model.grid.grad)
+
     def test_upsample_bicubic2d(self):
         def fn(a):
             return (
@@ -2795,7 +2795,7 @@ make_fallback(aten.replication_pad2d_backward)
 make_fallback(aten.upsample_linear1d_backward)
 make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
 make_fallback(aten.upsample_trilinear3d_backward)
-make_fallback(aten.grid_sampler_2d_backward, require_dense)
+make_fallback(aten.grid_sampler_2d_backward)
 make_fallback(aten._pdist_backward)
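Dropping the second argument keeps `grid_sampler_2d_backward` registered as a fallback but removes the layout constraint, so the expanded (stride-0) input can be passed through as a view, with the eager kernel handling the strides itself. For intuition, a rough sketch of what "dense" means in this context, assuming it requires the dims, ordered by stride, to pack the storage with no gaps (illustrative only, not Inductor's actual predicate):

import torch

def is_dense(t: torch.Tensor) -> bool:
    # True if some ordering of the dims packs the storage tightly:
    # no stride-0 (broadcast) dims and no gaps between elements.
    expected = 1
    for d in sorted(range(t.dim()), key=lambda d: t.stride(d)):
        if t.size(d) == 1:
            continue
        if t.stride(d) != expected:
            return False
        expected *= t.size(d)
    return True

a = torch.randn(48, 224, 224)
print(is_dense(a))                        # True: contiguous
print(is_dense(a.transpose(0, 1)))        # True: permuted but still packed
print(is_dense(a.expand(2, -1, -1, -1)))  # False: stride-0 broadcast dim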