[inductor][triton] Fix average pool nd for int64 dtype (#146061)

The eager-mode implementation of average pool nd returns an integer tensor when the input is an integer tensor. Inductor should preserve this behavior.

Fixes the `pytest -k test_comprehensive_nn_functional_avg_pool2d_cpu_int64` error: `Triton compilation failed: triton_poi_fused_avg_pool2d_0`.

See the WIP discussion at https://github.com/pytorch/pytorch/pull/145865#issuecomment-26200289890, which may enable these tests; they are not enabled yet.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146061
Approved by: https://github.com/eellison
This commit is contained in:
Mwiza Kunda
2025-03-04 13:53:47 +00:00
committed by PyTorch MergeBot
parent fdee60769a
commit f339e41a38

View File

@ -5256,7 +5256,8 @@ def _avg_poolnd(
else:
def fn(idx):
return ops.truediv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))
# C style integer division as done in native/cpu/AvgPoolKernel.cpp
return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))
else:
@ -5273,7 +5274,10 @@ def _avg_poolnd(
factor = ops.index_expr(hend - hstart, torch.int32)
divide_factors.append(factor)
divide_factor = functools.reduce(ops.mul, divide_factors)
return ops.truediv(fn_sum(idx, x_loader), divide_factor)
if dtype.is_floating_point:
return ops.truediv(fn_sum(idx, x_loader), divide_factor)
# C style integer division as done in native/cpu/AvgPoolKernel.cpp
return ops.truncdiv(fn_sum(idx, x_loader), divide_factor)
rv = Pointwise.create(
device=x.get_device(),