Summary:
As a first step for this plan: https://github.com/pytorch/pytorch/issues/19508#issuecomment-485178192, this PR moves `THCTensor_(uniform)` to ATen. Major changes are:
- The `uniform_` CUDA kernel now uses a Philox generator.
- The kernel also uses TensorIterator.
- The kernel uses a grid-stride loop to achieve peak effective bandwidth.
- Since the engine has changed from `curandStateMTGP32` to `curandStatePhilox4_32_10`, the random numbers generated will now be different.
- Here is the diff showing the codegen changes: https://gist.github.com/syed-ahmed/4af9ae0d42b6c7dbaa13b9dd0d1dd1e8 (this is the BC-breaking change, if any).
- Philox4_32_10 is known to pass the standard TestU01 BigCrush test (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf), so the quality of the generated random numbers is not a concern compared to the previously used `curandStateMTGP32`.
- I have added a test case in `aten/src/ATen/test/cuda_distributions_test.cu` which verifies that the Philox offset is incremented properly.

The benchmark was done on a DGX Station with 4 V100s.
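To illustrate why a counter-based generator fits a grid-stride kernel, here is a minimal pure-Python sketch, assuming the Philox4x32-10 round function and constants published with Random123 (the paper linked above). It is not the PyTorch kernel and does not reproduce curand's exact outputs: the counter layout `(offset + step, thread id)`, the helper names `philox4x32_10` and `uniform_fill`, and the simulated thread count are all hypothetical. Each simulated thread derives its numbers from `(seed, thread id, step)` with no shared state, and the returned offset shows why the generator's offset must advance after every call.

```python
# Illustrative sketch only: not PyTorch's CUDA kernel and not curand's exact
# output order. Philox4x32-10 constants follow Random123 (Salmon et al., SC'11).
MASK32 = 0xFFFFFFFF
PHILOX_M0, PHILOX_M1 = 0xD2511F53, 0xCD9E8D57  # round multipliers
PHILOX_W0, PHILOX_W1 = 0x9E3779B9, 0xBB67AE85  # Weyl increments for the key


def _mulhilo(a, b):
    """Return the (high, low) 32-bit halves of the 64-bit product a * b."""
    prod = a * b
    return prod >> 32, prod & MASK32


def philox4x32_10(counter, key):
    """Hash a 4 x uint32 counter under a 2 x uint32 key into 4 x uint32 outputs."""
    c0, c1, c2, c3 = counter
    k0, k1 = key
    for rnd in range(10):
        if rnd > 0:  # the key is bumped before every round except the first
            k0 = (k0 + PHILOX_W0) & MASK32
            k1 = (k1 + PHILOX_W1) & MASK32
        hi0, lo0 = _mulhilo(PHILOX_M0, c0)
        hi1, lo1 = _mulhilo(PHILOX_M1, c2)
        c0, c1, c2, c3 = hi1 ^ c1 ^ k0, lo1, hi0 ^ c3 ^ k1, lo0
    return c0, c1, c2, c3


def uniform_fill(n, seed, offset, lo=0.0, hi=1.0, num_threads=8):
    """Hypothetical helper: fill n values in [lo, hi) like a grid-stride kernel.

    Simulated thread `tid` visits indices tid, tid + num_threads, ... and draws
    from counter (offset + step, tid) under key = seed, so no state is shared
    between threads. Returns the values and the offset for the next call.
    """
    key = (seed & MASK32, (seed >> 32) & MASK32)
    out = [0.0] * n
    for tid in range(num_threads):  # each iteration stands in for one GPU thread
        for step, i in enumerate(range(tid, n, num_threads)):  # grid-stride loop
            ctr = offset + step
            counter = (ctr & MASK32, (ctr >> 32) & MASK32, tid & MASK32, 0)
            word = philox4x32_10(counter, key)[0]  # use 1 of the 4 output words
            out[i] = lo + (word / 2**32) * (hi - lo)  # word / 2**32 is in [0, 1)
    # Advance past the largest per-thread step count so the next call never
    # reuses a (counter, key) pair and therefore yields fresh numbers.
    steps_per_thread = (n + num_threads - 1) // num_threads
    return out, offset + steps_per_thread


values, next_offset = uniform_fill(100, seed=42, offset=0)
more, _ = uniform_fill(100, seed=42, offset=next_offset)
```

A real kernel would typically consume all four 32-bit words per Philox call (for example via `curand_uniform4`) rather than only the first; the sketch keeps one word per step for readability. Advancing the offset by `ceil(n / num_threads)` here is the analogue of the offset bookkeeping that the new test in `cuda_distributions_test.cu` checks.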
I modified the script from jcjohnson's [multinomial benchmark](https://github.com/jcjohnson/pytorch-multinomial-benchmark) to produce this notebook, which shows a general speedup with this PR and no regression: https://gist.github.com/syed-ahmed/9d26d4e96308aed274d0f2c7be5218ef

To reproduce the notebook:
- Run https://gist.github.com/syed-ahmed/4208c22c541f1d30ad6a9b1efc1d728f in a container with the current PyTorch top of tree, using the command: `python uniform_benchmark.py --stats_json before.json`
- Apply this PR's diff to the current PyTorch top of tree and run the same script in a container with the command: `python uniform_benchmark.py --stats_json after.json`
- Run the notebook linked above with `after.json` and `before.json` in the same directory.

The effective bandwidth was calculated using this script (thanks to ngimel): https://gist.github.com/syed-ahmed/f8b7384d642f4bce484228b508b4bc68

Following are the numbers before and after (forward times in seconds).

Before:
```
uniform, size, elements 65536 forward 5.168914794921875e-06 bandwidth (GB/s) 50.71548098597786
uniform, size, elements 131072 forward 5.056858062744141e-06 bandwidth (GB/s) 103.67860705101367
uniform, size, elements 262144 forward 7.164478302001953e-06 bandwidth (GB/s) 146.357621001797
uniform, size, elements 524288 forward 1.1217594146728515e-05 bandwidth (GB/s) 186.9520302275877
uniform, size, elements 1048576 forward 1.923084259033203e-05 bandwidth (GB/s) 218.10297600317384
uniform, size, elements 2097152 forward 3.640890121459961e-05 bandwidth (GB/s) 230.39992200138826
uniform, size, elements 4194304 forward 6.778717041015625e-05 bandwidth (GB/s) 247.49839679819922
uniform, size, elements 8388608 forward 0.00012810707092285157 bandwidth (GB/s) 261.92490202361347
uniform, size, elements 16777216 forward 0.00025241613388061524 bandwidth (GB/s) 265.86598474620627
uniform, size, elements 33554432 forward 0.000497891902923584 bandwidth (GB/s) 269.5720239913193
```

After:
```
uniform, size, elements 65536 forward 5.550384521484375e-06 bandwidth (GB/s) 47.22988091821306
uniform, size, elements 131072 forward 5.581378936767578e-06 bandwidth (GB/s) 93.93520954942333
uniform, size, elements 262144 forward 6.165504455566406e-06 bandwidth (GB/s) 170.071404141686
uniform, size, elements 524288 forward 6.3276290893554685e-06 bandwidth (GB/s) 331.4277702414469
uniform, size, elements 1048576 forward 8.509159088134765e-06 bandwidth (GB/s) 492.91639239047356
uniform, size, elements 2097152 forward 1.2989044189453124e-05 bandwidth (GB/s) 645.8218077979443
uniform, size, elements 4194304 forward 2.347707748413086e-05 bandwidth (GB/s) 714.6211452997259
uniform, size, elements 8388608 forward 4.4286251068115234e-05 bandwidth (GB/s) 757.6715389250498
uniform, size, elements 16777216 forward 8.672237396240235e-05 bandwidth (GB/s) 773.8356427961071
uniform, size, elements 33554432 forward 0.00016920566558837892 bandwidth (GB/s) 793.2224227438523
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20292

Differential Revision: D15277761

Pulled By: ezyang

fbshipit-source-id: 8bfe31a01eeed77f0ed6e7ec4d2dda4c6472ecaa
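The bandwidth column in the benchmark tables appears to count only the bytes written by `uniform_`. A minimal sketch, assuming float32 output (4 bytes per element), no bytes read, and 1 GB = 10^9 bytes, reproduces the reported figures from the forward timings. The helper name `effective_bandwidth_gbps` is hypothetical; the actual calculation is the gist credited to ngimel.

```python
def effective_bandwidth_gbps(numel, seconds, bytes_per_elem=4):
    """Effective bandwidth of a write-only fill such as uniform_.

    Assumes float32 output (4 bytes per element), no bytes read, and
    1 GB = 1e9 bytes.
    """
    return numel * bytes_per_elem / seconds / 1e9


# Timings copied from the largest row of each table (33554432 elements).
before = effective_bandwidth_gbps(33554432, 0.000497891902923584)
after = effective_bandwidth_gbps(33554432, 0.00016920566558837892)
print(f"before {before:.1f} GB/s, after {after:.1f} GB/s, speedup {after / before:.2f}x")
```

Because the kernel only writes its output, effective bandwidth is proportional to elements per second, so the ratio of two bandwidths at the same size equals the ratio of the timings.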