[Inductor][CPP] fix store mode atomic add (#147961)

**Summary**
Fixes https://github.com/pytorch/pytorch/issues/147848 and https://github.com/pytorch/pytorch/issues/146390. While addressing these issues, two problems were encountered:

- In `CppVecKernel`, when the number of threads is 1 and the store mode is `atomic_add`, `store` wrote the new value directly instead of first loading the existing value and adding to it. This PR fixes the single-threaded path to accumulate (see the sketch after this list).

- In `CppTile2DKernel`, `store` did not support the `atomic_add` mode at all. This PR adds support for it via a `transpose_mxn` helper that can accumulate into the destination.

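To illustrate the first point, here is a minimal scalar sketch. The `plain_store`/`accumulate_store` helpers below are hypothetical, not the actual generated code (which operates on `at::vec::Vectorized` lanes via `loadu()`/`store()`); they only show why the single-threaded `atomic_add` path must load the current value and add the new contribution before storing, rather than overwriting the output.

```cpp
#include <cstdio>

// Hypothetical helpers mirroring, in scalar form, the two behaviours discussed above.

// Buggy single-thread path: the freshly computed value overwrites the output.
void plain_store(float* dst, float contribution) {
    *dst = contribution;
}

// Fixed behaviour for atomic_add mode: load the current output, add, store back.
void accumulate_store(float* dst, float contribution) {
    float current = *dst;              // load
    *dst = current + contribution;     // add + store
}

int main() {
    float out = 1.0f;
    plain_store(&out, 2.0f);
    std::printf("plain store:      %.1f\n", out);  // 2.0 -> earlier contribution lost

    out = 1.0f;
    accumulate_store(&out, 2.0f);
    std::printf("accumulate store: %.1f\n", out);  // 3.0 -> contributions summed
    return 0;
}
```

In the multi-threaded path the same accumulation is done with a real atomic add; with a single thread no atomics are needed, but the load/add/store shown above is still required.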
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_nn_fold
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147961
Approved by: https://github.com/malfet
Author: leslie-fang-intel
Date: 2025-02-26 01:50:09 -08:00
Committed by: PyTorch MergeBot
Parent: f522d899fb
Commit: be830c8b1c
3 changed files with 84 additions and 5 deletions

test/inductor/test_cpu_repro.py

@@ -202,6 +202,45 @@ class CPUReproTests(TestCase):
             (v,),
         )
 
+    def test_nn_fold(self):
+        # Fix https://github.com/pytorch/pytorch/issues/147848
+        class Model(torch.nn.Module):
+            def __init__(self, output_size, kernel_size, stride) -> None:
+                super().__init__()
+                self.fold = torch.nn.Fold(
+                    output_size=output_size, kernel_size=kernel_size, stride=stride
+                )
+
+            def forward(self, x):
+                x = self.fold(x)
+                return x
+
+        output_sizes = [(64, 64), (64, 64)]
+        kernel_sizes = [(32, 32), (32, 32)]
+        strides = [(1, 1), (2, 2)]
+        input_sizes = [(1, 32 * 32, 1089), (1, 64 * 64, 289)]
+        for idx in range(len(output_sizes)):
+            output_size = output_sizes[idx]
+            kernel_size = kernel_sizes[idx]
+            stride = strides[idx]
+            input_size = input_sizes[idx]
+            for num_threads in [1, None]:
+                torch._dynamo.reset()
+                metrics.reset()
+                v = torch.randn(*input_size)
+                mod = Model(output_size, kernel_size, stride).eval()
+                with contextlib.nullcontext() if (
+                    num_threads != 1
+                ) else set_num_threads(1):
+                    with torch.no_grad():
+                        self.common(
+                            mod,
+                            (v,),
+                        )
+
     @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
     @patch("torch.cuda.is_available", lambda: False)
     def test_conv2d_packed(self):

torch/_inductor/codegen/cpp.py

@@ -2665,6 +2665,13 @@ class CppVecKernel(CppKernel):
         stride = self._try_get_const_stride(index, tiling_var)
         code = IndentedBuffer()
         if stride == 1:
+            if accu_store:
+                load = (
+                    f"{self._get_vec_type(dtype)}::loadu({var_expr})"
+                    if dtype == torch.float and self.tail_size is None
+                    else f"{self._get_vec_type(dtype)}::loadu({var_expr}, {cexpr_index(self.num_elems)})"
+                )
+                value = f"({value} + {load})"
             if dtype == torch.float and self.tail_size is None:
                 code.writeline(f"{value}.store({var_expr});")
             else:
@@ -3256,7 +3263,9 @@ class CppTile2DKernel(CppVecKernel):
             and not inner_stride.has(outer_var)
         )
 
-    def gen_transposed_tile_load_store(self, name, var, index, is_store):
+    def gen_transposed_tile_load_store(
+        self, name, var, index, is_store, store_mode=None
+    ):
         # transposed tile load/store outside the kernel inner loop
         dtype = V.graph.get_dtype(name)
         factor = self.tiling_factor
@@ -3276,16 +3285,17 @@ class CppTile2DKernel(CppVecKernel):
                 self.outer_num_elems,
                 self.inner_num_elems,
             )
+        atomic_add = "true" if (is_store and (store_mode == "atomic_add")) else "false"
         if (isinstance(M, sympy.Expr) and not M.is_number) or (
             isinstance(N, sympy.Expr) and not N.is_number
         ):
             load_or_store = (
-                f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>"
+                f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{atomic_add}>"
                 f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});"
             )
         else:
             load_or_store = (
-                f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>"
+                f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)},{atomic_add}>"
                 f"({src}, {ld_src}, {dst}, {ld_dst});"
             )
         if is_store:
@@ -3346,10 +3356,9 @@ class CppTile2DKernel(CppVecKernel):
         inner = self.inner_itervar()
         index = self.rename_indexing(index)
-        assert mode is None
         if self.need_vec_transpose(index):
             tile_var = self.gen_transposed_tile_load_store(
-                name, var, index, is_store=True
+                name, var, index, is_store=True, store_mode=mode
             )
             # vector store inside the kernel inner loop
             storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}"

torch/_inductor/codegen/cpp_prefix.h

@@ -644,6 +644,37 @@ void atomic_add_vec(T *addr, at::vec::VectorizedN<int64_t, NI> index, at::vec::V
     atomic_add(addr + tmpidx[i], tmpbuf[i]);
   }
 }
 
+template <typename T, bool atomic_add>
+struct transpose_mxn_helper;
+
+template <typename T>
+struct transpose_mxn_helper<T, true> {
+  static void call(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+    for (int i = 0; i < M; i++) {
+      for (int j = 0; j < N; j++) {
+        atomic_add(&dst[j*ld_dst + i], src[i*ld_src + j]);
+      }
+    }
+  }
+};
+
+template <typename T>
+struct transpose_mxn_helper<T, false> {
+  static void call(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+    at::vec::transpose_mxn<T>(src, ld_src, dst, ld_dst, M, N);
+  }
+};
+
+template <typename T, bool atomic_add>
+inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+  transpose_mxn_helper<T, atomic_add>::call(src, ld_src, dst, ld_dst, M, N);
+}
+
+template <typename T, int M, int N, bool atomic_add>
+inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+  transpose_mxn<T, atomic_add>(src, ld_src, dst, ld_dst, M, N);
+}
+
 #endif
 
 inline std::tuple<std::shared_ptr<int64_t[]>, int> _get_factors(int64_t number) {
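
For reference, here is a minimal scalar sketch of the behaviour selected by the new `atomic_add` template parameter. The `transpose_mxn_ref` name and the `main` driver below are illustrative only, not part of the PR: with the accumulate flag set, the transposed tile is added into `dst` (the `atomic_add` store mode); otherwise it overwrites `dst` as before.

```cpp
#include <cstdio>

// Hypothetical scalar stand-in for the helper added above: with Accumulate=true
// the transposed tile is accumulated into dst, with Accumulate=false it
// overwrites dst. The real helper uses atomic_add() per element or the
// vectorized at::vec::transpose_mxn.
template <bool Accumulate>
void transpose_mxn_ref(const float* src, long ld_src, float* dst, long ld_dst, int M, int N) {
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      if (Accumulate) {
        dst[j * ld_dst + i] += src[i * ld_src + j];
      } else {
        dst[j * ld_dst + i] = src[i * ld_src + j];
      }
    }
  }
}

int main() {
  const float src[2 * 3] = {1, 2, 3, 4, 5, 6};  // 2x3 tile, row-major
  float dst[3 * 2] = {10, 10, 10, 10, 10, 10};  // 3x2 destination, pre-filled

  transpose_mxn_ref<true>(src, /*ld_src=*/3, dst, /*ld_dst=*/2, /*M=*/2, /*N=*/3);
  // Each dst[j][i] is now 10 + src[i][j]; e.g. dst[0][0] == 11 and dst[2][1] == 16.
  std::printf("dst[0][0] = %.0f, dst[2][1] = %.0f\n", dst[0], dst[5]);
  return 0;
}
```

Dispatching on a compile-time boolean keeps the ordinary store path on the vectorized `at::vec::transpose_mxn`, so only stores lowered with `mode == "atomic_add"` take the per-element accumulating path.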