Compare commits

...

2 Commits

Author  SHA1        Message                                   Date
        d2597181fe  Update [ghstack-poisoned]                 2025-11-06 16:54:53 +00:00
        2aea412900  Update (base update) [ghstack-poisoned]   2025-11-06 16:54:53 +00:00
3 changed files with 57 additions and 70 deletions

View File

@@ -5398,16 +5398,14 @@ class CommonTemplate:
         )
 
     def test_avg_pool2d7(self):
-        # Large kernel size, use fallback
+        # Large kernel size
         def fn(x):
             return aten.avg_pool2d(x, [13, 13], [1, 1], [0, 0])
 
-        torch._inductor.metrics.generated_kernel_count = 0
         self.common(
             fn,
             (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),),
         )
-        assertGeneratedKernelCountEqual(self, 0)
 
     def test_avg_pool2d8(self):
         # https://github.com/pytorch/pytorch/issues/100987
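For reference, the behavior this test exercises can be reproduced standalone. A minimal sketch, assuming a recent PyTorch build with a working inductor backend, and using the public torch.nn.functional API in place of the aten overload called in the test:

    import torch
    import torch.nn.functional as F

    def fn(x):
        # 13x13 window, stride 1, no padding: the "large kernel size" case above
        return F.avg_pool2d(x, [13, 13], [1, 1], [0, 0])

    x = -torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24)
    # compiled output should match eager
    torch.testing.assert_close(torch.compile(fn)(x), fn(x))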

View File

@ -161,7 +161,6 @@ test_failures = {
"test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),

View File

@@ -284,7 +284,11 @@ def is_boolean_type(x):
     return isinstance(x, bool)
 
 
-def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
+def get_promoted_dtype(
+    *args,
+    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
+    return_compute_dtype: bool = False,
+):
     def construct_input(inp):
         if isinstance(inp, (Number, sympy.Basic)):
             return inp
@@ -294,8 +298,10 @@ def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KI
             return torch.zeros([1] * dim, dtype=inp.get_dtype())
 
     inps = [construct_input(arg) for arg in args]
-    _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind)
-    return dtype
+    compute_dtype, result_dtype = elementwise_dtypes(
+        *inps, type_promotion_kind=type_promotion_kind
+    )
+    return compute_dtype if return_compute_dtype else result_dtype
 
 
 def get_overloads(aten_fn):
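For context, the new return_compute_dtype flag picks between the two dtypes that elementwise_dtypes already returns. A minimal sketch of that pair, assuming a recent PyTorch build (torch._prims_common is a private module and may change):

    import torch
    from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

    x = torch.empty(1, dtype=torch.float16)
    compute_dtype, result_dtype = elementwise_dtypes(
        x, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    # Typically torch.float32 and torch.float16: the reduction accumulates in the
    # wider compute dtype while the result keeps the input dtype.
    print(compute_dtype, result_dtype)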
@@ -5510,14 +5516,6 @@ def upsample_nearest2d_backward(
     return rv
 
 
-fallback_avg_pool2d = fallback_handler(
-    aten.avg_pool2d.default, add_to_fallback_set=False
-)
-fallback_avg_pool3d = fallback_handler(
-    aten.avg_pool3d.default, add_to_fallback_set=False
-)
-
-
 @register_lowering(aten.avg_pool2d, type_promotion_kind=None)
 def avg_pool2d(
     x,
@@ -5606,57 +5604,52 @@ def _avg_poolnd(
     new_size = list(batch) + list(h_out)
     dtype = x.get_dtype()
 
+    # compute in higher-precision until scaling
+    output_dtype = get_promoted_dtype(
+        x,
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+        return_compute_dtype=True,
+    )
+
+    def fn_inner(idx, reduction_idx):
+        prefix = idx[:-dim]
+        bh = idx[-dim:]
+        ih = reduction_idx
+        ih = [bh[i] * stride[i] + ih[i] - padding[i] for i in range(dim)]
+        return x_loader([*prefix, *ih])
+
     window_size = functools.reduce(operator.mul, kernel_size)
-    if window_size > 25:
-        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
-        if dim == 2:
-            fallback = fallback_avg_pool2d
-        elif dim == 3:
-            fallback = fallback_avg_pool3d
-        else:
-            raise ValueError(f"Unknown dim: {dim}")
-        return fallback(
-            x,
-            kernel_size,
-            stride,
-            padding,
-            ceil_mode,
-            count_include_pad,
-            divisor_override,
-        )
-
-    def fn_sum(idx, loader):
-        prefix = idx[:-dim]
-        b = idx[-dim:]
-        total = None
-        for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]):
-            inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)]
-            val = loader([*prefix, *inp])
-            if total is None:
-                total = val
-            else:
-                total = ops.add(val, total)
-        return total
+
+    # TODO: remove this when #100331 is merged. We only do this
+    # for window_size <=25 to avoid performance regressions compared
+    # to the previous algorithm which unrolled manually for <=25
+    context = (
+        config.patch(unroll_reductions_threshold=25)
+        if window_size <= 25
+        else contextlib.nullcontext()
+    )
+    with context:
+        rv = Reduction.create(
+            reduction_type="sum",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=output_dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+
+    if isinstance(rv.data.data, Reduction):
+        # Only realize if reduction isn't unrolled
+        rv.realize()
 
     if not had_padding or divisor_override:
         divisor = divisor_override if divisor_override else window_size
-        if dtype.is_floating_point:
-            scale = 1 / divisor
-
-            def fn(idx):
-                return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))
-
-        else:
-
-            def fn(idx):
-                # C style integer division as done in native/cpu/AvgPoolKernel.cpp
-                return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))
-
+        result = div_prim(rv, divisor)
     else:
 
-        def fn(idx):
+        def fn_count(idx):
             bh = idx[-dim:]
             divide_factors = []
@@ -5668,20 +5661,17 @@ def _avg_poolnd(
                 hend = sympy.Min(hend, h[i])
                 factor = ops.index_expr(hend - hstart, torch.int32)
                 divide_factors.append(factor)
-            divide_factor = functools.reduce(ops.mul, divide_factors)
-            if dtype.is_floating_point:
-                return ops.truediv(fn_sum(idx, x_loader), divide_factor)
-            # C style integer division as done in native/cpu/AvgPoolKernel.cpp
-            return ops.truncdiv(fn_sum(idx, x_loader), divide_factor)
+            return functools.reduce(ops.mul, divide_factors)
 
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=dtype,
-        inner_fn=fn,
-        ranges=new_size,
-    )
-    # TODO(jansel): should we force these to be realized?
-    return rv
+        divide_factor = Pointwise.create(
+            device=x.get_device(),
+            dtype=dtype,
+            inner_fn=fn_count,
+            ranges=new_size,
+        )
+        result = div_prim(rv, divide_factor)
+
+    return to_dtype(result, dtype)
 
 
 fallback_avg_pool2d_backward = fallback_handler(
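The reworked lowering expresses average pooling as a sum reduction followed by a division, either by a constant divisor or by the per-output element count computed in fn_count. A minimal eager-mode sketch of that same decomposition, using only public APIs and illustrative variable names, to show why dividing the window sum by the valid-element count reproduces count_include_pad=False:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 16, 16)
    k, s, p = 5, 2, 2

    # Window sums: average over the full (zero-padded) window, scaled back up.
    window_sum = F.avg_pool2d(x, k, s, p, count_include_pad=True) * (k * k)
    # Divide factor: number of non-padding elements that fall in each window.
    counts = F.avg_pool2d(torch.ones_like(x), k, s, p, count_include_pad=True) * (k * k)

    ref = F.avg_pool2d(x, k, s, p, count_include_pad=False)
    torch.testing.assert_close(window_sum / counts, ref)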