[cuda rng] Making offset calculation independent of device properties (#98988)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98988
Approved by: https://github.com/ngimel
Animesh Jain
2023-04-18 19:07:31 +00:00
committed by PyTorch MergeBot
parent bb017d7671
commit 26f318574f
7 changed files with 46 additions and 20 deletions

@@ -51,10 +51,31 @@ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
const uint32_t unroll = curand4_engine_calls;
dim3 dim_block(block_size);
dim3 grid((numel + block_size - 1) / block_size);
uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
grid.x = std::min(
static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
grid.x);
// We changed the offset calculation to be independent of CUDA device
// properties. Earlier, the implementation was
//
// max_threads_per_sm = static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor);
// number_of_sm = static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
//
// However, having the offset depend on CUDA device properties makes it
// harder (and hacky) to support functionalization of RNG ops. So we have
// chosen lower bounds for the two device properties by looking at recent
// GPUs. Using lower bounds advances the offsets more aggressively, which
// ensures that random numbers are not repeated. Note that the Philox
// sequence length is 2**64, so we can advance the offsets liberally.
// maxThreadsPerMultiProcessor has been around 1536 and 2048 on recent GPUs.
// Choosing 1536 for safety.
uint32_t max_threads_per_sm = 1536;
// T4 has 40, V100 has 80, and A100 has 108 SMs. To get the lower bound, we
// assume MIG on an A100 GPU, i.e. 108/8, and round it down to the nicer even number 12.
uint32_t number_of_sm = 12;
uint32_t blocks_per_sm = max_threads_per_sm / block_size;
grid.x = std::min(number_of_sm * blocks_per_sm, grid.x);
// number of times random will be generated per thread, to offset philox counter in thc random state
uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1)
* curand4_engine_calls;
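
For reference, a minimal standalone C++ sketch of the new arithmetic described in the comment above (this is not the ATen code itself; the example `numel` is an arbitrary assumption, and no CUDA is needed):

```cpp
// Minimal standalone sketch of the device-independent offset arithmetic.
// numel is an arbitrary example value; the constants are the hard-coded
// lower bounds chosen in this change.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const uint64_t numel = 1u << 24;           // example element count (assumption)
  const uint32_t block_size = 256;
  const uint32_t unroll = 4;                 // curand4_engine_calls
  const uint32_t max_threads_per_sm = 1536;  // lower bound, not queried from the device
  const uint32_t number_of_sm = 12;          // lower bound, not queried from the device

  const uint32_t blocks_per_sm = max_threads_per_sm / block_size;                  // 6
  uint32_t grid_x = static_cast<uint32_t>((numel + block_size - 1) / block_size);
  grid_x = std::min(number_of_sm * blocks_per_sm, grid_x);                         // capped at 72

  // One grid-wide pass covers block_size * grid_x * unroll elements and costs
  // one curand4 call per thread, so the philox offset advances by
  // (number of passes) * curand4_engine_calls.
  const uint64_t counter_offset =
      ((numel - 1) / (static_cast<uint64_t>(block_size) * grid_x * unroll) + 1) * unroll;

  std::cout << "grid.x = " << grid_x << ", counter_offset = " << counter_offset << "\n";
  return 0;
}
```

Because 12 × 6 = 72 blocks is a lower bound on what the old device-dependent cap would produce on recent GPUs, the capped grid is never larger than before, so the offset computed here advances at least as much as the one it replaces.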


@@ -7,6 +7,7 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/native/cuda/DistributionTemplates.h>
#include <c10/macros/Macros.h>
#include <curand_kernel.h>
@@ -342,13 +343,11 @@ dropout_cuda(CUDAGeneratorImpl* gen, const Tensor& self, double p){
if (nelem==0) return std::tuple<Tensor,Tensor>(self.clone(), mask);
Tensor ret = at::empty_like(self);
const int64_t block_size = 256;
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
dim3 dim_block(block_size);
dim3 grid((nelem + block_size -1)/block_size);
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
auto execution_policy = at::native::calc_execution_policy(nelem);
auto counter_offset = std::get<0>(execution_policy);
auto grid = std::get<1>(execution_policy);
auto dim_block = std::get<2>(execution_policy);
PhiloxCudaState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]

@@ -118,10 +118,12 @@ TEST(DistributionsTest, TestPhiloxIncrementBigUniformTensor) {
// calculate maximum number of threads that can be launched
// and set the numel to be 8 times that
uint32_t max_threads_per_sm = 1536;
uint32_t number_of_sm = 12;
const int block_size = 256;
dim3 dim_block(block_size);
uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
dim3 grid(static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm);
uint32_t blocks_per_sm = max_threads_per_sm / block_size;
dim3 grid(number_of_sm * blocks_per_sm);
auto numel = block_size * grid.x * 8;
// get numel randoms from uniform_(), philox offset is now incremented to 8 by this call
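
As a sanity check on the arithmetic this test relies on, here is a standalone sketch (using the same hard-coded bounds; not part of the test file) showing that numel = block_size * grid.x * 8 advances the philox offset by exactly 8:

```cpp
// Standalone check: with the hard-coded bounds, numel = block_size * grid.x * 8
// yields a philox offset increment of 8.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t block_size = 256;
  const uint32_t unroll = 4;                         // curand4_engine_calls
  const uint32_t blocks_per_sm = 1536 / block_size;  // 6
  const uint32_t grid_x = 12 * blocks_per_sm;        // 72
  const uint64_t numel = static_cast<uint64_t>(block_size) * grid_x * 8;

  // Two grid-wide passes are needed (each pass covers block_size * grid_x * 4
  // elements), and each pass advances the counter by curand4_engine_calls.
  const uint64_t offset =
      ((numel - 1) / (static_cast<uint64_t>(block_size) * grid_x * unroll) + 1) * unroll;
  assert(offset == 8);
  return 0;
}
```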


@@ -354,6 +354,9 @@ class TestFFT(TestCase):
# nd-fft tests
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@toleranceOverride({
torch.cfloat : tol(2e-4, 1.3e-6),
})
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
allowed_dtypes=(torch.cfloat, torch.cdouble))
def test_reference_nd(self, device, dtype, op):

@@ -1213,7 +1213,7 @@ class TestSDPA(NNTestCase):
math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()
self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3)
self.assertEqual(math_ref_test, math_ref_lp_test, atol=9.5e-3, rtol=7e-3)
self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3)
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
@@ -1666,7 +1666,7 @@ class TestSDPA(NNTestCase):
# TODO: Investigate why grad_k needs larger tolerances
grad_k_deviation = key_ref.grad - key_ref_lp.grad
grad_k_ref_atol = max(7 * torch.abs(grad_k_deviation).max().item(), 7 * default_atol[out.dtype])
grad_k_ref_atol = max(9.5 * torch.abs(grad_k_deviation).max().item(), 9.5 * default_atol[out.dtype])
grad_k_ref_rtol = max(7 * get_rtol(key_ref.grad, key_ref_lp.grad), 7 * default_rtol[out.dtype])
grad_v_deviation = value_ref.grad - value_ref_lp.grad

@@ -22,7 +22,7 @@ struct ValidationConstants {
// Tolerances generated from randn + add + sum fusion
// compared against double precision
std::array<std::array<double, 2>, 20> sum_tolerances_float = {
{{4, 1.68222e-06}, {8, 2.23704e-06}, {16, 2.95788e-06},
{{4, 1.99955e-06}, {8, 2.23704e-06}, {16, 2.95788e-06},
{32, 4.4778e-06}, {64, 6.75395e-06}, {128, 8.57934e-06},
{256, 1.30594e-05}, {512, 2.19122e-05}, {1024, 3.3451e-05},
{2048, 5.78476e-05}, {4096, 0.000108292}, {8192, 0.00012207},

@@ -27,7 +27,7 @@ def throw_on_non_cuda(device):
def rand_offset_calculator(shape):
# For impl, look at the function calc_execution_policy in the file
# aten/src/ATen/native/cuda/DistributionTemplates.h. The impl was copied at
# commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6
# commit hash ccc5d1daec46da82ce17fcb8e9dcc871e9fef9a2
numel = 1
for dim_size in shape:
numel *= dim_size
@@ -35,10 +35,11 @@ def rand_offset_calculator(shape):
block_size = 256
unroll = 4
curand4_engine_calls = 4
device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
max_threads_per_sm = 1536
number_of_sm = 12
blocks_per_sm = max_threads_per_sm // block_size
grid_size = (numel + block_size - 1) // block_size
grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
grid_size = min(grid_size, number_of_sm * blocks_per_sm)
offset = (
(numel - 1) // (block_size * grid_size * unroll) + 1
) * curand4_engine_calls