Fix mtia_extension.cpp setDevice() to correctly set current_device (#149398)

We referred to this code and found that there was a minor bug. Fix for future reference for others.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149398
Approved by: https://github.com/janeyx99
This commit is contained in:
Youseok Yang
2025-03-31 06:07:22 +00:00
committed by PyTorch MergeBot
parent 4f14224dc8
commit b99e0c5412

View File

@ -38,9 +38,8 @@ struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}
void setDevice(c10::Device d) const override {
c10::Device current_device = getDevice();
if (current_device.index() != d.index()) {
current_device = d;
if (getDevice().index() != d.index()) {
current_device = d.index();
}
}
void uncheckedSetDevice(c10::Device d) const noexcept override {