Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor] support masked vectorization for the tail_loop for bool datatype
ghstack-source-id: 9db08e405f6f4df80f831e3906dc7782a35214c5
Pull-Request: https://github.com/pytorch/pytorch/pull/164202
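For orientation (not part of the diff): a minimal, hypothetical sketch of the path this change targets, loosely modeled on the test added below. The test's actual `fn` is outside this excerpt, so the function here is illustrative only; a bool tensor input exercises the vectorized mask load, and a length of 37 forces a tail loop. The added test asserts that `at::vec::VecMask<float,1>::from` appears exactly twice, once for the main loop and once for the masked tail loop, on a CPU where the kernel vectorizes.

# Hypothetical sketch; the flow mirrors the added test, but `fn` is a stand-in.
import torch
from torch._inductor import metrics
from torch._inductor.utils import run_and_get_cpp_code
from torch.testing import FileCheck

def fn(mask, x):
    # reading a bool tensor exercises the VecMask load path in the CPP backend
    return torch.where(mask, x, x + 1.0)

x = torch.rand(37)   # 37 is not a multiple of the SIMD width, so a tail loop remains
mask = x > 0.5       # bool input tensor
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    _, code = run_and_get_cpp_code(compiled_fn, mask, x)
# The added test asserts the mask load appears twice on a vectorizing CPU:
# VecMask::from(ptr) in the main loop and VecMask::from(ptr, count) in the tail loop.
FileCheck().check_count("at::vec::VecMask<float,1>::from", 2, exactly=True).run(code)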
@@ -165,6 +165,19 @@ class VecMask {
     return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask));
   }
 
+  template <typename U>
+  static VecMask<T, N> from(U* b, int count) {
+    using int_t = int_same_size_t<T>;
+    __at_align__ T mask[size()];
+#ifndef __msvc_cl__
+#pragma unroll
+#endif
+    for (int i = 0; i < count; i++) {
+      *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
+    }
+    return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask, count));
+  }
+
   static VecMask<T, N> blendv(
       const VecMask<T, N>& c,
       const VecMask<T, N>& b,
@@ -4661,6 +4661,18 @@ class CPUReproTests(TestCase):
         self.common(fn, (x,))
         check_metrics_vec_kernel_count(1)
 
+        # Tail vectorization case
+        x = torch.rand(37)
+        torch._dynamo.reset()
+        metrics.reset()
+        with torch.no_grad():
+            compiled_fn = torch.compile(fn)
+            _, code = run_and_get_cpp_code(compiled_fn, x)
+            # Check that both main and tail loops are vectorized
+            FileCheck().check_count(
+                "at::vec::VecMask<float,1>::from", 2, exactly=True
+            ).run(code)
+
     @torch._dynamo.config.patch(dynamic_shapes=True)
     @torch._dynamo.config.patch(assume_static_by_default=False)
     def test_symbolic_shape_scalar_value_reduction(self):
@@ -5399,7 +5411,7 @@ class CPUReproTests(TestCase):
         _, code = run_and_get_cpp_code(opt_fn, x)
         FileCheck().check_count(
             "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(),",
-            4,
+            8,
             exactly=True,
         ).run(code)
 
@@ -158,19 +158,6 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
     torch.float8_e5m2,
 ]
 
-MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
-    torch.float64,
-    torch.float,
-    torch.bfloat16,
-    torch.float16,
-    torch.uint8,
-    torch.int8,
-    torch.int32,
-    torch.int64,
-    torch.float8_e4m3fn,
-    torch.float8_e5m2,
-]
-
 
 def reduction_init(reduction_type, dtype):
     if dtype in DTYPE_LOWP_FP:
@@ -1873,8 +1860,7 @@ class CppVecOverrides(CppOverrides):
            with code.indent():
                code.writeline(f"tmpbuf_out[i] = {res};")
            if output_mask:
-               assert not kernel.tail_size
-               load_args = "tmpbuf_out.data()"
+               load_args = f"tmpbuf_out.data(), {cexpr_index(size)}"
                load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from"
            else:
                load_args = f"tmpbuf_out.data(), {cexpr_index(size)}"
@@ -2736,7 +2722,7 @@ class CppVecKernel(CppKernel):
         loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var
         if dtype == torch.bool:
             # TODO: should we consider load mask here?
-            line = f"{self._get_mask_type()}::from({loadbuf})"
+            line = f"{self._get_mask_type()}::from({loadbuf}, {cexpr_index(self.num_elems)})"
         else:
             line = (
                 f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})"
@@ -3450,7 +3436,10 @@ class CppVecKernel(CppKernel):
             if isinstance(next_value, CppCSEVariable):
                 assert next_value.dtype == torch.bool
                 (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,))
-                return f"{var} | {next_value}"
+                if self.tail_size:
+                    return f"any_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+                else:
+                    return f"{var} | {next_value}"
             else:
                 raise NotImplementedError
 
@@ -4357,13 +4346,6 @@ class CppKernelProxy(CppKernel):
            fn_list, var_sizes_list
        )
        assert len(tiling_factors) == len(tiling_indices)
-       # <TODO> This should be removed after full support for vectorization is implemented.
-       could_masked_vec = True
-       all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list))
-       if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes):
-           # can be removed after masked vectorizable dtype are same with vectorizable dtype
-           could_masked_vec = False
-
        _inner_loop_reduction_outer_not = False
        _outer_loop = None
        if tiling_indices:
@@ -4390,7 +4372,7 @@ class CppKernelProxy(CppKernel):
                )
                tail_size = loop.size - loop.tiled_size
                vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)}
-               if config.cpp.enable_loop_tail_vec and could_masked_vec:
+               if config.cpp.enable_loop_tail_vec:
                    tail_kernel = codegen_kernel(
                        self.vec_kernel_cls,
                        tiling_factors[0],
@@ -4437,7 +4419,7 @@ class CppKernelProxy(CppKernel):
                    inner_loop.var: inner_ranges["main"],
                }
                tail_kernel = []
-               if config.cpp.enable_loop_tail_vec and could_masked_vec:
+               if config.cpp.enable_loop_tail_vec:
                    for outer_r, inner_r in (
                        ("main", "tail"),
                        ("tail", "main"),
@@ -296,23 +296,50 @@ inline T cascade_sum_combine(
 }
 
 template <typename T>
-T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
+inline T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = at::vec::maximum(a, b);
   return T::set(a, out, tail_size);
 }
 
+template <>
+inline at::vec::VecMask<float, 1> max_masked_reduce(
+    const at::vec::VecMask<float, 1>& a,
+    const at::vec::VecMask<float, 1>& b,
+    const int64_t tail_size) {
+  auto out = a | b;
+  return at::vec::VecMask<float, 1>::set(a, out, tail_size);
+}
+
 template <typename T>
-T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
+inline T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = at::vec::minimum(a, b);
   return T::set(a, out, tail_size);
 }
 
+template <>
+inline at::vec::VecMask<float, 1> min_masked_reduce(
+    const at::vec::VecMask<float, 1>& a,
+    const at::vec::VecMask<float, 1>& b,
+    const int64_t tail_size) {
+  auto out = a & b;
+  return at::vec::VecMask<float, 1>::set(a, out, tail_size);
+}
+
 template <typename T>
-T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
+inline T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = a + b;
   return T::set(a, out, tail_size);
 }
 
+template <>
+inline at::vec::VecMask<float, 1> sum_masked_reduce(
+    const at::vec::VecMask<float, 1>& a,
+    const at::vec::VecMask<float, 1>& b,
+    const int64_t tail_size) {
+  auto out = a | b;
+  return at::vec::VecMask<float, 1>::set(a, out, tail_size);
+}
+
 template <typename T>
 T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = a * b;
@@ -324,6 +351,12 @@ T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
   auto out = a ^ b;
   return T::set(a, out, tail_size);
 }
+
+template <typename T1, typename T2>
+T1 any_masked_reduce(const T1& a, const T2& b, const int64_t tail_size) {
+  T1 out = a | b;
+  return T1::set(a, out, tail_size);
+}
 #endif
 
 // Refer to
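For context (again not part of the diff): the masked reduce helpers above all follow one pattern -- combine the full vectors, then use `T::set(a, out, tail_size)` so that only the first `tail_size` lanes take the combined value while the remaining lanes keep the running accumulator. Below is a small Python model of that lane behavior, under the assumption that `set(a, out, count)` selects `out` for lanes below `count` and `a` otherwise.

# Reference model (assumed at::vec set semantics), for illustration only.
def vec_set(a, out, count):
    # lanes < count come from `out`; lanes >= count keep `a`
    return [out[i] if i < count else a[i] for i in range(len(a))]

def any_masked_reduce(acc, nxt, tail_size):
    # models: T1 out = a | b; return T1::set(a, out, tail_size);
    combined = [a or b for a, b in zip(acc, nxt)]
    return vec_set(acc, combined, tail_size)

acc = [False] * 8
nxt = [True, False, True, True, True, True, True, True]  # lanes >= 3 are junk past the tail
print(any_masked_reduce(acc, nxt, 3))
# [True, False, True, False, False, False, False, False]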