Make RNGStateTracker support cuda-like device (#106771)

replace  `CudaRNGStateTracker` with `RNGStateTracker` by rewriting some Cuda-binding code with `device_handle`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106771
Approved by: https://github.com/wanchaol
This commit is contained in:
alanhe151220037
2023-08-10 19:14:30 +00:00
committed by PyTorch MergeBot
parent bb6b157458
commit 1afbc985fe
3 changed files with 31 additions and 25 deletions

View File

@ -429,7 +429,7 @@ def distribute_tensor(
# TODO: the value assignment to global variable is not the ideal solution
# we can replace it in future.
if is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
random._rng_tracker = OffsetBasedRNGTracker()
random._rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
if not tensor.is_leaf:
raise RuntimeError(