Add getCurrentDeviceIndex to torch::stable::accelerator (#160453)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160453
Approved by: https://github.com/janeyx99
ghstack dependencies: #159679
This commit is contained in:
Mikayla Gawarecki
2025-08-13 11:38:24 -07:00
committed by PyTorch MergeBot
parent e4e4dbd2f8
commit 50a8c11875
6 changed files with 54 additions and 0 deletions

View File

@@ -465,15 +465,29 @@ void boxed_test_stream(
stack[0] = from(res);
}
// Queries the stable-ABI accelerator API for the index of the device that is
// currently active, widening it to int64_t for the boxed-op calling convention.
int64_t test_get_current_device_index() {
  const auto current_index = torch::stable::accelerator::getCurrentDeviceIndex();
  return static_cast<int64_t>(current_index);
}
// Boxed-convention shim for test_get_current_device_index: the op takes no
// arguments, so nothing is read from the stack; the single output (the current
// device index) is written into slot 0.
void boxed_test_get_current_device_index(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  stack[0] = from(test_get_current_device_index());
}
// Schema registrations for the device/stream test operators exposed by the
// libtorch_agnostic test extension. Each entry declares an op signature only;
// the implementations are bound separately in the IMPL block below.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_device_guard(int device_index) -> int");
m.def("test_device_guard_set_index() -> int");
m.def("test_stream(int device_index) -> int");
m.def("test_get_current_device_index() -> int");
}
// Binds each declared test op to its boxed implementation under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_device_guard", &boxed_test_device_guard);
m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
m.impl("test_stream", &boxed_test_stream);
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
}
#endif // LAE_USE_CUDA

View File

@@ -237,3 +237,12 @@ def test_stream(device_index) -> int:
Returns: Stream ID as an integer
"""
return torch.ops.libtorch_agnostic.test_stream.default(device_index)
def test_get_current_device_index() -> int:
    """
    Tests the getCurrentDeviceIndex functionality by getting the current device index.

    Returns: Current device index as an integer
    """
    # Resolve the registered custom op and invoke its default overload.
    op = torch.ops.libtorch_agnostic.test_get_current_device_index
    return op.default()

View File

@@ -285,6 +285,22 @@ if not IS_WINDOWS:
self.assertEqual(stream_id, expected_stream_id)
@onlyCUDA
@deviceCountAtLeast(2)
def test_get_current_device_index(self, device):
    # Verify the extension op reports the device made current via torch.cuda.
    import libtorch_agnostic

    # Record the active device so the test restores global state on exit.
    original_device = torch.cuda.current_device()
    try:
        target_device = 1
        torch.cuda.set_device(target_device)
        reported = libtorch_agnostic.ops.test_get_current_device_index()
        self.assertEqual(reported, target_device)
    finally:
        torch.cuda.set_device(original_device)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@@ -526,6 +526,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream(
StreamHandle* ret_stream // returns new reference
);
// Writes the current accelerator device index into *ret_device_index.
// Returns an AOTITorchError code; on failure the output is not written.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_device_index(int32_t* ret_device_index);
#ifdef USE_CUDA
struct CUDAGuardOpaque;

View File

@@ -1676,3 +1676,8 @@ AOTITorchError aoti_torch_get_current_stream(
*ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
});
}
// C-shim implementation: stores the accelerator's device index (from
// at::accelerator::getDeviceIndex) through ret_device_index. Any C++
// exception is converted into the returned AOTITorchError code by the macro.
AOTITorchError aoti_torch_get_current_device_index(int32_t* ret_device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_device_index = at::accelerator::getDeviceIndex(); });
}

View File

@@ -68,4 +68,11 @@ inline Stream getCurrentStream(DeviceIndex device_index) {
return Stream(stream);
}
// Returns the index of the currently active accelerator device.
// Thin wrapper over the stable C shim aoti_torch_get_current_device_index;
// device_index is an out-parameter filled by the shim, and a non-OK error
// code is surfaced (e.g. as an exception) via TORCH_ERROR_CODE_CHECK.
inline DeviceIndex getCurrentDeviceIndex() {
DeviceIndex device_index;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_device_index(&device_index));
return device_index;
}
} // namespace torch::stable::accelerator