mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add getCurrentDeviceIndex to torch::stable::accelerator (#160453)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160453 Approved by: https://github.com/janeyx99 ghstack dependencies: #159679
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4e4dbd2f8
commit
50a8c11875
@ -465,15 +465,29 @@ void boxed_test_stream(
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
int64_t test_get_current_device_index() {
|
||||
return torch::stable::accelerator::getCurrentDeviceIndex();
|
||||
}
|
||||
|
||||
void boxed_test_get_current_device_index(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_get_current_device_index();
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("test_device_guard(int device_index) -> int");
|
||||
m.def("test_device_guard_set_index() -> int");
|
||||
m.def("test_stream(int device_index) -> int");
|
||||
m.def("test_get_current_device_index() -> int");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_device_guard", &boxed_test_device_guard);
|
||||
m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
|
||||
m.impl("test_stream", &boxed_test_stream);
|
||||
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
|
||||
}
|
||||
#endif // LAE_USE_CUDA
|
||||
|
@ -237,3 +237,12 @@ def test_stream(device_index) -> int:
|
||||
Returns: Stream ID as an integer
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.test_stream.default(device_index)
|
||||
|
||||
|
||||
def test_get_current_device_index() -> int:
|
||||
"""
|
||||
Tests the getCurrentDeviceIndex functionality by getting the current device index.
|
||||
|
||||
Returns: Current device index as an integer
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.test_get_current_device_index.default()
|
||||
|
@ -285,6 +285,22 @@ if not IS_WINDOWS:
|
||||
|
||||
self.assertEqual(stream_id, expected_stream_id)
|
||||
|
||||
@onlyCUDA
|
||||
@deviceCountAtLeast(2)
|
||||
def test_get_current_device_index(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
prev_device = torch.cuda.current_device()
|
||||
|
||||
try:
|
||||
expected_device = 1
|
||||
torch.cuda.set_device(expected_device)
|
||||
|
||||
current_device = libtorch_agnostic.ops.test_get_current_device_index()
|
||||
self.assertEqual(current_device, expected_device)
|
||||
finally:
|
||||
torch.cuda.set_device(prev_device)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -526,6 +526,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream(
|
||||
StreamHandle* ret_stream // returns new reference
|
||||
);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_get_current_device_index(int32_t* ret_device_index);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
struct CUDAGuardOpaque;
|
||||
|
@ -1676,3 +1676,8 @@ AOTITorchError aoti_torch_get_current_stream(
|
||||
*ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_get_current_device_index(int32_t* ret_device_index) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
|
||||
{ *ret_device_index = at::accelerator::getDeviceIndex(); });
|
||||
}
|
||||
|
@ -68,4 +68,11 @@ inline Stream getCurrentStream(DeviceIndex device_index) {
|
||||
return Stream(stream);
|
||||
}
|
||||
|
||||
// Get the current device index
|
||||
inline DeviceIndex getCurrentDeviceIndex() {
|
||||
DeviceIndex device_index;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_device_index(&device_index));
|
||||
return device_index;
|
||||
}
|
||||
|
||||
} // namespace torch::stable::accelerator
|
||||
|
Reference in New Issue
Block a user