use torch.accelerator and `_get_device_module` instead of cuda to make DataParallel more device agnostic

Fixes #162152

Recently I have been working on supporting my own privateuse1 backend in the DataParallel module, but I found that parallel_apply.py still calls several CUDA-specific APIs, which forced me to monkey-patch DataParallel to run DP on my backend. This change replaces the `cuda.xxx` calls with `accelerator.xxx` and obtains the device module via `_get_device_module`, so DataParallel no longer assumes CUDA.

This is my first contribution to PyTorch; please let me know if anything about the change needs adjustment.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162573
Approved by: https://github.com/ezyang, https://github.com/guangyey

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
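
For illustration, here is a minimal sketch of the device-agnostic pattern this PR applies; the helper name `_apply_on_device` and its call site are assumptions for the example, not the actual parallel_apply.py diff:

```python
import torch
from torch._utils import _get_device_module

# Minimal sketch of the device-agnostic pattern described above. The helper
# name `_apply_on_device` is hypothetical and is not part of this PR.
def _apply_on_device(fn, device: torch.device):
    # Resolve torch.cuda, torch.xpu, or a privateuse1 backend module from the
    # device type instead of hard-coding torch.cuda.
    device_mod = _get_device_module(device.type)
    with device_mod.device(device.index), device_mod.stream(
        device_mod.current_stream(device.index)
    ):
        return fn()

# torch.accelerator.current_accelerator() returns whichever accelerator device
# is available (or None on a CPU-only build), regardless of the backend.
acc = torch.accelerator.current_accelerator()
if acc is not None:
    x = torch.ones(2, 2, device=acc)
    print(_apply_on_device(lambda: x * 2, x.device).device)
```

The sketch assumes the custom backend's device module exposes the same `device`, `stream`, and `current_stream` interfaces that `torch.cuda` and `torch.xpu` provide, which is what resolving the module via `_get_device_module` relies on.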