Make torch._chunk_cat support non-contiguous inputs (#151263)
Currently, `torch._chunk_cat` only supports contiguous inputs (due to the `.view()` usage in `_pad_chunk()`, which only supports contiguous tensors). This doesn't work for internal models where input tensors can be non-contiguous, e.g.:

- size=[8192, 16416], stride=[16448, 1]  # stride[0] is larger than size[1]
- size=[1152, 384], stride=[1, 1152]  # column-major tensor

In this PR, we relax the assumption of contiguous input tensors by switching from `.view()` to `.reshape()`. Note that since `.reshape()` uses `.view()` under the hood whenever possible, this should not cause a regression for existing use cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151263
Approved by: https://github.com/BoyuanFeng
Committed by: PyTorch MergeBot
Parent: 30101aa450
Commit: 82200e33b5
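For illustration only (not part of the commit), a minimal sketch of the kind of column-major input this change enables, assuming the op's (tensors, dim, num_chunks) arguments:

import torch

# Column-major (non-contiguous) tensor matching the size=[1152, 384],
# stride=[1, 1152] case listed in the commit message.
t = torch.randn(384, 1152).t()
assert t.shape == (1152, 384) and not t.is_contiguous()

# Chunk each input into 4 chunks along dim 0 and concatenate; before this
# change, the .view() inside _pad_chunk() could fail for such inputs.
out = torch._chunk_cat([t], 0, 4)
print(out.shape)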
@@ -3366,7 +3366,7 @@ static std::vector<Tensor> _pad_chunk(
     std::vector<int64_t> view_sizes(
         tensor_size.begin(), tensor_size.begin() + dim);
     view_sizes.insert(view_sizes.end(), {num_chunks, -1});
-    padded_tensors.push_back(padded_tensor.view(view_sizes));
+    padded_tensors.push_back(padded_tensor.reshape(view_sizes));
   }
   return padded_tensors;
 }
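For illustration only (not part of the diff), the behavioral difference the change above relies on: .view() requires a layout it can reinterpret in place, while .reshape() falls back to a copy when no view is possible.

import torch

x = torch.randn(6, 4).t()   # shape (4, 6), column-major, non-contiguous
try:
    x.view(2, 12)           # raises: view needs a compatible size/stride layout
except RuntimeError as err:
    print("view failed:", err)

y = x.reshape(2, 12)        # reshape copies when a view is impossible
print(y.shape, y.is_contiguous())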
@@ -422,11 +422,12 @@ static __global__ void chunk_cat_cuda_kernel(
 }

 bool all_contiguous(TensorList tensors) {
-  bool contiguous = true;
   for (const auto& t : tensors) {
-    contiguous &= t.is_non_overlapping_and_dense();
+    if (!t.is_contiguous()) {
+      return false;
+    }
   }
-  return contiguous;
+  return true;
 }

 // Get leading dimensions before `dim`-th dimension.
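For context (also not part of the diff): the old check classified column-major tensors as "contiguous" because they are dense and non-overlapping even though they are not row-major contiguous. A small sketch, assuming the private helper torch._prims_common.is_non_overlapping_and_dense is available:

import torch
from torch._prims_common import is_non_overlapping_and_dense  # private helper, assumed available

t = torch.randn(4, 8).t()                  # column-major: shape (8, 4), stride (1, 8)
print(t.is_contiguous())                   # False
print(is_non_overlapping_and_dense(t))     # True: dense and non-overlapping, just not row-major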
@@ -1282,7 +1282,7 @@ def _pad_chunk(
         ]
         tensor = aten.constant_pad_nd(tensor, pad, 0)
         view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
-        padded_tensors.append(tensor.view(view_size))
+        padded_tensors.append(tensor.reshape(view_size))
     return padded_tensors


@@ -2302,6 +2302,7 @@ def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
     # No requirements for (wrapped_dim, ...)-th dimension.
     # 3. Expect positive num_chunks
     # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element
+    # 5. Non-contiguous input tensors are allowed.
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     same_ndim_cases = (
         (
@@ -2348,6 +2349,14 @@ def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
             tensors.append(make_arg(size))
         yield SampleInput(tensors, args=(dim, num_chunks))

+    # non-contiguous
+    for dim in range(max_dim):
+        tensors = []
+        for size in different_ndim_case:
+            # make the last 2 dims column-major (i.e. non-contiguous)
+            t = make_arg(size).transpose(-2, -1).contiguous().transpose(-2, -1)
+            tensors.append(t)
+        yield SampleInput(tensors, args=(dim, num_chunks))

 def error_inputs_chunk_cat(op_info, device, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=torch.float32)
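The transpose/contiguous/transpose pattern in the new sample inputs is a common way to build a column-major test tensor without changing its logical shape or values; a quick illustration (not part of the diff):

import torch

x = torch.randn(2, 3, 5)
t = x.transpose(-2, -1).contiguous().transpose(-2, -1)
print(t.shape)            # torch.Size([2, 3, 5]) -- same logical shape as x
print(t.stride())         # (15, 1, 3) -- last two dims are column-major
print(t.is_contiguous())  # False
print(torch.equal(t, x))  # True: values unchanged, only the memory layout differs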