[Inductor] support masked vectorization for the tail_loop for bool datatype

ghstack-source-id: 9db08e405f6f4df80f831e3906dc7782a35214c5
Pull-Request: https://github.com/pytorch/pytorch/pull/164202
Author: Sun, Jiayi
Date: 2025-10-20 10:40:34 +00:00
parent eedf6e950b
commit 7bc36099ec
4 changed files with 70 additions and 30 deletions
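Summary: masked (tail-loop) vectorization previously skipped torch.bool, because a VecMask could only be built from a full vector's worth of bool data and the mask type had no tail-aware reduction helpers. This commit adds a counted VecMask::from overload, VecMask<float, 1> specializations of the masked reduction helpers plus a new any_masked_reduce, threads element counts through the Inductor C++ codegen, and removes the now-redundant MASKED_VECTORIZABLE_DTYPES allowlist together with its could_masked_vec gate.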

View File

@@ -165,6 +165,19 @@ class VecMask {
     return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask));
   }

+  template <typename U>
+  static VecMask<T, N> from(U* b, int count) {
+    using int_t = int_same_size_t<T>;
+    __at_align__ T mask[size()];
+#ifndef __msvc_cl__
+#pragma unroll
+#endif
+    for (int i = 0; i < count; i++) {
+      *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
+    }
+    return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask, count));
+  }
+
   static VecMask<T, N> blendv(
       const VecMask<T, N>& c,
       const VecMask<T, N>& b,

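The existing from(U* b) above always reads a full size() elements; the new overload reads only count of them, which is what a vectorized tail loop needs when fewer than size() elements remain. A minimal usage sketch (not part of the diff; assumes the ATen vec headers and the is_masked accessor from vec_mask.h):

#include <ATen/cpu/vec/vec.h>
#include <iostream>

int main() {
  // A tail of 5 valid bools; memory past them must not be read.
  bool tail[5] = {true, false, true, true, false};
  auto m = at::vec::VecMask<float, 1>::from(tail, 5);
  for (int i = 0; i < 5; i++) {
    std::cout << m.is_masked(i);  // prints 10110
  }
  std::cout << std::endl;
  return 0;
}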
View File

@@ -4661,6 +4661,18 @@ class CPUReproTests(TestCase):
         self.common(fn, (x,))
         check_metrics_vec_kernel_count(1)

+        # Tail vectorization case
+        x = torch.rand(37)
+        torch._dynamo.reset()
+        metrics.reset()
+        with torch.no_grad():
+            compiled_fn = torch.compile(fn)
+            _, code = run_and_get_cpp_code(compiled_fn, x)
+            # Check that both main and tail loops are vectorized
+            FileCheck().check_count(
+                "at::vec::VecMask<float,1>::from", 2, exactly=True
+            ).run(code)
+
     @torch._dynamo.config.patch(dynamic_shapes=True)
     @torch._dynamo.config.patch(assume_static_by_default=False)
     def test_symbolic_shape_scalar_value_reduction(self):
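For x = torch.rand(37), the main loop covers 32 elements and a single masked tail iteration covers the remaining 5 (true for both 8- and 16-lane float vectors, since 37 = 4 * 8 + 5 = 2 * 16 + 5), so at::vec::VecMask<float,1>::from appears exactly twice in the generated code: once full-width, once with a count. A hand-written sketch of that shape (not the actual generated kernel; 16 lanes and the pointer names are assumptions):

#include <ATen/cpu/vec/vec.h>
#include <cstdint>

void kernel_sketch(const bool* in_ptr) {
  // main loop: full 16-lane vectors cover elements [0, 32)
  for (int64_t i = 0; i < 32; i += 16) {
    auto m = at::vec::VecMask<float, 1>::from(in_ptr + i);          // 1st `from`
    (void)m;  // vectorized body elided
  }
  // tail loop: one masked pass over the last 5 elements
  for (int64_t i = 32; i < 37; i += 16) {
    auto m = at::vec::VecMask<float, 1>::from(in_ptr + i, 37 - i);  // 2nd `from`, counted
    (void)m;  // same body, restricted to 37 - i valid lanes
  }
}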
@@ -5399,7 +5411,7 @@ class CPUReproTests(TestCase):
         _, code = run_and_get_cpp_code(opt_fn, x)
         FileCheck().check_count(
             "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(),",
-            4,
+            8,
             exactly=True,
         ).run(code)
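The expected occurrence count rises from 4 to 8 because the tail loop of this kernel is now vectorized as well and emits the same loadu-via-tmpbuf pattern as the main loop, doubling the matches.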

View File

@@ -158,19 +158,6 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
     torch.float8_e5m2,
 ]

-MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
-    torch.float64,
-    torch.float,
-    torch.bfloat16,
-    torch.float16,
-    torch.uint8,
-    torch.int8,
-    torch.int32,
-    torch.int64,
-    torch.float8_e4m3fn,
-    torch.float8_e5m2,
-]
-

 def reduction_init(reduction_type, dtype):
     if dtype in DTYPE_LOWP_FP:
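The MASKED_VECTORIZABLE_DTYPES allowlist existed only because masked (tail) vectorization supported fewer dtypes than regular vectorization, with torch.bool as the notable gap. After this commit the two sets coincide, so the list, and the could_masked_vec gate derived from it (removed further down), are no longer needed.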
@@ -1873,8 +1860,7 @@ class CppVecOverrides(CppOverrides):
             with code.indent():
                 code.writeline(f"tmpbuf_out[i] = {res};")
             if output_mask:
-                assert not kernel.tail_size
-                load_args = "tmpbuf_out.data()"
+                load_args = f"tmpbuf_out.data(), {cexpr_index(size)}"
                 load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from"
             else:
                 load_args = f"tmpbuf_out.data(), {cexpr_index(size)}"
@@ -2736,7 +2722,7 @@ class CppVecKernel(CppKernel):
         loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var
         if dtype == torch.bool:
             # TODO: should we consider load mask here?
-            line = f"{self._get_mask_type()}::from({loadbuf})"
+            line = f"{self._get_mask_type()}::from({loadbuf}, {cexpr_index(self.num_elems)})"
         else:
             line = (
                 f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})"
@@ -3450,7 +3436,10 @@ class CppVecKernel(CppKernel):
         if isinstance(next_value, CppCSEVariable):
             assert next_value.dtype == torch.bool
             (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,))
-            return f"{var} | {next_value}"
+            if self.tail_size:
+                return f"any_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+            else:
+                return f"{var} | {next_value}"
         else:
             raise NotImplementedError
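For a bool "any" reduction the accumulator and the incoming value are vector masks and the per-lane combine is a bitwise OR. In a tail kernel, lanes at or beyond tail_size may hold stale or undefined values, so the combine is routed through the new any_masked_reduce helper (defined in the last file of this diff), which ORs full-width and then uses set to restore the accumulator's values in the out-of-range lanes.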
@@ -4357,13 +4346,6 @@ class CppKernelProxy(CppKernel):
                 fn_list, var_sizes_list
             )
             assert len(tiling_factors) == len(tiling_indices)
-            # <TODO> This should be removed after full support for vectorization is implemented.
-            could_masked_vec = True
-            all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list))
-            if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes):
-                # can be removed after masked vectorizable dtype are same with vectorizable dtype
-                could_masked_vec = False
-
             _inner_loop_reduction_outer_not = False
             _outer_loop = None
             if tiling_indices:
@@ -4390,7 +4372,7 @@ class CppKernelProxy(CppKernel):
                 )
                 tail_size = loop.size - loop.tiled_size
                 vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)}
-                if config.cpp.enable_loop_tail_vec and could_masked_vec:
+                if config.cpp.enable_loop_tail_vec:
                     tail_kernel = codegen_kernel(
                         self.vec_kernel_cls,
                         tiling_factors[0],
@@ -4437,7 +4419,7 @@ class CppKernelProxy(CppKernel):
                     inner_loop.var: inner_ranges["main"],
                 }
                 tail_kernel = []
-                if config.cpp.enable_loop_tail_vec and could_masked_vec:
+                if config.cpp.enable_loop_tail_vec:
                     for outer_r, inner_r in (
                         ("main", "tail"),
                         ("tail", "main"),

View File

@@ -296,23 +296,50 @@ inline T cascade_sum_combine(
 }

 template <typename T>
-T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
+inline T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = at::vec::maximum(a, b);
   return T::set(a, out, tail_size);
 }

+template <>
+inline at::vec::VecMask<float, 1> max_masked_reduce(
+    const at::vec::VecMask<float, 1>& a,
+    const at::vec::VecMask<float, 1>& b,
+    const int64_t tail_size) {
+  auto out = a | b;
+  return at::vec::VecMask<float, 1>::set(a, out, tail_size);
+}
+
 template <typename T>
-T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
+inline T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = at::vec::minimum(a, b);
   return T::set(a, out, tail_size);
 }

+template <>
+inline at::vec::VecMask<float, 1> min_masked_reduce(
+    const at::vec::VecMask<float, 1>& a,
+    const at::vec::VecMask<float, 1>& b,
+    const int64_t tail_size) {
+  auto out = a & b;
+  return at::vec::VecMask<float, 1>::set(a, out, tail_size);
+}
+
 template <typename T>
-T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
+inline T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = a + b;
   return T::set(a, out, tail_size);
 }

+template <>
+inline at::vec::VecMask<float, 1> sum_masked_reduce(
+    const at::vec::VecMask<float, 1>& a,
+    const at::vec::VecMask<float, 1>& b,
+    const int64_t tail_size) {
+  auto out = a | b;
+  return at::vec::VecMask<float, 1>::set(a, out, tail_size);
+}
+
 template <typename T>
 T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = a * b;
@@ -324,6 +351,12 @@ T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = a ^ b;
   return T::set(a, out, tail_size);
 }
+
+template <typename T1, typename T2>
+T1 any_masked_reduce(const T1& a, const T2& b, const int64_t tail_size) {
+  T1 out = a | b;
+  return T1::set(a, out, tail_size);
+}
 #endif

 // Refer to
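For bool masks the arithmetic combines degenerate to bitwise ones, which is why the specializations above use OR for max and sum, AND for min, and why "any" gets its own generic helper. A self-contained sketch of the tail behaviour (not from this commit; the helper is restated locally so the snippet compiles on its own, and is_masked is assumed from vec_mask.h):

#include <ATen/cpu/vec/vec.h>
#include <cstdint>
#include <iostream>

// Restated from the diff above so the sketch is self-contained.
template <typename T1, typename T2>
T1 any_masked_reduce(const T1& a, const T2& b, const int64_t tail_size) {
  T1 out = a | b;
  return T1::set(a, out, tail_size);
}

int main() {
  bool acc_buf[64] = {};   // running "any" accumulator: all lanes false
  bool next_buf[64] = {};
  next_buf[2] = true;      // inside the 5-element tail: must be observed
  next_buf[6] = true;      // past the tail: never loaded, and set() keeps acc's lane
  auto acc = at::vec::VecMask<float, 1>::from(acc_buf);
  auto next = at::vec::VecMask<float, 1>::from(next_buf, 5);
  auto out = any_masked_reduce(acc, next, 5);
  std::cout << out.is_masked(2) << out.is_masked(6) << std::endl;  // prints 10
  return 0;
}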