mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] support masked vectorization for the tail_loop for integer datatypes
ghstack-source-id: 32ec9331c2363c3528ee11be461ad20fa932894d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163490
This commit is contained in:
@ -116,10 +116,10 @@ class Vectorized<int64_t> : public Vectorizedi {
|
||||
__at_align__ int64_t tmp_values[size()];
|
||||
// Ensure uninitialized memory does not change the output value See
|
||||
// https://github.com/pytorch/pytorch/issues/32502 for more details. We do
|
||||
// not initialize arrays to zero using "={0}" because gcc would compile it
|
||||
// not initialize arrays to one using "={1}" because gcc would compile it
|
||||
// to two instructions while a loop would be compiled to one instruction.
|
||||
for (const auto i : c10::irange(size())) {
|
||||
tmp_values[i] = 0;
|
||||
tmp_values[i] = 1;
|
||||
}
|
||||
std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
|
||||
return loadu(tmp_values);
|
||||
@ -266,10 +266,10 @@ class Vectorized<int32_t> : public Vectorizedi {
|
||||
__at_align__ int32_t tmp_values[size()];
|
||||
// Ensure uninitialized memory does not change the output value See
|
||||
// https://github.com/pytorch/pytorch/issues/32502 for more details. We do
|
||||
// not initialize arrays to zero using "={0}" because gcc would compile it
|
||||
// not initialize arrays to one using "={1}" because gcc would compile it
|
||||
// to two instructions while a loop would be compiled to one instruction.
|
||||
for (const auto i : c10::irange(size())) {
|
||||
tmp_values[i] = 0;
|
||||
tmp_values[i] = 1;
|
||||
}
|
||||
std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
|
||||
return loadu(tmp_values);
|
||||
@ -566,10 +566,10 @@ class Vectorized<int16_t> : public Vectorizedi {
|
||||
__at_align__ int16_t tmp_values[size()];
|
||||
// Ensure uninitialized memory does not change the output value See
|
||||
// https://github.com/pytorch/pytorch/issues/32502 for more details. We do
|
||||
// not initialize arrays to zero using "={0}" because gcc would compile it
|
||||
// not initialize arrays to one using "={1}" because gcc would compile it
|
||||
// to two instructions while a loop would be compiled to one instruction.
|
||||
for (const auto i : c10::irange(size())) {
|
||||
tmp_values[i] = 0;
|
||||
tmp_values[i] = 1;
|
||||
}
|
||||
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
|
||||
return loadu(tmp_values);
|
||||
@ -914,10 +914,10 @@ class Vectorized8 : public Vectorizedi {
|
||||
__at_align__ T tmp_values[size()];
|
||||
// Ensure uninitialized memory does not change the output value See
|
||||
// https://github.com/pytorch/pytorch/issues/32502 for more details. We do
|
||||
// not initialize arrays to zero using "={0}" because gcc would compile it
|
||||
// not initialize arrays to one using "={1}" because gcc would compile it
|
||||
// to two instructions while a loop would be compiled to one instruction.
|
||||
for (const auto i : c10::irange(size())) {
|
||||
tmp_values[i] = 0;
|
||||
tmp_values[i] = 1;
|
||||
}
|
||||
std::memcpy(tmp_values, ptr, count * sizeof(T));
|
||||
return loadu(tmp_values);
|
||||
|
||||
@ -130,7 +130,8 @@ class Vectorized<int64_t> : public Vectorizedi {
|
||||
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
|
||||
} else {
|
||||
__mmask8 mask = (1ULL << count) - 1;
|
||||
return _mm512_maskz_loadu_epi64(mask, ptr);
|
||||
auto ones = _mm512_set1_epi64(1);
|
||||
return _mm512_mask_loadu_epi64(ones, mask, ptr);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int count = size()) const {
|
||||
@ -332,7 +333,8 @@ class Vectorized<int32_t> : public Vectorizedi {
|
||||
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
|
||||
} else {
|
||||
__mmask16 mask = (1ULL << count) - 1;
|
||||
return _mm512_maskz_loadu_epi32(mask, ptr);
|
||||
auto ones = _mm512_set1_epi32(1);
|
||||
return _mm512_mask_loadu_epi32(ones, mask, ptr);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int count = size()) const {
|
||||
@ -660,7 +662,8 @@ class Vectorized<int16_t> : public Vectorizedi {
|
||||
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
|
||||
} else {
|
||||
__mmask32 mask = (1ULL << count) - 1;
|
||||
return _mm512_maskz_loadu_epi16(mask, ptr);
|
||||
auto ones = _mm512_set1_epi16(1);
|
||||
return _mm512_mask_loadu_epi16(ones, mask, ptr);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int count = size()) const {
|
||||
@ -1101,7 +1104,8 @@ class Vectorized8 : public Vectorizedi {
|
||||
return loadu_one_fourth(ptr);
|
||||
} else {
|
||||
__mmask64 mask = (1ULL << count) - 1;
|
||||
return _mm512_maskz_loadu_epi8(mask, ptr);
|
||||
auto ones = _mm512_set1_epi8(1);
|
||||
return _mm512_mask_loadu_epi8(ones, mask, ptr);
|
||||
}
|
||||
}
|
||||
void store(void* ptr, int count = size()) const {
|
||||
|
||||
@ -187,12 +187,13 @@ class VectorizedN {
|
||||
static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
|
||||
VectorizedN<T, N> result;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result.values[i] = Vectorized<T>::loadu(
|
||||
ptr, std::min(count, (int64_t)Vectorized<T>::size()));
|
||||
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
|
||||
count -= Vectorized<T>::size();
|
||||
if (count <= 0) {
|
||||
break;
|
||||
if (count > 0) {
|
||||
result.values[i] = Vectorized<T>::loadu(
|
||||
ptr, std::min(count, (int64_t)Vectorized<T>::size()));
|
||||
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
|
||||
count -= Vectorized<T>::size();
|
||||
} else {
|
||||
result.values[i] = Vectorized<T>((T)1);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
||||
@ -4682,6 +4682,22 @@ class CPUReproTests(TestCase):
|
||||
self.common(fn, (x,))
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
# Tail vectorization case
|
||||
x = torch.randint(0, 100, (22, 22), dtype=torch.int32)
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
with torch.no_grad():
|
||||
expected = fn(x)
|
||||
compiled_fn = torch.compile(fn)
|
||||
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||
self.assertEqual(expected, actual)
|
||||
# 1 generated vec kernel
|
||||
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
|
||||
# Check that both main and tail loops are vectorized
|
||||
FileCheck().check_count(
|
||||
"at::vec::Vectorized<int32_t>::loadu", 2, exactly=True
|
||||
).run(code)
|
||||
|
||||
def test_int32_reduction_vec(self):
|
||||
def fn(x):
|
||||
return x.sum(dim=1)
|
||||
@ -4691,6 +4707,22 @@ class CPUReproTests(TestCase):
|
||||
self.common(fn, (x,))
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
# Tail vectorization case
|
||||
x = torch.randint(0, 100, (22, 22), dtype=torch.int32)
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
with torch.no_grad():
|
||||
expected = fn(x)
|
||||
compiled_fn = torch.compile(fn)
|
||||
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||
self.assertEqual(expected, actual)
|
||||
# 1 generated vec kernel
|
||||
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
|
||||
# Check that both main and tail loops are vectorized
|
||||
FileCheck().check_count(
|
||||
"at::vec::Vectorized<int32_t>::loadu", 2, exactly=True
|
||||
).run(code)
|
||||
|
||||
def test_uint32_pointwise_vec(self):
|
||||
def fn(x):
|
||||
return x * x
|
||||
@ -4720,6 +4752,22 @@ class CPUReproTests(TestCase):
|
||||
self.common(fn, (x,))
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
# Tail vectorization case
|
||||
x = torch.randint(0, 100, (22, 22), dtype=torch.int64)
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
with torch.no_grad():
|
||||
expected = fn(x)
|
||||
compiled_fn = torch.compile(fn)
|
||||
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||
self.assertEqual(expected, actual)
|
||||
# 1 generated vec kernel
|
||||
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
|
||||
# Check that both main and tail loops are vectorized
|
||||
FileCheck().check_count(
|
||||
"at::vec::VectorizedN<int64_t,2>::loadu", 2, exactly=True
|
||||
).run(code)
|
||||
|
||||
def test_int64_reduction_vec(self):
|
||||
def fn(x):
|
||||
return x.sum(dim=1)
|
||||
@ -4729,6 +4777,22 @@ class CPUReproTests(TestCase):
|
||||
self.common(fn, (x,))
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
# Tail vectorization case
|
||||
x = torch.randint(0, 100, (22, 22), dtype=torch.int64)
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
with torch.no_grad():
|
||||
expected = fn(x)
|
||||
compiled_fn = torch.compile(fn)
|
||||
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||
self.assertEqual(expected, actual)
|
||||
# 1 generated vec kernel
|
||||
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
|
||||
# Check that both main and tail loops are vectorized
|
||||
FileCheck().check_count(
|
||||
"at::vec::VectorizedN<int64_t,2>::loadu", 2, exactly=True
|
||||
).run(code)
|
||||
|
||||
def test_uint64_pointwise_vec(self):
|
||||
def fn(x):
|
||||
return x * x
|
||||
|
||||
@ -165,6 +165,8 @@ MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
||||
torch.float16,
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
]
|
||||
@ -2977,7 +2979,10 @@ class CppVecKernel(CppKernel):
|
||||
cdtype = DTYPE_TO_CPP[dtype]
|
||||
index = ops.index_expr(index, torch.int64).value
|
||||
assert isinstance(index, CppCSEVariable) and index.is_vec
|
||||
line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});"
|
||||
if self.tail_size:
|
||||
line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value}, {cexpr_index(self.tail_size)});"
|
||||
else:
|
||||
line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});"
|
||||
self.stores.writeline(DeferredLine(name, line))
|
||||
else:
|
||||
raise NotImplementedError(f"store mode={mode}")
|
||||
|
||||
@ -859,14 +859,16 @@ template <typename T, int NI, int NV>
|
||||
void atomic_add_vec(
|
||||
T* addr,
|
||||
at::vec::VectorizedN<int64_t, NI> index,
|
||||
at::vec::VectorizedN<T, NV> offset) {
|
||||
at::vec::VectorizedN<T, NV> offset,
|
||||
std::optional<int64_t> tail_size = std::nullopt) {
|
||||
constexpr int len = at::vec::VectorizedN<int64_t, NI>::size();
|
||||
static_assert(len <= at::vec::VectorizedN<T, NV>::size());
|
||||
__at_align__ std::array<T, len> tmpbuf;
|
||||
__at_align__ std::array<int64_t, len> tmpidx;
|
||||
offset.store(tmpbuf.data(), len);
|
||||
index.store(tmpidx.data(), len);
|
||||
for (int i = 0; i < len; i++) {
|
||||
int size = tail_size.has_value() ? tail_size.value() : len;
|
||||
for (int i = 0; i < size; i++) {
|
||||
atomic_add(addr + tmpidx[i], tmpbuf[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user