Fix: nDims is mutated inside the loop in Shape.cu (#165446)
Summary: The `nDims` variable is mutated inside the loop but never restored to its original value, so the change leaks into subsequent iterations of the outer loop: every batch after the first may see an incorrect `nDims`.

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D84612194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165446
Approved by: https://github.com/ngimel
commit 4f400ab520
parent 839f6facdb
committed by PyTorch MergeBot
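The bug pattern, as a minimal standalone C++ sketch (hypothetical values; the real loop is in parallel_cat in Shape.cu): a value computed once before the loop is clobbered inside one iteration, so later iterations start from the wrong state, and copying it into a per-iteration local restores independence.

#include <cstdio>

int main() {
  int nDims = 4;  // loop-invariant input, computed once before the batch loop
  for (int batch = 0; batch < 3; ++batch) {
    // Before the fix, the flattening path assigned `nDims = dimension + 1;`,
    // so batches after the first started from the already-reduced value.
    // The fix works on a per-iteration copy instead:
    int nDimsLocal = nDims;
    bool isInOutAligned = (batch == 0);  // pretend only batch 0 takes the flatten path
    if (isInOutAligned) {
      nDimsLocal = 2;  // stand-in for `dimension + 1`; nDims itself stays 4
    }
    std::printf("batch %d: nDims=%d nDimsLocal=%d\n", batch, nDims, nDimsLocal);
  }
  return 0;
}

With the buggy direct assignment, batches 1 and 2 would inherit the reduced value; with the local copy, every iteration sees nDims=4 as intended.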
@@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
     }
 #endif
     int32_t trailingSize;
+    int nDimsLocal = nDims;
     TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
     if (isInOutAligned) {
       // in this case we can and should flatten the tensors after the cat dim
@@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       // and divide all strides except last by elems_per_vec (last stride is 1 always)
       // for input, we will fix up the sizes and strides in the kernel directly
       kernelOutputParam = outputParam;
-      nDims = dimension + 1;
+      nDimsLocal = dimension + 1;
       constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
       auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
       kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
@@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
         case 0:
           break;
         case 1:
-          cat_dim = nDims - cat_dim;
+          cat_dim = nDimsLocal - cat_dim;
           break;
         default:
           cat_dim--;
@@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
           data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     }\
     C10_CUDA_KERNEL_LAUNCH_CHECK();
-    switch (nDims) {
+    switch (nDimsLocal) {
       case 1:
         HANDLE_CASE(1);
         break;