[CUDA] Check size calculation in ilpReduce for softmax (#144009)

For #143644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144009
Approved by: https://github.com/Skylion007
Author: eqy
Date: 2025-01-04 02:31:13 +00:00
Committed by: PyTorch MergeBot
Parent: dbdda654af
Commit: 7e3cd0e488
2 changed files with 9 additions and 2 deletions


@@ -468,7 +468,7 @@ ilpReduce(index_t shift,
     if (offset >= shift && offset < size) {
       threadVal = r(threadVal, data[offset]);
     }
-    size -= blockDim.x;
+    size -= blockDim.x > size ? size : blockDim.x;
     data += blockDim.x;
   }
   index_t last = size % (ILP * blockDim.x);
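
The change above (mirrored in WriteFpropResultsVectorized below) keeps the running `size` from dropping below zero, or wrapping around if index_t is unsigned, when blockDim.x exceeds the number of elements left after the alignment prologue; otherwise the following `size % (ILP * blockDim.x)` and the loop bounds derived from it no longer cover the remaining elements. A minimal host-side sketch of the arithmetic, using hypothetical values for blockDim.x, ILP, and size rather than an actual launch configuration:

// Hypothetical values: a 1024-thread block and 513 elements remaining.
#include <cstdio>

int main() {
  const int ILP = 2;
  const int blockDimX = 1024;  // stands in for blockDim.x
  int before = 513;            // stands in for size (pre-fix path)
  int after = 513;             // stands in for size (post-fix path)

  // Old behaviour: unconditional subtraction drives size negative.
  before -= blockDimX;                             // -511
  int lastBefore = before % (ILP * blockDimX);     // -511: any `offset < size`
                                                   // check now fails, so the
                                                   // remaining elements are skipped

  // Fixed behaviour: subtract at most size, so it bottoms out at zero.
  after -= blockDimX > after ? after : blockDimX;  // 0
  int lastAfter = after % (ILP * blockDimX);       // 0

  std::printf("unclamped: size=%d last=%d\n", before, lastBefore);
  std::printf("clamped:   size=%d last=%d\n", after, lastAfter);
  return 0;
}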
@@ -518,7 +518,7 @@ WriteFpropResultsVectorized(
     if (offset >= shift && offset < size) {
       output[offset] = epilogue(input[offset]);
     }
-    size -= blockDim.x;
+    size -= blockDim.x > size ? size : blockDim.x;
     input += blockDim.x;
     output += blockDim.x;
   }


@@ -10331,6 +10331,13 @@ class TestNNDeviceType(NNTestCase):
         run_test(1100000000, 2) # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
         run_test(2200000000, 1) # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
 
+    @onlyCUDA
+    @dtypes(torch.double)
+    def test_softmax_double(self, device, dtype):
+        logits = torch.randn(5, 513, dtype=dtype, device=device)
+        expected_ones = F.log_softmax(logits, dim=1).exp().sum(dim=1)
+        self.assertEqual(expected_ones, torch.ones_like(expected_ones))
+
     @onlyCUDA
     @dtypes(torch.half)
     @largeTensorTest("20GB")