mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix: nDims is mutated inside the loop in Shape.cu (#165446)
Summary: The `nDims` variable is mutated inside the loop but never restored to its original value. This affects subsequent iterations of the outer loop. Each batch iteration may get incorrect `nDims` after the first batch. Test Plan: CI Reviewed By: ngimel Differential Revision: D84612194 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165446 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
839f6facdb
commit
4f400ab520
@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
}
|
||||
#endif
|
||||
int32_t trailingSize;
|
||||
int nDimsLocal = nDims;
|
||||
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
|
||||
if (isInOutAligned) {
|
||||
// in this case we can and should flatten the tensors after the cat dim
|
||||
@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
// and divide all strides except last by elems_per_vec (last stride is 1 always)
|
||||
// for input, we will fix up the sizes and strides in the kernel directly
|
||||
kernelOutputParam = outputParam;
|
||||
nDims = dimension + 1;
|
||||
nDimsLocal = dimension + 1;
|
||||
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
|
||||
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
|
||||
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
|
||||
@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
cat_dim = nDims - cat_dim;
|
||||
cat_dim = nDimsLocal - cat_dim;
|
||||
break;
|
||||
default:
|
||||
cat_dim--;
|
||||
@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
}\
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
switch (nDims) {
|
||||
switch (nDimsLocal) {
|
||||
case 1:
|
||||
HANDLE_CASE(1);
|
||||
break;
|
||||
|
Reference in New Issue
Block a user