mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 00:24:53 +08:00
Make RNGStateTracker support cuda-like device (#106771)
replace `CudaRNGStateTracker` with `RNGStateTracker` by rewriting some Cuda-binding code with `device_handle` Pull Request resolved: https://github.com/pytorch/pytorch/pull/106771 Approved by: https://github.com/wanchaol
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb6b157458
commit
1afbc985fe
@ -429,7 +429,7 @@ def distribute_tensor(
|
||||
# TODO: the value assignment to global variable is not the ideal solution
|
||||
# we can replace it in future.
|
||||
if is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
|
||||
random._rng_tracker = OffsetBasedRNGTracker()
|
||||
random._rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
|
||||
|
||||
if not tensor.is_leaf:
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user