- During op dispatch, LocalTensor is supposed to collect the RNG state from the CPU and CUDA devices so that it can be reset before the op executes for each rank, ensuring that ops with randomness produce the same result on all ranks (note that we are planning a separate change to add support for per-rank RNG state). Previously we relied on the op's input arguments to deduce which devices to collect RNG state from, which does not work for factory functions such as torch.randn. Hence this change switches to unconditionally collecting RNG state from all devices (see the sketch below).
- Fixes per-rank-specific computations in the _MaskedPartial and Shard placements discovered during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
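
A minimal sketch of the RNG-handling idea using plain PyTorch APIs. The helper names and the loop over "ranks" are illustrative assumptions, not the actual LocalTensor dispatch code; the point is only that the state is captured from all devices up front (so factory ops with no tensor inputs are covered) and restored before each per-rank execution:

```python
import torch

def _collect_rng_states():
    # Unconditionally gather RNG state from the CPU and every visible CUDA
    # device, rather than inferring devices from op inputs (which fails for
    # factory functions like torch.randn that take no tensor arguments).
    states = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        states["cuda"] = [
            torch.cuda.get_rng_state(i) for i in range(torch.cuda.device_count())
        ]
    return states

def _restore_rng_states(states):
    # Reset every device to the captured state before running the op.
    torch.set_rng_state(states["cpu"])
    for i, s in enumerate(states.get("cuda", [])):
        torch.cuda.set_rng_state(s, i)

# Usage sketch: run the same random op once per simulated rank; each run
# sees identical RNG state, so the results agree across ranks.
states = _collect_rng_states()
results = []
for rank in range(4):  # hypothetical number of local ranks
    _restore_rng_states(states)
    results.append(torch.randn(3))
assert all(torch.equal(results[0], r) for r in results)
```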