mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][triton] Fix average pool nd for int64 dtype (#146061)
The eager-mode implementation of average pool N-d returns an integer tensor when the input is an integer tensor. Inductor should preserve this behavior.
Fixes the error raised by `pytest -k test_comprehensive_nn_functional_avg_pool2d_cpu_int64`: "Triton compilation failed: triton_poi_fused_avg_pool2d_0".
See the WIP discussion at https://github.com/pytorch/pytorch/pull/145865#issuecomment-26200289890 about potentially enabling such tests, as they aren't enabled yet.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146061
Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
fdee60769a
commit
f339e41a38
@@ -5256,7 +5256,8 @@ def _avg_poolnd(
|
||||
else:
|
||||
|
||||
def fn(idx):
|
||||
return ops.truediv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))
|
||||
# C style integer division as done in native/cpu/AvgPoolKernel.cpp
|
||||
return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))
|
||||
|
||||
else:
|
||||
|
||||
@@ -5273,7 +5274,10 @@ def _avg_poolnd(
|
||||
factor = ops.index_expr(hend - hstart, torch.int32)
|
||||
divide_factors.append(factor)
|
||||
divide_factor = functools.reduce(ops.mul, divide_factors)
|
||||
return ops.truediv(fn_sum(idx, x_loader), divide_factor)
|
||||
if dtype.is_floating_point:
|
||||
return ops.truediv(fn_sum(idx, x_loader), divide_factor)
|
||||
# C style integer division as done in native/cpu/AvgPoolKernel.cpp
|
||||
return ops.truncdiv(fn_sum(idx, x_loader), divide_factor)
|
||||
|
||||
rv = Pointwise.create(
|
||||
device=x.get_device(),
|
||||
|
Reference in New Issue
Block a user