Compare commits

...

6 Commits

SHA1 Message Date
fb311d2516 simplify register_meta 2025-11-14 20:29:29 +00:00
fc2425329d use wrapper for out variant and simplify comments 2025-11-14 18:14:57 +00:00
1130aa2001 Remove max_pool2d_with_indices_backward dynamic shapes from expected failures
The decomposition now handles large kernels and dilation != 1 correctly,
so these tests pass with dynamic shapes.
2025-11-13 15:37:56 +00:00
e424feb451 improve from review comments 2025-11-13 15:13:51 +00:00
8b217b4462 update test expectations (decomp, kernel number) and meta for grad_input 2025-11-10 17:52:45 +00:00
5df3142e85 [decomp] Add max_pool2d_with_indices_backward decomposition using gather+reduction
Algorithm:
    For each input position, gather all gradient contributions from output
    positions that selected it as the maximum, then sum using scatter_add.

Technical details:
    - Uses functional scatter_add for AOT Autograd compatibility
    - Uses reshape (not view) to handle non-contiguous tensors
    - Handles all configurations without fallback (dilation, large kernels, ceil_mode)
(A standalone sketch of this algorithm follows the commit list below.)
2025-11-10 15:30:13 +00:00
6 changed files with 69 additions and 13 deletions
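For reference, the gather + scatter_add idea from commit 5df3142e85 can be checked end to end with a small standalone script. This is an illustrative sketch only (shapes and variable names here are made up, and it is not code from the diff below): it rebuilds the backward by summing grad_output into the positions recorded in indices, then compares the result against the native kernel.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)
out, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
grad_output = torch.ones_like(out)

# indices holds, for every output element, the flat H*W position of the input
# element that won the max; summing grad_output into those positions per
# (batch, channel) plane reproduces the backward pass.
B, C, H, W = x.shape
grad_input = torch.zeros(B * C, H * W, dtype=grad_output.dtype)
grad_input = grad_input.scatter_add(
    1, indices.reshape(B * C, -1), grad_output.reshape(B * C, -1)
).reshape(B, C, H, W)

# Reference: the native backward kernel.
ref = torch.ops.aten.max_pool2d_with_indices_backward(
    grad_output, x, [2, 2], [2, 2], [0, 0], [1, 1], False, indices
)
torch.testing.assert_close(grad_input, ref)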

View File

@@ -348,6 +348,8 @@ aten::lt.Tensor
aten::lt.Tensor_out
aten::lt_.Scalar
aten::lt_.Tensor
aten::max_pool2d_with_indices_backward
aten::max_pool2d_with_indices_backward.grad_input
aten::maximum
aten::maximum.out
aten::mean

View File

@@ -946,8 +946,6 @@ aten::max_pool2d_backward
aten::max_pool2d_backward.out
aten::max_pool2d_with_indices
aten::max_pool2d_with_indices.out
aten::max_pool2d_with_indices_backward
aten::max_pool2d_with_indices_backward.grad_input
aten::max_pool3d_with_indices
aten::max_pool3d_with_indices.out
aten::max_pool3d_with_indices_backward

View File

@@ -9729,11 +9729,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
                indices,
            ],
        )
        # Decomposition compiles via scatter_add (fuses into 1 kernel)
        assertGeneratedKernelCountEqual(self, 1)

    @expectedFailureXPU
    def test_max_pool2d_with_indices_backward5(self):
        # Window size is too big. Should fallback
        # Large window size - decomposition handles via scatter_add
        def fn(a, b, c):
            return aten.max_pool2d_with_indices_backward(
                a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c
@@ -9757,11 +9758,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
                indices,
            ],
        )
        assertGeneratedKernelCountEqual(self, 0)
        # Decomposition compiles via scatter_add (fuses into 1 kernel)
        assertGeneratedKernelCountEqual(self, 1)

    # From https://github.com/pytorch/pytorch/issues/93384
    def test_max_pool2d_with_indices_backward6(self):
        # dilation is not 1. Should fallback
        # dilation != 1 - decomposition handles all dilation cases
        def fn(a, b, c):
            return aten.max_pool2d_with_indices_backward(
                a, b, [3, 2], [2, 1], [1, 1], [1, 2], False, c
@@ -9785,7 +9787,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
                indices,
            ],
        )
        assertGeneratedKernelCountEqual(self, 0)
        # Decomposition compiles via scatter_add (fuses into 1 kernel)
        assertGeneratedKernelCountEqual(self, 1)

    def test_issue102546(self):
        def fn(x):

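The assertion changes above flip assertGeneratedKernelCountEqual from 0 (the op fell back to the native kernel, so nothing was code-generated) to 1 (the decomposition lowers the whole backward into a single fused kernel). As a rough sketch of what that counter measures outside the test harness, assuming torch._inductor.metrics is read directly and reusing the large-window configuration from test_max_pool2d_with_indices_backward5:

import torch
from torch._inductor import metrics

aten = torch.ops.aten

def fn(grad_output, x, indices):
    return aten.max_pool2d_with_indices_backward(
        grad_output, x, [13, 13], [1, 1], [2, 2], [1, 1], False, indices
    )

x = torch.randn(2, 4, 18, 18)
out, indices = aten.max_pool2d_with_indices(x, [13, 13], [1, 1], [2, 2], [1, 1], False)
grad_output = torch.randn_like(out)

metrics.reset()
torch.compile(fn)(grad_output, x, indices)
# With the decomposition, the updated tests expect a single generated kernel.
print(metrics.generated_kernel_count)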
View File

@@ -203,12 +203,6 @@ test_failures = {
    "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
    "test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)),
    "test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)),
    "test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure(
        ("cpu", "cuda")
    ),
    "test_max_pool2d_with_indices_backward6_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu")
    ),
    "test_misaligned_address_issue1_dynamic_shapes": TestFailure(("cpu",)),
    "test_mm_views_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
    "test_new_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),

View File

@@ -5347,6 +5347,64 @@ def resize_as(self, other, memory_format=None):
    return aten.resize(self, other.shape, memory_format=memory_format)


@register_decomposition(aten.max_pool2d_with_indices_backward)
def max_pool2d_with_indices_backward(
    grad_output: Tensor,
    self: Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode: bool,
    indices: Tensor,
):
    # Use native kernel in deterministic mode
    if torch.are_deterministic_algorithms_enabled():
        return NotImplemented

    # Get spatial dimensions
    in_height = self.size(-2)
    in_width = self.size(-1)
    out_height = grad_output.size(-2)
    out_width = grad_output.size(-1)

    # Handle both 3D (C, H, W) and 4D (B, C, H, W) cases by treating 3D as 4D
    is_batched = self.dim() == 4
    if not is_batched:
        self = self.unsqueeze(0)
        grad_output = grad_output.unsqueeze(0)
        indices = indices.unsqueeze(0)

    batch_size = self.size(0)
    channels = self.size(1)

    # Create grad_input in the flattened shape for efficient scatter_add
    grad_input_flat = torch.zeros(
        batch_size * channels,
        in_height * in_width,
        dtype=grad_output.dtype,
        device=grad_output.device,
    )

    # Reshape grad_output and indices to (B*C, H_out*W_out)
    grad_output_flat = grad_output.reshape(
        batch_size * channels, out_height * out_width
    )
    indices_flat = indices.reshape(batch_size * channels, out_height * out_width)

    # Use scatter_add to accumulate gradients
    grad_input_flat = grad_input_flat.scatter_add(1, indices_flat, grad_output_flat)

    # Reshape back to original input shape
    grad_input = grad_input_flat.reshape(batch_size, channels, in_height, in_width)

    # Remove batch dimension for 3D case
    if not is_batched:
        grad_input = grad_input.squeeze(0)

    return grad_input


register_inplace(aten.addbmm_, aten.addbmm)
register_inplace(aten.addmm_, aten.addmm)
register_inplace(aten.addmv_, aten.addmv)
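As a worked micro-example of why the flattening in the decomposition above is valid (made-up values, not part of the diff): the indices tensor returned by max_pool2d_with_indices stores, for every output element, the flat offset of the winning input element within its H_in * W_in plane, so a single functional scatter_add along dim=1 of a (B*C, H_in*W_in) buffer routes each output gradient back to its source position.

import torch

x = torch.tensor([[[[1.0, 5.0],
                    [3.0, 2.0]]]])  # shape (1, 1, 2, 2)
out, idx = torch.ops.aten.max_pool2d_with_indices(
    x, [2, 2], [2, 2], [0, 0], [1, 1], False
)
print(out)  # tensor([[[[5.]]]])
print(idx)  # tensor([[[[1]]]]) -> flat offset of 5.0 within the 2x2 plane

go = torch.ones_like(out)
grad = torch.zeros(1, 4).scatter_add(1, idx.reshape(1, -1), go.reshape(1, -1))
print(grad.reshape(1, 1, 2, 2))
# tensor([[[[0., 1.],
#           [0., 0.]]]]) -> the whole gradient lands on the max position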

View File

@@ -4810,7 +4810,8 @@ def max_pool2d_checks_and_compute_shape(
    return nInputPlane, outputHeight, outputWidth

@register_meta(aten.max_pool2d_with_indices_backward.default)
@register_meta(aten.max_pool2d_with_indices_backward)
@out_wrapper("grad_input")
def meta_max_pool2d_with_indices_backward(
    grad_output,
    self,