Revert "[Inductor][CPP] Support vectorization of remainder (#129849)"

This reverts commit 8624a571b4eecd11547867591d70992843265e97.

Reverted https://github.com/pytorch/pytorch/pull/129849 on behalf of https://github.com/izaitsevfb due to ptedge_executorch_benchmark build failed again with LLVM crash ([comment](https://github.com/pytorch/pytorch/pull/129849#issuecomment-2294408526))
PyTorch MergeBot
2024-08-16 22:41:05 +00:00
parent 98d6a6eb7d
commit 19ff9059eb
4 changed files with 40 additions and 118 deletions

Changed file 1 of 4: ATen at::vec vectorization header

@@ -8,43 +8,6 @@
namespace at::vec {
template <typename scalar_t, int N>
inline VectorizedN<scalar_t, N> div_floor_floating_vec(
const VectorizedN<scalar_t, N>& a,
const VectorizedN<scalar_t, N>& b) {
VectorizedN<scalar_t, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = div_floor_floating_vec(a[i], b[i]);
}
return result;
}
template <typename scalar_t>
inline Vectorized<scalar_t> div_floor_floating_vec(
const Vectorized<scalar_t>& a,
const Vectorized<scalar_t>& b) {
using vec_t = Vectorized<scalar_t>;
const auto basic_div = a / b;
vec_t inf(std::numeric_limits<scalar_t>::infinity());
auto mod = a.fmod(b);
// Fixup for a case that isn't properly handled by Sleef_fmod
auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
auto div = floor / b;
const auto zero = vec_t(0);
auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
const auto one = vec_t(1);
div = vec_t::blendv(div, div - one, mask);
auto floordiv = div.floor();
mask = (div - floordiv) > vec_t(0.5);
floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
return floordiv;
};
// slow path
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(
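
For reference, the at::vec::div_floor_floating_vec helpers removed above implement floor division for floating-point vectors: divide, then step the quotient toward negative infinity when the remainder is nonzero and has the opposite sign of the divisor. A minimal scalar sketch of the same semantics in Python (illustrative only, not PyTorch code; assumes b != 0 and finite inputs):

import math

def div_floor_reference(a: float, b: float) -> float:
    # Mirrors the blendv-based vector code above for b != 0 and finite inputs.
    mod = math.fmod(a, b)  # remainder carrying the sign of the dividend
    div = (a - mod) / b
    # Step the quotient down when the remainder is nonzero and has the
    # opposite sign of the divisor.
    if mod != 0.0 and (b < 0.0) != (mod < 0.0):
        div -= 1.0
    floordiv = float(math.floor(div))
    # Guard against the quotient landing just below an integer boundary.
    if div - floordiv > 0.5:
        floordiv += 1.0
    # Preserve the sign of a zero result, as the vector code does with copysign.
    return math.copysign(0.0, a / b) if div == 0.0 else floordiv

assert div_floor_reference(7.0, -2.0) == 7.0 // -2.0 == -4.0
assert div_floor_reference(-7.0, 2.0) == -7.0 // 2.0 == -4.0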

Changed file 2 of 4: ATen CPU binary-op kernels

@@ -243,6 +243,29 @@ void div_trunc_kernel(TensorIteratorBase& iter) {
}
}
template <typename scalar_t>
inline Vectorized<scalar_t> div_floor_floating_vec(
const Vectorized<scalar_t>& a,
const Vectorized<scalar_t>& b) {
using vec_t = Vectorized<scalar_t>;
const auto basic_div = a / b;
vec_t inf(std::numeric_limits<scalar_t>::infinity());
auto mod = a.fmod(b);
// Fixup for a case that isn't properly handled by Sleef_fmod
auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
auto div = floor / b;
const auto zero = vec_t(0);
auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
const auto one = vec_t(1);
div = vec_t::blendv(div, div - one, mask);
auto floordiv = div.floor();
mask = (div - floordiv) > vec_t(0.5);
floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
return floordiv;
};
void div_floor_kernel(TensorIteratorBase& iter) {
const auto dtype = iter.common_dtype();
if (dtype == kByte) {
@@ -275,7 +298,7 @@ void div_floor_kernel(TensorIteratorBase& iter) {
[=](Vectorized<scalar_t> a) {
return binary_op_scalar(
a, b, [](const vec_t& x, const vec_t& y) {
return at::vec::div_floor_floating_vec(x, y);
return div_floor_floating_vec(x, y);
});
});
});
@@ -289,7 +312,7 @@ void div_floor_kernel(TensorIteratorBase& iter) {
return c10::div_floor_floating(a, b);
},
[](vec_t a, vec_t b) -> vec_t {
return at::vec::div_floor_floating_vec(a, b);
return div_floor_floating_vec(a, b);
});
});
}
@@ -710,7 +733,7 @@ void fmin_kernel(TensorIteratorBase& iter) {
void smooth_l1_kernel(TensorIteratorBase& iter, double beta) {
if (iter.dtype() == kBFloat16) {
const float beta_val(static_cast<float>(beta));
const float beta_val(beta);
const Vectorized<float> beta_val_vec(beta_val);
const Vectorized<float> point_five_vec(static_cast<float>(0.5));
cpu_kernel_vec(

Changed file 3 of 4: Inductor CPU repro tests (CPUReproTests)

@@ -34,7 +34,6 @@ from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
from torch._inductor.utils import timed
from torch._inductor.virtualized import V
from torch._prims_common import is_float_dtype
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
@@ -2114,37 +2113,6 @@ class CPUReproTests(TestCase):
self.common(bitwise_fn, _args)
check_metrics_vec_kernel_count(1)
@requires_vectorization
def test_vec_remainder(self):
for dtype in [
torch.int8,
torch.uint8,
torch.int32,
torch.int64,
torch.bfloat16,
torch.float16,
torch.float32,
torch.float64,
]:
if is_float_dtype(dtype):
x = torch.randn(64, dtype=dtype)
y = torch.randn(64, dtype=dtype)
else:
lower = 1 if dtype == torch.uint8 else -100
x = torch.randint(lower, 100, (64,), dtype=dtype)
y = torch.randint(lower, 100, (64,), dtype=dtype)
y = torch.where(
y == torch.zeros_like(y),
torch.ones_like(y),
y,
)
torch._dynamo.reset()
metrics.reset()
_args = (x, y)
self.common(torch.remainder, _args)
check_metrics_vec_kernel_count(1)
@requires_vectorization
@patch("torch.cuda.is_available", lambda: False)
def test_vec_compare_op_cpu_only(self):
@@ -3583,19 +3551,13 @@ class CPUReproTests(TestCase):
def fn(x, y, mode):
return torch.div(x, y, rounding_mode=mode)
for dtype in [
torch.int8,
torch.uint8,
torch.int32,
torch.int64,
]:
x = torch.randint(1, 100, (32, 32), dtype=dtype)
y = torch.randint(1, 100, (32, 32), dtype=dtype)
for mode in [None, "trunc", "floor"]:
with torch.no_grad():
metrics.reset()
self.common(fn, (x, y, mode))
check_metrics_vec_kernel_count(1)
x = torch.randint(1, 100, (32, 32))
y = torch.randint(1, 100, (32, 32))
for mode in [None, "trunc", "floor"]:
with torch.no_grad():
metrics.reset()
self.common(fn, (x, y, mode))
check_metrics_vec_kernel_count(1)
def test_uint8_add(self):
# https://github.com/pytorch/pytorch/issues/113016

Changed file 4 of 4: Inductor C++ codegen (CppVecOverrides / CppVecKernel)

@@ -16,7 +16,7 @@ import sympy
import torch
import torch.fx
from torch._inductor import dependencies
from torch._prims_common import is_float_dtype, is_integer_dtype
from torch._prims_common import is_float_dtype
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
@@ -1181,13 +1181,6 @@ class CppVecOverrides(CppOverrides):
def bitwise_right_shift(a, b):
return f"{a} >> {b}"
@staticmethod
def remainder(a, b):
assert (
a.dtype == b.dtype
), "remainder vec implementation expect the same inputs' dtype."
return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"
@staticmethod
def tan(a):
return f"{a}.tan()"
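
The CppVecOverrides.remainder override removed above emitted a - (floordiv(a, b)) * b, which relies on the identity remainder(a, b) == a - floor(a / b) * b; like torch.remainder, this takes the sign of the divisor rather than the dividend. A small illustrative check in Python (not part of this diff):

import torch

a = torch.tensor([7.0, -7.0, 7.0, -7.0])
b = torch.tensor([2.0, 2.0, -2.0, -2.0])

# remainder via floor division, the same identity the removed codegen used
via_identity = a - torch.div(a, b, rounding_mode="floor") * b
assert torch.equal(via_identity, torch.remainder(a, b))
print(via_identity)  # tensor([ 1.,  1., -1., -1.])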
@@ -1291,30 +1284,16 @@
@staticmethod
def floordiv(a, b):
if is_float_dtype(a.dtype):
assert (
a.dtype == b.dtype
), "div_floor_floating_vec implementation expect the same inputs' dtype."
return f"at::vec::div_floor_floating_vec({a}, {b})"
else:
assert all(is_integer_dtype(item.dtype) for item in [a, b])
# a and b are integer type
_t = f"decltype({a})"
if V.kernel._get_raw_num_vectors(b.dtype) < 1:
# Doing blend to set the remaining bits of b to non-zero
b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})"
quot = f"{a} / {b}"
has_rem = f"({a} % {b} != {_t}(0))"
is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))"
return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})"
# a and b are integer type
_t = f"decltype({a})"
quot = f"{a} / {b}"
has_rem = f"({a} % {b} != {_t}(0))"
is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))"
return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})"
@staticmethod
def truncdiv(a, b):
# a and b are integer type
if V.kernel._get_raw_num_vectors(b.dtype) < 1:
# Doing blend to set the remaining bits of b to non-zero
_t = f"decltype({b})"
b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})"
return f"{a} / {b}"
@staticmethod
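
The integer floordiv path above builds floor division out of truncating division: compute quot = a / b, then blend in quot - 1 in every lane where a remainder exists and the operands have opposite signs (truncdiv only needed the removed extra blend that keeps unused divisor lanes nonzero). A plain-Python sketch of that adjustment (illustrative, not Inductor output):

def trunc_div(a: int, b: int) -> int:
    # C-style division that truncates toward zero; assumes b != 0.
    q = abs(a) // abs(b)
    return -q if (a < 0) != (b < 0) else q

def floordiv_from_truncdiv(a: int, b: int) -> int:
    quot = trunc_div(a, b)
    has_rem = (a - quot * b) != 0
    is_neg = (a < 0) != (b < 0)
    # The vector code does this per lane with blendv(quot, quot - 1, mask).
    return quot - 1 if has_rem and is_neg else quot

assert floordiv_from_truncdiv(-7, 2) == -7 // 2 == -4
assert floordiv_from_truncdiv(7, -2) == 7 // -2 == -4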
@@ -2031,11 +2010,6 @@
assert num_vectors >= 1
return num_vectors
def _get_raw_num_vectors(self, dtype: torch.dtype) -> float:
# This utility function is used to check if the vector lanes has been
# fully utilized. For example, uint8 will only use 1/4 of the vector lanes.
return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
def _get_vec_type(self, dtype: torch.dtype) -> str:
num_vectors = self._get_num_vectors(dtype)
if num_vectors == 1:
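
The _get_raw_num_vectors helper removed above returned tiling_factor * dtype.itemsize * 8 / vec_isa.bit_width(), i.e. what fraction (or multiple) of the vector register one tile of the dtype fills; a value below 1 meant under-utilized lanes, which is why the removed floordiv/truncdiv paths blended ones into the unused divisor lanes before dividing. A worked example with assumed AVX-512 numbers (tiling factor 16, 512-bit registers; illustrative only):

def raw_num_vectors(tiling_factor: int, itemsize: int, isa_bit_width: int) -> float:
    # Fraction (or multiple) of the vector register filled by one tile of a dtype.
    return tiling_factor * itemsize * 8 / isa_bit_width

print(raw_num_vectors(16, 1, 512))  # uint8  -> 0.25, lanes under-utilized, blend needed
print(raw_num_vectors(16, 4, 512))  # int32  -> 1.0
print(raw_num_vectors(16, 8, 512))  # int64  -> 2.0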