# Motivation
We propose adding support for the Python `with` statement to `torch.accelerator.device_index`, so that it can be used as a context manager for switching devices. This simplifies writing device-agnostic code and benefits all accelerators. Its device-specific counterparts include [`torch.cuda.device`](00199acdb8/torch/cuda/__init__.py (L482)) and [`torch.cuda._DeviceGuard`](00199acdb8/torch/cuda/__init__.py (L469)).

**Design Philosophy**
It accepts either an `int` or `None` as input. When `None` is passed, no device switch is performed. Supporting `None` matters for compatibility, because `torch.device.index` can be `None`.

With this PR, we can write:

```python
import torch

src = 0
dst = 1
# Make src the current device
torch.accelerator.set_device_index(src)
with torch.accelerator.device_index(dst):
    # Inside the with statement, dst is the current device
    assert torch.accelerator.current_device_index() == dst
# After leaving the with statement, the current device is src again
assert torch.accelerator.current_device_index() == src
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148864
Approved by: https://github.com/albanD
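The `None` handling described under Design Philosophy can be illustrated with a minimal pure-Python sketch. This is not the implementation in this PR: `get_index`, `set_index`, and `device_index_guard` are hypothetical names, and a module-level variable stands in for the accelerator runtime's current-device state so that the example runs without any hardware.

```python
from contextlib import contextmanager
from typing import Iterator, Optional

# Hypothetical stand-in for the runtime's "current device index" state.
_current_index = 0


def get_index() -> int:
    """Return the simulated current device index."""
    return _current_index


def set_index(idx: int) -> None:
    """Set the simulated current device index."""
    global _current_index
    _current_index = idx


@contextmanager
def device_index_guard(idx: Optional[int]) -> Iterator[None]:
    """Switch to `idx` inside the block; do nothing when `idx` is None."""
    if idx is None:
        # No switch requested, e.g. the value came from a device without an index.
        yield
        return
    prev = get_index()
    set_index(idx)
    try:
        yield
    finally:
        # Restore the previous device even if the block raised.
        set_index(prev)
```

Restoring the previous index in a `finally` clause keeps the guard exception-safe, and the early return for `None` lets callers pass an optional index straight through without branching at the call site.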