Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor][CPP] fix store mode atomic add (#147961)

**Summary**
Fixes https://github.com/pytorch/pytorch/issues/147848 and https://github.com/pytorch/pytorch/issues/146390. While addressing these issues, two problems were encountered:

- In `CppVecKernel`, when the number of threads is 1 and the mode is `atomic_add`, `store` did not `load/add` before storing. This is fixed in this PR.
- In `CppTile2DKernel`, `store` did not support `atomic_add` mode. Support for this is added in this PR.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_nn_fold
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147961
Approved by: https://github.com/malfet
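For context, the failure mode can be exercised with a compiled `nn.Fold` module: `nn.Fold` sums overlapping patches into its output, so its Inductor CPU lowering relies on `atomic_add`-mode stores. The sketch below is illustrative (shapes borrowed from the new test; before this fix the compiled result could diverge from eager because overlapping patches were overwritten instead of accumulated):

```python
import torch

# Shapes match input_sizes[0] in the new test added below.
fold = torch.nn.Fold(output_size=(64, 64), kernel_size=(32, 32), stride=(1, 1)).eval()
x = torch.randn(1, 32 * 32, 1089)

with torch.no_grad():
    eager = fold(x)
    compiled = torch.compile(fold)(x)

# A store that overwrites instead of load/add shows up as a mismatch here.
torch.testing.assert_close(eager, compiled)
```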
Committed by: PyTorch MergeBot
Parent: f522d899fb
Commit: be830c8b1c
```diff
@@ -202,6 +202,45 @@ class CPUReproTests(TestCase):
             (v,),
         )
 
+    def test_nn_fold(self):
+        # Fix https://github.com/pytorch/pytorch/issues/147848
+
+        class Model(torch.nn.Module):
+            def __init__(self, output_size, kernel_size, stride) -> None:
+                super().__init__()
+                self.fold = torch.nn.Fold(
+                    output_size=output_size, kernel_size=kernel_size, stride=stride
+                )
+
+            def forward(self, x):
+                x = self.fold(x)
+                return x
+
+        output_sizes = [(64, 64), (64, 64)]
+        kernel_sizes = [(32, 32), (32, 32)]
+        strides = [(1, 1), (2, 2)]
+        input_sizes = [(1, 32 * 32, 1089), (1, 64 * 64, 289)]
+
+        for idx in range(len(output_sizes)):
+            output_size = output_sizes[idx]
+            kernel_size = kernel_sizes[idx]
+            stride = strides[idx]
+            input_size = input_sizes[idx]
+
+            for num_threads in [1, None]:
+                torch._dynamo.reset()
+                metrics.reset()
+                v = torch.randn(*input_size)
+                mod = Model(output_size, kernel_size, stride).eval()
+                with contextlib.nullcontext() if (
+                    num_threads != 1
+                ) else set_num_threads(1):
+                    with torch.no_grad():
+                        self.common(
+                            mod,
+                            (v,),
+                        )
+
     @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
     @patch("torch.cuda.is_available", lambda: False)
     def test_conv2d_packed(self):
```
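The sweep over `num_threads in [1, None]` is deliberate: per the summary above, the missing `load/add` in `CppVecKernel` only manifests when the number of threads is 1, while the default thread count presumably exercises the tiled `CppTile2DKernel` store path that gains `atomic_add` support below.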
```diff
@@ -2665,6 +2665,13 @@ class CppVecKernel(CppKernel):
         stride = self._try_get_const_stride(index, tiling_var)
         code = IndentedBuffer()
         if stride == 1:
+            if accu_store:
+                load = (
+                    f"{self._get_vec_type(dtype)}::loadu({var_expr})"
+                    if dtype == torch.float and self.tail_size is None
+                    else f"{self._get_vec_type(dtype)}::loadu({var_expr}, {cexpr_index(self.num_elems)})"
+                )
+                value = f"({value} + {load})"
             if dtype == torch.float and self.tail_size is None:
                 code.writeline(f"{value}.store({var_expr});")
             else:
```
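In effect, when `accu_store` is set the emitted vectorized store becomes a read-modify-write. A minimal scalar model of the intended semantics (a sketch for illustration, not Inductor's actual codegen):

```python
def store(buf, idx, value, accu_store=False):
    # Scalar model of the store fixed above.
    if accu_store:
        # atomic_add mode: load the current value and add before storing.
        # Before this fix, the single-thread path fell through to the
        # plain branch below and silently dropped the accumulation.
        buf[idx] = buf[idx] + value
    else:
        buf[idx] = value
```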
```diff
@@ -3256,7 +3263,9 @@ class CppTile2DKernel(CppVecKernel):
             and not inner_stride.has(outer_var)
         )
 
-    def gen_transposed_tile_load_store(self, name, var, index, is_store):
+    def gen_transposed_tile_load_store(
+        self, name, var, index, is_store, store_mode=None
+    ):
         # transposed tile load/store outside the kernel inner loop
         dtype = V.graph.get_dtype(name)
         factor = self.tiling_factor
```
```diff
@@ -3276,16 +3285,17 @@ class CppTile2DKernel(CppVecKernel):
             self.outer_num_elems,
             self.inner_num_elems,
         )
+        atomic_add = "true" if (is_store and (store_mode == "atomic_add")) else "false"
         if (isinstance(M, sympy.Expr) and not M.is_number) or (
             isinstance(N, sympy.Expr) and not N.is_number
         ):
             load_or_store = (
-                f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>"
+                f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{atomic_add}>"
                 f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});"
             )
         else:
             load_or_store = (
-                f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>"
+                f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)},{atomic_add}>"
                 f"({src}, {ld_src}, {dst}, {ld_dst});"
             )
         if is_store:
```
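With the `atomic_add` flag threaded into the template arguments, a transposed atomic-add store of `float` data with compile-time tile sizes of, say, 16x16 now emits a call like `transpose_mxn<float,16,16,true>(src, ld_src, dst, ld_dst);` (values for illustration), dispatching to the accumulating helper added in the last hunk below rather than calling `at::vec::transpose_mxn` directly.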
```diff
@@ -3346,10 +3356,9 @@ class CppTile2DKernel(CppVecKernel):
 
         inner = self.inner_itervar()
         index = self.rename_indexing(index)
-        assert mode is None
         if self.need_vec_transpose(index):
             tile_var = self.gen_transposed_tile_load_store(
-                name, var, index, is_store=True
+                name, var, index, is_store=True, store_mode=mode
             )
             # vector store inside the kernel inner loop
             storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}"
```
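Dropping `assert mode is None` is what lets `store` accept `mode == "atomic_add"` on this path at all; the mode is simply forwarded as `store_mode` so `gen_transposed_tile_load_store` can select the accumulating variant.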
```diff
@@ -644,6 +644,37 @@ void atomic_add_vec(T *addr, at::vec::VectorizedN<int64_t, NI> index, at::vec::V
     atomic_add(addr + tmpidx[i], tmpbuf[i]);
   }
 }
+
+template <typename T, bool atomic_add>
+struct transpose_mxn_helper;
+
+template <typename T>
+struct transpose_mxn_helper<T, true> {
+  static void call(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+    for (int i = 0; i < M; i++) {
+      for (int j = 0; j < N; j++) {
+        atomic_add(&dst[j*ld_dst + i], src[i*ld_src + j]);
+      }
+    }
+  }
+};
+
+template <typename T>
+struct transpose_mxn_helper<T, false> {
+  static void call(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+    at::vec::transpose_mxn<T>(src, ld_src, dst, ld_dst, M, N);
+  }
+};
+
+template <typename T, bool atomic_add>
+inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+  transpose_mxn_helper<T, atomic_add>::call(src, ld_src, dst, ld_dst, M, N);
+}
+
+template <typename T, int M, int N, bool atomic_add>
+inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+  transpose_mxn<T, atomic_add>(src, ld_src, dst, ld_dst, M, N);
+}
 #endif
 
 inline std::tuple<std::shared_ptr<int64_t[]>, int> _get_factors(int64_t number) {
```
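As a reference for the `true` specialization above, the following Python sketch (function name hypothetical) computes the same result that the C++ helper produces with atomic read-modify-write updates:

```python
import numpy as np

def transpose_mxn_accumulate(src: np.ndarray, dst: np.ndarray) -> None:
    # Accumulate the transpose of src into dst instead of overwriting it;
    # the C++ helper performs each += atomically so concurrent tiles may
    # target overlapping destinations.
    M, N = src.shape
    for i in range(M):
        for j in range(N):
            dst[j, i] += src[i, j]

# The atomic_add=false variant behaves like dst[:] = src.T;
# the accumulating variant behaves like dst += src.T:
src = np.random.rand(4, 3).astype(np.float32)
dst = np.random.rand(3, 4).astype(np.float32)
expected = dst + src.T
transpose_mxn_accumulate(src, dst)
assert np.allclose(dst, expected)
```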