Fix: nDims is mutated inside the loop in Shape.cu (#165446)

Summary:
The `nDims` variable is mutated inside the loop but never restored to its original value.
This affects subsequent iterations of the outer loop.
Each batch iteration may get incorrect `nDims` after the first batch.

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D84612194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165446
Approved by: https://github.com/ngimel
This commit is contained in:
Alex Sibiryakov
2025-10-15 02:32:12 +00:00
committed by PyTorch MergeBot
parent 839f6facdb
commit 4f400ab520

View File

@@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
   }
 #endif
   int32_t trailingSize;
+  int nDimsLocal = nDims;
   TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
   if (isInOutAligned) {
     // in this case we can and should flatten the tensors after the cat dim
@@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
     // and divide all strides except last by elems_per_vec (last stride is 1 always)
     // for input, we will fix up the sizes and strides in the kernel directly
     kernelOutputParam = outputParam;
-    nDims = dimension + 1;
+    nDimsLocal = dimension + 1;
     constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
     auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
     kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
@@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       case 0:
         break;
       case 1:
-        cat_dim = nDims - cat_dim;
+        cat_dim = nDimsLocal - cat_dim;
         break;
       default:
         cat_dim--;
@@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
         data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
       }\
       C10_CUDA_KERNEL_LAUNCH_CHECK();
-    switch (nDims) {
+    switch (nDimsLocal) {
       case 1:
         HANDLE_CASE(1);
         break;