Make torch._chunk_cat support non-contiguous inputs (#151263)

Currently, `torch._chunk_cat` only supports contiguous inputs, because `_pad_chunk()` uses `.view()`, which works only on contiguous tensors. This breaks internal models whose input tensors can be non-contiguous, for example:

- size=[8192, 16416], stride=[16448, 1]  # stride[0] is larger than size[1]
- size=[1152, 384], stride=[1, 1152]  # column-major tensor

This PR relaxes the contiguity assumption on input tensors by switching from `.view()` to `.reshape()`. Since `.reshape()` uses `.view()` under the hood whenever possible, this should not regress existing use cases.
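
For reference (not part of the change itself), a minimal sketch of the behavior difference that motivates the switch, using the column-major layout from the second example above:

```python
import torch

# A column-major tensor like size=[1152, 384], stride=[1, 1152] cannot be
# flattened with .view(), because no view with the requested shape exists
# for that memory layout; .reshape() falls back to a copy and succeeds.
t = torch.randn(384, 1152).t()   # size=[1152, 384], stride=[1, 1152]
assert not t.is_contiguous()

try:
    t.view(2, -1)                # raises: view size is not compatible ...
except RuntimeError as e:
    print("view failed:", e)

out = t.reshape(2, -1)           # copies when a view is impossible
print(out.shape)                 # torch.Size([2, 221184])
```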

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151263
Approved by: https://github.com/BoyuanFeng
Author: Will Feng
Date: 2025-04-15 13:29:26 -07:00
Committed by: PyTorch MergeBot
Commit: 82200e33b5 (parent: 30101aa450)
4 changed files with 15 additions and 5 deletions


@@ -3366,7 +3366,7 @@ static std::vector<Tensor> _pad_chunk(
     std::vector<int64_t> view_sizes(
         tensor_size.begin(), tensor_size.begin() + dim);
     view_sizes.insert(view_sizes.end(), {num_chunks, -1});
-    padded_tensors.push_back(padded_tensor.view(view_sizes));
+    padded_tensors.push_back(padded_tensor.reshape(view_sizes));
   }
   return padded_tensors;
 }
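
With the eager-mode `_pad_chunk` now using `reshape`, a call along the lines of the following hypothetical sketch should no longer fail for a column-major input (the output shape is only printed, not asserted, since it depends on the op's chunk-and-concatenate layout):

```python
import torch

# Hypothetical usage sketch: both tensors share the leading dimension; the
# first one is column-major, the layout that previously hit the .view()
# error inside _pad_chunk.
a = torch.randn(384, 1152).t()     # size=[1152, 384], stride=[1, 1152]
b = torch.randn(1152, 512)         # contiguous
out = torch._chunk_cat([a, b], dim=1, num_chunks=4)
print(out.shape)
```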


@@ -422,11 +422,12 @@ static __global__ void chunk_cat_cuda_kernel(
 }
 bool all_contiguous(TensorList tensors) {
-  bool contiguous = true;
   for (const auto& t : tensors) {
-    contiguous &= t.is_non_overlapping_and_dense();
+    if (!t.is_contiguous()) {
+      return false;
+    }
   }
-  return contiguous;
+  return true;
 }
 // Get leading dimensions before `dim`-th dimension.
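
For context, a small standalone sketch (not part of the diff) of why this check is tightened: a column-major tensor is dense and non-overlapping, so the old predicate accepted it, but it is not row-major contiguous; with `is_contiguous()`, such inputs presumably fall back to the generic path that now handles them via `reshape`.

```python
import torch

# A column-major tensor fills its storage densely with no overlapping
# elements, yet it is not contiguous in the row-major sense.
t = torch.randn(384, 1152).t()   # size=[1152, 384], stride=[1, 1152]
print(t.is_contiguous())         # False
print(t.t().is_contiguous())     # True: the same storage viewed row-major
```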


@@ -1282,7 +1282,7 @@ def _pad_chunk(
             ]
             tensor = aten.constant_pad_nd(tensor, pad, 0)
         view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
-        padded_tensors.append(tensor.view(view_size))
+        padded_tensors.append(tensor.reshape(view_size))
     return padded_tensors


@@ -2302,6 +2302,7 @@ def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
     # No requirements for (wrapped_dim, ...)-th dimension.
     # 3. Expect positive num_chunks
     # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element
+    # 5. Non-contiguous input tensors are allowed.
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     same_ndim_cases = (
         (
@@ -2348,6 +2349,14 @@ def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
             tensors.append(make_arg(size))
         yield SampleInput(tensors, args=(dim, num_chunks))
+    # non-contiguous
+    for dim in range(max_dim):
+        tensors = []
+        for size in different_ndim_case:
+            # make the last 2 dims column-major (i.e. non-contiguous)
+            t = make_arg(size).transpose(-2, -1).contiguous().transpose(-2, -1)
+            tensors.append(t)
+        yield SampleInput(tensors, args=(dim, num_chunks))
 def error_inputs_chunk_cat(op_info, device, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=torch.float32)
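
The new sample uses the transpose-contiguous-transpose pattern to build tensors with the original shape but a column-major layout in the last two dimensions; a quick standalone illustration:

```python
import torch

# transpose -> contiguous -> transpose back keeps the logical shape but
# materializes the data in column-major order over the last two dims,
# producing a non-contiguous tensor with the original sizes and values.
x = torch.randn(3, 4, 5)
y = x.transpose(-2, -1).contiguous().transpose(-2, -1)
print(y.shape, y.stride(), y.is_contiguous())
# torch.Size([3, 4, 5]) (20, 1, 4) False
print(torch.equal(x, y))  # True: same values, different layout
```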