mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BUG] Fix nonzero_static crash on CUDA when the input is an empty tensor (#162578)
Fixes #162473
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162578
Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
755cf90672
commit
b3ad8f4a9c
@ -317,6 +317,17 @@ void nonzero_static_cuda_out_impl(
|
||||
out_temp =
|
||||
Tensor(at::detail::empty_cuda({self.dim(), size}, out.options())).t();
|
||||
}
|
||||
// If input has zero elements, avoid kernel grid calculations (which can
|
||||
// produce zero divisors) and just fill the output with fill_value.
|
||||
if (self.numel() == 0) {
|
||||
if (need_to_copy) {
|
||||
out_temp.fill_(fill_value);
|
||||
out.copy_(out_temp);
|
||||
} else {
|
||||
out.fill_(fill_value);
|
||||
}
|
||||
return;
|
||||
}
|
||||
int64_t* out_data_ptr = need_to_copy ? out_temp.mutable_data_ptr<int64_t>()
|
||||
: out.mutable_data_ptr<int64_t>();
|
||||
|
||||
|
@ -1654,6 +1654,15 @@ class TestUnaryUfuncs(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
# empty input
|
||||
# https://github.com/pytorch/pytorch/issues/162473
|
||||
input_tensor = torch.tensor([], device=device)
|
||||
static_size = 1
|
||||
self.assertEqual(
|
||||
torch.nonzero_static(input_tensor, size=static_size),
|
||||
torch.tensor([[-1]], device=device),
|
||||
)
|
||||
|
||||
# 1D input
|
||||
input_tensor = torch.tensor([0, 8], device=device)
|
||||
static_size = 1
|
||||
|
Reference in New Issue
Block a user